#!/usr/bin/env python

#
# Copyright (C) 2017, 2019, 2022, 2023
# Smithsonian Astrophysical Observatory
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#

"Compute color color diagrams"

import sys

import numpy as np

import stk

from ciao_contrib.param_soaker import get_params
from ciao_contrib._tools.fileio import outfile_clobber_checks
import ciao_contrib.logger_wrapper as lw

# Ensure the XSPEC models are loaded
from sherpa.astro.ui import *

from color_color import ColorColor, EnergyBand, ModelParameter


# Identity of this tool; both values are handed to the CIAO error handler
# that wraps tool().
toolname = "color_color"
__revision__ = "13 January 2023"

# Initialize the tool's logger, then bind one shortcut per verbosity level.
lw.initialize_logger(toolname)
lgr = lw.get_logger(toolname)
verb0, verb1, verb2, verb3 = (lgr.verbose0, lgr.verbose1,
                              lgr.verbose2, lgr.verbose3)


# Set up the energy bands
def make_energy(band, token):
    """Convert a band string into an EnergyBand object.

    Parameters
    ----------
    band : str
        One of:
          * "lo:hi" - two numbers separated by a single colon
            (presumably keV - confirm against EnergyBand),
          * "csc" (any case) - use the built-in CSC band for `token`,
          * "" or "none" (any case) - no band.
    token : str
        Label handed to EnergyBand; also the key used for the CSC
        band lookup.

    Returns
    -------
    EnergyBand or None
        None when no band was requested.

    Raises
    ------
    ValueError
        If `band` is "csc" and `token` is not a known CSC band, if
        `band` does not hold exactly two colon-separated values, or
        if either value is not a number.
    """
    _csc = {'S': '0.5:1.2', 'M': '1.2:2.0', 'H': '2.0:7.0',
            'B': '0.5:7.0', 'U': '0.2:0.5', 'W': '0.1:10.0'}

    if band.lower() == 'csc':
        try:
            band = _csc[token]
        except KeyError:
            raise ValueError("Unknown CSC band") from None

    if not band or band.lower() == 'none':
        return None

    # Require exactly "lo:hi": a single value used to fail with an
    # opaque IndexError and extra values were silently discarded.
    limits = band.split(":")
    if len(limits) != 2:
        raise ValueError(f"Invalid energy band '{band}' for {token}: "
                         "expected 'lo:hi'")

    try:
        lo, hi = (float(x) for x in limits)
    except ValueError:
        raise ValueError(f"Invalid energy band '{band}' for {token}: "
                         "limits must be numbers") from None

    return EnergyBand(lo, hi, token)


def make_model(param, grid, sample):
    """Build a ModelParameter from parameter-file strings.

    `param` is evaluated with eval() to obtain the object handed to
    ModelParameter, `grid` is expanded with stk.build() and every
    item converted to float, and `sample` is converted to int and
    passed as the fine_grid_resolution.
    """
    # NOTE(review): eval() runs whatever expression the user supplied,
    # so this is only appropriate for trusted input.  It must stay
    # the first statement so no extra local names are visible to it.
    model_par = eval(param)
    grid_values = list(map(float, stk.build(grid)))
    return ModelParameter(model_par, grid_values,
                          fine_grid_resolution=int(sample))


@lw.handle_ciao_errors(toolname, __revision__)
def tool():
    """Main routine.

    Reads the tool parameters, seeds numpy's random generator, builds
    the ColorColor object from the model expression, evaluates it over
    the two parameter grids and the three energy bands, writes the
    result to `outfile` and, when `outplot` is set, plots it.
    """
    pars = get_params(toolname, "rw", sys.argv,
                      verbose={"set": lw.set_verbosity, "cmd": verb2})

    # Run the clobber checks for both outputs before doing any work.
    clobber = pars["clobber"] == "yes"
    outfile_clobber_checks(clobber, pars["outfile"])
    outfile_clobber_checks(clobber, pars["outplot"])

    # numpy's random seed (used by fake).  A negative value asks for
    # a random seed; it is stored back into pars so that the pars
    # later given to add_tool_history hold the seed actually used.
    if int(pars["random_seed"]) < 0:
        from random import getrandbits
        pars["random_seed"] = str(getrandbits(29))
    np.random.seed(int(pars["random_seed"]))

    # An empty or "none" rmffile means no RMF file is passed on.
    rmffile = pars["rmffile"]
    if rmffile.lower() in ("", "none"):
        rmffile = None

    # The ColorColor object has to exist before make_model is called
    # so that the model parameters have been created.
    # NOTE(review): the model expression is run through eval(), so
    # the names bound above are kept unchanged for its scope.
    verb3(f"Using model expression {pars['model']}")
    diagram = ColorColor(eval(pars["model"]), pars["infile"],
                         rmffile=rmffile)

    # The two model parameters that get varied.
    oversample = pars["plot_oversample"]
    first_par = make_model(pars["param1"], pars["grid1"], oversample)
    second_par = make_model(pars["param2"], pars["grid2"], oversample)

    # The three energy bands.
    soft_band = make_energy(pars["soft"], 'S')
    medium_band = make_energy(pars["medium"], 'M')
    hard_band = make_energy(pars["hard"], 'H')

    # Evaluate the grid, write it out and record the tool history.
    from ciao_contrib.runtool import add_tool_history
    matrix = diagram(first_par, second_par,
                     soft_band, medium_band, hard_band, None)
    matrix.write(pars["outfile"], toolname=toolname)
    add_tool_history(pars["outfile"], toolname, pars,
                     toolversion=__revision__)

    # Nothing more to do unless a plot file was requested.
    out_plot = pars["outplot"]
    if out_plot.lower() in ("", "none"):
        return

    # Curve styles first, then label styles, for each parameter.
    styles = ((first_par, "forestgreen", "--"),
              (second_par, "black", "-"))
    for model_par, colour, dashes in styles:
        model_par.set_curve_style(color=colour, linestyle=dashes,
                                  marker="")
    for model_par, colour, _ in styles:
        model_par.set_label_style(color=colour)

    matrix.plot(out_plot)

    if pars["showplot"] == 'yes':
        import matplotlib.pylab as plt
        plt.show()


# Only run the tool when this file is executed as a script.
if __name__ == "__main__":
    tool()
