#include "interaction.h"
#include "math.h"

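// Debug switches (uncomment to enable), checked in interaction_ufunc_float2D:
//   WARN_NAN_OCCUR     - print a warning when a momentum increment is NaN.
//   WARN_CORRECTED_R_0 - print the random direction chosen when r is not > 0.
//#define WARN_NAN_OCCUR
//#define WARN_CORRECTED_R_0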

// Square of x. The whole expansion is parenthesized so the macro acts as a
// single operand (e.g. 1 / raise2(x) == 1 / (x * x)). x is evaluated twice,
// so do not pass expressions with side effects.
#define raise2(x) ((x)*(x))

static float 
interaction_force_function
	( float r
	, float * coefficients)
{
	float result = 0;
	int i;
	result = coefficients[0];
	for(i = 1; i < 7; i++)
	{
		result += coefficients[i] * (powf(r, i - 1) + coefficients[8]*powf(r, i));
	}
	result *= coefficients[7] * expf(coefficients[8] * (r - coefficients[9]));
	result += coefficients[10] * coefficients[11] * expf(coefficients[11] * (r - coefficients[12]));
	result += coefficients[13] 
			* 2 * (r - coefficients[15]) 
			* coefficients[14] * expf(coefficients[14] * raise2(r - coefficients[15]));
	result += coefficients[16] 
			* 2 * (r - coefficients[18]) 
			* coefficients[17] * expf(coefficients[17] * raise2(r - coefficients[18]));
	return result;
}
/*
 * Scalar pair potential V(r) for the coefficient array c = coefficients
 * (c[0]..c[18] are read):
 *
 *   V(r) = (sum_{k=0}^{6} c[k] r^k) * c[7] * exp(c[8] (r - c[9]))
 *        + c[10] * exp(c[11] (r - c[12]))
 *        + c[13] * exp(c[14] (r - c[15])^2)
 *        + c[16] * exp(c[17] (r - c[18])^2)
 */
static float 
interaction_potential_function
	( float r
	, float * coefficients)
{
	const float *c = coefficients;
	const float shift_15 = r - c[15];
	const float shift_18 = r - c[18];
	float polynomial = 0;
	float value;
	int k;

	for(k = 0; k <= 6; ++k)
	{
		polynomial += c[k] * powf(r, k);
	}

	value = polynomial * (c[7] * expf(c[8] * (r - c[9])));
	value += c[10] * expf(c[11] * (r - c[12]));
	value += c[13] * expf(c[14] * shift_15 * shift_15);
	value += c[16] * expf(c[17] * shift_18 * shift_18);
	return value;
}

/*
 * NumPy ufunc inner loop for "force_function" (1 input, 1 output).
 * For each of the dimensions[0] elements it reads a float from args[0] and
 * writes interaction_force_function(value, data) to args[1]. steps[0] and
 * steps[1] are the byte strides of input and output. data is the float
 * coefficient array registered with the ufunc.
 */
static void
interaction_ufunc_force
	( char ** args
	, npy_intp * dimensions
	, npy_intp * steps
	, void * data)
{
	const npy_intp count = dimensions[0];
	float * const coeffs = data;
	npy_intp k;

	for(k = 0; k < count; ++k)
	{
		const float x = *(float *)(args[0] + k * steps[0]);
		float * const y = (float *)(args[1] + k * steps[1]);
		*y = interaction_force_function(x, coeffs);
	}
}
	
/*
 * NumPy ufunc inner loop for "potential_function" (1 input, 1 output).
 * For each of the dimensions[0] elements it reads a float from args[0] and
 * writes interaction_potential_function(value, data) to args[1]. steps[0]
 * and steps[1] are the byte strides of input and output. data is the float
 * coefficient array registered with the ufunc.
 */
static void
interaction_ufunc_potential
	( char ** args
	, npy_intp * dimensions
	, npy_intp * steps
	, void * data)
{
	const npy_intp count = dimensions[0];
	float * const coeffs = data;
	npy_intp k;

	for(k = 0; k < count; ++k)
	{
		const float x = *(float *)(args[0] + k * steps[0]);
		float * const y = (float *)(args[1] + k * steps[1]);
		*y = interaction_potential_function(x, coeffs);
	}
}

/*
 * NumPy ufunc inner loop for "interaction2D" (4 inputs, 2 outputs):
 *
 *   args[0], args[1]  x_old, y_old      particle coordinates
 *   args[2], args[3]  p_x_old, p_y_old  momenta before the step
 *   args[4], args[5]  p_x_new, p_y_new  momenta after the step (written)
 *
 * steps[k] is the byte stride of args[k]. data is the float coefficient
 * array, and the time step dt is coefficients[19].
 *
 * Every new momentum starts as a copy of the old one. Then, for each pair
 * j < i:
 *   - r = |(x_i, y_i) - (x_j, y_j)|
 *   - F = interaction_force_function(r, coefficients)
 *   - e is the unit vector from j to i
 *   - p_i += dt * F * e and p_j -= dt * F * e
 * If r is not > 0 (coincident particles, or NaN coordinates), e is taken
 * at a random angle.
 *
 * NOTE(review): all dimensions[0] elements of one inner-loop call are
 * treated as the complete particle set. If NumPy splits the call into
 * several chunks (e.g. when it has to buffer/cast), pairs across chunks are
 * skipped. TODO: confirm callers always pass float32 arrays.
 */
static void
interaction_ufunc_float2D
	( char ** args
	, npy_intp * dimensions
	, npy_intp * steps
	, void * data)
{
	// POSIX random() returns values in [0, 2^31 - 1] regardless of RAND_MAX,
	// so normalise by that bound rather than RAND_MAX.
	const float random_max = 2147483647.0f;

	npy_intp i;
	npy_intp j;
	npy_intp n = dimensions[0];

	char * x_old = args[0]
		, * y_old = args[1]
		, * p_x_old = args[2]
		, * p_y_old = args[3]
		, * p_x_new = args[4]
		, * p_y_new = args[5];

	npy_intp x_old_steps = steps[0]
		, y_old_steps = steps[1]
		, p_x_old_steps = steps[2]
		, p_y_old_steps = steps[3]
		, p_x_new_steps = steps[4]
		, p_y_new_steps = steps[5];

	float * coefficients = (float *) data;
	float dt = coefficients[19];

	float r;          // Distance between interacting particles.
	float delta_x;    // x_i - x_j
	float delta_y;    // y_i - y_j
	float delta_x_e;  // x-component of the unit vector from j to i.
	float delta_y_e;  // y-component of the unit vector from j to i.
	float force;      // Scalar force at distance r.
	float delta_p_x;
	float delta_p_y;

	// Coordinates of particles i and j.
	float this_x_i;
	float this_y_i;
	float this_x_j;
	float this_y_j;

	for(i = 0; i < n; i++)
	{
		this_x_i = *(float *)(x_old + i*x_old_steps);
		this_y_i = *(float *)(y_old + i*y_old_steps);

		// Start from the current momenta; pair contributions are added below.
		*(float *)(p_x_new + i*p_x_new_steps) = *(float *)(p_x_old + i*p_x_old_steps);
		*(float *)(p_y_new + i*p_y_new_steps) = *(float *)(p_y_old + i*p_y_old_steps);

		// Only j < i: each unordered pair is visited once, and p_*_new[j]
		// has already been initialised by an earlier iteration of i.
		for(j = 0; j < i; j++)
		{
			this_x_j = *(float *)(x_old + j*x_old_steps);
			this_y_j = *(float *)(y_old + j*y_old_steps);

			delta_x = this_x_i - this_x_j;
			delta_y = this_y_i - this_y_j;
			r = sqrtf(delta_x * delta_x + delta_y * delta_y);

			if(r > 0)
			{
				delta_x_e = delta_x / r;
				delta_y_e = delta_y / r;
			}
			else
			{
				// No direction is defined for r = 0; pick a random one.
				float random_angle = ((float)random() / random_max) * 2 * M_PI;
#ifdef WARN_CORRECTED_R_0
				printf("Warning: corrected r = 0 with random angle pi*%f\n", random_angle / M_PI);
#endif
				delta_x_e = cosf(random_angle);
				delta_y_e = sinf(random_angle);
			}

			// The force depends only on r: evaluate it once per pair.
			force = interaction_force_function(r, coefficients);
			delta_p_x = delta_x_e * force;
			delta_p_y = delta_y_e * force;
#ifdef WARN_NAN_OCCUR
			if(isnan(delta_p_x))
			{
				printf("Warning: delta_p_x is NaN, after r = %f\n", r);
			}
			if(isnan(delta_p_y))
			{
				printf("Warning: delta_p_y is NaN, after r = %f\n", r);
			}
#endif
			// Equal and opposite momentum changes for i and j.
			*(float *)(p_x_new + i*p_x_new_steps) += dt*delta_p_x;
			*(float *)(p_y_new + i*p_y_new_steps) += dt*delta_p_y;
			*(float *)(p_x_new + j*p_x_new_steps) -= dt*delta_p_x;
			*(float *)(p_y_new + j*p_y_new_steps) -= dt*delta_p_y;
		}
	}
}
static char interaction_types[] = 
	{ NPY_FLOAT, NPY_FLOAT, NPY_FLOAT, NPY_FLOAT, NPY_FLOAT, NPY_FLOAT};
static char force_types[] =
	{ NPY_FLOAT, NPY_FLOAT};
static char potential_types[] =
	{ NPY_FLOAT, NPY_FLOAT};
static PyUFuncGenericFunction potential_funcs[1] = 
	{ interaction_ufunc_potential};
static PyUFuncGenericFunction force_funcs[1] = 
	{ interaction_ufunc_force};
static PyUFuncGenericFunction interaction_funcs[1] = 
	{ interaction_ufunc_float2D};

/*
 * Instance layout of the UFuncWrapper Python type: a wrapped ufunc plus the
 * coefficient storage its inner loops read.
 */
typedef struct interaction_UFuncWrapper
{
	PyObject_HEAD
	float coefficients[20];  // Filled by __init__; data[0] points here.
	PyObject *ufunc;         // Wrapped ufunc, exposed as the "ufunc" attribute.
	void *data[1];           // Loop-data array passed to PyUFunc_FromFuncAndData.
} interaction_UFuncWrapper;

/*
 * tp_init for UFuncWrapper: UFuncWrapper(type_, coefficients)
 *
 *   type_         0 -> "force_function"     ufunc, 1 input, 1 output
 *                 1 -> "interaction2D"      ufunc, 4 inputs, 2 outputs
 *                 2 -> "potential_function" ufunc, 1 input, 1 output
 *   coefficients  sequence of exactly 20 items. Each item is converted with
 *                 PyFloat_AsDouble and stored as a float.
 *
 * The ufunc receives self->data as its loop data, and self->data[0] points
 * at self->coefficients, so the ufunc reads the coefficients held by this
 * object. self->ufunc owns the reference returned by
 * PyUFunc_FromFuncAndData. Calling __init__ again releases the previous
 * ufunc.
 *
 * Returns 0 on success, or -1 with a Python exception set. On failure the
 * stored coefficients and ufunc are left unchanged.
 */
static int
interaction_UFuncWrapper_init
	( interaction_UFuncWrapper * self
	, PyObject * args
	, PyObject * kwds)
{
	char type;
	PyObject * coefficients;
	float parsed[20];
	Py_ssize_t length;
	Py_ssize_t i;
	PyObject * this_coefficient;
	double value;
	PyObject * new_ufunc = NULL;
	PyObject * old_ufunc;

	char *kwords[] = { "type_", "coefficients", NULL};

	if(!PyArg_ParseTupleAndKeywords(args, kwds, "bO", kwords, &type, &coefficients))
	{
		return -1;
	}
	if(!PySequence_Check(coefficients))
	{
		// Returning -1 without an exception set would surface as SystemError.
		PyErr_SetString(PyExc_TypeError, "coefficients must be a sequence");
		return -1;
	}

	length = PySequence_Size(coefficients);
	if(length < 0)
	{
		return -1; // PySequence_Size already set the exception.
	}
	if(length != 20)
	{
		PyErr_SetString(PyExc_ValueError, "coefficients must have length 20");
		return -1;
	}

	// Convert into a local buffer first. A failure part-way through then
	// does not clobber coefficients an existing ufunc is still reading.
	for(i = 0; i < 20; i++)
	{
		// PySequence_GetItem returns a new reference. It keeps the item
		// alive while PyFloat_AsDouble possibly runs Python code (__float__).
		// It must be released exactly once.
		this_coefficient = PySequence_GetItem(coefficients, i);
		if(!this_coefficient)
		{
			return -1;
		}
		value = PyFloat_AsDouble(this_coefficient);
		Py_DECREF(this_coefficient);
		if(value == -1.0 && PyErr_Occurred())
		{
			return -1;
		}
		parsed[i] = (float) value;
	}

	self->data[0] = (void *)self->coefficients;

	switch(type)
	{
		case 0:
		{
			new_ufunc = PyUFunc_FromFuncAndData(
				force_funcs // func
				, self->data // data
				, force_types //types
				, 1 // ntypes
				, 1 // nin
				, 1 // nout
				, PyUFunc_None // identity
				, "force_function" // name
				, "computes the scalar force between two particles with given coefficients" // doc
				, 0); // unused
			break;
		}
		case 1:
		{
			new_ufunc = PyUFunc_FromFuncAndData(
				interaction_funcs
				, self->data
				, interaction_types
				, 1
				, 4
				, 2
				, PyUFunc_None
				, "interaction2D"
				, "Update the momenta according to the given coefficients and positions"
				, 0);
			break;
		}
		case 2:
		{
			new_ufunc = PyUFunc_FromFuncAndData(
				potential_funcs // func
				, self->data // data
				, potential_types //types
				, 1 // ntypes
				, 1 // nin
				, 1 // nout
				, PyUFunc_None // identity
				, "potential_function" // name
				, "computes the scalar potential between two particles with given coefficients" // doc
				, 0); // unused
			break;
		}
		default:
		{
			PyErr_SetString(PyExc_ValueError, "unknown ufunc type, must be 0, 1 or 2");
			return -1;
		}
	}
	if(!new_ufunc)
	{
		return -1; // PyUFunc_FromFuncAndData set the exception.
	}

	for(i = 0; i < 20; i++)
	{
		self->coefficients[i] = parsed[i];
	}

	// self->ufunc takes over the new reference. The reference from a
	// previous __init__ call, if any, is dropped.
	old_ufunc = self->ufunc;
	self->ufunc = new_ufunc;
	Py_XDECREF(old_ufunc);
	return 0;
}


/*
 * tp_call: forwards the positional and keyword arguments to the wrapped
 * ufunc. Returns its result (a new reference), or NULL with an exception
 * set.
 */
static PyObject *
interaction_UFuncWrapper_call
	(interaction_UFuncWrapper * self
	 , PyObject * args
	 , PyObject * kwargs)
{
	// self->ufunc is NULL if __init__ never completed (tp_new is
	// PyType_GenericNew, which zero-fills the object) or if the writable
	// "ufunc" attribute was deleted. PyObject_Call(NULL, ...) would crash.
	if(self->ufunc == NULL)
	{
		PyErr_SetString(PyExc_RuntimeError, "UFuncWrapper has no ufunc; was __init__ called?");
		return NULL;
	}
	return PyObject_Call(self->ufunc, args, kwargs);
}

static PyMemberDef interaction_UFuncWrapper_members[] = 
{
	{"ufunc", T_OBJECT_EX, offsetof(interaction_UFuncWrapper, ufunc), 0, "ufunc"},
	{NULL}
};

static PyTypeObject interaction_UFuncWrapperType = 
{
	PyVarObject_HEAD_INIT(NULL, 0)
	.tp_name = "brown.interaction.UFuncWrapper",
	.tp_doc = "A wrapper that wraps the ufuncs for interaction and force, storing the coefficients",
	.tp_basicsize = sizeof(interaction_UFuncWrapper),
	.tp_itemsize = 0,
	.tp_flags = Py_TPFLAGS_DEFAULT,
	.tp_new = PyType_GenericNew,
	.tp_init = (initproc) interaction_UFuncWrapper_init,
	.tp_call = interaction_UFuncWrapper_call,
	.tp_members = interaction_UFuncWrapper_members
};



static PyMethodDef InteractionMethods[] = {
	{NULL, NULL, 0, NULL}
};


static struct PyModuleDef moduledef = {
	PyModuleDef_HEAD_INIT
	, "brown.interaction"
	, NULL
	, -1
	, InteractionMethods
	, NULL
	, NULL
	, NULL
	, NULL
};

/*
 * Module init: loads the NumPy array and ufunc C APIs, readies the
 * UFuncWrapper type and adds it to the module as "UFuncWrapper".
 * Returns the new module, or NULL with an exception set.
 */
PyMODINIT_FUNC 
PyInit_interaction(void)
{
	PyObject * module;

	// import_array()/import_ufunc() `return NULL` from this function on
	// failure, so run them before anything is allocated that would leak.
	import_array();
	import_ufunc();

	if(PyType_Ready(&interaction_UFuncWrapperType) < 0)
	{
		return NULL;
	}

	module = PyModule_Create(&moduledef);
	if(!module)
	{
		return NULL;
	}

	// PyModule_AddObject steals the reference only on success.
	Py_INCREF(&interaction_UFuncWrapperType);
	if(PyModule_AddObject(module, "UFuncWrapper", (PyObject *) &interaction_UFuncWrapperType) < 0)
	{
		Py_DECREF(&interaction_UFuncWrapperType);
		Py_DECREF(module);
		return NULL;
	}

	return module;
}