from collections import deque
import matplotlib.pyplot as plt
import numpy as np
import json

from pyqcs import State, H, X, S, CZ
from pyqcs.graph.state import GraphState
from pyqcs.util.random_circuits import random_circuit

from measure_circuit import execution_statistics

def S_with_extra_arg(act, i):
    """Return the S gate acting on qubit ``act``; ``i`` is accepted and ignored.

    This exists so ``S`` can be handed to ``random_circuit`` in the slot
    where a two-argument gate factory is expected (NOTE(review): presumably
    the slot normally holds a parametrized gate such as a rotation taking
    ``(act, phi)`` — confirm against pyqcs.util.random_circuits).
    """
    del i  # deliberately unused; only present to match the call signature
    gate = S(act)
    return gate

def test_scaling_circuits(state_factory, nstart, nstop, step, nqbits,
                          ncircuits, **kwargs):
    """Measure execution statistics for random circuits of growing length.

    For every gate count in ``range(nstart, nstop, step)`` a batch of
    ``ncircuits`` random circuits over the gates X, H, S and CZ is built,
    and ``execution_statistics`` is run on that batch against a fresh
    initial state from ``state_factory(nqbits)``.

    Parameters:
        state_factory: callable taking the qubit count and returning the
            initial state the circuits are applied to.
        nstart, nstop, step: range of gate counts (``nstop`` is exclusive).
        nqbits: number of qubits per circuit.
        ncircuits: number of random circuits generated per gate count.
        **kwargs: forwarded unchanged to ``execution_statistics``
            (e.g. ``repeat``).

    Returns:
        A float64 ``numpy`` array with one row ``[ngates, N, avg, std_dev]``
        per gate count. NOTE(review): the meaning of ``N``, ``avg`` and
        ``std_dev`` is defined by ``measure_circuit.execution_statistics``
        (presumably sample count, mean and standard deviation of the run
        time — confirm there).
    """
    rows = []

    for gate_count in range(nstart, nstop, step):
        batch = []
        for _ in range(ncircuits):
            batch.append(
                random_circuit(nqbits, gate_count, X, H, S_with_extra_arg, CZ))
        initial_state = state_factory(nqbits)

        print("running test with", gate_count, "gates on", nqbits, "qbits")

        # Unpack explicitly so an unexpected result shape fails loudly here.
        n_samples, mean, spread = execution_statistics(
            batch, initial_state, scale=1, **kwargs)
        rows.append([gate_count, n_samples, mean, spread])

    return np.array(rows, dtype=np.double)


if __name__ == "__main__":
    # Experiment parameters. These are the single source of truth: they
    # drive the runs below AND are written to the metadata file, so the
    # saved metadata cannot drift from what was actually executed
    # (previously "nstop" and "ncircuits" in the JSON were stale literals).
    nstart = 400
    nstop = 2800
    step = 50
    ncircuits = 100
    nqbits0 = 100
    nqbits1 = 50
    seed = 0xdeadbeef
    repeat = 10

    # One run per qubit count; the index selects the output file name
    # (circuit_scaling_graph0.csv for nqbits0, ..._graph1.csv for nqbits1).
    for index, nqbits in enumerate((nqbits0, nqbits1)):
        # Re-seed before every run so each qubit count starts from the same
        # numpy RNG state. NOTE(review): this assumes random_circuit draws
        # from numpy's global RNG — confirm in pyqcs.util.random_circuits.
        np.random.seed(seed)
        results = test_scaling_circuits(GraphState.new_zero_state
                                        , nstart
                                        , nstop
                                        , step
                                        , nqbits
                                        , ncircuits
                                        , repeat=repeat)

        # Save right away so a failure in a later run does not discard
        # results that were already computed.
        filename = "circuit_scaling_graph{}.csv".format(index)
        np.savetxt(filename, results)
        print("saved results{} to {}".format(index, filename))

    meta = {
            "nstart": nstart
            , "nstop": nstop
            , "step": step
            , "ncircuits": ncircuits
            , "nqbits0": nqbits0
            , "nqbits1": nqbits1
            , "seed": seed
            , "repeat": repeat}

    with open("circuit_scaling_meta.json", "w", encoding="utf-8") as fout:
        json.dump(meta, fout)
    print("saved meta to circuit_scaling_meta.json")