import timeit
import itertools
import sys
import random
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick

from scipy.sparse import dok_matrix


# Figure styling: y tick labels in scientific notation on a 20in x 10in canvas.
plt.gcf().set_size_inches(20, 10)
plt.gca().yaxis.set_major_formatter(mtick.FormatStrFormatter('%.2e'))

def construction(node_number):
    """Create an empty square DOK sparse matrix with one row/column per node.

    Args:
        node_number: number of nodes; the matrix shape is
            ``(node_number, node_number)``.

    Returns:
        A new, empty ``scipy.sparse.dok_matrix`` with scipy's default dtype.
    """
    shape = (node_number,) * 2
    return dok_matrix(shape)

# Benchmark parameters: every measurement executes `number` calls per trial
# and keeps the fastest of `repeat` trials.
number = 100
repeat = 100
node_numbers = [
    100, 500, 1000, 1500, 2000, 3000, 4000, 6000, 8000,
    10000, 20000, 50000, 100000, 500000, 1000000,
]

# result[k]: best per-call construction() time in seconds for node_numbers[k].
result = []

for n_nodes in node_numbers:
    print("running trial with", n_nodes, "nodes... ", end="", flush=True)
    # Bind the node count as a default so the timed callable is self-contained.
    trials = timeit.repeat(lambda n=n_nodes: construction(n), repeat=repeat, number=number)
    result.append(min(trials) / number)
    print("done")

# Plot best per-call construction time against node count and save it.
construction_plot_path = "runtime_dok_matrix_graph_construction.png"

construction_line = plt.plot(node_numbers, result, "go-", label="Graph Construction Time in $s$")[0]
plt.title(f"Graph Construction Time in Seconds ({number} loops, best out of {repeat})")
plt.xlabel("Number of Nodes")
plt.ylabel("Time in $s$")
plt.legend(handles=[construction_line])

plt.savefig(construction_plot_path, dpi=400)
print("figure saved to " + construction_plot_path)

# Start a fresh figure for the edge-insertion plot, styled like the first one.
plt.clf()
plt.gcf().set_size_inches(20, 10)
plt.gca().yaxis.set_major_formatter(mtick.FormatStrFormatter('%.2e'))

# Requested number of edges per trial: 100, 200, 500, 1000, 5000, 10000.
edge_counts = [10 * base for base in (10, 20, 50, 100, 500, 1000)]

def add_edges(edges, node_number):
    """Build a fresh graph via ``construction`` and set every listed edge to 1.

    The graph is not returned; the function exists to be timed, so its cost is
    one ``construction(node_number)`` call plus inserting all of ``edges``.

    Args:
        edges: iterable whose items are used directly as matrix indices
            (in this script, ``(source, target)`` tuples).
        node_number: node count forwarded to ``construction``.
    """
    graph = construction(node_number)
    for index_pair in edges:
        graph[index_pair] = 1

def _sample_edges(edge_count, node_number):
    """Return up to ``edge_count`` distinct (source, target) pairs for one trial.

    Sources come from ``[0, node_number // 2)`` and targets from
    ``[node_number // 2, node_number // 2 + node_number // 4)``; each range is
    shuffled independently (sources first, then targets). Pairs are taken in
    ``itertools.product`` order, so consecutive pairs share a source, and the
    result has ``min(edge_count, (node_number // 2) * (node_number // 4))``
    elements -- fewer than requested on small graphs.
    """
    half = node_number // 2
    sources = list(range(half))
    targets = list(range(half, half + node_number // 4))
    random.shuffle(sources)
    random.shuffle(targets)
    # islice truncates the (possibly huge) lazy product without materialising it.
    return list(itertools.islice(itertools.product(sources, targets), edge_count))


# edges[i][j] == (edge_list, node_numbers[j]) for requested count edge_counts[i].
# Outer loop over edge counts, inner over node counts, as before, so the
# sequence of random.shuffle calls is unchanged.
edges = [
    [(_sample_edges(c, e), e) for e in node_numbers]
    for c in edge_counts
]

# results[i][j]: best per-call add_edges() time in seconds for the trial
# edges[i][j]. Each call includes one construction() call; that cost is
# subtracted afterwards.
results = []

for row in edges:
    row_times = []
    for edge_list, node_number in row:
        # Report the number of edges actually inserted: small graphs can offer
        # fewer candidate pairs than the requested edge count.
        print("running trial with", node_number, "nodes;", len(edge_list), "edges ... ", end="", flush=True)
        timer = timeit.Timer(lambda el=edge_list, nn=node_number: add_edges(el, nn))
        row_times.append(min(timer.repeat(repeat=repeat, number=number)) / number)
        print("done")
    results.append(row_times)
        
# Turn each total add_edges() time into a per-edge insertion time: subtract the
# matching construction() time, then divide by the number of edges actually
# inserted. len(edge_list) can be smaller than the requested edge count on
# small graphs (e.g. 100 nodes only offer 50 * 25 = 1250 pairs).
results = [
    [(t - ct) / len(edge_list) for t, ct, (edge_list, _) in zip(times, result, row)]
    for times, row in zip(results, edges)
]

handles = [
    plt.plot(node_numbers, per_edge, label="When adding {} edges".format(ec))[0]
    for per_edge, ec in zip(results, edge_counts)
]

plt.legend(handles=handles)

plt.title("Time to add one Edge in Seconds ({} loops, best out of {})".format(number, repeat))
plt.xlabel("Number of Nodes")
plt.ylabel("Time in $s$")

# One name for both the saved file and the message so they cannot disagree.
add_edges_plot_path = "runtime_dok_matrix_graph_add_edges.png"
plt.savefig(add_edges_plot_path, dpi=400)
print("figure saved to", add_edges_plot_path)