#include "Quadetrend.hh"
#include "DVecType.hh"
#include "TSeries.hh"
#include <cmath>
#include <iostream>
#include <stdexcept>

using namespace std;

//======================================  Default Constructor
//  Build an unconfigured detrender with a length of zero.  Until
//  setLength() is called with a usable length, dataCheck() (and hence
//  apply()) rejects every non-empty series.
Quadetrend::Quadetrend(void) : mLength(0) {
}

//======================================  Data Constructor
//  Construct a detrender configured for series of exactly N samples.
//  mLength is initialized up front (as in the default constructor) so the
//  object never holds an indeterminate length; setLength() then computes
//  the fit coefficients.  May throw std::logic_error from setLength(),
//  e.g. for N < 3 where the quadratic-fit matrix is singular.
Quadetrend::Quadetrend(int N)
  : mLength(0)
{
    setLength(N);
}

//======================================  Destructor
//  Empty destructor body: this file performs no explicit cleanup, so any
//  member destruction is left to the compiler.
Quadetrend::~Quadetrend(void) {}

//======================================  Clone (polymorphic copy)
//  Return a heap-allocated copy of this filter made with the copy
//  constructor.  Nothing else retains the pointer, so the caller owns it
//  and is responsible for deleting it.
Quadetrend*
Quadetrend::clone(void) const {
    Quadetrend* duplicate = new Quadetrend(*this);
    return duplicate;
}

//======================================  Inverse sanity-check helper
//  Return true when t0 + t1 + t2 (one element of Hessian * covariance)
//  equals the expected identity element `expect` to within a tolerance
//  relative to the magnitude of the individual terms.  A fixed absolute
//  tolerance is not usable: the diagonal elements are differences of terms
//  that can be larger than one, so a few ulp of rounding error in each
//  term can already exceed a tight bound such as 1e-15.  Any NaN makes the
//  comparison false, so it is reported as a failure.
static bool
QDidentityOK(double t0, double t1, double t2, double expect) {
    const double rtol(1e-12);
    double scale = fabs(t0) + fabs(t1) + fabs(t2);
    return fabs((t0 + t1 + t2) - expect) <= rtol * scale;
}

//======================================  Set the filter length
//  Configure the detrender for series of exactly N samples.
//
//  Sample i is assigned the abscissa x = i - N/2 (integer division), the
//  same abscissa used when detrending.  The symmetric normal-equation
//  (Hessian) matrix of a quadratic least-squares fit in x,
//        | S0  S1  S2 |
//        | S1  S2  S3 |      with Sk = sum over samples of x^k,
//        | S2  S3  S4 |
//  is inverted by cofactors, and the upper triangle of the inverse is
//  stored row-wise (00 01 02 11 12 22) in mCovMtx.
//
//  N == 0 simply marks the filter unconfigured.  For N < 3 (or N < 0)
//  the matrix is singular and logic_error("Quadetrend: Matrix is
//  singular") is thrown.  If Hessian * inverse is not the identity to
//  within a relative tolerance, logic_error("Quadetrend: Matrix inversion
//  failed") is thrown.  On any exception mLength and mCovMtx keep their
//  previous values, so the object stays self-consistent.
void 
Quadetrend::setLength(int N) {
    if (!N) {
        mLength = 0;
        return;
    }

    //----------------------------------  Moments of the abscissa
    double sum1(N);
    double sumx(0);
    double sumxx(0);
    double sum3x(0);
    double sum4x(0);
    for (int i=0; i<N; ++i) {
        double x  = i - N/2;
        double xx = x*x;
        sumx  += x;
        sumxx += xx;
        sum3x += x*xx;
        sum4x += xx*xx;
    }

    //----------------------------------  Hessian elements (symmetric)
    double a00 = sum1;
    double a01 = sumx;
    double a02 = sumxx;
    double a11 = sumxx;
    double a12 = sum3x;
    double a22 = sum4x;

    //----------------------------------  Invert by cofactors into a local
    //  buffer; the members are only updated once every check has passed.
    double det = a00*(a11*a22 - a12*a12) 
               - a01*(a01*a22 - a02*a12) 
               + a02*(a01*a12 - a02*a11);
    if (det == 0) throw logic_error("Quadetrend: Matrix is singular");

    double cov[6];
    cov[0] =  (a11*a22 - a12*a12)/det;
    cov[1] = -(a01*a22 - a02*a12)/det;
    cov[2] =  (a01*a12 - a11*a02)/det;
    cov[3] =  (a00*a22 - a02*a02)/det;
    cov[4] = -(a00*a12 - a02*a01)/det;
    cov[5] =  (a00*a11 - a01*a01)/det;

#ifdef DEBUG
    //  Printed before the sanity check so a failing inversion can be
    //  inspected.
    cout << "Covariance matrix: " << endl;
    cout << cov[0] << " " << cov[1] << " " << cov[2] << endl;
    cout << cov[1] << " " << cov[3] << " " << cov[4] << endl;
    cout << cov[2] << " " << cov[4] << " " << cov[5] << endl;
    double i00 = a00*cov[0] + a01*cov[1] + a02*cov[2];
    double i01 = a00*cov[1] + a01*cov[3] + a02*cov[4];
    double i02 = a00*cov[2] + a01*cov[4] + a02*cov[5];
    double i11 = a01*cov[1] + a11*cov[3] + a12*cov[4];
    double i12 = a01*cov[2] + a11*cov[4] + a12*cov[5];
    double i22 = a02*cov[2] + a12*cov[4] + a22*cov[5];
    cout << "Covariance * Hessian  (Ident?) matrix: " << endl;
    cout << i00 << " " << i01 << " " << i02 << endl;
    cout << i01 << " " << i11 << " " << i12 << endl;
    cout << i02 << " " << i12 << " " << i22 << endl;
#endif

    //----------------------------------  Sanity check: Hessian * cov == I
    bool ok =
        QDidentityOK(a00*cov[0], a01*cov[1], a02*cov[2], 1.0) &&  // (0,0)
        QDidentityOK(a00*cov[1], a01*cov[3], a02*cov[4], 0.0) &&  // (0,1)
        QDidentityOK(a00*cov[2], a01*cov[4], a02*cov[5], 0.0) &&  // (0,2)
        QDidentityOK(a01*cov[1], a11*cov[3], a12*cov[4], 1.0) &&  // (1,1)
        QDidentityOK(a01*cov[2], a11*cov[4], a12*cov[5], 0.0) &&  // (1,2)
        QDidentityOK(a02*cov[2], a12*cov[4], a22*cov[5], 1.0);    // (2,2)
    if (!ok) throw logic_error("Quadetrend: Matrix inversion failed");

    //----------------------------------  Commit the new configuration
    for (int k=0; k<6; ++k) mCovMtx[k] = cov[k];
    mLength = N;
}

//  Subtract the least-squares quadratic a + b*x + c*x^2 from the samples
//  of dv, in place.  Sample k has abscissa x = k - N/2 (integer division),
//  matching Quadetrend::setLength.  CovMtx holds the upper triangle of the
//  inverted normal-equation matrix, row-wise: {00, 01, 02, 11, 12, 22}.
template<class T>
inline void
QDtrend(DVecType<T>& dv, const double CovMtx[]) {
    const long nSamp = dv.getLength();
    const long mid   = nSamp / 2;
    T* data = dv.refTData();

    //----------------------------------  Data moments: sum y, y*x, y*x^2
    double m0 = 0.0;
    double m1 = 0.0;
    double m2 = 0.0;
    for (long k = 0; k < nSamp; ++k) {
        const double xk = k - mid;
        const double yk = data[k];
        const double yx = yk * xk;
        m0 += yk;
        m1 += yx;
        m2 += yx * xk;
    }

    //----------------------------------  Fit coefficients = Cov * moments
    const double a = CovMtx[0]*m0 + CovMtx[1]*m1 + CovMtx[2]*m2;
    const double b = CovMtx[1]*m0 + CovMtx[3]*m1 + CovMtx[4]*m2;
    const double c = CovMtx[2]*m0 + CovMtx[4]*m1 + CovMtx[5]*m2;

#ifdef DEBUG
    cout << "Quadetrend coefficients: " << a << " " << b << " " << c << endl;
#endif

    //----------------------------------  Remove the fitted quadratic
    for (long k = 0; k < nSamp; ++k) {
        const double xk = k - mid;
        data[k] -= (c * xk + b) * xk + a;
    }
}

//======================================  Detrend a time series
//  Return a copy of ts with the least-squares quadratic removed.  An empty
//  series is returned unchanged; otherwise its length must equal the
//  configured length (dataCheck throws runtime_error if not).  The first
//  non-empty series seen sets mStartTime if it is not already set.
//  Double-precision data are detrended as doubles; every other data type
//  is first converted to float.
TSeries 
Quadetrend::apply(const TSeries& ts) {
    if (!ts.getNSample()) return ts;
    dataCheck(ts);
    if (!mStartTime) mStartTime = ts.getStartTime();

    //---------------------------------  Detrend a private copy of the data
    TSeries out(ts);
    if (out.refDVect()->getType() == DVector::t_double) {
        QDtrend(dynamic_cast<DVecType<double>&>(*out.refDVect()), mCovMtx);
    } else {
        out.Convert(DVector::t_float);
        QDtrend(dynamic_cast<DVecType<float>&>(*out.refDVect()), mCovMtx);
    }
    return out;
}

//======================================  Check the data length
//  Accept only series whose sample count equals the configured length;
//  anything else raises runtime_error("Bad series length").
void 
Quadetrend::dataCheck(const TSeries& ts) const {
    if (ts.getNSample() == mLength) return;
    throw runtime_error("Bad series length");
}

//======================================  Reset/invalidate
void 
Quadetrend::reset(void) {
    mLength = 0;
}
