/*Numerical solver for AdS mode equation

phi''-(1+u^2)/u/f phi'+w^2/u/f^2 phi=0,  f=1-u^2

by 

Paul Romatschke

Used in "Spectral sum rules for the quark-gluon plasma",
by PR and D.T. Son, http://arxiv.org/abs/0903.3946

Copyright of this code is granted provided you keep this disclaimer.
*/

#include <complex>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <vector>

#include <math.h>
#include <stdio.h>
#include <stdlib.h>

using namespace std;

// Complex types: double precision for the near-horizon series lphi1(),
// long double for the matching coefficients in getsolution().
using myc = std::complex<double>;
using mylc = std::complex<long double>;

// Number of grid intervals in u.
const long int NUM = 10000000;
// forwardsolve() writes a profile line every UPDATE steps; the drivers reset it.
long int UPDATE = NUM / 100;

// Integration range in u: from just above u=0 to just below u=1, where
// f = 1-u^2 vanishes and the mode equation is singular.
double u0 = 0.000001;
double u1 = 0.999;
// Uniform grid spacing.
double a = (u1 - u0) / NUM;

// Current values of the two independent solutions and of their u-derivatives.
double phi1, phi2;
double pi1, pi2;

// Each solution at u1-a (index 0) and u1 (index 1), filled in by forwardsolve().
double sol1[2], sol2[2];

// Stream receiving the solution profile written by forwardsolve().
std::fstream out;

void initialize(double w)
{

  printf("===> Initializing\n");
  //setting initial condition

  

  u0+=a/2.;

  phi1=pow(u0,2)*(1-pow(w,2)*u0/3.);
  //phi1+=pow(u0,4)*(0.5+pow(w,4)/24);
  phi2=-pow(w,4)/2.*log(u0)*phi1+1+pow(w,2)*u0-2./9.*pow(w,6)*pow(u0,3);
  //phi2-=pow(u0,4)/8.*(1.5*pow(w,4)-25./72.*pow(w,8));

  pi1=2*u0-pow(u0,2)*pow(w,2);
  //pi1+=pow(u0,3)/2.*(1.5*pow(w,4)-25./72.*pow(w,8));
  pi2=-pow(w,4)/2.*(phi1/u0+log(u0)*pi1)+pow(w,2)-2./3.*pow(w,6)*pow(u0,2);
  //pi2-=pow(u0,3)/2.*(1.5*pow(w,4)-25./72.*pow(w,8));


  u0-=a/2.;

  phi1=pow(u0,2)*(1-pow(w,2)*u0/3.);
  //phi1+=pow(u0,4)*(0.5+pow(w,4)/24);
  phi2=-pow(w,4)/2.*log(u0)*phi1+1+pow(w,2)*u0-2./9.*pow(w,6)*pow(u0,3);
  //phi2-=pow(u0,4)/8.*(1.5*pow(w,4)-25./72.*pow(w,8));
}




/* Integrate both solutions from u0 to u1 on the uniform grid of spacing a.

   Staggered scheme: phi1/phi2 live on the grid points u0+k*a, and their
   derivatives pi1/pi2 live on the half points in between.  Each step first
   advances phi with the current pi.  It then advances pi from
   phi'' = (1+u^2)/(u f) phi' - w^2/(u f^2) phi, evaluated at the new grid
   point.  The phi' term is taken as the average of the old and new pi, which
   gives the 1 - (a/2)(1+u^2)/(u f) denominators.

   Every UPDATE-th step, and once more at the end, a line "u phi1 phi2" is
   written to the global stream `out`.  On return sol1/sol2 hold both solutions
   at u1-a (index 0) and u1 (index 1). */
void forwardsolve(double w)
{
  printf("====>Starting solver for w=%f\n",w);

  for (long int k=1;k<NUM;k++)
    {
      // phi: the half-step derivative carries it to grid point k
      phi1+=a*pi1;
      phi2+=a*pi2;

      const double u=u0+k*a;
      const double fu=1-u*u;

      // coefficients of phi' and phi in the mode equation at u
      const double damping=(1+u*u)/(u*fu);
      const double restoring=w*w/(u*fu*fu);

      const double rhs1=damping*pi1-phi1*restoring;
      const double rhs2=damping*pi2-phi2*restoring;
      pi1+=a*rhs1/(1-a/2.*(1+u*u)/(u*fu));
      pi2+=a*rhs2/(1-a/2.*damping);

      if ((k-1)%UPDATE==0)
        out << u << "\t" << phi1 << "\t" << phi2 << "\n";
    }

  // Last interior grid point (u1-a), then one more phi step onto u1.
  sol1[0]=phi1;
  sol2[0]=phi2;
  phi1+=a*pi1;
  phi2+=a*pi2;
  sol1[1]=phi1;
  sol2[1]=phi2;

  out << u1 << "\t" << phi1 << "\t" << phi2 << "\n";
}

/* Near-horizon (u -> 1) infalling solution of the mode equation, truncated
   after the (1-u)^2 term:
     lphi1(u) = (1-u)^(-i w/2) * [1 + c1 (1-u) + c2 (1-u)^2],
     c1 = -(3w^2/4 - i w/4 + i w^3/2) / (1 + w^2),
     c2 = -w (4i - 2w + 7i w^2 + 4w^3) / (32 (-2 + 3i w + w^2)).
   Its complex conjugate is the outgoing solution; getsolution() matches the
   numerical solutions to a combination of the two. */
std::complex<double> lphi1(double u,double w)
{
  const std::complex<double> I(0,1);
  const double x=1-u;                      // distance from u=1

  // first-order coefficient, written as -a1/b1
  const std::complex<double> a1=3*w*w/4.-I*w/4.+I*std::pow(w,3)/2.;
  const double b1=1.0+w*w;
  std::complex<double> series=1.0-x*a1/b1;

  // second-order coefficient, written as -a2/(32*b2)
  const std::complex<double> a2=4.*I-2*w+7.*I*w*w+4.*std::pow(w,3);
  const std::complex<double> b2=-2.+3.*I*w+w*w;
  series-=x*x*w*a2/32./b2;

  // infalling exponent (1-u)^(-i w/2)
  series*=std::exp(-I*0.5*w*std::log(x));
  return series;
}


double getsolution(double w)
{
  //printf("====>Starting Evaluation\n");
  mylc aa,bb,cc,dd;
  
  myc temp1=lphi1(u1-a,w);
  myc temp2=lphi1(u1,w);

  aa=(sol1[0]/conj(temp1)-sol1[1]/conj(temp2))/(temp1/conj(temp1)-temp2/conj(temp2));
  //aa=(sol1[0]*conj(temp2)-sol1[1]*conj(temp1))/(temp1*conj(temp2)-temp2*conj(temp1));
  bb=(sol1[0]/temp1-sol1[1]/temp2)/(conj(temp1)/temp1-conj(temp2)/temp2);
  
  cc=(sol2[0]/conj(temp1)-sol2[1]/conj(temp2))/(temp1/conj(temp1)-temp2/conj(temp2));
  dd=(sol2[0]/temp1-sol2[1]/temp2)/(conj(temp1)/temp1-conj(temp2)/temp2);


  //  cout << "a " << aa << "\n";
  //cout << "b " << bb << "\n";
  //cout << "c " << cc << "\n";
  //cout << "d " << dd << "\n";

  mylc sol;
  sol=dd/(aa*dd-bb*cc);
  sol/=bb/(bb*cc-aa*dd);

  //cout << "sol " << sol << "\n";
  return imag(sol);
}

/* Solve at a single frequency w.  The solution profile (one line every NUM/100
   steps, plus the endpoint) goes to sol.dat, and getsolution(w) is printed. */
void singlesol(double w)
{
  UPDATE=NUM/100;   // 100 profile samples plus the endpoint in sol.dat

  printf("This is the numerical solver for AdS correlators\n");

  initialize(w);

  out.open("sol.dat",ios::out);
  forwardsolve(w);
  out.close();

  const double result=getsolution(w);
  printf("Solution :w=%f, %.12g\n",w,result);
}

/* Sweep w = 0.01, 0.11, ... while w < 5.1 (w advanced by repeated addition of
   0.1).  For each w, let sub = getsolution(w) - pi w^4/2.  The value sub/w is
   printed, and "w  2*sub/w  |2*sub/w|" is appended to sf.dat.  UPDATE=NUM, so
   sol.dat only gets the first-step line and the endpoint line for each w. */
void multisol()
{
  UPDATE=NUM;

  printf("This is the numerical solver for AdS correlators\n");
  fstream sfout;
  out.open("sol.dat",ios::out);
  sfout.open("sf.dat",ios::out);

  double w=0.01;
  while (w<5.1)
    {
      initialize(w);
      forwardsolve(w);
      const double sub=getsolution(w)-pow(w,4)*M_PI/2.;
      const double twice=2*sub/w;

      printf("Solution :w=%f, %.12g\n",w,sub/w);
      sfout << w << "\t" << twice;
      sfout << "\t" << fabs(twice) << endl;

      w+=0.1;
    }

  out.close();
  sfout.close();
}


/* Sum-rule check on a uniform frequency grid w = wmin, wmin+stepsize, ... < wmax.

   At each w two integrands are computed, with temp = getsolution(w) and
   temp0 = temp/w at the first grid point:
     I1(w) = (temp - pi w^4/2)/w,
     I2(w) = (temp - pi w^4/2 - temp0*w)/w^3.
   Both are printed to stdout and written to sf.dat.  They are then integrated
   with the composite trapezoid rule and with the extended Simpson rule
   (O(N^-4), end weights 3/8, 7/6, 23/24), and the results are printed as
     "sum rule one" = integral(I1)/pi,
     "sum rule two" = 4*integral(I2)/pi - 2/(pi*lastw),
   where lastw is the last grid point.

   Side effects: sets UPDATE=NUM, overwrites sol.dat (global `out`) and sf.dat.

   Changes from the earlier version:
   - The integrand samples are collected in std::vectors, so there is no dry-run
     pass over the grid and no manual delete[].
   - An empty grid is reported instead of reading sum1[-1].
   - The O(N^-4) rule needs at least 6 points.  On shorter grids it is skipped
     with a message instead of indexing out of bounds. */
void allsumrules(double wmin, double wmax, double stepsize)
{
  UPDATE=NUM;
  fstream sfout;

  printf("This is the numerical solver for AdS correlators\n");
  printf("Checking sumrule 1 and 2:\n");
  out.open("sol.dat",ios::out);
  sfout.open("sf.dat",ios::out);
  double temp0=0.5;
  double lastw=wmax;

  // Integrand samples, already multiplied by the step width.
  std::vector<double> sum1,sum2;

  for (double w=wmin;w<wmax;w+=stepsize)
    {
      initialize(w);
      forwardsolve(w);
      const double temp=getsolution(w);

      // temp0 = temp/w at the first grid point; subtracted in the second integrand.
      if (w<wmin+0.5*stepsize)
        temp0=temp/w;

      const double s1=(temp-pow(w,4)*M_PI/2.)/w;
      const double s2=(temp-pow(w,4)*M_PI/2.-temp0*w)/pow(w,3);
      sum1.push_back(s1*stepsize);
      sum2.push_back(s2*stepsize);

      printf("Sumrules :w=%f, %.12g, %.12g\n\n",w,s1,s2);
      sfout << w << "\t" << s1;
      sfout << "\t" << s2 << endl;
      lastw=w;
    }

  const int n=(int)sum1.size();
  if (n==0)
    {
      printf("No frequencies in [wmin, wmax): nothing to integrate\n");
      out.close();
      sfout.close();
      return;
    }

  double ssum1=0;
  double ssum2=0;

  //trapezoid:
  for (int i=1;i<n-1;i++)
    {
      ssum1+=sum1[i];
      ssum2+=sum2[i];
    }
  ssum1+=0.5*(sum1[0]+sum1[n-1]);
  ssum2+=0.5*(sum2[0]+sum2[n-1]);

  printf("Trapezoid:\n");
  printf("Sum rule one: %.12g\n",ssum1/M_PI);
  printf("sum rule two: %.12g\n",4*ssum2/M_PI-2/M_PI/lastw);

  //higher order: the three end weights on each side must not overlap, so n >= 6.
  if (n<6)
    {
      printf("O(N^4): skipped, needs at least 6 frequencies (got %i)\n",n);
    }
  else
    {
      ssum1=0;
      ssum2=0;
      for (int i=3;i<n-3;i++)
        {
          ssum1+=sum1[i];
          ssum2+=sum2[i];
        }
      ssum1+=3./8.*(sum1[0]+sum1[n-1]);
      ssum2+=3./8.*(sum2[0]+sum2[n-1]);
      ssum1+=7./6.*(sum1[1]+sum1[n-2]);
      ssum2+=7./6.*(sum2[1]+sum2[n-2]);
      ssum1+=23./24.*(sum1[2]+sum1[n-3]);
      ssum2+=23./24.*(sum2[2]+sum2[n-3]);

      printf("O(N^4):\n");
      printf("Sum rule one: %.12g\n",ssum1/M_PI);
      printf("sum rule two: %.12g\n",4*ssum2/M_PI-2/M_PI/lastw);
    }

  out.close();
  sfout.close();
}


void allsumrulestry(double wmin, double wmax, double stepsize)
{

  UPDATE=NUM;
  fstream sfout;

  int reducer=80;
  double wswitch=0.1;

  printf("This is the numerical solver for AdS correlators\n");
  printf("Checking sumrule 1 and 2:\n");
  out.open("sol.dat",ios::out);
  sfout.open("sf.dat",ios::out);
  double temp0=0.5;
  double lastw=wmax;
  double temp=0;
  double es1=0,es2=0,es3=0,es4=0;

  int asize=0;

  stepsize/=reducer;

  //dry-run:
  for (double w=wmin;w<wswitch+2*wmin;w+=stepsize)
    {
      asize++;
    }

  double *sum1,*sum2;
  sum1=new double[asize];
  sum2=new double[asize];

  asize=0;

  for (double w=wmin;w<wswitch+2*wmin;w+=stepsize)
    {
      initialize(w);
      forwardsolve(w);
      temp=getsolution(w);

      if (w<wmin+0.5*stepsize)
	temp0=temp/w;
      
      sum1[asize]=(temp-pow(w,4)*M_PI/2.)/w*stepsize;
      sum2[asize]=(temp-pow(w,4)*M_PI/2.-temp0*w)/pow(w,3)*stepsize;

      printf("Sumrules :w=%f, %.12g, %.12g\n\n",w,(temp-pow(w,4)*M_PI/2.)/w,(temp-pow(w,4)*M_PI/2.-temp0*w)/pow(w,3)); 
     sfout << w << "\t" << (temp-pow(w,4)*M_PI/2.)/w;
     sfout << "\t" << (temp-pow(w,4)*M_PI/2.-temp0*w)/pow(w,3) << endl;
     lastw=w;

     asize++;
    }

  double ssum1,ssum2;

  ssum1=0;
  ssum2=0;

  
  //trapezoid:
  for (int i=1;i<asize-1;i++)
    {
      ssum1+=sum1[i];
      ssum2+=sum2[i];
    }
  ssum1+=0.5*(sum1[0]+sum1[asize-1]);
  ssum2+=0.5*(sum2[0]+sum2[asize-1]);
  
  

  double ssum12=0;
  double ssum22=0;

  //higher order:
  for (int i=3;i<asize-3;i++)
    {
      ssum12+=sum1[i];
      ssum22+=sum2[i];
    }
  ssum12+=3./8.*(sum1[0]+sum1[asize-1]);
  ssum22+=3./8.*(sum2[0]+sum2[asize-1]);
  ssum12+=7./6.*(sum1[1]+sum1[asize-2]);
  ssum22+=7./6.*(sum2[1]+sum2[asize-2]);
  ssum12+=23./24.*(sum1[2]+sum1[asize-3]);
  ssum22+=23./24.*(sum2[2]+sum2[asize-3]);

  
  delete [] sum1;
  delete [] sum2;
  

  es1=ssum1;
  es2=ssum2;
  es3=ssum12;
  es4=ssum22;

  stepsize*=reducer;
  lastw=wmax;
  temp=0;
  asize=0;

  //dry-run:
  for (double w=wswitch+wmin;w<wmax+2*wmin;w+=stepsize)
    {
      asize++;
    }

  sum1=new double[asize];
  sum2=new double[asize];

  asize=0;

  for (double w=wswitch+wmin;w<wmax+2*wmin;w+=stepsize)
    {
      initialize(w);
      forwardsolve(w);
      temp=getsolution(w);

      if (w<wmin+0.5*stepsize)
	temp0=temp/w;
      
      sum1[asize]=(temp-pow(w,4)*M_PI/2.)/w*stepsize;
      sum2[asize]=(temp-pow(w,4)*M_PI/2.-temp0*w)/pow(w,3)*stepsize;

      printf("Sumrules :w=%f, %.12g, %.12g\n\n",w,(temp-pow(w,4)*M_PI/2.)/w,(temp-pow(w,4)*M_PI/2.-temp0*w)/pow(w,3)); 
     sfout << w << "\t" << (temp-pow(w,4)*M_PI/2.)/w;
     sfout << "\t" << (temp-pow(w,4)*M_PI/2.-temp0*w)/pow(w,3) << endl;
     lastw=w;

     asize++;
    }


  //trapezoid:
  for (int i=1;i<asize-1;i++)
    {
      ssum1+=sum1[i];
      ssum2+=sum2[i];
    }
  ssum1+=0.5*(sum1[0]+sum1[asize-1]);
  ssum2+=0.5*(sum2[0]+sum2[asize-1]);
  
  //higher order:
  for (int i=3;i<asize-3;i++)
    {
      ssum12+=sum1[i];
      ssum22+=sum2[i];
    }
  ssum12+=3./8.*(sum1[0]+sum1[asize-1]);
  ssum22+=3./8.*(sum2[0]+sum2[asize-1]);
  ssum12+=7./6.*(sum1[1]+sum1[asize-2]);
  ssum22+=7./6.*(sum2[1]+sum2[asize-2]);
  ssum12+=23./24.*(sum1[2]+sum1[asize-3]);
  ssum22+=23./24.*(sum2[2]+sum2[asize-3]);

  printf("\n\n Results for %i iterations\n",NUM);
  printf("Trapezoid:\n");
  printf("Sum rule one for wmax=%f, deltaw=%f: %.12g\n",lastw,stepsize,ssum1/M_PI);
  printf("sum rule two for wmax=%f, deltaw=%f: %.12g\n",lastw,stepsize,4*ssum2/M_PI-2/M_PI/lastw);

  printf("O(N^4):\n");
  printf("Sum rule one: %.12g\n",ssum12/M_PI);
  printf("sum rule two: %.12g\n",4*ssum22/M_PI-2/M_PI/lastw);

  //printf("results 1 %.12g and %.12g where %.12g\n",(ssum1-es1)/M_PI,4*(ssum2-es2)/M_PI-2/M_PI/lastw,2/M_PI/lastw);
  //printf("N^4: results 1 %.12g and %.12g\n ",(ssum12-es3)/M_PI,4*(ssum22-es4)/M_PI-2/M_PI/lastw);

  printf("Error estimate: reduce wmax by 0.1:\n");
  asize-=(int)(0.1/stepsize);
  double redlastw=0.1+wmin+(asize-1)*stepsize;
  //printf("reduced lastw=%.12g\n",redlastw);

  
  //trapezoid:
  for (int i=1;i<asize-1;i++)
    {
      ssum1-=sum1[i];
      ssum2-=sum2[i];
    }
  ssum1-=0.5*(sum1[0]+sum1[asize-1]);
  ssum2-=0.5*(sum2[0]+sum2[asize-1]);
  
  //higher order:
  for (int i=3;i<asize-3;i++)
    {
      ssum12-=sum1[i];
      ssum22-=sum2[i];
    }
  ssum12-=3./8.*(sum1[0]+sum1[asize-1]);
  ssum22-=3./8.*(sum2[0]+sum2[asize-1]);
  ssum12-=7./6.*(sum1[1]+sum1[asize-2]);
  ssum22-=7./6.*(sum2[1]+sum2[asize-2]);
  ssum12-=23./24.*(sum1[2]+sum1[asize-3]);
  ssum22-=23./24.*(sum2[2]+sum2[asize-3]);

  
  ssum1-=es1;
  ssum2-=es2;
  ssum12-=es3;
  ssum22-=es4;

  printf("Trapezoid:\n");
  printf("Error est sum rule one for wmax=%f, deltaw=%f: %.12g\n",redlastw,stepsize,ssum1/M_PI);
  printf("Error est sum rule two for wmax=%f, deltaw=%f: %.12g\n",redlastw,stepsize,4*ssum2/M_PI-2/M_PI/lastw+2/M_PI/redlastw);

  printf("O(N^4):\n");
  printf("Error est sum rule one: %.12g\n",ssum12/M_PI);
  printf("Error est sum rule two: %.12g\n",4*ssum22/M_PI-2/M_PI/lastw+2/M_PI/redlastw);

  out.close();
  sfout.close();

  delete [] sum1;
  delete [] sum2;
}


/* Driver: select a run by (un)commenting it. */
int main()
{
  // Alternative runs:
  //   singlesol(1.0);                    one frequency, profile in sol.dat
  //   multisol();                        fixed w sweep, results in sf.dat
  //   allsumrules(0.000001,1.00,0.1);    sum rules on a single uniform grid
  // Earlier notes: NUM=40000000 "gives good first sumrule"; a call
  //   sumrule1(0.000001,2.05,0.05) (not present in this source) was used
  //   "because extends far out. No great w resolution needed".

  // Default: sum rules on a refined grid near wmin, plus an error estimate.
  allsumrulestry(0.0000001,6.00,0.01);

  return 0;
}

