55 lines
1.7 KiB
Python
55 lines
1.7 KiB
Python
|
import scipy.optimize
|
||
|
import numpy as np
|
||
|
from defusedxml import ElementTree
|
||
|
from collections import deque
|
||
|
import matplotlib.pyplot as plt
|
||
|
from scipy.optimize import curve_fit
|
||
|
|
||
|
def fit_f1(x, K, alpha):
    """Exponential model ``K * exp(alpha * x)``.

    ``x`` may be a scalar or a NumPy array; ``K`` is the prefactor and
    ``alpha`` the growth rate. The parameter names and order are what
    ``curve_fit`` sees when this function is used as the model.
    """
    return K * np.exp(alpha * x)
|
||
|
def fit_f2(x, a, b):
    """Linear model ``a * x + b`` with slope ``a`` and intercept ``b``.

    ``x`` may be a scalar or a NumPy array.
    """
    return a * x + b
|
||
|
|
||
|
|
||
|
|
||
|
# Countries to analyse, each mapped to the (year, value) string pairs collected
# for it.  Further names (e.g. "United States", "Angola", "China") can be added
# here; they must match the "Country or Area" field of the XML records.
data = {
    "Germany": deque(),
    "France": deque(),
    "Italy": deque(),
}

# Binary mode lets the XML parser determine the encoding from the document
# itself instead of relying on the platform's default text encoding.
with open("data/API_NY.GDP.MKTP.KN_DS2_en_xml_v2_10230884.xml", "rb") as fin:
    tree = ElementTree.parse(fin)

for record in tree.getroot().find("data").findall("record"):
    # Flatten the record's <field name="..."> children into a name -> text map.
    this_data = {field.get("name"): field.text for field in record.findall("field")}
    country = this_data["Country or Area"]
    # Keep only the selected countries, and skip records whose "Value" field
    # has no text (field.text is None in that case).
    if country in data and this_data["Value"] is not None:
        data[country].append((this_data["Year"], this_data["Value"]))
|
||
|
|
||
|
|
||
|
|
||
|
class Data:
    """Numeric x/y arrays built from an iterable of ``(x, y)`` pairs.

    Each pair's first item is converted with ``int`` and its second with
    ``float`` (so numeric strings are accepted). The results are stored as
    the NumPy arrays ``x`` and ``y``, in input order.
    """

    def __init__(self, raw_data):
        # Materialise once: the input is traversed for both arrays, and a
        # one-shot iterable (e.g. a generator) would otherwise be exhausted
        # after the first pass, leaving ``y`` empty.
        pairs = list(raw_data)
        self.x = np.array([int(k) for k, _ in pairs])
        self.y = np.array([float(v) for _, v in pairs])
|
||
|
|
||
|
# Line handles of every curve drawn, in plotting order, used for the legend.
plots = deque()
for country, raw_points in data.items():
    series = Data(raw_points)
    if len(series.x) < 2:
        # Both models have two parameters, so at least two points are needed;
        # this also avoids indexing into an empty array below.
        print("{}: not enough data points to fit, skipping".format(country))
        continue

    # Secant through the first and last points, used to seed both fits.
    span = series.x[-1] - series.x[0]
    slope = (series.y[-1] - series.y[0]) / span

    # Seed the exponential with the curve through the two end points.  A fixed
    # guess such as alpha=1 makes exp(alpha * x) overflow float64 once x
    # exceeds ~709, which happens when x holds calendar years (x is parsed
    # from the "Year" field).  Fall back to a flat curve when the end points
    # are not both positive, since the logarithm is undefined there.
    if series.y[0] > 0 and series.y[-1] > 0:
        alpha0 = np.log(series.y[-1] / series.y[0]) / span
    else:
        alpha0 = 0.0
    k0 = series.y[0] * np.exp(-alpha0 * series.x[0])
    popt1, _ = curve_fit(fit_f1, series.x, series.y, p0=[k0, alpha0])

    # fit_f2 is a*x + b, so the initial guess is ordered (slope, intercept).
    intercept = series.y[0] - slope * series.x[0]
    popt2, _ = curve_fit(fit_f2, series.x, series.y, p0=[slope, intercept])

    # Raw data plus both fitted curves evaluated at the sampled x values.
    p1, = plt.plot(series.x, series.y, label="{}: real".format(country))
    p2, = plt.plot(series.x, fit_f1(series.x, *popt1),
                   label="%s: exponential fit, K=%.3e, $\\alpha$=%.3e" % (country, *popt1))
    p3, = plt.plot(series.x, fit_f2(series.x, *popt2),
                   label="%s: linear fit, a=%.3e, b=%.3e" % (country, *popt2))
    plots.extend([p1, p2, p3])

plt.legend(handles=list(plots))
plt.show()
|