/**
 * main.cpp
 *
 * Computer Graphics 2
 * A simple ray tracer
 * 1/7/09
 *
 * Authors:
 * 	leaf corcoran
 * 	adam damiano
 *
 */

#include <stdio.h>
#include <iostream>
#include <vector>
#include <stdlib.h>
#include <math.h>
#include <GL/glut.h>

using namespace std;

// raster / window size in pixels
// (each setting may be overridden at build time, e.g. -DTONE_REPRODUCTION=1)
#ifndef WIDTH
#define WIDTH 600
#endif
#ifndef HEIGHT
#define HEIGHT 600
#endif

// max recursion depth for rays
#ifndef RAY_DEPTH
#define RAY_DEPTH 16
#endif

// tone reproduction parameters:
// LMAX scales the clamped [0,1] pixel values before tone reproduction,
// LDMAX is the display maximum handed to the Ward operator
#ifndef LMAX
#define LMAX 1000
#endif
#ifndef LDMAX
#define LDMAX 100
#endif

// tone reproduction operator applied after rendering:
// 0 for none
// 1 for Ward
// 2 for Reinhard
#ifndef TONE_REPRODUCTION
#define TONE_REPRODUCTION 0
#endif

/**
 * a location in 3d space; the default point is the origin
 */
class Point {
public:
	double x, y, z;

	Point()
		: x(0), y(0), z(0)
	{
	}

	Point(double px, double py, double pz)
		: x(px), y(py), z(pz)
	{
	}
};

// types that are referenced before their full definitions appear
class Pixel;
class Vector;
class Object;
class World;

// shading and ray-direction helpers, defined after World
Pixel phong(Object & o, World & w, Point p);
Vector reflect(Vector x, Vector n);
Vector refract(Vector incident, Vector normal, double n1, double n2, bool & tir);

/**
 * a color object: red, green and blue channels stored as floats.
 * arithmetic does not keep values in range; call clamp() to limit each
 * channel to [cMIN, cMAX].
 */
class Pixel {
public:
	// channel limits used by clamp()
	static float cMAX, cMIN;
	float r, g, b;

	Pixel() : r(0), g(0), b(0) {}
	Pixel(float r, float g, float b)
		: r(r), g(g), b(b)
	{
		// ~
	}

	// clamp the rgb values to [cMIN, cMAX]
	Pixel clamp() const
	{
		return Pixel(
				r > cMAX ? cMAX : r < cMIN ? cMIN : r,
				g > cMAX ? cMAX : g < cMIN ? cMIN : g,
				b > cMAX ? cMAX : b < cMIN ? cMIN : b);
	}

	// component-wise product of two colors (e.g. light color * surface color)
	Pixel operator*(const Pixel & p) const
	{
		return Pixel(r * p.r, g * p.g, b * p.b);
	}

	// component-wise sum of two colors
	Pixel operator+(const Pixel & p) const
	{
		return Pixel(r + p.r, g + p.g, b + p.b);
	}

	// scale every channel by a scalar
	Pixel operator*(const double & i) const
	{
		return Pixel(r*i, g*i, b*i);
	}

	// weighted average: balance * (*this) + (1 - balance) * p
	Pixel average(const Pixel & p, const double & balance = 0.5) const
	{
		return Pixel(
				(balance * r + (1-balance) * p.r),
				(balance * g + (1-balance) * p.g),
				(balance * b + (1-balance) * p.b));
	}

	// quick and dirty luminance approximation (weighted channel sum);
	// const so it can be evaluated on const colors
	double luminance() const
	{
		return 0.27*r + 0.67*g + 0.06*b;
	}

};
float Pixel::cMIN = 0.0;
float Pixel::cMAX = 1.0;

// named colors used by the scene, as (r, g, b)
Pixel white  = Pixel(1, 1, 1);
Pixel black  = Pixel(0, 0, 0);
Pixel red    = Pixel(1, 0, 0);
Pixel green  = Pixel(0, 1, 0);
Pixel blue   = Pixel(0, 0, 1);
Pixel yellow = Pixel(1.0, 0.9294, 0.0431);

Pixel silver = Pixel(0.6274, 0.7686, 0.8078);

// scene defaults
Pixel dbg = Pixel(0.4705, 0.9490, 1.0); // default background color
Pixel dam(black);                         // default ambient

/**
 * A 3d vector with the handful of operations the tracer needs.
 * All operations return new vectors; none modify this one.
 */
class Vector
{
public:
	double x,y,z;

	// the vector pointing from p1 to p2
	Vector(const Point & p1, const Point & p2)
		: x(p2.x - p1.x), y(p2.y - p1.y), z(p2.z - p1.z)
	{
		// ~
	}

	Vector(double x, double y, double z) 
		: x(x), y(y), z(z)
	{
		// ~
	}

	/**
	 * return the length of this vector
	 */
	double length() const
	{
		return sqrt(x*x + y*y + z*z);
	}

	/**
	 * calculate the dot product
	 */
	double dot(const Vector & v) const
	{
		return (x*v.x) + (y*v.y) + (z*v.z);
	}

	// multiply every component by a
	Vector scale(const double a) const
	{
		return Vector(a*x, a*y, a*z);
	}

	Vector add(const Vector & v) const
	{
		return Vector(x + v.x, y + v.y, z + v.z);
	}

	Vector subtract(const Vector & v) const
	{
		return Vector(x - v.x, y - v.y, z - v.z);
	}

	/**
	 * return the unit vector in the same direction.
	 * a zero-length vector has no direction; it is returned unchanged
	 * instead of dividing by zero and producing NaN components.
	 */
	Vector normalize() const
	{
		double l = length();
		if (l == 0)
			return *this;
		return Vector(x/l, y/l, z/l);
	}

	void print() const
	{
		cout << "Vector(" << x << ", " << y << ", " << z << ")" << endl;
	}
};


/**
 * a ray is a point and a vector.
 * the direction is always normalized by the constructors; t holds the
 * parameter of the most recent intersection (World::intersect writes it,
 * -1 when nothing was hit) and parent is the object the ray was spawned
 * from, if any.
 */
class Ray
{
public:
	Point origin;
	Vector direction; // unit length
	double t; // last intersection value
	Object *parent;

	Ray(Point o, Vector dir, Object *p = NULL)
		: origin(o), direction(dir.normalize()), t(-1), parent(p)
	{
		// ~
	}

	/**
	 * make a ray from origin o toward target, a point that lies on the ray
	 * (named so it does not shadow the member t)
	 */
	Ray(Point o, Point target, Object *p = NULL)
		: origin(o), direction(Vector(o, target).normalize()), t(-1), parent(p)
	{
		// ~
	}

	// the point at the last intersection value (member t)
	Point getPoint() const { return getPoint(t); }

	// the point at parameter s along the ray: origin + s * direction
	Point getPoint(const double s) const
	{
		return Point(
			origin.x + (direction.x * s),
			origin.y + (direction.y * s),
			origin.z + (direction.z * s)
		);
	}
};

/**
 * light source
 */
class Light
{
public:
	Point position;
	Pixel specular, diffuse;
	static Pixel ambient;

	Light() {}
	Light(Point p, Pixel d, Pixel s)
		: position(p), specular(s), diffuse(d)
	{
		// ~
	}
};
Pixel Light::ambient = dam;


/**
 * an object is something in the scene that
 * can collide with a ray via the intersect
 * function. Subclasses supply intersection and normals; surface colors
 * may vary with position by overriding the get* accessors.
 */
class Object
{
protected:
	// color of the ambient, specular, and diffuse components of light reflection
	Pixel ambient, specular, diffuse;

	// the ambient color starts equal to the diffuse color; specular is white
	Object(Pixel d, double refl, double trans)
		: ambient(d), specular(white), diffuse(d),
		  reflection(refl), transparency(trans)
	{
	}

public:
	// weight given to the reflected color when tracing (0 = none)
	double reflection;
	// weight given to the refracted color when tracing (0 = opaque);
	// also lets light through partially when this object casts a shadow
	double transparency;

	// this is a polymorphic base held through Object* (e.g. World::objects);
	// a virtual destructor makes deleting through that pointer well-defined
	virtual ~Object() {}

	// returns the t value for the ray intersects
	// t is negative 1 if there is no intersection or it is behind
	virtual double intersect(Ray r) = 0;
	// return the normal of a Point on this object
	virtual Vector getNormal( Point p ) = 0;

	virtual Pixel getAmbient(const Point & p) {
		return ambient;
	}

	virtual Pixel getSpecular(const Point & p) {
		return specular;
	}

	virtual Pixel getDiffuse(const Point & p) {
		return diffuse;
	}
};


/** 
 * a plane
 */
class Plane : public Object
{
private:
	Point on; // some point on the plane
	Vector n; // plane normal
	double max_x, min_x;
	double max_z, min_z;
	Pixel diffuse_2;
public:

	Plane(Point on, Vector n, Pixel d1, Pixel d2, double refl = 0, double trans = 0)
		: on(on), n(n), diffuse_2(d2), Object(d1, refl, trans)
	{

		// plane boundaries 
		min_x = -1;
		max_x = 2;

		min_z = -1;
		max_z = 5;
	}

	virtual Vector getNormal( Point p ) {
		return n.normalize();
	}

	virtual Pixel getDiffuse(const Point & p)
	{
		// calculate the u,v coordinate
		// magic numbers prevent fuzz on borders
		double u = (p.x + 1) * 1000 / (31 * 9);
		double v = (p.z + 1) * 1000 / (31 * 9);

		if ((int) u % 2 == (int) v % 2)
			return diffuse_2;
		else
			return diffuse;
	}

	virtual double intersect(Ray r)
	{
		// point on plane
		Vector p(on.x, on.y, on.z);
		double d = p.dot(n);

		Vector la = Vector(r.origin.x, r.origin.y, r.origin.z);
		Vector lb = la.add(Vector(r.direction.x, r.direction.y, r.direction.z)); 

		double num = d - la.dot(n);
		double denom = lb.subtract(la).dot(n);
		
		if (denom == 0)
			return -1;

		double t = num/denom;

		// cut off the sides of the planes to make it appear as a quad
		Point i = r.getPoint(t);
		if (i.x < min_x || i.x > max_x) return -1;
		if (i.z < min_z || i.z > max_z) return -1;

		return t;
	}

};

/**
 * a sphere is a type of object with an origin
 * and a radius
 */
class Sphere : public Object
{
private:
	Point center;
	double r; // radius
public:

	// base class is initialized first, then members in declaration order
	Sphere(Point c, double r, Pixel d, double refl = 0, double trans = 0)
		: Object(d, refl, trans), center(c), r(r)
	{
		// ~
	}

	// outward unit normal at a point on the surface
	virtual Vector getNormal(Point p) {
		return Vector(center, p).normalize();
	}

	/**
	 * intersect by solving t^2 + b t + c = 0; the t^2 coefficient is 1
	 * because Ray normalizes its direction.
	 * returns -1 when the ray misses or the sphere is behind the origin.
	 */
	virtual double intersect(Ray ray)
	{
		// vector from the sphere center to the ray origin
		double xx = ray.origin.x - center.x;
		double yy = ray.origin.y - center.y;
		double zz = ray.origin.z - center.z;

		Vector & rd = ray.direction;

		double b = 2 * ((rd.x*xx) + (rd.y*yy) + (rd.z*zz));
		double c = (xx*xx) + (yy*yy) + (zz*zz) - (r*r);

		// see where the vector goes
		double com = (b*b) - (4*c);
		if (com < 0)
			return -1; // no intersection

		double root = sqrt(com);
		double t1 = (-b + root)/2; // far hit
		double t2 = (-b - root)/2; // near hit, t2 <= t1 always

		// compare against a small epsilon to account for rounding errors
		// (e.g. a ray spawned exactly on this surface)
		const double eps = 0.0005;

		if (t1 - eps < 0)
			return -1; // both hits behind
		if (t2 - eps < 0)
			return t1 - eps; // origin inside: only the far hit is ahead

		return t2;
	}
};

/**
 * the world contains all of our objects and 
 * it also spawns the rays
 *
 */
class World
{
public:
	// the screen raster data
	Pixel screen[WIDTH*HEIGHT];

	Pixel background; // background for missed rays
	Pixel & ambient; // ambient color of the world

	Point camera; // camera Point of convergence

	// all lights and objects in the world
	vector<Object*> objects;
	vector<Light*> lights;

	double lmax, ldmax;


	World() : ambient(Light::ambient), background(dbg)
	{
		camera = Point(0,0,-1.5);

		lmax = LMAX;
		ldmax = LDMAX;


		// clear the screen with the dbg
		for (int i = 0; i < WIDTH * HEIGHT; i++)
			screen[i] = background;

		// set up the light..
		//lights.push_back(new Light(Point(0.3, 1.5, 0), white, white));
		// lights.push_back(new Light(Point(-1.5, 1.5, 1), white, white));
		lights.push_back(new Light(Point(-0.7, 1.8, -1), white, white));


		// transmissive
		objects.push_back(new Sphere(
					Point(-0.21,0, 0.5), 0.5, white, 0.0, 0.9));

		// reflective
		objects.push_back(new Sphere(
					Point(0.5,-0.4, 1.5),0.5, silver, 0.6, 0.0));

		// ground
		objects.push_back(new Plane(
					Point(0, -1, 0), Vector(0,1,0), red, yellow));
	}

	// calculate the log average luminance in a buffer
	double logAverage(Pixel *buffer) const
	{
		int n = WIDTH  * HEIGHT;
		double sum = 0.0;
		for (int i = 0; i < n; i++) {
			sum += log(0.000001 + buffer[i].luminance());
		}

		return exp(sum/n);
	}

	double wardScale(double la, double lmax, double ldmax = 100.0) const
	{
		double num = 1.219 + pow(ldmax/2.0, 0.4);
		double den = 1.219 + pow(la,0.4);

		return pow(num/den, 2.5);
	}

	// apply the ward tone reproduction to the scene
	void ward()
	{
		// apply lmax to scene
		for (int i = 0; i < WIDTH * HEIGHT; i++)
			screen[i] = screen[i] * lmax;


		double la = logAverage(screen);
		double sf = wardScale(la, lmax, ldmax) / ldmax;

		cout << "Log Average: " << la << endl;
		cout << "Ward Scale Factor: " << sf << endl;

		// now scale all the colors in the buffer
		for (int i = 0; i < WIDTH * HEIGHT; i++)
			screen[i] = screen[i] * sf;
	}

	void reindhard(double a = 0.18)
	{
		// apply lmax to scene
		for (int i = 0; i < WIDTH * HEIGHT; i++)
			screen[i] = screen[i] * lmax;


		double la = logAverage(screen);

		double zone = a / la;

		cout << "zone " << zone << endl;
		// scale all the colors in the buffer
		for (int i = 0; i < WIDTH * HEIGHT; i++) {
			Pixel p = screen[i] * zone;

			p.r = p.r / (1 + p.r);
			p.g = p.g / (1 + p.g);
			p.b = p.b / (1 + p.b);

			screen[i] = p.clamp();
		}

	}
	
	/** 
	 * render the entire screen by spawning rays for each pixel
	 */
	void render() 
	{
		for (int y = 0; y < HEIGHT; y++) 
		{
			for (int x = 0; x < WIDTH; x++)
			{
				Pixel & me = screen[y*HEIGHT + x];

				// find the offsets in x and y directions
				// our viewport is 2x2 centered on origin
				double ox = 2*(1.0*x/WIDTH) -1;
				double oy = 2*(1.0*y/HEIGHT) -1;

				Ray r(camera, Point(ox, oy, 0));
				me = trace(r);

				me = me.clamp();
			}
		}

		switch (TONE_REPRODUCTION) {
			case 1:
				ward();	
				break;
			case 2:
				reindhard();
				break;
		}

		// draw the screen
		glRasterPos2i(0,0);
		glDrawPixels(WIDTH, HEIGHT, GL_RGB, GL_FLOAT, screen);
	}

	/**
	 * recursively trace a ray
	 * returns the final color of the pixel
	 */
	Pixel trace(Ray r, int depth = 0, bool inside = false)
	{
		// intersect the ray and don't ignore the ray's source
		Object *o = intersect(r, false);
		if (o == NULL) return dbg;

		Pixel i = phong(*o, *this, r.getPoint());

		if (depth > RAY_DEPTH) return i;

		// are we calculating reflection
		if (o->reflection > 0) {
			Vector n = o->getNormal(r.getPoint());
			Pixel j = trace(
					Ray(r.getPoint(), reflect(r.direction, n), o), 
					depth+1);	

			i = i.average(j, 1 - o->reflection);
		}

		// are we refracting inside a transparent object
		if (o->transparency > 0) {
			// find the refraction vector
			double a = 1.0;
			double b = 0.98;
			Vector n = o->getNormal(r.getPoint());

			// if we are inside we need to swap some values
			if (inside) {
				n = n.scale(-1);

				// swap
				double tmp = a;
				a = b;
				b = tmp;
			}

			bool tir = false;
			Vector refracted = refract(r.direction, n, a, b, tir);
			if (tir) inside != inside;

			Pixel j = trace(
					Ray(r.getPoint(), refracted, o),
					depth+ 1, !inside);


			i = i.average(j, 1 - o->transparency);
		}

		return i;
	}


	/** 
	 * attempt to intersect a ray with all objects in the scene 
	 * excluding the excluded object
	 *
	 * returns pointer to object hit, or null if no object
	 * the t value of intersection is stored in the ray
	 */
	Object *intersect(Ray & r, bool ignore = true)
	{
		double t = -1;
		Object *winner = NULL;
		for (vector<Object*>::iterator i = objects.begin();
			i != objects.end(); i++)
		{
			if (ignore && *i == r.parent) continue; // do not consider this object
			double tt = (*i)->intersect(r);
			if (tt > 0 && (tt < t || t == -1)) {
				t = tt;
				winner = *i;
			}
			
		}

		r.t = t;
		return winner;
	}
};

// the global scene; the GLUT display callback render() draws it
World w;

/**
 * mirror a vector about the plane perpendicular to n:
 * returns normalize(x - 2 (x . n) n)
 *
 * For an incoming ray direction (pointing toward the surface) this is the
 * outgoing direction of perfect reflection. For a vector pointing away
 * from the surface (e.g. toward a light, as phong() passes it) the result
 * is the negated mirror direction.
 *
 * 	x	the vector of interest
 * 	n	the reflection normal (expected to be unit length)
 */
Vector reflect(Vector x, Vector n)
{
	Vector twiceProjection = n.scale(2.0 * x.dot(n));
	Vector mirrored = x.subtract(twiceProjection);
	return mirrored.normalize();
}


/**
 * calculate the vector of refraction (Snell's law)
 *
 * incident	incident direction (expected unit length)
 * normal	surface normal, expected to face the incident side (unit length)
 * n1		index of refraction of the medium the ray travels in
 * n2		index of refraction of the medium the ray enters
 * tir		set to true on total internal reflection, false otherwise;
 *		in the tir case the mirrored direction is returned instead
 */
Vector refract(Vector incident, Vector normal, double n1, double n2, bool & tir) 
{
	// always leave the out-parameter in a defined state
	tir = false;

	double n = n1/n2;

	// cosine of the angle of incidence
	double ndn = normal.dot(incident.scale(-1));

	// squared cosine of the transmission angle; negative means no
	// transmitted ray exists
	double root = 1 + (n * n * ((ndn * ndn) - 1));
	if (root < 0) {
		tir = true;
		return reflect(incident, normal);
	}

	double coef = (ndn * n) - sqrt(root);
	return incident.scale(n).add(normal.scale(coef)).normalize();
}

/**
 * phong illumination model with shadow rays
 *
 * 	o	object being illuminated
 * 	w	the world the object is in (lights, ambient color, camera)
 * 	p	the 3d point under investigation
 *
 * returns the ambient term plus each light's diffuse and specular terms,
 * pulled back toward the ambient term when shadow rays are blocked.
 * NOTE(review): the viewer is always the camera, even when p was reached
 * by a reflected or refracted ray.
 */
Pixel phong(Object & o, World & w, Point p)
{
	Vector n = o.getNormal(p);
	Vector v(p, w.camera); // direction toward the viewer
	v = v.normalize();

	Pixel darkness = (w.ambient * o.getAmbient(p));
	Pixel i = darkness;

	
	// for each light find out the contributed intensity
	for (vector<Light*>::iterator li = w.lights.begin();
		li != w.lights.end(); li++)
	{
		Light & light = (**li);

		Vector l(p, light.position); // Point toward light source
		l = l.normalize();

		// a light behind the surface contributes neither diffuse nor
		// specular light; a negative l.n would otherwise subtract color
		double lambert = l.dot(n);
		if (lambert <= 0)
			continue;

		// reflect() of a vector pointing away from the surface yields the
		// negated mirror direction, hence the dot product with -v
		Vector r = reflect(l, n);
		double rv = r.dot(v.scale(-1.0));
		double highlight = rv > 0 ? pow(rv, 51.0) : 0.0;

		// add in the specular and diffuse components for this light
		i = i + (light.diffuse * o.getDiffuse(p) * lambert) +
			(light.specular * o.getSpecular(p) * highlight).clamp();

	}

	// for each light launch a shadow ray and reduce intensity accordingly
	double num = 0; // amount of total lights hit
	for (vector<Light*>::iterator li = w.lights.begin();
		li != w.lights.end(); li++)
	{
		Light & light = (**li);

		// shadow ray, spawn from the Point on the object with a vector Pointing toward the light
		Ray sr(p, light.position, &o);
		if (Object *winner = w.intersect(sr)) {
			Point pt = sr.getPoint();
			Vector v1(p, pt);
			Vector v2(p, light.position);

			// see if the intersected object is closer than the light
			if (v1.length() < v2.length()) {
				// transparent blockers only partially shadow
				num += (1 - winner->transparency);
			}

		}
	}

	// make the total a little bigger so we don't get full
	// darkness in complete shadow to simulate some ambient light
	double total = w.lights.size()*2;

	if (num > 0)
		i = darkness.average(i, sqrt(num / total));

	return i;
}

/**
 * one-time GL state setup, called from main() after the window exists
 */
void init()
{
	// color (r, g, b, a) that glClear fills the window with
	const double clearRed = 0.3, clearGreen = 0.1, clearBlue = 0.1, clearAlpha = 0.0;
	glClearColor(clearRed, clearGreen, clearBlue, clearAlpha);
}

void render()
{
	cout << "Started drawing ..." << endl;
	glClear (GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT);

	// render the world
	w.render();

	glutSwapBuffers();
}


void reshape(int w, int h)
{
	glViewport (0, 0, (GLsizei) w, (GLsizei) h);
	glMatrixMode (GL_PROJECTION);
	glLoadIdentity();
	gluOrtho2D(0,WIDTH,0,HEIGHT);

	glMatrixMode(GL_MODELVIEW);
	glLoadIdentity();
}

/**
 * GLUT keyboard callback: space re-renders the scene, Escape quits;
 * the mouse position is not used
 */
static void Key(unsigned char key, int /*x*/, int /*y*/)
{
	const unsigned char kEscape = 27;

	if (key == ' ') {
		render();
	} else if (key == kEscape) {
		exit(0);
	}
}

/**
 * create the window, register the GLUT callbacks and enter the event loop
 */
int main(int argc, char** argv)
{
	// window and GL context must exist before any GL state is touched
	glutInit(&argc, argv);
	glutInitDisplayMode(GLUT_SINGLE | GLUT_RGB | GLUT_DEPTH);
	glutInitWindowSize(WIDTH, HEIGHT);
	glutCreateWindow("Ray Tracer");

	init();

	// event handlers
	glutDisplayFunc(render);
	glutReshapeFunc(reshape);
	glutKeyboardFunc(Key);

	glutMainLoop();
	return 0;
}

