#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <limits>
#include <sstream>
#include <string>
#include <vector>
using namespace std;

// data structures ------------------------------------------------------------

// One sample of the decay curve: the number of nuclei N at time t.
struct DataPoint {
    double t, n;                // time t and number of nuclei N(t)
};

// All samples computed with a single Euler time step.
struct DataSet {
    double dt;                  // the time step Delta t used for this set
    vector<DataPoint> points;   // (t, N(t)) samples in the order they were computed
};

// global variables -----------------------------------------------------------

// Physical and numerical parameters, shared by all functions below; they
// keep their values from one computation to the next.
double n0   = 1000.0;           // N(0): number of nuclei at t = 0
double tau  = 1.0;              // mean lifetime
double tMax = 5.0;              // time at which the simulation stops
double dt   = 0.2;              // Euler time step Delta t

// data sets produced by the current computation, one per time step tried
vector<DataSet> results;

// function declarations ------------------------------------------------------

// stages of one computation, called in this order by main
void getConstants();            // show N(0), tau, t_max and optionally re-enter them
void doComputation();           // solve the decay equation for one or more dt values
void plotResults();             // plot the exact curve and all data sets with Gnuplot

// console-input and formatting helpers
double getDouble(string prompt);// prompt until a real number is entered
bool getYesNo(string prompt);   // prompt until a Yes/No answer is entered
string convert(int i);          // decimal text of an integer

// function definitions -------------------------------------------------------

// Run complete computations (constants, integration, plot) until the user
// declines another one.
int main() {
    bool again = true;
    while (again) {
        getConstants();
        doComputation();
        plotResults();
        again = getYesNo("Another computation?");
    }
}

// Show the current N(0), tau and t_max and, if the user asks to change them,
// read new values from standard input. tau must be positive (it divides N in
// dN/dt = -N/tau) and t_max must not be negative; invalid entries are
// re-requested.
void getConstants() {

    // print old values of physical constants which are global variables
    cout << endl
         << "Simulation of Radioactive Decay\n"
         << "-------------------------------\n"
         << "Current values of physical constants: \n"
         << "N(0) = " << n0 << ", tau = " << tau
         << ", t_max = " << tMax << endl;

    if (getYesNo("Change these values?") == false)
        return;

    n0 = getDouble("Enter new value of N(0)");

    // tau = 0 would divide by zero in the Euler step; the negated
    // comparison also rejects NaN
    tau = getDouble("Enter new value of tau");
    while (!(tau > 0)) {
        cout << "tau must be positive ... ";
        tau = getDouble("Enter new value of tau");
    }

    tMax = getDouble("Enter new value of t_max");
    while (!(tMax >= 0)) {
        cout << "t_max must not be negative ... ";
        tMax = getDouble("Enter new value of t_max");
    }
}

void doComputation() {

    results.clear();      // empty the results vector

    // loop to generate data for several values of time step dt
    do {
        // get new time step
        cout << "Current value of dt = " << dt << endl;
        dt = getDouble("Enter new value of dt");

        // initialize computational variables and data set
        double t = 0, n = n0;
        DataSet set;                    // new empty data set
        set.dt = dt;
        DataPoint p;                    // new data point
        p.t = t;
        p.n = n0;
        set.points.push_back(p);        // add the point to the set

        int iterations = int(tMax/dt);
        for (int i = 0; i < iterations; i++) {

            // one time step using Euler's algorithm
            p.n = n -= n / tau * dt;
            p.t = t += dt;

            set.points.push_back(p);    // add the new point to the set
        }

        results.push_back(set);         // add the new set to results vector

    } while (getYesNo("Another time step?") == true);
}

void plotResults() {

    // use Gnuplot to plot data sets stored in global variable results
    // write a Gnuplot script file, setting title, etc.
    ofstream scriptFile("script");
    scriptFile << "set title \"N(0) = " << n0 << ", tau = " << tau << "\"\n"
               << "set xlabel \"Time t (sec)\"\n"
               << "set ylabel \"Number of nuclei N(t)\"\n"
               << "f(x)=" << n0 << "*exp(-x/" << tau << ")\n"
               << "plot f(x) title \"Exact\"";

    // write the data sets in individual files
    for (int i = 0; i < results.size(); i++) {

        // create an individual file name
        string dataFileName("data_");
        dataFileName += convert(i);

        // open the file and write the data
        ofstream dataFile(dataFileName.c_str());
        for (int j = 0; j < results[i].points.size(); j++) {
            dataFile << results[i].points[j].t << '\t'
                     << results[i].points[j].n << '\n';
        }
        dataFile.close();

        // add plotting information for this set to the script file
        scriptFile << ", \"" << dataFileName << "\" "
                   <<" title \"dt = " << results[i].dt << "\"";
    }

    scriptFile << "\npause mouse\n";
    scriptFile.close();

    // call Gnuplot to execute the script file
    string gnuplot("\"C:\\Program Files\\gnuplot\\bin\\pgnuplot\"");
    string command = gnuplot + " script";
    system(command.c_str());

}

// Print "<prompt>: " and read a real number from standard input, repeating
// until one is entered; returns that number. An invalid entry prints
// "Bad response ... ", discards the rest of that input line and re-prompts.
// If standard input ends or fails irrecoverably, no value can ever be
// obtained, so the program reports it and exits with EXIT_FAILURE instead
// of re-prompting forever.
double getDouble(string prompt) {
    while (true) {
        cout << prompt << ": " << flush;
        double value;
        if (cin >> value)
            return value;

        if (cin.eof() || cin.bad()) {
            cerr << "\nInput ended before a number was entered" << endl;
            exit(EXIT_FAILURE);
        }

        cin.clear();                                          // reset cin
        cin.ignore(numeric_limits<streamsize>::max(), '\n');  // drop the bad line
        cout << "Bad response ... ";
    }
}

// Print "<prompt> Enter Y[es] or N[o]: " and read one word from standard
// input, repeating until it starts with y/Y (returns true) or n/N (returns
// false); any other word prints "Bad response ... " and re-prompts. If
// standard input ends or fails, returns false, so the caller's "another ...?"
// loops finish instead of re-prompting forever.
bool getYesNo(string prompt) {
    while (true) {
        cout << prompt << " Enter Y[es] or N[o]: " << flush;
        string response;
        if (!(cin >> response)) {       // no more input: treat as "No"
            cout << endl;
            return false;
        }
        // a successful extraction guarantees response is non-empty
        switch (response[0]) {
            case 'y':
            case 'Y':
                return true;
            case 'n':
            case 'N':
                return false;
            default:
                cout << "Bad response ... ";
                break;
        }
    }
}

// Return the decimal representation of i (e.g. 7 -> "7"); used to build
// numbered file names.
string convert(int i) {
    ostringstream digits;
    digits << i;
    const string text = digits.str();
    return text;
}

