/**
 * main.cpp
 *
 * Computer Graphics 2
 * A simple ray tracer
 * 1/7/09
 *
 * Authors:
 * 	leaf corcoran
 * 	adam damiano
 *
 */

#include <stdio.h>
#include <iostream>
#include <vector>
#include <stdlib.h>
#include <math.h>
#include <GL/glut.h>

using namespace std;

// size of the window and of the screen raster, in pixels
const int WIDTH = 600;
const int HEIGHT = 600;

// max recursion depth for reflected / refracted rays
const int RAY_DEPTH = 4;

// camera preset used by World::render; the space key toggles it between 0 and 1
int view = 0;

/**
 * A location in 3D space.
 */
class Point {
public:
	double x, y, z;
	Point() : x(0), y(0), z(0) {}
	Point(double x, double y, double z)
		: x(x), y(y), z(z) {}

	/**
	 * exact component-wise equality (no epsilon); const-correct so
	 * temporaries and const points can be compared
	 */
	bool operator==(const Point &o) const {
		return x == o.x && y == o.y && z == o.z;
	}

	void print() const
	{
		cout << "Point: (" << x << ", " << y << ", " << z << ")" << endl;
	}
};

// Forward declarations: these classes and functions are defined later in
// this file but are named before their definitions.
class World;
class Object;
class Vector;
class Pixel;

Pixel phong(Object & o, World & w, Point p);
Vector reflect(Vector x, Vector n);

// Refraction helpers. Only refract() is called (from World::trace);
// bad_refract() and python_refract() are alternative versions that nothing
// in this file calls.
Vector refract(Vector incident, Vector normal, double n1, double n2, bool & tir);
Vector bad_refract(Vector i, Vector n, double n1, double n2);
Vector python_refract(Vector i, Vector n, double n1, double n2);

/**
 * A class representing a color/pixel.
 *
 * Only put additional data after the three components: World hands an
 * array of Pixels straight to glDrawPixels as GL_RGB / GL_UNSIGNED_BYTE,
 * so r, g, b must stay first and tightly packed.
 */
class Pixel {
public:
	unsigned char r, g, b;

	Pixel() : r(0), g(0), b(0) {}
	Pixel(unsigned char r, unsigned char g, unsigned char b) 
		: r(r), g(g), b(b) {}

	/**
	 * component-wise product (color modulation): 255 acts as 1.0, so
	 * multiplying by white leaves a color unchanged
	 */
	Pixel operator*(const Pixel & p) const
	{
		return Pixel(
			((int)r * p.r)/255,
			((int)g * p.g)/255,
			((int)b * p.b)/255
			);
	}

	/**
	 * component-wise sum, saturating at 255
	 */
	Pixel operator+(const Pixel & p) const {
		return Pixel(
			clampChannel(r + p.r),
			clampChannel(g + p.g),
			clampChannel(b + p.b));
	}

	/**
	 * scale every component by i; a factor <= 0 gives black and the
	 * result saturates at 255 (clamped before the integer conversion so
	 * very large factors cannot overflow)
	 */
	Pixel operator*(const double & i) const {
		if (i <= 0) return Pixel();

		return Pixel(
			clampChannel(r * i),
			clampChannel(g * i),
			clampChannel(b * i));
	}

	/**
	 * weighted blend: balance * this + (1 - balance) * p, per component.
	 * Results are clamped to [0, 255] so a balance outside [0, 1] cannot
	 * wrap around (converting an out-of-range double to unsigned char is
	 * undefined behavior).
	 */
	Pixel average(const Pixel & p, double balance = 0.5f) const
	{
		return Pixel(
			clampChannel(balance * r + (1-balance) * p.r),
			clampChannel(balance * g + (1-balance) * p.g),
			clampChannel(balance * b + (1-balance) * p.b));
	}

	/**
	 * exact equality of all three components
	 */
	bool operator==(const Pixel & o) const {
		return r == o.r && g == o.g && b == o.b;
	}

private:
	/**
	 * convert a channel value to a byte: truncates toward zero, maps
	 * anything <= 0 (and NaN) to 0 and anything >= 255 to 255
	 */
	static unsigned char clampChannel(double v) {
		if (!(v > 0)) return 0;
		if (v >= 255) return 255;
		return (unsigned char)(int)v;
	}
};

// named colors used to build the scene
Pixel white  = Pixel(255, 255, 255);
Pixel blue   = Pixel(0, 0, 255);
Pixel green  = Pixel(0, 255, 0);
Pixel red    = Pixel(255, 0, 0);
Pixel yellow = Pixel(255, 237, 11);
Pixel brown  = Pixel(139, 60, 1);
Pixel black  = Pixel(0, 0, 0);

// diffuse color of the reflective sphere in World
Pixel silver = Pixel(160, 196, 206);

// default background: the color returned for rays that hit nothing
Pixel dbg = Pixel(11, 203, 255);
// default ambient light color; black, so the ambient term adds nothing
Pixel dam = Pixel(0, 0, 0);

/**
 * A direction / displacement in 3D space. Every operation returns a new
 * vector and leaves this one untouched, so all of them are const.
 */
class Vector
{
public:
	double x,y,z;

	/**
	 * the displacement from p1 to p2 (p2 - p1)
	 */
	Vector(const Point & p1, const Point & p2)
		: x(p2.x - p1.x), y(p2.y - p1.y), z(p2.z - p1.z)
	{
	}

	Vector(double x, double y, double z) 
		: x(x), y(y), z(z)
	{
	}

	/**
	 * return the length of this vector
	 */
	double length() const
	{
		return sqrt(x*x + y*y + z*z);
	}

	/**
	 * calculate the dot product
	 */
	double dot(const Vector & v) const
	{
		return (x*v.x) + (y*v.y) + (z*v.z);
	}

	/**
	 * multiply every component by a
	 */
	Vector scale(const double a) const
	{
		return Vector(a*x, a*y, a*z);
	}

	/**
	 * component-wise sum
	 */
	Vector add(const Vector & v) const
	{
		return Vector(x + v.x, y + v.y, z + v.z);
	}

	/**
	 * component-wise difference (this - v)
	 */
	Vector subtract(const Vector & v) const
	{
		return Vector(x - v.x, y - v.y, z - v.z);
	}

	/**
	 * return this vector scaled to unit length.
	 * A zero-length vector is returned unchanged instead of producing
	 * NaN components from 0/0, which would poison every later calculation.
	 */
	Vector normalize() const
	{
		double l = length();
		if (l == 0)
			return *this;
		return Vector(x/l, y/l, z/l);
	}

	void print() const
	{
		cout << "Vector(" << x << ", " << y << ", " << z << ")" << endl;
	}
};



/**
 * A ray: a start point `origin` plus a unit-length `direction`.
 *
 * `t` holds the parameter of the most recent hit (World::intersect
 * stores it; it is -1 until then) and `parent` is the object the ray
 * was spawned from, if any.
 */
class Ray
{
public:
	Point origin;
	Vector direction;
	double t;        // last intersection value
	Object *parent;  // object this ray was cast from (may be NULL)

	/**
	 * ray starting at o and heading along dir (normalized here)
	 */
	Ray(Point o, Vector dir, Object *p = NULL)
		: origin(o), direction(dir.normalize()), t(-1), parent(p)
	{
	}

	/**
	 * ray starting at o and aimed through the point target
	 */
	Ray(Point o, Point target, Object *p = NULL)
		: origin(o), direction(Vector(o, target).normalize()), t(-1), parent(p)
	{
	}

	/**
	 * the point at the most recently stored parameter t
	 */
	Point getPoint() { return getPoint(t); }

	/**
	 * the point origin + dist * direction
	 */
	Point getPoint(const double dist) const
	{
		const double px = origin.x + (direction.x * dist);
		const double py = origin.y + (direction.y * dist);
		const double pz = origin.z + (direction.z * dist);
		return Point(px, py, pz);
	}

	void print()
	{
		cout << "Ray: o(" << origin.x << ", " << origin.y << ", " << origin.z
			<< ") d(" << direction.x << ", " << direction.y << ", "
			<< direction.z << ")" << endl;
	}
};

/**
 * Light source
 */
class Light
{
public:
	Point position;
	Pixel specular, diffuse;
	static Pixel ambient;

	Light() {}
	Light(Point p, Pixel d, Pixel s)
		: position(p), specular(s), diffuse(d)
	{
		// ~
	}
};
Pixel Light::ambient = dam;


/**
 * an object is something in the scene that
 * can collide with a ray via the intersect
 * function
 */
class Object
{
protected:
	// color of the ambient, specular, and diffuse components of light reflection
	Pixel ambient, specular, diffuse;
public:
	// weight of the mirror-reflected color; World::trace blends the
	// reflected color in with balance (1 - reflection)
	double reflection;
	// weight of the refracted color; World::trace blends the refracted
	// color in with balance (1 - transparency)
	double transparency;

	// start both at 0 so a subclass that forgets to set them is simply
	// opaque and non-reflective instead of reading indeterminate values
	Object() : reflection(0), transparency(0) {}

	// objects are created with new and handled through Object*, so the
	// destructor must be virtual for delete-through-base to be well defined
	virtual ~Object() {}

	// returns the ray parameter t of the closest usable hit;
	// a value <= 0 (conventionally -1) means no intersection or it is behind
	virtual double intersect(Ray r) = 0;
	// return the normal of a Point on this object
	virtual Vector getNormal( Point p ) = 0;

	virtual Pixel getAmbient(const Point & p) {
		return ambient;
	}

	virtual Pixel getSpecular(const Point & p) {
		return specular;
	}

	virtual Pixel getDiffuse(const Point & p) {
		return diffuse;
	}
};


/** 
 * a plane is a type of object with an origin
 * Point and a normal vector
 * all planes are truncated to certain x values 
 * for the time being
 */
class Plane : public Object
{
private:
	double x,y,z;	// point on plane
	Vector n; // plane normal
	double max_x, min_x;
	double max_z, min_z;
	Pixel even, odd;
public:

	Plane(double x, double y, double z, Vector n, Pixel d)
		: x(x), y(y), z(z), n(n)
	{
		reflection = 0;
		transparency = 0;
		diffuse = d;
		ambient = white;
		specular = white;
		min_x = -1;
		max_x = 2;

		min_z = -1;
		max_z = 5;



		// colors...
		even = diffuse;
		odd = yellow;
	}

	virtual Vector getNormal( Point p ) {
		return n.normalize();
	}

	virtual Pixel getDiffuse(const Point & p)
	{
		// calculate the u,v coordinate
		// magic numbers prevent fuzz on borders
		double u = (p.x + 1) * 1000 / (31 * 9);
		double v = (p.z + 1) * 1000 / (31 * 9);

		if ((int) u % 2 == (int) v % 2)
			return odd;
		else
			return even;

		return odd;
	}

	virtual double intersect(Ray r)
	{
		// point on plane
		Vector p(x,y,z);
		double d = p.dot(n);

		Vector la = Vector(r.origin.x, r.origin.y, r.origin.z);
		Vector lb = la.add(Vector(r.direction.x, r.direction.y, r.direction.z)); 

		double num = d - la.dot(n);
		double denom = lb.subtract(la).dot(n);
		
		if (denom == 0)
			return -1;

		double t = num/denom;

		// cut off the sides of the planes to make it appear as a quad
		Point i = r.getPoint(t);
		if (i.x < min_x || i.x > max_x) return -1;
		if (i.z < min_z || i.z > max_z) return -1;

		return t;
	}

};

/**
 * a sphere is a type of object with an origin
 * and a radius
 */
class Sphere : public Object
{
private:
	double x,y,z; // center
	double r; // radius
public:

	Sphere(double x, double y, double z, double r, Pixel d)
		: x(x), y(y), z(z), r(r)
	{
		reflection = 0;
		transparency = 0;
		diffuse = d;
		ambient = white;
		specular = white;
	}

	virtual Vector getNormal(Point p) {
		return Vector(Point(x,y,z), p).normalize();
	}
	
	virtual double intersect(Ray ray)
	{
		double xx = ray.origin.x - x;
		double yy = ray.origin.y - y;
		double zz = ray.origin.z - z;

		Vector & rd = ray.direction;

		// calculate the B value
		double b = 2 * ((rd.x*xx) + (rd.y*yy) + (rd.z*zz));	
		double c = (xx*xx) + (yy*yy) + (zz*zz) - (r*r); 

		// see where the vector goes
		double com = (b*b) - (4*c);
		if (com < 0)
			return -1; // no intersection

		double t1 = (-b + sqrt(com))/2 - 0.0005;
		double t2 = (-b - sqrt(com))/2 - 0.0005;



		/*
		cout << "\tSphere:" 
			<< " with: (" << t1
			<< ") and (" << t2 << ") "
			<< endl;

		cout << "\t\t";
		ray.getPoint(t1).print();
		cout << "\t\t";
		ray.getPoint(t2).print();
		*/

		if (t1 < 0 && t2 < 0)
			return -1; // it is behind us

		if (t1 < 0)
			return t2;

		if (t2 < 0)
			return t1;

		// if (ray.parent == this) cout << "\tinside?" << endl;

		// both are positive, return lesser of two
		double t;
		if (ray.parent == this)
			t = t1 < t2 ? t2 : t1;
		else
			t = t1 < t2 ? t1 : t2;

		/*
		if (ray.getPoint(t) == ray.origin) {
			cout << "got the same freaking point" << endl;
			return -1;
		}
		*/

		return t;
	}
};

/**
 * the world contains all of our objects and 
 * it also spawns the rays
 *
 */
class World
{
public:
	// the screen raster data
	Pixel screen[WIDTH*HEIGHT];

	Pixel background; // background for missed rays
	Pixel & ambient; // ambient color of the world

	// camera Point of convergence
	Point camera;


	// all the objects in the world
	vector<Object*> objects;

	// all the lights in the world
	vector<Light*> lights;
	// Light light;

	World() : ambient(Light::ambient)
	{
		// default background and camera
		background = dbg;
		camera = Point(0,0,-1);

		// clear the screen with the dbg
		for (int i = 0; i < WIDTH * HEIGHT; i++)
			screen[i] = background;

		// set up the light..
		lights.push_back(new Light(Point(0.3f, 1.5f, 0), white, white));
		// lights.push_back(new Light(Point(1, 0.5f, 0), white, white));
		// lights.push_back(new Light(Point(-1.5, 0.5f,2), white, white));


		objects.push_back(new Sphere(-0.21,0, 0.5,  0.5, blue));
		(*(objects.end() - 1))->transparency = 0.9f;

		

		objects.push_back(new Sphere(0.5,-0.4, 1.5,  0.5, silver));
		(*(objects.end() - 1))->reflection = 0.6f;


		objects.push_back(new Plane(0, -1, 0, Vector(0,1,0), red ));
	}
	
	/** 
	 * render the entire screen by spawning rays for each pixel
	 */
	void render() 
	{
		
		/*
		// trace the path of a single ray
		Ray r(camera, Point(0,.09, 1));

		Pixel c = trace(r);

		return;
		*/


		if (view == 0) {
			camera = Point(0,0,-1.5);
		} else {
			camera = Point(-3,0,1);
		}



		for (int y = 0; y < HEIGHT; y++) 
		{
			for (int x = 0; x < WIDTH; x++)
			{
				Pixel & me = screen[y*HEIGHT + x];

				// find the offsets in x and y directions
				// our viewport is 2x2 centered on origin
				double ox = 2*(1.0*x/WIDTH) -1;
				double oy = 2*(1.0*y/HEIGHT) -1;

				Ray r(camera, Point(ox, oy, 0));
				if (view == 0) { 
					r = Ray(camera, Point(ox, oy, 0));
				} else {
					r = Ray(camera, Point(-2, oy, ox+1));
				}
				
				me = trace(r);
			}
		}

		// draw the screen
		glRasterPos2i(0,0);
		glDrawPixels(WIDTH, HEIGHT, GL_RGB, GL_UNSIGNED_BYTE, screen);
	}

	/**
	 * recursively trace a ray
	 * returns the final color of the pixel
	 */
	Pixel trace(Ray r, int depth = 0, bool inside = false)
	{
		/*
		if (inside) cout << "inside ";
		else cout << "outside ";
		cout << "Trace.." << endl;
		cout << "\t";
		r.print();
		*/

		// intersect the ray and don't ignore the ray's source
		Object *o = intersect(r, false);
		if (o == NULL) return dbg;

		Pixel i = phong(*o, *this, r.getPoint());

		if (depth > RAY_DEPTH) return i;

		// are we calculating reflection
		if (o->reflection > 0) {
			Vector n = o->getNormal(r.getPoint());
			Pixel j = trace(
					Ray(r.getPoint(), reflect(r.direction, n), o), 
					depth+1);	

			i = i.average(j, 1 - o->reflection);
		}

		// are we refracting inside a transparent object
		if (o->transparency > 0) {
			// find the refraction vector


			double a = 1.0;
			double b = 0.97;
			Vector n = o->getNormal(r.getPoint());

			// if we are inside we need to swap some values
			if (inside) {
				n = n.scale(-1);

				// swap
				double tmp = a;
				a = b;
				b = tmp;
			}

			bool tir = false;
			Vector refracted = refract(r.direction, n, a, b, tir);
			if (tir) inside != inside;

			Pixel j = trace(
					Ray(r.getPoint(), refracted, o),
					depth+ 1, !inside);


			if (inside)
				i = j;
			else
				i = i.average(j, 1 - o->transparency);
		}

		return i;
	}


	/** 
	 * attempt to intersect a ray with all objects in the scene 
	 * excluding the excluded object
	 *
	 * returns pointer to object hit, or null if no object
	 * the t value of intersection is stored in the ray
	 */
	Object *intersect(Ray & r, bool ignore = true)
	{
		double t = -1;
		Object *winner = NULL;
		for (vector<Object*>::iterator i = objects.begin();
			i != objects.end(); i++)
		{
			if (ignore && *i == r.parent) continue; // do not consider this object
			double tt = (*i)->intersect(r);
			if (tt > 0 && (tt < t || t == -1)) {
				t = tt;
				winner = *i;
			}
			
		}

		r.t = t;
		return winner;
	}
};

// the single global scene; its constructor creates the lights and objects.
// Being a global (static storage) also keeps its large screen buffer off the stack.
World w;

/**
 * mirror a vector about the plane whose normal is n:
 *
 * 	result = x - 2 (x . n) n, then normalized
 *
 * The formula does not depend on which way x points. Given an incoming
 * direction (pointing toward the surface) it yields the outgoing mirror
 * direction, which is how World::trace uses it. Given a vector pointing
 * away from the surface (as phong passes the light vector) it yields the
 * negated textbook Phong reflection vector.
 *
 * 	x	the vector to reflect
 * 	n	the reflection normal (expected to be unit length)
 */
Vector reflect(Vector x, Vector n)
{
	Vector offset = n.scale(-2.0 * x.dot(n));
	Vector mirrored = x.add(offset);
	return mirrored.normalize();
}


/**
 * calculate the vector of refraction (Snell's law in vector form)
 *
 * Not called anywhere in this file; kept as an alternative to refract().
 * On total internal reflection it returns the mirror reflection instead.
 *
 * i	incident direction, pointing toward the surface (unit length)
 * n	surface normal on the incident side, i.e. n . i <= 0 (unit length)
 * n1	index of refraction on the incident side
 * n2	index of refraction on the transmitted side
 */
Vector bad_refract(Vector i, Vector n, double n1, double n2)
{
	double r = n1 / n2;
	double cosi = -n.dot(i);
	// squared sine of the transmitted angle
	double sin2t = r * r * (1.0 - (cosi * cosi));

	if (sin2t > 1) {
		// total internal reflection
		return reflect(i, n);
	}

	// t = r*i + (r*cos(theta_i) - cos(theta_t)) * n
	// (the previous n * (r + cos(theta_t)) term was only right head-on)
	double cost = sqrt(1.0 - sin2t);
	return i.scale(r).add(n.scale(r * cosi - cost)).normalize();
}


/**
 * calculate the vector of refraction (Snell's law in vector form)
 *
 * Not called anywhere in this file; kept as an alternative to refract().
 *
 * incident	direction of travel, pointing toward the surface (unit length)
 * normal	surface normal on the incident side (unit length)
 * n1		index of refraction on the incident side
 * n2		index of refraction on the transmitted side
 *
 * returns the normalized transmitted direction, or the mirror reflection
 * of `incident` on total internal reflection
 */
Vector python_refract(Vector incident, Vector normal, double n1, double n2) 
{
	// the math below works with the direction pointing away from the surface
	Vector away = incident.scale(-1.0);

	double n = n1 / n2;
	double cosi = away.dot(normal);

	double d = 1.0 - ((n*n) * (1.0 - (cosi * cosi)));
	if (d < 0) {
		// total internal reflection: reflect the original (toward-surface)
		// direction; reflecting `away` would give the mirrored vector
		// pointing back into the surface
		return reflect(incident, normal);
	}

	double coef = (n*cosi) - sqrt(d);
	return normal.scale(coef).subtract(away.scale(n)).normalize();
}


/**
 * calculate the vector of refraction (Snell's law in vector form)
 *
 * incident	direction of travel, pointing toward the surface (unit length)
 * normal	surface normal on the incident side (normal . incident <= 0)
 * n1		index of refraction on the incident side
 * n2		index of refraction on the transmitted side
 * tir		set to true on total internal reflection, false otherwise
 *
 * returns the normalized transmitted direction, or the mirror reflection
 * of `incident` when total internal reflection occurs
 */
Vector refract(Vector incident, Vector normal, double n1, double n2, bool & tir) 
{
	double n = n1/n2;

	// cosine of the angle of incidence
	double cosi = normal.dot(incident.scale(-1));

	// 1 - sin^2(theta_t); negative means total internal reflection
	double root = 1 + (n * n * ((cosi * cosi) - 1));

	// write the out-parameter on every path, not only on TIR, so the
	// caller never reads a stale value
	tir = root < 0;
	if (tir)
		return reflect(incident, normal);

	double coef = (cosi * n) - sqrt(root);
	return incident.scale(n).add(normal.scale(coef)).normalize();
}

/**
 * phong illumination model with simple shadowing
 *
 * 	o	object being illuminated
 * 	w	the world the object is in
 * 	p	the 3d point under investigation
 *
 * The color is the ambient term plus, for every light, a diffuse and a
 * specular term. The result is then blended toward the ambient-only color
 * with balance (blocked lights) / (number of lights + 0.5). A light counts
 * as blocked when an opaque object sits between p and the light.
 */
Pixel phong(Object & o, World & w, Point p)
{
	Vector n = o.getNormal(p);
	Vector v(p, w.camera); // direction toward the viewer
	v = v.normalize();

	Pixel darkness = (w.ambient * o.getAmbient(p));
	Pixel i = darkness;

	// for each light find out the contributed intensity
	for (vector<Light*>::iterator li = w.lights.begin();
		li != w.lights.end(); ++li)
	{
		Light & light = (**li);

		Vector l(p, light.position); // direction toward the light source
		l = l.normalize();
		double ndl = l.dot(n);

		// diffuse term (Pixel * double yields black when ndl <= 0)
		i = i + (light.diffuse * o.getDiffuse(p) * ndl);

		// specular term only when the light is in front of the surface;
		// otherwise the mirrored vector can still line up with the viewer
		// and leave a stray highlight on the unlit side
		if (ndl > 0) {
			Vector r = reflect(l, n); // equals minus the textbook Phong R
			i = i + (light.specular * o.getSpecular(p) *
				(double)pow(r.dot(v.scale(-1.0f)), 51.0f));
		}
	}

	// for each light launch a shadow ray and count the blocked lights
	double blocked = 0;
	for (vector<Light*>::iterator li = w.lights.begin();
		li != w.lights.end(); ++li)
	{
		Light & light = (**li);

		// shadow ray from the point toward the light, skipping this object
		Ray sr(p, light.position, &o);
		Object *winner = w.intersect(sr);
		if (winner == NULL)
			continue;

		Vector toHit(p, sr.getPoint());
		Vector toLight(p, light.position);

		// only objects closer than the light cast a shadow
		if (!(toHit.length() < toLight.length()))
			continue;

		// transparent blockers currently cast no shadow
		// TODO: attenuate or refract through transparent objects instead
		if (winner->transparency > 0)
			continue;

		blocked += 1;
	}

	i = darkness.average(i, blocked/(w.lights.size()+.5));

	return i;
}

/**
 * return true if the ray is in shadow
 */
bool shadowRay(Ray ray, World & w)
{
	if (Object *winner = w.intersect(ray)) 
	{
		// if the object is transparent then we need to go deeper
		if (winner->transparency > 0) {

		}

		Point pt = sr.getPoint();
		Vector v1(p, pt);
		Vector v2(p, light.position);
		
		// see if the object is closer than light
		if (v1.length() < v2.length()) {



		Point pt = ray.getPoint();
	}

	return false;
}

void init()
{
	glClearColor (0.3, 0.1, 0.1, 0.0);
}

void render()
{
	cout << "Start Drawing with view " << view << endl;
	glClear (GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT);

	// render the world
	w.render();

	glutSwapBuffers();
}


void reshape(int w, int h)
{
	glViewport (0, 0, (GLsizei) w, (GLsizei) h);
	glMatrixMode (GL_PROJECTION);
	glLoadIdentity();
	gluOrtho2D(0,WIDTH,0,HEIGHT);

	glMatrixMode(GL_MODELVIEW);
	glLoadIdentity();
}

static void Key(unsigned char key, int x, int y)
{
	switch (key) {
		case ' ':
			view = (view + 1) % 2;
			render();
			break;
		case 27:
			exit(0);
	}
}

/**
 * create a GLUT window the size of the raster, register the callbacks
 * and enter the GLUT event loop
 */
int main(int argc, char** argv)
{
	glutInit(&argc, argv);

	const unsigned int mode = GLUT_SINGLE | GLUT_RGB | GLUT_DEPTH;
	glutInitDisplayMode(mode);
	glutInitWindowSize(WIDTH, HEIGHT);
	glutCreateWindow("Ray Tracer");

	init();

	// callbacks
	glutDisplayFunc(render);
	glutReshapeFunc(reshape);
	glutKeyboardFunc(Key);

	glutMainLoop();
	return 0;
}

