#!/opt/conda/envs/ciao-4.17.0/bin/python3
# 
#  Copyright (C) 2009-2010,2012,2015-2016,2021  Smithsonian Astrophysical Observatory
#
#
#  This program is free software; you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation; either version 3 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#
#  You should have received a copy of the GNU General Public License along
#  with this program; if not, write to the Free Software Foundation, Inc.,
#  51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#



from __future__ import print_function
import sys
import numpy
import paramio
import sherpa.astro.ui as sherpa
from sherpa.astro.xspec import set_xschatter
from sherpa.utils.logging import SherpaVerbosity


# Module-level verbosity level shared by set_verbose() and get_verbose();
# it stays None until set_verbose() has been called.
__verbose=None


def set_verbose(verb):
    """Store the verbosity level and set the XSPEC chatter to match.

    A level above 4 selects chatter 25, a level above 1 selects
    chatter 10, and any other level selects chatter 0.
    """
    global __verbose
    __verbose = verb

    xspec_chatter = 25 if verb > 4 else (10 if verb > 1 else 0)
    set_xschatter(xspec_chatter)


def get_verbose():
    """Return the module-level verbosity value (None if it was never set)."""
    return __verbose


def _check_if_nan(arg, format="%.4e"):
    val = "INDEF"
    if arg is not None:
        if not (numpy.isnan(arg) or numpy.isinf(arg)):
            val = format % arg
    return val

def _pget_stripped(pfile, name):
    """Return string parameter *name* without surrounding quotes/spaces."""
    return paramio.pgetstr(pfile, name).strip("\"' ")


def _pget_optional_double(pfile, name, default=None):
    """Return parameter *name* as a double, or *default* when it is INDEF.

    The INDEF comparison is case-insensitive.
    """
    if paramio.pgetstr(pfile, name).upper() == "INDEF":
        return default
    return paramio.pgetd(pfile, name)


def read_param_file(filename):
    """Open the parameter file and read every modelflux parameter.

    Parameters
    ----------
    filename : str
        Name of the parameter file to open (also used in the error
        message when the open fails).

    Returns
    -------
    (params, pfile)
        ``params`` is a dict with the keys arf, rmf, mdl, vals, emin,
        emax, absmdl, absvals, abund, oemin, oemax, rate, pflux, flux,
        opt and verbose.  Energy limits and fluxes left as INDEF are
        None (oemin/oemax fall back to emin/emax).  ``pfile`` is the
        open paramio handle; the caller is responsible for closing it.

    Raises
    ------
    IOError
        If the parameter file could not be opened.
    """
    args = sys.argv[:]
    if len(args) > 1 and args[1].startswith("@@"):
        # First argument starts with "@@": open with no file name and a
        # plain "rw" mode, passing the argument list to paramio.
        # NOTE(review): confirm the "@@" semantics against the paramio docs.
        open_args = (None, "rw", args)
    else:
        open_args = (filename, "rwL", sys.argv)

    try:
        pfile = paramio.paramopen(*open_args)
    except Exception:
        # Catch Exception rather than everything so that Ctrl-C and
        # SystemExit are not converted into an IOError.
        raise IOError("could not open parameter file '%s'" % filename)

    # NOTE: parameters are read in a fixed order; keep it unchanged.
    params = {}

    # ARF file
    params["arf"] = paramio.pgetstr(pfile, "arf")

    # RMF file
    params["rmf"] = paramio.pgetstr(pfile, "rmf")

    # Sherpa model definition string
    params["mdl"] = _pget_stripped(pfile, "model")

    # ';' delimited string of parameter values
    # If comma-delimited, turn it into ';' delimited
    params["vals"] = _pget_stripped(pfile, "paramvals").replace(",", ";")

    # upper and lower energy in keV (None when INDEF)
    params["emin"] = _pget_optional_double(pfile, "emin")
    params["emax"] = _pget_optional_double(pfile, "emax")

    # Absorption model, for calculating unabsorbed flux
    params["absmdl"] = _pget_stripped(pfile, "absmodel")

    # Absorption model parameters, for calculating unabsorbed flux
    # ';' delimited string of parameter values
    # If comma-delimited, turn it into ';' delimited
    params["absvals"] = _pget_stripped(pfile, "absparams").replace(",", ";")

    # XSPEC abundance table setting
    params["abund"] = paramio.pgetstr(pfile, "abund")

    # observed instrumental upper and lower energy in keV; these fall
    # back to emin/emax when left as INDEF
    params["oemin"] = _pget_optional_double(pfile, "oemin", params["emin"])
    params["oemax"] = _pget_optional_double(pfile, "oemax", params["emax"])

    # rate double, default=1.0
    # count rate in [counts s^-1]
    params["rate"] = paramio.pgetd(pfile, "rate")

    # pflux double, default=None
    # photon flux in energy range [photon cm^-2 s^-1]
    params["pflux"] = _pget_optional_double(pfile, "pflux")

    # flux double, default=None
    # energy flux in energy range [erg cm^-2 s^-1]
    params["flux"] = _pget_optional_double(pfile, "flux")

    # opt string, default="rate", values=(rate|flux|pflux)
    # convert: rate->flux,pflux or flux->rate,pflux
    params["opt"] = _pget_stripped(pfile, "opt")

    params["verbose"] = paramio.pgeti(pfile, "verbose")

    return params, pfile


def print_params(rate, pflux, flux, urate, upflux, uflux, oemin, oemax, emin, emax, mdl, absmdl):
    """Print the computed rate and fluxes; detail depends on the verbosity.

    verbose > 0 : absorbed (and, when all three are available,
                  unabsorbed) rate, photon flux and energy flux.
    verbose > 2 : additionally the fluxes per count.
    verbose > 3 : additionally the Sherpa source model and its parameters.

    rate/pflux/flux are the absorbed values, urate/upflux/uflux the
    unabsorbed ones (None when not calculated).  oemin/oemax and
    emin/emax are the energy limits printed next to the rate and the
    fluxes respectively; None limits are replaced by the bounds of the
    current Sherpa filter.  mdl and absmdl are the model and absorption
    model expressions (absmdl is "" when there is no absorption model).
    """
    verbose = get_verbose()

    if verbose > 0:

        # If the user left emin, emax as INDEF in the parameter file,
        # these are type None in this script.  So turn them into
        # floating point numbers, that are the minimum and maximum
        # energies defined by the RMF energy grid.  We can get those
        # bounds by asking Sherpa for its filter; when no filter is
        # defined, then get_filter() returns bounds covering the
        # entire energy grid.

        bounds = sherpa.get_filter()      # Will always be entire dataspace
                                          # even if we did set emin and emax
                                          # bounds
        if emin is None:
            emin = float(bounds.split(":")[0])
        if emax is None:
            emax = float(bounds.split(":")[-1])
        if oemin is None:
            oemin = emin
        if oemax is None:
            oemax = emax

        # Now print fluxes.  In unabsorbed case, print both absorbed
        # and unabsorbed fluxes.
        print('Model fluxes:\n' +
              'Rate (%.4g,%.4g)= %s count s^-1\n' %
              (oemin, oemax, _check_if_nan(rate, "%.5g")) +
              'Photon Flux (%.4g,%.4g)= %s photon cm^-2 s^-1\n' %
              (emin, emax, _check_if_nan(pflux, "%.5g")) +
              'Energy Flux (%.4g,%.4g)= %s erg cm^-2 s^-1' %
              (emin, emax, _check_if_nan(flux, "%.5g")))

        if (urate is not None and
            upflux is not None and
            uflux is not None):
            print('Unabsorbed model fluxes:\n' +
                  'Rate (%.4g,%.4g)= %s count s^-1\n' %
                  (oemin, oemax, _check_if_nan(urate, "%.5g")) +
                  'Photon Flux (%.4g,%.4g)= %s photon cm^-2 s^-1\n' %
                  (emin, emax, _check_if_nan(upflux, "%.5g")) +
                  'Energy Flux (%.4g,%.4g)= %s erg cm^-2 s^-1' %
                  (emin, emax, _check_if_nan(uflux, "%.5g")))

    if verbose > 2:

        # N.B. -- in the calculations below, we really do mean to
        # divide *unabsorbed* fluxes by the ***absorbed*** count rate
        # (which is the count rate that was actually observed).  This
        # is what is needed for the user to check the math.  SMD 11/27/2012
        print("\nFluxes per Count:\n")
        # Count rate
        rate_output = "Count rate: (absorbed) %.5g" % rate
        if urate is not None:
            rate_output = rate_output + " (unabsorbed) %.5g" % urate
        rate_output = rate_output + " count s^-1"
        print(rate_output)

        # Photon flux per count
        pflux_per_count = pflux/rate
        pflux_output = "Photon flux per count: (absorbed) %.5g" % pflux_per_count
        if upflux is not None:
            upflux_per_count = upflux/rate
            pflux_output = pflux_output + " (unabsorbed) %.5g" % upflux_per_count
        pflux_output = pflux_output + " photon cm^-2 count^-1"
        print(pflux_output)

        # Energy flux per count
        flux_per_count = flux/rate
        flux_output = "Flux per count: (absorbed) %.5g" % flux_per_count
        # Guard on uflux (the value actually divided here), not upflux.
        if uflux is not None:
            uflux_per_count = uflux/rate
            flux_output = flux_output + " (unabsorbed) %.5g" % uflux_per_count
        flux_output = flux_output + " erg cm^-2 count^-1"
        print(flux_output)

    if verbose > 3:
        # Print all model parameters used in model calculations,
        # even those with values left undefined by user.

        print("\nModel Parameters:\n")
        sherpa.set_source(mdl)
        if absmdl != "":
            sherpa.set_source(absmdl + "*(" + mdl + ")")
        print(sherpa.get_source())

def write_params(pfile, rate, pflux, flux, urate, upflux, uflux):
    """Store the six results in the parameter file, then close it.

    Each value is written as a string produced by _check_if_nan, under
    the parameter names rate, pflux, flux, urate, upflux and uflux (in
    that order).  The handle *pfile* is closed afterwards.
    """
    results = (
        ("rate", rate),
        ("pflux", pflux),
        ("flux", flux),
        ("urate", urate),
        ("upflux", upflux),
        ("uflux", uflux),
    )
    for par_name, par_value in results:
        paramio.pputstr(pfile, par_name, _check_if_nan(par_value))

    paramio.paramclose(pfile)



def main():
    """Run the modelflux conversion driven by the 'modelflux.par' file.

    Reads the tool parameters, builds a zero-count PHA dataset on the
    RMF channel grid, evaluates the source model to get the predicted
    count rate, energy flux and photon flux, scales them to the
    user-supplied quantity selected by 'opt', then prints the results
    and writes them back to the parameter file.

    Raises
    ------
    IOError
        If the RMF or the model parameter is missing, if the quantity
        selected by 'opt' was not supplied, or if 'opt' is not one of
        rate, flux or pflux.
    """
    params, pfile = read_param_file("modelflux.par")

    set_verbose(params["verbose"])

    # If verbosity less than 2, shut off warnings from NumPy
    # per Doug's 11/21/2012 RFE (SMD 11/27/2012)
    if get_verbose() < 2:
        numpy.seterr(all='ignore')

    # Set XSPEC abundance (can be set whether using unabsorbed flux
    # or not)
    if sherpa.get_xsabund() != params["abund"]:
        sherpa.set_xsabund(params["abund"])

    if params["rmf"] == "" or params["rmf"].upper() == "INDEF":
        raise IOError("an RMF file is required")

    if params["mdl"] == "":
        raise IOError("The model parameter can not be empty.")

    rmf = sherpa.unpack_rmf(params["rmf"])

    # Zero-count dataset with channels 1..detchans, used only as a grid
    # on which the model is evaluated.
    counts = numpy.zeros(rmf.detchans)
    channel = numpy.arange(1, rmf.detchans+1)
    data = sherpa.DataPHA('', channel, counts)
    sherpa.set_data(data)
    sherpa.set_rmf(rmf)

    if not (params["arf"] == "" or params["arf"].upper() == "INDEF"):
        arf = sherpa.unpack_arf(params["arf"])

        # By not setting an exposure time for DataPHA, the convolved model returns
        # the predicted count rate, not predicted counts.
        arf.exposure = None
        sherpa.set_arf(arf)

    # Source = absmdl * (mdl) if calculating unabsorbed flux
    # Source = mdl otherwise
    has_absorption = params["absmdl"] != ""

    with SherpaVerbosity("ERROR"):
        sherpa.set_analysis("energy")

    if has_absorption:
        sherpa.set_source(params["absmdl"] + "*(" + params["mdl"] + ")")
    else:
        sherpa.set_source(params["mdl"])

    # NOTE(security): the 'paramvals' and 'absparams' strings are executed
    # as Python code.  Only run this tool with parameter values from a
    # trusted source.  The executed code can also see this function's
    # local variable names.
    exec(params["vals"])
    if has_absorption:
        exec(params["absvals"])

    calc_cnt_rate = sherpa.calc_model_sum(lo=params["oemin"], # counts sec^-1
                                          hi=params["oemax"])
    calc_eflux = sherpa.calc_energy_flux(lo=params["emin"], # ergs cm^-2 sec^-1
                                         hi=params["emax"])
    calc_pflux = sherpa.calc_photon_flux(lo=params["emin"], # photons cm^-2 sec^-1
                                         hi=params["emax"])

    # The unabsorbed values are only calculated when an absorption model
    # was given: the source is then reset to the bare model.
    calc_cnt_rate_unabsorbed = None
    calc_eflux_unabsorbed = None
    calc_pflux_unabsorbed = None
    if has_absorption:
        sherpa.set_source(params["mdl"])
        calc_cnt_rate_unabsorbed = sherpa.calc_model_sum(lo=params["oemin"],
                                                         hi=params["oemax"])
        calc_eflux_unabsorbed = sherpa.calc_energy_flux(lo=params["emin"],
                                                        hi=params["emax"])
        calc_pflux_unabsorbed = sherpa.calc_photon_flux(lo=params["emin"],
                                                        hi=params["emax"])

    params["urate"] = None
    params["upflux"] = None
    params["uflux"] = None
    opt = params["opt"].lower().strip()
    if opt == "flux":
        obs_flux = params["flux"]
        if obs_flux is None:
            raise IOError("please provide a flux, using 'flux' parameter")
        params["rate"] = (calc_cnt_rate/calc_eflux) * obs_flux
        params["pflux"] = (calc_pflux/calc_cnt_rate) * params["rate"]

    elif opt == "pflux":
        obs_flux = params["pflux"]
        if obs_flux is None:
            raise IOError("please provide a photon flux, using 'pflux' parameter")
        params["rate"] = (calc_cnt_rate/calc_pflux) * obs_flux
        params["flux"] = (calc_eflux/calc_cnt_rate) * params["rate"]

    elif opt == "rate":
        obs_rate = params["rate"]
        if obs_rate is None:
            raise IOError("please provide a count rate, using 'rate' parameter")
        params["flux"] = (calc_eflux/calc_cnt_rate) * obs_rate
        params["pflux"] = (calc_pflux/calc_cnt_rate) * obs_rate

    else:
        raise IOError("unknown option '%s'" % opt)

    # Now correct the unabsorbed fluxes, if they were calculated.
    # All of them can be based on the ratio of the output rate to the
    # calculated absorbed count rate, since that ratio is now known
    # from the absorbed values.
    if calc_cnt_rate_unabsorbed is not None:
        params["urate"] = (calc_cnt_rate_unabsorbed/calc_cnt_rate) * params["rate"]
    if calc_pflux_unabsorbed is not None:
        params["upflux"] = (calc_pflux_unabsorbed/calc_cnt_rate) * params["rate"]
    if calc_eflux_unabsorbed is not None:
        params["uflux"] = (calc_eflux_unabsorbed/calc_cnt_rate) * params["rate"]

    print_params(params["rate"], params["pflux"], params["flux"],
                 params["urate"], params["upflux"], params["uflux"],
                 params["oemin"], params["oemax"], params["emin"],
                 params["emax"], params["mdl"], params["absmdl"])

    write_params(pfile, params["rate"], params["pflux"], params["flux"],
                 params["urate"], params["upflux"], params["uflux"])


if __name__ == "__main__":

    import sys

    # Script entry point: report any failure from main() on stderr in
    # the tool's error format and exit with status 1.
    try:
        main()
    except Exception as err:
        sys.stderr.write("modelflux: ERROR: %s\n" % str(err))
        raise SystemExit(1)
