#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <vector>
using namespace std;

#include "findroot.hpp"

namespace cpl {

//      Declare some utility functions

// Print the banner (algorithm name, requested accuracy) and the column
// headings of a verbose root-finding trace to os.
static void printHead ( const char *algorithm, double accuracy, ostream& os );
// Print one row of the verbose trace to os: step number, current guess for
// the root, step size and function value.
static void printStep ( int step, double x, double dx,
                                    double f_of_x, ostream& os );
// Warn on cerr that an iteration stopped after exceeding maxSteps.
static void printWarning ( int maxSteps );

//      class RootFinder functions

void RootFinder::guessRoot ( double xGuess )
{
        x0 = xGuess;
}

// Use stepSize as the current step dx and place the second starting
// point one step away from the first: x1 = x0 + stepSize.
void RootFinder::guessStep ( double stepSize )
{
        x1 = x0 + stepSize;
        dx = stepSize;
}

// Take xGuess as the second starting point x1; dx becomes the signed
// width x1 - x0 of the resulting starting interval.
void RootFinder::secondGuessRoot ( double xGuess )
{
        dx = xGuess - x0;
        x1 = xGuess;
}

int RootFinder::bracketRoot ( double f ( double x ) )
{
     // Try to bracket a root of f, starting from the RootFinder interval
     // [x0, x1], by geometrically expanding the interval until f changes
     // sign (or vanishes) at an endpoint.  At most maxSteps expansions are
     // made, and the interval produced by every expansion -- including the
     // last one -- is tested.
     // Returns 1 on success: x0 and x1 bracket the root and dx = x1 - x0.
     // Returns 0 if x0 == x1 (message on cerr) or if maxSteps expansions
     // did not bracket a root (warning on cerr); in the latter case x0, x1
     // hold the last, expanded interval.

        //      a convenient expansion factor
        const double expansionFactor = 1.6;

        if ( x0 == x1 ) {
                cerr << "\n RootFinder::bracketRoot: sorry, x0 = x1!\n"
                        << " x0 = " << x0 << endl
                        << " x1 = " << x1 << endl;
                return 0;
        }

        // The sign of dx decides which way each endpoint moves, so it must
        // equal x1 - x0.  The member may be stale (e.g. guessRoot changes x0
        // without touching x1 or dx), so derive it from the actual interval.
        dx = x1 - x0;

        int step = 0;
        double f0 = f(x0);
        double f1 = f(x1);

        for ( ;; ) {
                // Bracketed when f vanishes at an endpoint or the signs differ.
                // Signs are compared directly: the product f0 * f1 can
                // underflow to zero and report a bracket that is not there.
                if ( f0 == 0 || f1 == 0
                     || ( f0 < 0 && f1 > 0 ) || ( f0 > 0 && f1 < 0 ) )
                        return 1;         // success .. root is bracketed
                if ( ++step > maxSteps )
                        break;
                if ( abs(f0) < abs(f1) ) {  // x0 probably closer to root
                        x0 -= expansionFactor * dx;
                        f0 = f(x0);
                } else {                    // x1 probably closer to the root
                        x1 += expansionFactor * dx;
                        f1 = f(x1);
                }
                dx = x1 - x0;               // also adjust dx
        }

        printWarning(maxSteps);   // maxSteps expansions did not bracket a root
        return 0;
}

// Set the convergence tolerance used by the root-finding algorithms.
// Values not greater than 1e-38 (including zero, negatives and NaN) are
// rejected with a message on cerr that reports the rejected value, and
// the current accuracy is left unchanged.
void RootFinder::setAccuracy ( double epsilon )
{
        if ( epsilon > 1e-38 ) {
                accuracy = epsilon;
                return;
        }
        // Report what the caller asked for, not the value being kept.
        cerr << " RootFinder: you have requested a ridiculous"
                << " accuracy: " << epsilon << " !!!" << endl;
}

// Set the iteration limit used by the root-finding algorithms.
// Non-positive requests are rejected with a message on cerr and the
// current limit is left unchanged.
void RootFinder::setMaxSteps ( int steps )
{
        if ( steps <= 0 ) {
                cerr << " RootFinder: you have requested a ridiculous"
                        << " number of steps: " << steps << " !!!" << endl;
                return;
        }
        maxSteps = steps;
}

//      class SimpleSearch functions

double SimpleSearch::findRoot ( double f ( double x ) )
{
     // Walk from x0 in steps of dx.  Whenever a step lands where f has the
     // opposite sign to f at the starting point, the step is undone and dx
     // is halved.  Stops when |dx| <= accuracy, when f is exactly zero at
     // the current point, or after maxSteps steps (warning on cerr).
     // The final step counter is stored in `steps`; the final x0 is returned.

        double f0 = f(x0);          // sign of f on the side the search keeps to
        double f_of_x = f0;

        int step = 0;
        if ( verbose ) {
                printHead("Simple Search with Step-Halving", accuracy, *os);
                printStep(step, x0, dx, f_of_x, *os);
        }

        while ( abs(dx) > accuracy && f_of_x != 0 && ++step <= maxSteps ) {
                x0 += dx;                       //      take a step
                f_of_x = f(x0);
                // Jumped past the root?  Compare signs directly instead of
                // testing f0 * f_of_x < 0: for tiny function values the
                // product underflows to zero and the crossing is missed.
                if ( ( f0 > 0 && f_of_x < 0 ) || ( f0 < 0 && f_of_x > 0 ) ) {
                        x0 -= dx;               //      backup
                        dx /= 2;                //      halve the step size
                }
                // After a backup, f_of_x is the value at the rejected point.
                if ( verbose )
                        printStep(step, x0, dx, f_of_x, *os);
        }

        if ( step > maxSteps )
                printWarning(maxSteps);
        steps = step;
        return x0;
}

//      class BisectionSearch functions

double BisectionSearch::findRoot ( double f ( double x ) )
{
     // Find a root of f in [x0, x1] by repeated bisection.
     // If f(x0) and f(x1) have the same sign, bracketRoot is tried first.
     // If that also fails, x0 and x1 are restored and the endpoint with the
     // smaller |f| is returned.  Iteration stops when |x1 - x0| <= accuracy,
     // when f is exactly zero at a midpoint, or after maxSteps bisections
     // (warning on cerr).  The final step counter is stored in `steps`.

        double f0 = f(x0);
        double f1 = f(x1);

        // Same-sign test written without the product f0 * f1, which can
        // underflow to zero and hide that the root is not bracketed.
        if ( ( f0 > 0 && f1 > 0 ) || ( f0 < 0 && f1 < 0 ) ) {
                cerr << " BisectionSearch: sorry, root not bracketed!\n"
                        << " f(" << x0 << ") = " << f0 << endl
                        << " f(" << x1 << ") = " << f1 << endl
                        << " Trying to bracket the root using bracketRoot ..."
                        << flush;
                double save_x0 = x0;
                double save_x1 = x1;
                if ( bracketRoot(f) ) {
                        cerr << " Bracketing succeeded !\n"
                                << " x0 = " << x0 << " x1 = " << x1
                                << " continuing ..." << endl;
                        f0 = f(x0);
                        f1 = f(x1);
                } else {
                        cerr << " Sorry, bracketing attempt failed" << endl;
                        x0 = save_x0;
                        x1 = save_x1;
                        // f0, f1 still hold the values at the restored points
                        return abs(f0) < abs(f1) ? x0 : x1;
                }
        }
        if ( f0 == 0 )
                return x0;
        if ( f1 == 0 )
                return x1;

        // Initialise xHalf so the return value is defined even if the loop
        // exits before performing a single bisection.
        double xHalf = 0.5 * (x0 + x1);
        double fHalf = 0.5 * (f0 + f1);         //      step-0 display value only
        int step = 0;
        if ( verbose ) {
                printHead("Bisection Search", accuracy, *os);
                printStep(step, x0, x1 - x0, fHalf, *os);
        }
        do {                                    //      iteration loop
                if ( ++step > maxSteps )
                        break;
                xHalf = 0.5 * (x0 + x1);        //      bisection point
                fHalf = f(xHalf);
                // Signs compared directly: f0 * fHalf can underflow to zero,
                // which would move the wrong endpoint and lose the bracket.
                if ( ( f0 > 0 && fHalf > 0 ) || ( f0 < 0 && fHalf < 0 ) ) {
                        x0 = xHalf;             //      x0, xHalf same side: replace x0
                        f0 = fHalf;
                } else {
                        x1 = xHalf;             //      x1, xHalf same side: replace x1
                        f1 = fHalf;
                }
                if ( verbose )
                        printStep(step, x0, x1 - x0, fHalf, *os);
        } while ( abs(x1 - x0) > accuracy && fHalf != 0);

        if ( step > maxSteps )
                printWarning(maxSteps);
        steps = step;
        return xHalf;
}

//      class SecantSearch functions

double SecantSearch::findRoot ( double f ( double x ) )
{
     // Secant search started from the two points x0 and x1.  Each step
     // replaces (x0, x1) by (x1, x1 + dx), where dx is the secant step.
     // Stops when |dx| <= accuracy, when f(x1) is exactly zero, when
     // f(x0) == f(x1) (reported on cerr), or after maxSteps steps (warning).
     // The final step counter is stored in `steps`; the final x1 is returned.

        double f0 = f(x0);
        double f1 = f(x1);

        if ( f0 == 0 )
                return x0;
        if ( f1 == 0 )
                return x1;

        // The update below scales dx, so dx must start as x1 - x0.  The
        // member may be stale (e.g. guessRoot changes x0 without touching
        // x1 or dx, and other searches modify x0/dx), so re-derive it.
        dx = x1 - x0;

        int step = 0;
        if ( verbose ) {
                printHead("Secant Search", accuracy, *os);
                printStep(step, x1, dx, f1, *os);
        }
        do {
                if ( ++step > maxSteps )
                        break;
                if ( f0 == f1 ) {
                        cerr << " Secant Search: f(x0) = f(x1), algorithm fails!\n"
                                << " f(" << x0 << ") = " << f0 << endl
                                << " f(" << x1 << ") = " << f1 << endl;
                        break;
                }
                dx *= - f1 / ( f1 - f0 );       //      secant step from x1
                x0 = x1;
                f0 = f1;
                x1 += dx;
                f1 = f(x1);
                // Show the newest estimate together with its own f value.
                if ( verbose )
                        printStep(step, x1, dx, f1, *os);
        } while ( abs(dx) > accuracy && f1 != 0);

        if ( step > maxSteps )
                printWarning(maxSteps);
        steps = step;
        return x1;
}

//      class TangentSearch functions

double TangentSearch::findRoot ( double f ( double x ),
                                                   double fPrime ( double x ) )
{
        double f0 = f(x0);
        double fPrime0 = fPrime(x0);

        if ( f0 == 0 )
                return x0;

        if ( fPrime0 != 0 )
                dx = - f0 / fPrime0;

        int step = 0;
        if ( verbose ) {
                printHead("Tangent Search", accuracy, *os);
                printStep(step, x0, dx, f0, *os);
        }
        do {
                if ( ++step > maxSteps )
                        break;
                if ( fPrime0 == 0 ) {
                        cerr << " Tangent Search: f'(x0) = 0, algorithm fails!\n"
                                << " f(" << x0 << ") = " << f0 << endl
                                << " f'(" << x0 << ") = " << fPrime0 << endl;
                        break;
                }
                dx = - f0 / fPrime0;
                x0 += dx;
                f0 = f(x0);
                fPrime0 = fPrime(x0);
                if ( verbose )
                        printStep(step, x0, dx, f0, *os);
        } while ( abs(dx) > accuracy && f0 != 0);

        if ( step > maxSteps )
                printWarning(maxSteps);
        steps = step;
        return x0;
}

//      Utility functions

// Write the banner of a verbose root-finding trace to os: the algorithm
// name, the requested accuracy, and the column headings used by printStep.
static void printHead ( const char *algorithm, double accuracy, ostream& os )
{
        os << "\n ROOT FINDING using " << algorithm;
        os << "\n Requested accuracy = " << accuracy;
        // column headings followed by their underline
        os << "\n Step     Guess For Root          Step Size           Function Value";
        os << "\n ----  --------------------  --------------------  --------------------";
        os << endl;
}

// Write one row of a verbose root-finding trace to os: the step number
// (right-aligned, width 4), then x, dx and f_of_x (left-aligned, width 20,
// 14 significant digits).  The caller's width, precision and format flags
// are restored exactly before returning.
static void printStep ( int step, double x, double dx,
                                    double f_of_x, ostream& os )
{
        const streamsize oldWidth = os.width();
        const streamsize oldPrecision = os.precision();
        const ios::fmtflags oldFlags = os.flags();

        os.setf(ios::right, ios::adjustfield);
        os << " " << setw(4) << step << "  ";
        os.setf(ios::left, ios::adjustfield);
        os << setprecision(14)
           << setw(20) << x << "  "
           << setw(20) << dx << "  "
           << setw(20) << f_of_x
           << endl;

        os.width(oldWidth);
        os.precision(oldPrecision);
        // flags() replaces the whole flag set; setf() would only OR the saved
        // bits back in and leave ios::left switched on for the caller.
        os.flags(oldFlags);
}

// Report on cerr that an iteration stopped because it exceeded its step
// limit; maxSteps is the limit that was exceeded.
static void printWarning ( int maxSteps )
{
        cerr << " Warning: maximum number of steps ";
        cerr << maxSteps;
        cerr << " exceeded!";
        cerr << endl;
}

}  /* end namespace cpl */
