/**
 * main.cpp
 *
 * Computer Graphics 2
 * A simple ray tracer
 * 1/7/09
 *
 * Authors:
 * 	leaf corcoran
 * 	adam damiano
 *
 */

#include <stdio.h>
#include <iostream>
#include <vector>
#include <stdlib.h>
#include <math.h>
#include <float.h>
#include <GL/glut.h>

using namespace std;

// size of the ray-traced raster in pixels; the same values are used for the
// GLUT window size and the orthographic projection extent
#define WIDTH 400 
#define HEIGHT 400 

/**
 * a location in 3d space; default-constructs to the origin
 */
class Point {
public:
	float x, y, z;

	Point()
		: x(0.0f), y(0.0f), z(0.0f)
	{
	}

	Point(float px, float py, float pz)
		: x(px), y(py), z(pz)
	{
	}
};

// forward declarations so that phong() can be declared here, before the
// classes it takes are defined; World::render calls phong() and its
// definition appears after World
class World;
class Object;
class Vector;
class Pixel;
Pixel phong(Object & o, World & w, Point p);

/**
 * A class representing a color/pixel: three 8-bit channels.
 *
 * The layout must stay exactly r, g, b (three unsigned chars and no other
 * non-static data members): World hands its Pixel array straight to
 * glDrawPixels as tightly packed GL_RGB / GL_UNSIGNED_BYTE data, so any
 * extra member would break the pixel stride.
 *
 * All arithmetic saturates to the range [0, 255].
 */
class Pixel {
public:
	unsigned char r, g, b;

	Pixel() : r(0), g(0), b(0) {}
	Pixel(unsigned char r, unsigned char g, unsigned char b) 
		: r(r), g(g), b(b) {}

	/**
	 * component-wise (modulating) product: each channel is treated as a
	 * fraction of 255, so white * p == p and black * p == black
	 */
	Pixel operator*(const Pixel & p) const
	{
		return Pixel(
			(static_cast<int>(r) * p.r) / 255,
			(static_cast<int>(g) * p.g) / 255,
			(static_cast<int>(b) * p.b) / 255
			);
	}

	/**
	 * saturating component-wise sum
	 */
	Pixel operator+(const Pixel & p) const
	{
		return Pixel(
			clampChannel(r + p.r),
			clampChannel(g + p.g),
			clampChannel(b + p.b));
	}

	/**
	 * scale every channel by a factor (truncating, saturating at 255);
	 * a non-positive factor yields black
	 */
	Pixel operator*(const float & i) const
	{
		// written as !(i > 0) so NaN is rejected too: converting a NaN
		// channel value to an integer is undefined behaviour
		if (!(i > 0)) return Pixel();

		return Pixel(
			clampChannel(r * i),
			clampChannel(g * i),
			clampChannel(b * i));
	}

	/**
	 * linear blend: balance * this + (1 - balance) * p
	 * balance = 1 returns this colour, balance = 0 returns p
	 */
	Pixel average(const Pixel & p, float balance = 0.5f) const
	{
		return Pixel(
				clampChannel(balance * r + (1-balance) * p.r),
				clampChannel(balance * g + (1-balance) * p.g),
				clampChannel(balance * b + (1-balance) * p.b));
	}

private:
	// convert a channel value to a byte: truncates toward zero and
	// saturates to [0, 255]; NaN and negatives map to 0
	static unsigned char clampChannel(float v)
	{
		if (!(v > 0)) return 0;
		if (v >= 255) return 255;
		return static_cast<unsigned char>(v);
	}
};

// named colours used throughout the scene
Pixel white = Pixel(255, 255, 255);
Pixel blue = Pixel(0, 0, 255);
Pixel green = Pixel(0, 255, 0);
Pixel red = Pixel(255, 0, 0);
Pixel brown = Pixel(139, 60, 1);
Pixel black = Pixel(0, 0, 0);

// scene defaults
Pixel dbg = Pixel(255, 0, 100); // default background
Pixel dam = Pixel(52, 31, 76);  // default ambient

/**
 * A 3d vector (direction or displacement). All operations return a new
 * vector and leave this one unchanged.
 */
class Vector
{
public:
	float x,y,z;

	/**
	 * the vector pointing from p1 to p2, i.e. p2 - p1
	 */
	Vector(const Point & p1, const Point & p2)
		: x(p2.x - p1.x), y(p2.y - p1.y), z(p2.z - p1.z)
	{
	}

	Vector(float x, float y, float z) 
		: x(x), y(y), z(z)
	{
	}

	/**
	 * return the length of this vector
	 */
	float length() const
	{
		return sqrt(x*x + y*y + z*z);
	}

	/**
	 * calculate the dot product
	 */
	float dot(const Vector & v) const
	{
		return (x*v.x) + (y*v.y) + (z*v.z);
	}

	/**
	 * return this vector with every component multiplied by a
	 */
	Vector scale(const float a) const
	{
		return Vector(a*x, a*y, a*z);
	}

	/**
	 * return the component-wise sum of this vector and v
	 */
	Vector add(const Vector & v) const
	{
		return Vector(x + v.x, y + v.y, z + v.z);
	}

	/**
	 * return this vector normalized to unit length; a zero-length vector
	 * has no direction and is returned unchanged
	 */
	Vector normalize() const
	{
		float l = length();
		if (l == 0)
			return *this; // avoid 0/0, which would produce NaN components
		return Vector(x/l, y/l, z/l);
	}

	void print() const
	{
		cout << "Vector(" << x << ", " << y << ", " << z << ")" << endl;
	}
};



/**
 * A ray consists of an origin Point and a direction vector. Both
 * constructors normalize the direction to unit length. t holds the
 * parameter of the last intersection stored into the ray (-1 when there is
 * none yet); getPoint() with no argument evaluates the ray at that t.
 */
class Ray
{
public:
	Point origin;
	Vector direction;	
	float t; // last intersection value, -1 if none

	Ray(Point o, Vector dir)
		: origin(o), direction(dir.normalize()), t(-1)
	{
		// ~	
	}	

	/**
	 * make a ray from origin o that passes through the point `through`
	 */
	Ray(Point o, Point through) 
		: origin(o), direction(Vector(o, through).normalize()), t(-1)
	{
		// t is initialized here as well: getPoint() reads it
	}

	// get the coordinate at the stored t value
	Point getPoint() const { return getPoint(t); }

	// get the coordinate at a given t value: origin + direction * at
	Point getPoint(const float at) const
	{
		return Point(
			origin.x + (direction.x * at),
			origin.y + (direction.y * at),
			origin.z + (direction.z * at)
		);
	}
	
	void print() const
	{
		cout << "Ray: o(" << origin.x << ", " << origin.y << ", " << origin.z << ") d(" <<
			direction.x << ", " << direction.y << ", " << 
			direction.z << ")" << endl;
	}
};

/**
 * A point light source with its own diffuse and specular colours.
 * The ambient colour is a single static value shared by all lights.
 */
class Light
{
public:
	Point position;
	Pixel specular;
	Pixel diffuse;
	static Pixel ambient;

	Light() {}

	Light(Point p, Pixel d, Pixel s)
	{
		position = p;
		specular = s;
		diffuse = d;
	}
};

// the shared ambient colour starts out as the default ambient
Pixel Light::ambient = dam;


/**
 * an object is something in the scene that
 * can collide with a ray via the intersect
 * function
 */
class Object
{
public:
	// color of the ambient, specular, and diffuse components of light reflection
	Pixel ambient, specular, diffuse;

	// objects are created by derived type and handled through Object*,
	// so the destructor must be virtual for delete-through-base to be
	// well-defined
	virtual ~Object() {}

	// returns the t value at which the ray hits this object; a value > 0 is
	// a hit in front of the ray origin, anything <= 0 (conventionally -1)
	// means there is no intersection or it is behind the origin
	virtual float intersect(Ray r) = 0;
	// return the normal of a Point on this object
	virtual Vector getNormal( Point p ) = 0;
};


/** 
 * a plane is a type of object with an origin
 * Point and a normal vector
 * all planes are truncated to certain x values 
 * for the time being
 */
class Plane : public Object
{
private:
	float x,y,z;	// point on plane
	Vector n; // plane normal
	float max_x, min_x;
public:

	Plane(float x, float y, float z, Vector n, Pixel d)
		: x(x), y(y), z(z), n(n)
	{
		diffuse = d;
		ambient = white;
		specular = white;
		min_x = -1;
		max_x = 2;
	}

	virtual Vector getNormal( Point p ) {
		return n.normalize();
	}

	virtual float intersect(Ray r)
	{
		// vector of the Point on the plane
		Vector a(x,y,z);

		float d = a.dot(n);

		float denom = n.dot(r.direction);
		if (denom == 0) {
			// parallel to plane
			return -1;
		}

		float t = (n.dot(Vector(r.origin.x, r.origin.y, r.origin.z)) + d) / denom;

		// lets trim the edge of the plane to make it more turner whitted
		Point i = r.getPoint(t);
		if (i.x < min_x || i.x > max_x) return -1;



		return t;
	}

};

/**
 * a sphere is a type of object with an origin
 * and a radius
 */
class Sphere : public Object
{
private:
	float x,y,z; // center
	float r; // radius
public:

	Sphere(float x, float y, float z, float r, Pixel d)
		: x(x), y(y), z(z), r(r)
	{
		diffuse = d;
		ambient = white;
		specular = white;
	}

	virtual Vector getNormal(Point p) {
		return Vector(Point(x,y,z), p).normalize();
	}
	
	virtual float intersect(Ray ray)
	{
		float xx = ray.origin.x - x;
		float yy = ray.origin.y - y;
		float zz = ray.origin.z - z;

		Vector & rd = ray.direction;

		// calculate the B value
		float b = 2 * ((rd.x*xx) + (rd.y*yy) + (rd.z*zz));	
		float c = (xx*xx) + (yy*yy) + (zz*zz) - (r*r); 

		// see where the vector goes
		float com = (b*b) - (4*c);
		if (com < 0)
			return -1; // no intersection

		float t1 = (-b + sqrt(com))/2;
		float t2 = (-b - sqrt(com))/2;

		if (t1 < 0 && t2 < 0)
			return -1; // it is behind us

		if (t1 < 0)
			return t2;

		if (t2 < 0)
			return t1;

		// both are positive, return lesser of two
		return t1 < t2 ? t1 : t2;
	}
};

/**
 * the world contains all of our objects, lights, the camera and the
 * screen raster, and it spawns the rays
 */
class World
{
public:
	// the screen raster data, row-major: pixel (x, y) is screen[y*WIDTH + x]
	Pixel screen[WIDTH*HEIGHT];

	Pixel background; // background for missed rays
	Pixel & ambient; // ambient color of the world (refers to Light::ambient)

	// camera Point of convergence
	Point camera;

	// all the objects in the world
	// NOTE(review): allocated with new and never freed; only acceptable
	// while the World lives for the whole program
	vector<Object*> objects;

	// all the lights in the world (same ownership caveat as objects)
	vector<Light*> lights;

	World() : ambient(Light::ambient)
	{
		// default background and camera
		background = dbg;
		camera = Point(0,0,-1);

		// clear the screen with the background colour
		for (int i = 0; i < WIDTH * HEIGHT; i++)
			screen[i] = background;

		// set up the light..
		lights.push_back(new Light(Point(0.3, 1.5f, 0), white, white));
		// lights.push_back(new Light(Point(1, 0.5f, 0), white, white));
		// lights.push_back(new Light(Point(-1.5, 0.5f,2), white, white));

		objects.push_back(new Sphere(0,0,1,0.5, blue));
		objects.push_back(new Sphere(0.5,-0.5,2,0.5, green));
		objects.push_back(new Plane(0, -1, 0, Vector(0,1,0), brown ));
	}
	
	/** 
	 * spawn a ray for each Pixel in the screen, shade the nearest hit (or
	 * use the background on a miss), then upload the raster to GL
	 */
	void render() 
	{
		for (int y = 0; y < HEIGHT; y++) 
		{
			for (int x = 0; x < WIDTH; x++)
			{
				// each row is WIDTH pixels long, matching the packed
				// layout glDrawPixels expects
				Pixel & me = screen[y*WIDTH + x];

				// find the offsets in x and y directions
				// our viewport is 2x2 centered on origin
				float ox = 2*(1.0*x/WIDTH) -1;
				float oy = 2*(1.0*y/HEIGHT) -1;

				Ray r(camera, Point(ox, oy, 0));
				
				// every pixel is written each frame so the image never
				// depends on what a previous render left behind
				if (Object* winner = intersect(r))
					me = phong(*winner, *this, r.getPoint() );
				else
					me = background;
			}
		}

		// draw the screen
		glRasterPos2i(0,0);
		glDrawPixels(WIDTH, HEIGHT, GL_RGB, GL_UNSIGNED_BYTE, screen);
	}

	/** 
	 * attempt to intersect a ray with all objects in the scene 
	 * excluding the excluded object
	 *
	 * returns pointer to the nearest object hit (smallest t > 0), or NULL
	 * if no object is hit; the t value of that intersection (-1 on a miss)
	 * is stored in the ray
	 */
	Object *intersect(Ray & r, Object *exclude = NULL)
	{
		float t = -1;
		Object *winner = NULL;
		for (vector<Object*>::iterator i = objects.begin();
			i != objects.end(); ++i)
		{
			if (*i == exclude) continue; // do not consider this object
			float tt = (*i)->intersect(r);
			if (tt > 0 && (tt < t || t == -1)) {
				t = tt;
				winner = *i;
			}
		}

		r.t = t;
		return winner;
	}
};

// the single scene instance; the GLUT display callback renders through it
// NOTE(review): presumably a global (rather than a local in main) because
// screen[] holds WIDTH*HEIGHT pixels and the GLUT callbacks need access — confirm
World w;

/**
 * reflect x about the plane perpendicular to n and normalize the result:
 * returns normalize(x - 2(x.n) n)
 *
 * !! the vector of interest must be Pointing away from the 
 * surface !!  For such an x and a unit-length n, the result is the
 * negation of the classic mirror direction 2(n.x)n - x, so callers
 * comparing it with a view direction should negate one of the two.
 *
 * 	x	the vector of interest
 * 	n	the reflection normal
 */
Vector reflect(Vector x, Vector n)
{
	Vector offset = n.scale(-2.0 * x.dot(n));
	Vector mirrored = x.add(offset);
	return mirrored.normalize();
}

/**
 * phong illumination model
 *
 * 	o	object being illuminated
 * 	w	the world the object is in
 * 	p	the 3d point under investigation
 *
 * returns the colour at p. It starts from the ambient term
 * (w.ambient * o.ambient) and adds diffuse and specular terms for every
 * light the surface faces. The result is then blended back toward the
 * ambient colour by the fraction of lights whose shadow ray is blocked.
 */
Pixel phong(Object & o, World & w, Point p)
{
	Vector n = o.getNormal(p);
	Vector v(p, w.camera); // direction toward the viewer
	v = v.normalize();

	// ambient-only colour: the base of the sum and the fully shadowed colour
	Pixel darkness = (w.ambient * o.ambient);

	// with no lights only the ambient term applies; returning here also
	// avoids the 0/0 shadow fraction computed below
	if (w.lights.empty())
		return darkness;

	Pixel i = darkness;

	// for each light find out the contributed intensity
	for (vector<Light*>::iterator li = w.lights.begin();
		li != w.lights.end(); ++li)
	{
		Light & light = (**li);

		Vector l(p, light.position); // Point toward light source
		l = l.normalize();

		// cosine between normal and light direction; when it is not
		// positive the surface faces away from this light, so it must
		// contribute neither diffuse light nor a specular highlight
		float lambert = l.dot(n);
		if (!(lambert > 0))
			continue;

		// reflect() of a vector pointing away from the surface gives the
		// negated mirror direction, hence the dot with the negated view vector
		Vector r = reflect(l, n);
		float highlight = static_cast<float>(pow(r.dot(v.scale(-1.0f)), 51.0f));

		// add in the specular and diffuse components for this light
		i = i + (light.diffuse * o.diffuse * lambert) +
			(light.specular * o.specular * highlight);
	}

	// for each light launch a shadow ray and count the lights that are blocked
	float num = 0;
	for (vector<Light*>::iterator li = w.lights.begin();
		li != w.lights.end(); ++li)
	{
		Light & light = (**li);

		// shadow ray, spawn from the Point on the object toward the light;
		// the object itself is excluded from the test
		Ray sr(p, light.position);
		if (w.intersect(sr, &o) != NULL) {
			Point pt = sr.getPoint();
			Vector v1(p, pt);
			Vector v2(p, light.position );

			// only an occluder closer than the light casts a shadow
			if (v1.length() < v2.length())
				num++;
		}
	}

	// blend toward the ambient colour by the fraction of occluded lights
	i = darkness.average(i, num/w.lights.size());

	return i;
}

/**
 * one-time GL state setup
 */
void init()
{
	// dark red clear colour with zero alpha
	const float cr = 0.3;
	const float cg = 0.1;
	const float cb = 0.1;
	const float ca = 0.0;
	glClearColor(cr, cg, cb, ca);
}

void render()
{
	glClear (GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT);

	// render the world
	w.render();

	glutSwapBuffers();
}


void reshape(int w, int h)
{
	glViewport (0, 0, (GLsizei) w, (GLsizei) h);
	glMatrixMode (GL_PROJECTION);
	glLoadIdentity();
	gluOrtho2D(0,WIDTH,0,HEIGHT);

	glMatrixMode(GL_MODELVIEW);
	glLoadIdentity();
}

/**
 * GLUT keyboard callback: Escape quits, every other key is ignored
 */
static void Key(unsigned char key, int /* x */, int /* y */)
{
	const unsigned char escapeKey = 27;
	if (key == escapeKey)
		exit(0);
}

/**
 * create a WIDTH x HEIGHT window, register the GLUT callbacks and enter
 * the event loop (which does not return)
 */
int main(int argc, char** argv)
{
	// window: single-buffered RGB with a depth buffer
	glutInit(&argc, argv);
	glutInitDisplayMode(GLUT_SINGLE | GLUT_RGB | GLUT_DEPTH);
	glutInitWindowSize(WIDTH, HEIGHT);
	glutCreateWindow("Ray Tracer");

	init();

	// callbacks (an idle callback of render would re-trace continuously)
	glutDisplayFunc(render);
	glutReshapeFunc(reshape);
	glutKeyboardFunc(Key);

	glutMainLoop();
	return 0;
}

