#!/usr/bin/env python

#
# Copyright (C) 2021, 2023
# Smithsonian Astrophysical Observatory
#
#
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301 USA.
#

"""Compute monochromatic energy based on spectral model
(incl responses) and band"""

import sys

import numpy as np

from paramio import pset, paramopen, paramclose

from sherpa.astro import ui

import ciao_contrib.logger_wrapper as lw
from ciao_contrib.param_soaker import get_params


# Name of the tool: passed to the logger, the parameter-file routines,
# and the CIAO error handler wrapped around main().
TOOLNAME = "find_mono_energy"
# Revision string handed to the CIAO error handler together with TOOLNAME.
__REVISION__ = "25 September 2023"

# Identifier of the single Sherpa dataset that every routine in this
# file creates, configures, and evaluates.
DATASET_ID = "myid"

# pylint: disable=undefined-variable

# Create the tool's logger and bind short aliases for its
# verbose1 .. verbose4 methods, which are used for all messages below.
lgr = lw.initialize_logger(TOOLNAME)
verb1 = lgr.verbose1
verb2 = lgr.verbose2
verb3 = lgr.verbose3
verb4 = lgr.verbose4


def setup_dataspace(arffile, rmffile):
    """Create the PHA dataspace and attach the requested responses.

    The channel grid is taken from the RMF when one is given, otherwise
    it is sized from the number of ARF entries. Whichever responses
    were loaded are then associated with the dataset DATASET_ID.

    Parameters
    ----------
    arffile : str
        The name of the ARF to use. A blank string means "no ARF".
    rmffile : str
        The name of the RMF to use. A blank string means "no RMF".

    Raises
    ------
    ValueError
        If both arffile and rmffile are blank.

    """

    # The RMF is read before the ARF, so that is the order in which
    # the log messages (and any load errors) appear.
    rmf = None
    if rmffile.strip():
        verb3(f"Loading RMF from {rmffile}")
        rmf = ui.unpack_rmf(rmffile)
        verb3(f"RMF contains {rmf.detchans} channels, starting at {rmf.offset}")

    arf = None
    if arffile.strip():
        verb3(f"Loading ARF from {arffile}")
        arf = ui.unpack_arf(arffile)
        verb3(f"ARF contains {len(arf.specresp)} entries")

    have_rmf = rmf is not None
    have_arf = arf is not None

    if not (have_arf or have_rmf):
        raise ValueError("At least one of arffile or rmffile must be set")

    # Prefer the RMF for the channel range; fall back to one channel
    # per ARF entry, numbered from 1.
    #
    # NOTE(review): the last channel is rmf.detchans, not
    # offset + detchans - 1 - confirm this is intended for RMFs whose
    # first channel is not 1.
    #
    if have_rmf:
        first_chan, last_chan = rmf.offset, rmf.detchans
    else:
        first_chan, last_chan = 1, len(arf.specresp)

    ui.dataspace1d(first_chan, last_chan, id=DATASET_ID, dstype=ui.DataPHA)

    if have_arf:
        ui.set_arf(DATASET_ID, arf)

    if have_rmf:
        ui.set_rmf(DATASET_ID, rmf)


def setup_model(model, model_params):
    """Install the source model and apply any parameter settings.

    Parameters
    ----------
    model : str
        The Sherpa model expression, written with
        modelname.instancename terms.
    model_params : str
        Python statements used to configure the model, normally
        instancename.paramname=value terms separated by semi-colons.
        Other commands, such as 'set_xsabund("lodd")', are also
        accepted. A blank string means there is nothing to apply.

    Raises
    ------
    ValueError
        If the model expression cannot be set, or if executing
        model_params fails. The original exception is chained.

    """

    try:
        ui.set_model(DATASET_ID, model)
    except Exception as err:
        raise ValueError(f"ERROR with model={model}\n{err}") from err

    verb3("Starting source model:")
    verb3(ui.get_source(DATASET_ID))

    if model_params.strip():
        try:
            # The statements are run with the ui module's namespace as
            # their globals, so ui functions can be called unqualified.
            #
            # pylint: disable=exec-used
            exec(model_params, vars(ui))
        except Exception as err:
            raise ValueError(f'Error with paramvals={model_params}\n{err}') from err


def setup_band(band):
    """Parse the energy band and restrict the dataset to that range.

    Parameters
    ----------
    band : str
        Either one of the named bands ('soft', 'medium', 'hard',
        'broad', 'ultrasoft', 'wide') or an explicit 'low:high'
        pair of energies in keV.

    Raises
    ------
    ValueError
        If band is neither a known name nor a 'low:high' pair, if
        either limit is not a number (including NaN), if low > high,
        or if low < 0.

    """

    csc = {'soft': [0.5, 1.2], 'medium': [1.2, 2.0], 'hard': [2.0, 7.0],
           'broad': [0.5, 7.0], 'ultrasoft': [0.2, 0.5],
           'wide': [0.1, 10.0]}

    if band in csc:
        elo, ehi = csc[band]
    elif band.count(":") != 1:
        raise ValueError(f"Unknown band value {band}")
    else:
        try:
            elo, ehi = [float(x) for x in band.split(":")]
        except ValueError:
            raise ValueError(f"Band ranges must be numbers: sent '{band}'") from None

    # float() happily parses 'nan', and every comparison against NaN
    # is False, so it would slip through the range checks below.
    #
    if np.isnan(elo) or np.isnan(ehi):
        raise ValueError(f"Band ranges must be numbers: sent '{band}'")

    # Validate the range explicitly so the user gets a clear message,
    # rather than relying on the filter call below to complain.
    #
    if elo > ehi:
        raise ValueError(f"Invalid band '{band}': low must be <= high")

    if elo < 0:
        raise ValueError(f"Invalid band '{band}': low must be >= 0")

    # Only report the range once it is known to be usable.
    verb2(f"Selecting energy range {elo} - {ehi} keV")

    ui.notice_id(DATASET_ID, elo, ehi)


def compute_metric(metric):
    """Compute the metric based on the fluxed spectrum

    The model for dataset DATASET_ID is evaluated over the noticed
    bins and reduced to a single energy.

    Parameters
    ----------
    metric : {"mean", "max"}
        How is the monochromatic energy to be calculated? "mean" is
        the model-weighted mean of the bin x values and "max" is the
        x value of the bin where the model is largest.

    Returns
    -------
    energy : number
        The selected energy, in the units of the model plot x axis.

    Raises
    ------
    ValueError
        If metric is not recognised, if no bins are selected, or if
        metric is "mean" and the summed model is zero or not finite.

    """

    # Reject an unsupported metric before any model evaluation is done.
    if metric not in ("mean", "max"):
        raise ValueError(f"Unknown metric '{metric}'")

    verb3("Full model expression:")
    verb3(ui.get_model(DATASET_ID))

    # Ensure the analysis is in energy units and the model is plotted
    # as counts, not counts/kev (we don't care about whether it's
    # normalised by the exposure time). This is true for CIAO 4.14;
    # it may change in the future - e.g. see the exploratory work
    # at https://github.com/sherpa/sherpa/pull/1085
    #
    # Note that for "normal" Chandra observations this does not matter,
    # since the bin-width is constant in energy space, but this could
    # be used for non-Chandra data or for 'unusual' Chandra data.
    #
    from sherpa.utils.logging import SherpaVerbosity

    with SherpaVerbosity("ERROR"):
        ui.set_analysis(id=DATASET_ID, quantity='energy', type='counts')
    verb4(f"Data:\n{ui.get_data(DATASET_ID)}")

    # It would be nice if Sherpa gave us a nice way to do
    # this, but it's currently a bit messy (in particular
    # sorting out the bin edges), so rely on this approach.
    #
    dater = ui.get_model_plot(DATASET_ID)
    verb4(f"Model:\n{dater}")

    # Check the axis labels: a mismatch should not happen but just in
    # case. It is also technically dependent on the plotting backend in
    # use by Sherpa (e.g. to potentially add LaTeX-like capabilities).
    #
    if dater.xlabel != 'Energy (keV)':
        verb1(f"Note: the model evaluation is not in keV but {dater.xlabel}")

    if dater.ylabel not in ['Counts', 'Counts/s']:
        verb1(f"Note: the model evaluation creates {dater.ylabel} and not Counts")

    emid = np.asarray(dater.x)
    flux = np.asarray(dater.y)

    # With no bins np.argmax fails with an unhelpful message and the
    # mean is undefined, so report the problem directly.
    #
    if flux.size == 0:
        raise ValueError("No bins are selected, so the energy cannot be computed")

    if metric == "max":
        return emid[np.argmax(flux)]

    # metric == "mean": a zero (or non-finite) normalisation would
    # otherwise silently produce NaN as the answer.
    #
    total = np.sum(flux)
    if total == 0 or not np.isfinite(total):
        raise ValueError("The model sums to zero or is not finite over the "
                         "selected band, so the mean energy is undefined")

    return np.sum(flux * emid) / total


def pset_output(mono_energy):
    """Save output back to own parameter file

    Parameters
    ----------
    mono_energy
        The value to store in the tool's "energy" parameter. It is
        converted with str() before being written.

    """

    # having to call preprocess_arglist here is not ideal
    args = lw.preprocess_arglist(sys.argv)
    pfile = paramopen(TOOLNAME, "rwHL", args)   # allow @@

    # Always release the parameter file, even if the write fails.
    try:
        pset(pfile, "energy", str(mono_energy))
    finally:
        paramclose(pfile)


def clear_output():
    """Blank the tool's "energy" output parameter.

    Called at startup so the parameter file does not keep a value
    written by an earlier run.
    """
    pset_output(mono_energy="")


@lw.handle_ciao_errors(TOOLNAME, __REVISION__)
def main():
    """Run the tool: read the parameters, evaluate the model over the
    requested band, then report and store the resulting energy."""

    # Reset the output parameter before anything else is done.
    clear_output()

    # Read the tool parameters; the verbose hooks set the verbosity
    # level and provide the routine used to report the settings.
    verbose_hooks = {"set": lw.set_verbosity, "cmd": verb1}
    params = get_params(TOOLNAME, "rw", sys.argv, verbose=verbose_hooks)

    # Build the dataset, model, and energy filter in that order.
    setup_dataspace(params["arffile"], params["rmffile"])
    setup_model(params["model"], params["paramvals"])
    setup_band(params["band"])

    energy = compute_metric(params["metric"])
    verb1(f"Characteristic Energy = {energy:.4g}")

    # Store the answer in the tool's own parameter file.
    pset_output(energy)


# Run the tool only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
