//--------------------------------------------------------
// COMP 238 Programming Assignment 1 - Whitted ray tracer
// Adrian Ilie
//--------------------------------------------------------

#include "raytracer.h"
#include <math.h>
#include <fstream>
#include <iostream>
#include <stdlib.h>

// Deepest recursion level for secondary (reflected/refracted) rays in cTracer::Cast;
// deeper rays return the background colour.
#define MAX_TREE_DEPTH 5
// Distance tolerance: hits closer than this to a ray origin are discarded by
// cTracer::Query (self-intersection guard); also used by cSphere::Normal.
#define EPSILON 1e-3
// Refractive index of air (about 1.000293 for visible light at STP). The previous
// value 0.999 was below vacuum's 1.0, which is physically impossible.
#define AIR_REFRACTION_INDEX 1.000293

//--------------------------------------------------------
// Point class
//--------------------------------------------------------
// Default point: the origin (0, 0, 0).
cPoint::cPoint() : X(0.0), Y(0.0), Z(0.0)
{
}

// Point with the given Cartesian coordinates.
cPoint::cPoint(double vX, double vY, double vZ) : X(vX), Y(vY), Z(vZ)
{
}

// Euclidean distance between this point and p.
double cPoint::Distance(const cPoint& p) const
{
	const double dx = p.X - X;
	const double dy = p.Y - Y;
	const double dz = p.Z - Z;
	return sqrt(dx * dx + dy * dy + dz * dz);
}

// Exact (bitwise-value) coordinate equality; no tolerance is applied.
bool cPoint::operator==(cPoint& p) const
{
	return !(X != p.X || Y != p.Y || Z != p.Z);
}

//--------------------------------------------------------
// Vector class
//--------------------------------------------------------
// Zero vector.
cVector::cVector() : I(0.0), J(0.0), K(0.0)
{
}

// Vector pointing from p1 to p2 (p2 - p1).
cVector::cVector(const cPoint& p1, const cPoint& p2)
	: I(p2.X - p1.X), J(p2.Y - p1.Y), K(p2.Z - p1.Z)
{
}

// Vector with the given components.
cVector::cVector(double vI, double vJ, double vK) : I(vI), J(vJ), K(vK)
{
}

// Returns a unit-length copy of this vector; a zero-length vector yields (0, 0, 0).
cVector cVector::Normalize() const
{
	const double len = Magnitude();
	if (len == 0)
		return cVector(0.0, 0.0, 0.0);
	return cVector(I / len, J / len, K / len);
}

// Scalar (dot) product of this vector with v.
double cVector::DotProd(const cVector& v) const
{
	const double dot = I * v.I + J * v.J + K * v.K;
	return dot;
}

// Cross product this x v.
cVector cVector::CrossProd(const cVector& v) const
{
	const double ci = J * v.K - v.J * K;
	const double cj = -(I * v.K - v.I * K);
	const double ck = I * v.J - v.I * J;
	return cVector(ci, cj, ck);
}

// Component-wise sum this + v (v is a by-value copy, reused as the result).
cVector cVector::operator+ (cVector v) const
{
	v.I = I + v.I;
	v.J = J + v.J;
	v.K = K + v.K;
	return v;
}

// Component-wise difference this - v.
cVector cVector::operator- (cVector v) const
{
	v.I = I - v.I;
	v.J = J - v.J;
	v.K = K - v.K;
	return v;
}

// Translates point p by this vector.
cPoint cVector::operator+ (cPoint p) const
{
	p.X = I + p.X;
	p.Y = J + p.Y;
	p.Z = K + p.Z;
	return p;
}

// Scales this vector by s.
cVector cVector::operator* (double s) const
{
	cVector scaled(*this);
	scaled.I *= s;
	scaled.J *= s;
	scaled.K *= s;
	return scaled;
}

// Interprets this vector as a position relative to the origin.
cPoint cVector::VectorToPoint() const
{
	const cPoint pt(I, J, K);
	return pt;
}

double cVector::Magnitude() const
{
	return sqrt(I*I + J*J + K*K);
}

//--------------------------------------------------------
// Color class
//--------------------------------------------------------
// Clamps d into [0, 1]: negatives become 0, values above 1 become 1.
// NaN fails both comparisons and is returned unchanged.
double clamp01(double d)
{
	if (d < 0)
		return 0;
	return (d > 1) ? 1 : d;
}

// Colour from RGB components; each channel is clamped into [0, 1].
cColor::cColor(double vR, double vG, double vB)
	: R(clamp01(vR)), G(clamp01(vG)), B(clamp01(vB))
{
}

// Scales every channel by s; the constructor clamps the result to [0, 1].
cColor cColor::operator* (double s) const
{
	const double r = R * s;
	const double g = G * s;
	const double b = B * s;
	return cColor(r, g, b);
}

// Channel-wise sum, clamped to [0, 1] by the constructor.
cColor cColor::operator+ (const cColor& c) const
{
	const double r = R + c.R;
	const double g = G + c.G;
	const double b = B + c.B;
	return cColor(r, g, b);
}

// Channel-wise (modulating) product, clamped to [0, 1] by the constructor.
cColor cColor::operator* (const cColor& c) const
{
	const double r = R * c.R;
	const double g = G * c.G;
	const double b = B * c.B;
	return cColor(r, g, b);
}

unsigned char cColor::bR() const
{
	return DoubleToByte(R);
}
unsigned char cColor::bG() const
{
	return DoubleToByte(G);
}
unsigned char cColor::bB() const
{
	return DoubleToByte(B);
}

// Maps d (clamped to [0, 1]) onto 0..255, truncating any fractional part.
unsigned char cColor::DoubleToByte(double d) const
{
	const double scaled = clamp01(d) * 255.0;
	return static_cast<unsigned char>(scaled);
}

//--------------------------------------------------------
// Light class
//--------------------------------------------------------
// Black light at the origin.
cLight::cLight() : Point(), Color(0.0, 0.0, 0.0)
{
}

// Black light at position p.
cLight::cLight(const cPoint& p) : Point(p), Color(0.0, 0.0, 0.0)
{
}

// Light at position p with colour c.
cLight::cLight(const cPoint& p, const cColor& c) : Point(p), Color(c)
{
}

// Light at (x, y, z) with colour (r, g, b).
cLight::cLight(const double x, const double y, const double z,
               const double r, const double g, const double b)
	: Point(x, y, z), Color(r, g, b)
{
}

// Diffuse + specular (Phong-style) light arriving from this light at point p on
// shape s, viewed along ray r, modulated by this light's Color. Returns black when
// the light lies behind the surface (negative N.L).
cColor cLight::Contribution(const cRay& r, const cShape* s, const cPoint& p)
{
	const cColor black(0.0, 0.0, 0.0);

	// Surface normal at p, oriented relative to the light position.
	const cVector normal = s->Normal(p, Point);
	// Unit vector from the intersection point toward the light.
	const cVector toLight = cVector(p, Point).Normalize();
	const double lambert = normal.DotProd(toLight);
	if (lambert < 0)
		return black; // light behind the surface contributes nothing

	const cColor diffuse = s->Diffuse * lambert;

	// Light-to-point ray mirrored about the surface, compared with the reversed view ray.
	const cVector mirrored = s->Reflect(p, cRay(Point, p)).Direction.Normalize();
	const cVector toViewer = r.Direction.Normalize() * -1.0;
	const double alignment = mirrored.DotProd(toViewer);

	cColor specular = black;
	if (alignment >= 0.0) // viewer inside the highlight lobe
		specular = s->Specular * pow(alignment, s->SpecularPower) * s->Reflectivity;

	return Color * (diffuse + specular);
}

//--------------------------------------------------------
// Ray class
//--------------------------------------------------------
// Ray starting at p travelling along v (v is stored as given, not normalized).
cRay::cRay(const cPoint& p, const cVector& v) : Direction(v), Origin(p)
{
}

// Ray starting at p1 aimed at p2; the direction p2 - p1 is stored unnormalized.
cRay::cRay(const cPoint& p1, const cPoint& p2) : Direction(p1, p2), Origin(p1)
{
}

//--------------------------------------------------------
// Shape class (abstract)
//--------------------------------------------------------
// Base-class constructor; members keep their default-constructed values.
cShape::cShape() {}

// Mirror-reflects ray r about the surface normal at point p.
// Returns a ray starting at p whose direction is R = V - 2(N.V)N, where V is r's
// normalized direction and N the unit normal from Normal(p, r.Origin).
// This formula is independent of which way N is oriented. The previous version
// flipped N only when N.V < 0 and then added 2(N.V)N, which gave a wrong
// direction whenever N.V > 0.
cRay cShape::Reflect(const cPoint& p, const cRay& r) const
{
	const cVector V = r.Direction.Normalize();
	const cVector N = Normal(p, r.Origin);
	const double NdotV = N.DotProd(V);

	cRay reflected;
	reflected.Origin = p;
	reflected.Direction = V - N * (2.0 * NdotV);
	return reflected;
}

// Refracts ray r through the surface at point p using Snell's law.
// Indices: if the normal from Normal(p, r.Origin) points along the ray, the ray
// is treated as leaving the material (RefractionIndex -> AIR_REFRACTION_INDEX);
// otherwise it is treated as entering (AIR_REFRACTION_INDEX -> RefractionIndex).
// NOTE(review): cSphere::Normal orients its result toward the ray origin, so for
// spheres the "leaving" branch may never be taken - confirm.
// On total internal reflection the mirror-reflected ray is returned instead of
// a NaN direction.
cRay cShape::Refract(const cPoint& p, const cRay& r) const
{
	// Work with a unit incident direction: the cosine must come from a unit vector
	// (primary rays, for instance, are not normalized).
	const cVector incident = r.Direction.Normalize();
	cVector normal = Normal(p, r.Origin);
	double cosI = -normal.DotProd(incident);
	double n1, n2;
	if (cosI < 0.0) // normal points along the ray: leaving the material
	{
		normal = normal * -1.0; // flip the normal to face the incident ray
		cosI = -cosI;
		n1 = RefractionIndex;
		n2 = AIR_REFRACTION_INDEX;
	}
	else
	{
		n1 = AIR_REFRACTION_INDEX;
		n2 = RefractionIndex;
	}

	const double eta = n1 / n2;
	const double k = 1.0 - eta * eta * (1.0 - cosI * cosI);
	if (k < 0.0)
		return Reflect(p, r); // total internal reflection: no transmitted ray

	cRay refracted;
	refracted.Origin = p;
	refracted.Direction = (incident + normal * cosI) * eta - normal * sqrt(k);
	return refracted;
}

//--------------------------------------------------------
// Sphere class
//--------------------------------------------------------
// Sphere centred at p with radius r.
cSphere::cSphere(cPoint p, double r) : Center(p), Radius(r)
{
	Type = string("Sphere");
}

// Sphere of radius r centred at (x, y, z).
cSphere::cSphere(double r, double x, double y, double z) : Center(x, y, z), Radius(r)
{
	Type = string("Sphere");
}

bool cSphere::Intersect(const cRay& r, cPoint& p) const
{
	cVector EO(r.Origin, Center);	
	cVector E(cPoint(0,0,0), r.Origin); // E = vector from origin to ray origin
	cVector V = r.Direction.Normalize();
	double v = V.DotProd(EO);
	if (v < 0) return false;
	{
		double t = Radius*Radius + v*v - EO.DotProd(EO);
		if (t < 0) return false;
		else
		{
			cVector P = E + V*(v-sqrt(t));
			p = P.VectorToPoint();
			return true;
		}
	}
}

// Unit normal at surface point p1, oriented toward p2.
// The outward radial direction is reversed when p2 is not farther from the centre
// than p1 (by more than EPSILON), i.e. when p2 is inside, on, or near the sphere.
cVector cSphere::Normal(const cPoint& p1, const cPoint& p2) const
{
	const cVector outward(Center, p1);
	const bool p2NotOutside = (Center.Distance(p2) - Center.Distance(p1)) <= EPSILON;
	const cVector oriented = p2NotOutside ? outward * -1.0 : outward;
	return oriented.Normalize();
}

//--------------------------------------------------------
// Plane class
//--------------------------------------------------------
// Default plane: Point and Normal2 are left default-constructed and must be set
// before use. Type is tagged "Plane" like in the other constructors (it was
// previously left unset here).
cPlane::cPlane()
{
	Type = string("Plane");
}

// Plane through point p with normal n (stored normalized).
cPlane::cPlane(const cPoint& p, const cVector& n) : Point(p), Normal2(n.Normalize())
{
	Type = string("Plane");
}

// Plane through (x, y, z) with normal (nx, ny, nz) (stored normalized).
cPlane::cPlane(double x, double y, double z, double nx, double ny, double nz)
	: Point(x, y, z), Normal2(cVector(nx, ny, nz).Normalize())
{
	Type = string("Plane");
}

cPlane::~cPlane()
{
}

bool cPlane::Intersect(const cRay& r, cPoint& p) const
{
	
	cVector nray = r.Direction.Normalize();
	if ( Normal2.DotProd(nray) > 0.0 )
		return false;//going in opposite direction

	cVector P_minus_U (r.Origin, Point);

	double t = Normal2.DotProd(P_minus_U) / Normal2.DotProd(r.Direction);
	p =  (r.Direction * t) + r.Origin;//intersect ray with plane

	return true; 
}

// A plane's normal is the same everywhere and is not re-oriented toward p2.
cVector cPlane::Normal(const cPoint& p1, const cPoint& p2) const
{
	(void)p1; // unused: constant normal
	(void)p2; // unused: no orientation flip for planes
	return Normal2;
}

//--------------------------------------------------------
// Surface class
//--------------------------------------------------------
// Default material:
// - black ambient, diffuse and specular colours;
// - Reflectivity 1, Transparency 0, SpecularPower 1;
// - refraction index of air.
cSurface::cSurface()
	: Ambient(0.0, 0.0, 0.0), Diffuse(0.0, 0.0, 0.0), Specular(0.0, 0.0, 0.0)
{
	RefractionIndex = AIR_REFRACTION_INDEX;
	Reflectivity = 1;
	Transparency = 0;
	SpecularPower = 1;
}


//--------------------------------------------------------
// Tracer class
//--------------------------------------------------------
// Creates a tracer with:
// - a grey background;
// - eye at (1,1,1) looking at the origin, up vector (0,0,1), 45-degree field of view;
// - output file "t.ppm";
// - empty shape/light/surface lists.
// Width and Height start at 0 so Trace() never reads indeterminate values if the
// caller forgets to set them.
// NOTE(review): the lists are owned through raw pointers, so copying a cTracer
// would double-delete them in the destructor; confirm copies never happen.
cTracer::cTracer() : Background(0.5,0.5,0.5), Eye(1,1,1), lookAt(0,0,0),upVector(0,0,1)
{
	fov = 45;
	FileName = "t.ppm";
	Width = 0;
	Height = 0;
	Shapes = new cShapeList;
	Lights = new cLightList;
	Surfaces = new cSurfaceList;
}

// Frees the three lists allocated in the constructor. Surfaces was previously
// leaked.
// NOTE(review): the cShape/cLight/cSurface objects the lists point to are not
// deleted here; their ownership is not established in this file - confirm.
cTracer::~cTracer()
{
	delete Shapes;
	delete Lights;
	delete Surfaces;
}

void cTracer::Trace()
{
	cRay rEye(Eye,cVector(0,0,0));
	cColor color;
	float x, y, dx, dy;
	int c = 0;
	ofstream file(FileName.c_str(),ios::binary);
	file << "P6 " << Width << " " << Height << " 255" << '\n';
	for (y = -Height/2 + 0.5; y < Height/2; y++)
	{
		c++;//line counter
		printf("%d\n",c);
		dy = y / Height;
		for (x = -Width/2 + 0.5; x < Width/2; x++)
		{
			dx = x / Width;
			rEye.Direction.K = u.K * dx + v.K * dy + o.K;
			rEye.Direction.J = u.J * dx + v.J * dy + o.J;
			rEye.Direction.I = u.I * dx + v.I * dy + o.I;
			color = Cast(rEye, 0);
			file << (unsigned char)color.bR() << (unsigned char)color.bG() << (unsigned char)color.bB();
		}
	}
	file.close();
}

// Colour seen along ray r at recursion level treeDepth.
// Returns Background past MAX_TREE_DEPTH or when nothing is hit. Otherwise
// returns local shading plus the reflected and refracted contributions.
cColor cTracer::Cast(const cRay& r, int treeDepth)
{
	if (treeDepth > MAX_TREE_DEPTH)
		return Background;

	cPoint pHit;
	cShape* hit = Query(r, pHit);
	if (hit == NULL)
		return Background;

	const cColor local = Shade(r, hit, pHit);
	const cColor mirrored = Reflect(r, hit, pHit, treeDepth);
	const cColor transmitted = Refract(r, hit, pHit, treeDepth);
	return local + mirrored + transmitted;
}

// Direct illumination at point p on shape s, seen along ray r.
// The result is s's ambient term plus each light's Contribution(). A light
// whose path from p is blocked by a shape lying between p and the light
// contributes only in proportion to that blocker's Transparency.
cColor cTracer::Shade(const cRay& r, const cShape* s, const cPoint& p)
{
	cColor intensity = s->Ambient;

	for (cLightListIterator iLights = Lights->begin(); iLights != Lights->end(); ++iLights)
	{
		cLight* light = *iLights;
		cRay rToLight(p, light->Point); // from the intersection point toward the light
		cPoint pBlock;
		cShape* sBlock = Query(rToLight, pBlock);
		// Query() reports the nearest hit anywhere along the ray. A hit beyond the
		// light cannot shadow p, so only closer hits count as occluders.
		const bool blocked = (sBlock != NULL) &&
			(p.Distance(pBlock) < p.Distance(light->Point));
		const cColor contribution = light->Contribution(r, s, p);
		if (blocked)
			intensity = intensity + contribution * sBlock->Transparency; // light passes through the blocker
		else
			intensity = intensity + contribution; // light reaches p directly
	}
	return intensity;
}

// Colour arriving by mirror reflection at pIntersect on sClosest.
// The reflected ray is cast one level deeper and scaled by the surface's
// Specular colour and Reflectivity. Non-reflective surfaces give black.
cColor cTracer::Reflect(const cRay& r, const cShape* sClosest, const cPoint& pIntersect, int treedepth)
{
	if (sClosest->Reflectivity != 0.0)
	{
		const cRay bounce = sClosest->Reflect(pIntersect, r);
		return Cast(bounce, treedepth + 1) * sClosest->Specular * sClosest->Reflectivity;
	}
	return cColor(0.0, 0.0, 0.0);
}

// Colour transmitted through sClosest at pIntersect.
// The refracted ray is cast one level deeper and scaled by Transparency.
// Opaque surfaces give black.
cColor cTracer::Refract(const cRay& r, const cShape* sClosest, const cPoint& pIntersect, int treedepth)
{
	if (sClosest->Transparency != 0.0)
	{
		const cRay transmitted = sClosest->Refract(pIntersect, r);
		return Cast(transmitted, treedepth + 1) * sClosest->Transparency;
	}
	return cColor(0.0, 0.0, 0.0);
}

// Finds the shape whose intersection with r is nearest to r.Origin.
// Hits closer than EPSILON are ignored (self-intersection guard).
// On success stores the hit point in p and returns the shape; returns NULL and
// leaves p untouched if nothing is hit.
cShape* cTracer::Query(const cRay& r, cPoint& p)
{
	cShape* sClosest = NULL;
	// Start at +infinity rather than an arbitrary 999999, which silently dropped
	// hits farther away than that.
	double closest = HUGE_VAL;
	for (cShapeListIterator iShapes = Shapes->begin(); iShapes != Shapes->end(); ++iShapes)
	{
		cPoint pIntersect;
		if (!(*iShapes)->Intersect(r, pIntersect))
			continue;
		const double distance = pIntersect.Distance(r.Origin);
		if (distance < closest && distance >= EPSILON)
		{
			p = pIntersect;
			sClosest = *iShapes;
			closest = distance;
		}
	}
	return sClosest;
}

// Appends shape s to the end of the scene's shape list (pointer stored as-is).
void cTracer::AddShape(cShape* s)
{
	(*Shapes).push_back(s);
}

// Appends light l to the end of the scene's light list (pointer stored as-is).
void cTracer::AddLight(cLight* l)
{
	(*Lights).push_back(l);
}

// Appends surface s to the end of the surface list (pointer stored as-is).
void cTracer::AddSurface(cSurface* s)
{
	(*Surfaces).push_back(s);
}

// Returns the most recently added surface (the last element of Surfaces), or
// NULL if no surface has been added. The result used to be an uninitialized
// pointer in the empty case.
cSurface* cTracer::GetLastSurface()
{
	cSurface* su = NULL;
	for (cSurfaceListIterator iSurfaces = Surfaces->begin(); iSurfaces != Surfaces->end(); ++iSurfaces)
		su = *iSurfaces;
	return su;
}

// Looks up a surface by name. If several share the name, the last one added wins.
// Prints an error and terminates the program (exit(1)) when no match exists, so
// the returned pointer is never NULL.
cSurface* cTracer::GetSurfaceByName(char *sn)
{
	cSurface* match = NULL;
	for (cSurfaceListIterator it = Surfaces->begin(); it != Surfaces->end(); ++it)
	{
		if (strcmp((*it)->name.c_str(), sn) == 0)
			match = *it; // keep scanning: later definitions override earlier ones
	}
	if (match == NULL)
	{
		fprintf(stderr,"Cannot find surface %s\n",sn);
		exit(1);
	}
	return match;
}

// Returns the most recently added shape (AddShape appends with push_back, so this
// is the last element), or NULL if there are no shapes.
// Previously this returned the FIRST shape and dereferenced begin() of an
// empty list.
cShape* cTracer::GetLastShape()
{
	cShape* last = NULL;
	for (cShapeListIterator iShapes = Shapes->begin(); iShapes != Shapes->end(); ++iShapes)
		last = *iShapes;
	return last;
}

// Copies every material property of the surface named sn onto shape sh.
// GetSurfaceByName exits the program if sn is unknown.
void cTracer::ApplySurface(cShape* sh, char* sn)
{
	const cSurface* src = GetSurfaceByName(sn);

	// Colour coefficients.
	sh->Ambient = src->Ambient;
	sh->Diffuse = src->Diffuse;
	sh->Specular = src->Specular;
	sh->SpecularPower = src->SpecularPower;

	// Secondary-ray parameters.
	sh->Reflectivity = src->Reflectivity;
	sh->Transparency = src->Transparency;
	sh->RefractionIndex = src->RefractionIndex;
}

void cTracer::SetupCamera()
{
	// setup u, v and o from lookAt and upVector
	cVector look = lookAt - cVector(Eye.X, Eye.Y, Eye.Z);
	u = look.CrossProd(upVector).Normalize();
	v = look.CrossProd(u).Normalize();
	o = look.Normalize() * (0.5 / tan((fov*3.1415926535897932384626433832795/180.0) / 2));
}