#!/usr/bin/env python
#
# Copyright (C) 2020 Smithsonian Astrophysical Observatory
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#

'''Computes approximate gain maps for ACIS data and reports the
energy range of the dataset'''


import sys
import ciao_contrib.logger_wrapper as lw

# Tool name: used to set up and fetch the logger below, and passed to the
# parameter reader and the error-handling decorator in main().
TOOLNAME = "acis_check_pha_range"
# Revision string handed to the error-handling decorator on main().
__REVISION__ = "02 December 2020"

# Initialise logging for this tool, then bind one module-level shortcut
# per verbosity level (0 through 3) so the rest of the file can simply
# call VERBn(message).
lw.initialize_logger(TOOLNAME)
VERB0 = lw.get_logger(TOOLNAME).verbose0
VERB1 = lw.get_logger(TOOLNAME).verbose1
VERB2 = lw.get_logger(TOOLNAME).verbose2
VERB3 = lw.get_logger(TOOLNAME).verbose3


class ChipInfo():
    "Hold info for each individual ACIS chip."

    def __init__(self, cr):
        """Get chip info from current crate.

        The mask file has multiple FITS blocks, so multiple crates
        in the dataset.  We get the info for each crate.

        PHAMIN and PHAMAX in the mask file are the **READOUT**, pha_ro,
        values -- the limits on the PHA value computed on-orbit.

        The PHA in the event file has additional calibrations applied
        (CTI corrections).

        cr is the crate for one mask-file block; it must provide the
        ccd_id keyword and the phamin and phamax columns.
        """

        # These are filled in later by compute_approximate_gain() and
        # compute_energy_limit().
        self.energy_lo = None
        self.energy_hi = None
        self.gain_values = None

        self.chip = cr.get_key_value("ccd_id")
        # Conservative choice: the largest lower limit and the smallest
        # upper limit found over all rows of this block.
        self.pharo_lo = max(cr.get_column("phamin").values)
        self.pharo_hi = min(cr.get_column("phamax").values)
        # The upper limit is capped at 3800 ADU.
        # NOTE(review): the rationale for 3800 is not stated in this file.
        if self.pharo_hi > 3800:
            VERB2("Replacing max={} with 3800".format(self.pharo_hi))
            self.pharo_hi = 3800

        VERB1("Processing data for CCD_ID={}".format(self.chip))
        VERB2("PHA_RO min = {} ADU".format(self.pharo_lo))
        VERB2("PHA_RO_max = {} ADU".format(self.pharo_hi))

    @staticmethod
    def _gain_from_images(energy_img, pharo_img):
        """Return the usable per-cell gains (eV/ADU) as a flat array.

        energy_img and pharo_img are numpy arrays of the same shape
        holding the energy-weighted and pha_ro-weighted images.  The
        result is the element-wise ratio, flattened, keeping only the
        cells whose ratio is finite and greater than zero.
        """
        import numpy as np

        # Empty cells give 0/0 (NaN) and cells with zero PHA_RO but
        # non-zero energy give +/-inf.  Silence the warnings only inside
        # this block so numpy's global error state is left untouched for
        # whoever called us.
        with np.errstate(divide="ignore", invalid="ignore"):
            flat = (energy_img / pharo_img).flatten()
            # NaN fails "> 0" on its own, but +inf does not, so test for
            # finiteness explicitly; an infinite gain would otherwise
            # produce an infinite lower energy limit.
            usable = np.isfinite(flat) & (flat > 0)

        return flat[usable]

    def compute_approximate_gain(self, evt, binsize):
        """Compute an **APPROXIMATE** gain

        The ACIS gain (ADU -> eV) is different for each chip and
        varies spatially across each chip.  The gain is non-linear
        and is time|temperature dependent.

        In this routine we compute an approximate gain correction
        computing the gain in {binsize} blocks -- so we try to
        capture the spatial variation -- but we ignore the
        time|temperature variation and the non-linearity.  This is
        mostly because we usually don't have enough counts.  If we did
        then we could compute a "soft" gain and a "hard" gain to apply
        to the upper and lower limits.

        We restrict to blocksize <= 256 since gain jumps across
        node boundaries.

        evt is the event file name and binsize the chip-coordinate
        block size.  The gains are stored in self.gain_values and the
        energy limits are then updated via compute_energy_limit().
        """
        VERB2("Computing gain for CCD_ID={}".format(self.chip))

        from pycrates import IMAGECrate

        # Two-stage format: the first pass fills in file, chip and block
        # size and leaves a literal {weight} placeholder; the second
        # pass selects the column used to weight the image.
        binspec = "{evt}[ccd_id={ccd}][bin chip=::{blk};{{weight}}]"
        binspec = binspec.format(evt=evt, ccd=self.chip, blk=binsize)
        energy = binspec.format(weight="energy")
        pharo = binspec.format(weight="pha_ro")
        VERB3("Energy bin command: {}".format(energy))
        VERB3("PHA_RO bin command: {}".format(pharo))

        energy_img = IMAGECrate(energy, "r").get_image().values
        pharo_img = IMAGECrate(pharo, "r").get_image().values

        # Ratio of the two weighted images = approximate gain per cell.
        self.gain_values = self._gain_from_images(energy_img, pharo_img)

        VERB3("Number of gain cells: {}".format(len(self.gain_values)))

        self.compute_energy_limit()

    def compute_energy_limit(self):
        """Compute energy limits (in eV)

        We take a conservative view here.  We want the max()
        low energy above which we expect to have received
        all events and we want the min() high energy below which
        we expect to have received all events.

        Other options could be to locate region of interest
        on chip of interest and use values closest spatially.
        Or use some per-chip mean. Or, yadda, yadda, yadda ...

        Sets self.energy_lo and self.energy_hi; both become NaN when
        there are no gain values to work with.
        """

        # len() rather than truthiness: gain_values is normally a numpy
        # array, whose truth value is ambiguous.
        if len(self.gain_values) == 0:
            # No values, return nan's.
            from math import nan
            self.energy_lo = nan
            self.energy_hi = nan
            VERB2("No gain cells available")
            return

        # Low -- use max(gain)
        max_gain = max(self.gain_values)
        self.energy_lo = max_gain * self.pharo_lo
        VERB2("Max gain : {:.3f} eV/ADU".format(max_gain))

        # Upper -- use min(gain)
        min_gain = min(self.gain_values)
        self.energy_hi = min_gain * self.pharo_hi
        VERB2("Min gain : {:.3f} eV/ADU".format(min_gain))

    def summarize(self, header=False):
        """Summarize results

        Logs one tab-separated row (chip, PHA_RO limits, energy limits)
        at verbosity 1.  When header is true the column banner is
        logged first.
        """
        fmt = "{chip}\t{plo}\t{phi}\t{elo:>6.1f}\t{ehi:>8.1f}"
        retval = fmt.format(chip=self.chip, plo=self.pharo_lo,
                            phi=self.pharo_hi, elo=self.energy_lo,
                            ehi=self.energy_hi)

        if header:
            VERB1("#ccd_id\tpha_lo\tpha_hi\tE_lo\tE_hi")
        VERB1(retval)


def check_pha_range(evtfile, mskfile="", binsize=32):
    """Main routine, this should be callable from other modules

    Builds a ChipInfo for every block after the first in the mask
    file, computes an approximate gain for each chip from the event
    file, and combines the per-chip energy limits conservatively
    (largest lower limit, smallest upper limit).

    evtfile is the event file.  mskfile is the mask file; when it is
    the empty string the mask is looked up from the event file via
    ObsInfo.get_ancillary("mask").  binsize is the chip block size used
    for the gain map and must be a power of 2 no larger than 256.

    Returns a dictionary with lower and upper limits, along with
    per chip details: keys 'lower_limit', 'upper_limit' and 'details'
    (the list of ChipInfo objects).  If no chip yields limits the
    initial values 0.0 and 999999.0 are returned.

    Raises ValueError when binsize is not allowed, when the mask file
    is 'none' or blank, when the mask is not from ACIS, or when its
    OBS_ID does not match the event file.
    """

    # Check the cheap argument first so a bad binsize is reported
    # before any file is opened.
    if binsize not in (1, 2, 4, 8, 16, 32, 64, 128, 256):
        raise ValueError("ERROR: binsize must be an integer " +
                         "power of 2 and <= 256")

    from math import isnan
    from ciao_contrib._tools.obsinfo import ObsInfo
    o_info = ObsInfo(evtfile)
    if mskfile == '':
        mskfile = o_info.get_ancillary("mask")
        VERB1("Using maskfile: '{}'".format(mskfile))

    if mskfile.lower() == 'none':
        raise ValueError("ERROR: mskfile is not optional; it " +
                         "cannot be {}".format(mskfile))

    # Still blank after the lookup above: nothing to work with.
    if mskfile == '':
        raise ValueError("ERROR: The mskfile parameter cannot be blank")

    # Collect data from each block in mask file
    from pycrates import CrateDataset
    msk = CrateDataset(mskfile, "r")

    # The first crate carries the header keywords used for the checks.
    primary = msk.get_crate(1)

    if "ACIS" != primary.get_key_value("INSTRUME"):
        raise ValueError("ERROR: Must be an ACIS observation")

    if o_info.obsid.obsid != primary.get_key_value("OBS_ID"):
        raise ValueError("ERROR: Wrong OBSID in mask file")

    # Crates are numbered from 1; crates 2..N are handed to ChipInfo.
    chip_info = [ChipInfo(msk.get_crate(nn))
                 for nn in range(2, msk.get_ncrates()+1)]

    lower = 0.0
    upper = 999999.0
    for chip in chip_info:
        chip.compute_approximate_gain(evtfile, binsize)
        # ChipInfo sets both limits to NaN when it has no gain cells.
        # Skip those chips explicitly instead of relying on max()/min()
        # happening to return their first argument when the second is
        # NaN.
        if isnan(chip.energy_lo) or isnan(chip.energy_hi):
            continue
        lower = max(lower, chip.energy_lo)
        upper = min(upper, chip.energy_hi)

    return {'lower_limit': lower,
            'upper_limit': upper,
            'details': chip_info}


@lw.handle_ciao_errors(TOOLNAME, __REVISION__)
def main():
    "Command-line entry point: read parameters, run the check, report."

    from ciao_contrib.param_soaker import get_params

    verbose_opts = {"set": lw.set_verbosity, "cmd": VERB1}
    pars = get_params(TOOLNAME, "rw", sys.argv, verbose=verbose_opts)

    infile = pars["infile"]
    mskfile = pars["mskfile"]
    binsize = int(pars["binsize"])
    results = check_pha_range(infile, mskfile=mskfile, binsize=binsize)

    # One row per chip; only the first row is preceded by the banner.
    for index, chip in enumerate(results["details"]):
        chip.summarize(index == 0)

    VERB0("")
    VERB0("Conservative limits are {:.1f} to {:.1f} eV".format(
        results["lower_limit"], results["upper_limit"]))


# Run the tool only when this file is executed as a script, so that
# importing the module does not start it.
if __name__ == "__main__":
    main()
