#!/usr/bin/env python

# Copyright (C) 2013,2016,2022 Smithsonian Astrophysical Observatory
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#

'Look up PSF fraction inside of circular region'

import os
import sys

import ciao_contrib.logger_wrapper as lw
from pycrates import read_file

# Tool name: used below to set up the logger, and passed by main() to the
# parameter reader, the error-handling decorator and the history writer.
TOOLNAME = "src_psffrac"
# Revision string: handed to the error-handling decorator and recorded as the
# tool version when history is added to the output file.
__REVISION__ = "01 March 2022"

# The logger must be initialized before the logging handles are fetched.
lw.initialize_logger(TOOLNAME)
# Logging callables; presumably they emit at verbosity levels 1 and 2
# respectively -- confirm against ciao_contrib.logger_wrapper.
VERB1 = lw.get_logger(TOOLNAME).verbose1
VERB2 = lw.get_logger(TOOLNAME).verbose2


def lookup_psffile_in_caldb(calfile):
    """
    Locate the single 'REEF' calibration file through the CALDB.

    The calfile argument is accepted for interface compatibility but is
    not consulted; the query is always for ("chandra", None, "REEF").

    Returns the matching entry with everything from the first "[" onwards
    removed.  Raises IndexError unless exactly one entry is found.
    """
    # pylint: disable=unused-argument

    import caldb4 as cldb

    matches = cldb.Caldb("chandra", None, "REEF").search
    nfound = len(matches)
    if nfound != 1:
        raise IndexError("One 'REEF' CALDB file expected.  Found {}".format(nfound))

    # Keep only the text ahead of any "[...]" qualifier
    path, _, _ = matches[0].partition("[")
    return path


def get_psf_fractions(psffile, theta, phi, energy, radii, pixsize):
    """
    Get the PSF fraction for each (theta, phi, radius) triple.

    Parameters
    ----------
    psffile
        Name of the PSF calibration file.  An empty value, or one that
        starts with 'CALDB', means the file is looked up in the CALDB.
    theta, phi
        Equal-length sequences; element i is passed to psfFrac as the
        third and fourth argument for region i.
    energy
        Single energy value used for every region.
    radii
        One radius per region; each is multiplied by pixsize before it is
        handed to psfFrac.
    pixsize
        Scale factor applied to the radii (presumably the pixel size in
        arcsec/pixel -- confirm against the caller).

    Returns
    -------
    list
        One psfFrac result per region, in input order.

    Raises
    ------
    LookupError
        If theta, phi and radii do not all have the same length.
    """
    # Validate first: there is no point opening the calibration file for
    # input that cannot be processed, and zip() below would otherwise
    # silently drop the unmatched trailing values.
    if len(theta) != len(phi):
        raise LookupError("Must have same number of theta & phi values")
    if len(radii) != len(theta):
        raise LookupError("Must have same number of radii as theta & phi values")

    from psf import psfInit, psfFrac

    calfile = psffile
    if not calfile or calfile.startswith('CALDB'):
        calfile = lookup_psffile_in_caldb(calfile)
    psf = psfInit(calfile)

    return [psfFrac(psf, energy, off_axis, azimuth, pixsize * radius)
            for off_axis, azimuth, radius in zip(theta, phi, radii)]


def sky_to_cel(myfile, x_vals, y_vals):
    """
    Convert paired x,y values to RA/DEC via the file's "eqpos" transform.

    Each (x, y) pair is pushed through the transform on its own; the first
    and second elements of the transformed point are collected into the
    RA and DEC lists respectively.

    Returns a tuple (ra, dec) of two lists.
    """
    transform = myfile.get_transform("eqpos")    # Must be chandra so OK

    ra, dec = [], []
    for point in zip(x_vals, y_vals):
        converted = transform.apply([list(point)])[0]
        ra.append(converted[0])
        dec.append(converted[1])

    return (ra, dec)


def get_keys_from_crate(myfile):
    """
    Collect every key of the crate into a {name: value} dictionary.

    The names come from get_keynames() and each value from
    get_key_value(name).
    """
    return {name: myfile.get_key_value(name)
            for name in myfile.get_keynames()}


def parse_ebands(energy):
    """
    Parse energy band string

    Accepted forms, tried in this order:

    * a named band ('broad', 'soft', 'ultrasoft', 'medium', 'hard',
      'wide'), matched ignoring case and surrounding whitespace, which
      maps to the fixed value in the table below;
    * anything float() accepts, returned as a float;
    * a three-field colon separated string "a:b:c", for which the third
      field is converted with float() and returned.

    Raises ValueError for anything else.
    """
    csc = {'broad': 2.3, 'soft': 0.92, 'ultrasoft': 0.4,
           'medium': 1.56, 'hard': 3.8, 'wide': 1.5}

    # Normalise only strings so that numeric input keeps working
    band = energy.strip().lower() if isinstance(energy, str) else energy
    if band in csc:
        return csc[band]

    try:
        return float(energy)
    except ValueError:
        pass    # okay, will try the colon separated form

    fields = energy.split(":")
    if len(fields) == 1:
        raise ValueError("Unknown energy {}".format(energy))
    if len(fields) != 3:
        raise ValueError("Bad energy range format {}".format(energy))

    mono = fields[2]
    try:
        return float(mono)
    except ValueError as err:
        # Keep the float() failure attached as the cause
        raise ValueError("Invalid mono energy {}".format(mono)) from err


def write_output(outfile, sky_x, sky_y, radii, ra_vals, dec_vals, coords, frac, mykeys):
    """
    Write output file

    Builds a table named 'REGION' with one row per region and writes it to
    outfile, overwriting any existing file (clobber=True).

    Parameters
    ----------
    outfile
        Name of the file to write.
    sky_x, sky_y, radii
        Per-region circle centers and radii (written with unit 'pixel').
    ra_vals, dec_vals
        Per-region celestial coordinates (written with unit 'deg').
    coords
        Mapping providing the per-region "theta", "phi", "detx", "dety",
        "chip_id", "chipx" and "chipy" columns.
    frac
        Per-region PSF fraction.
    mykeys
        Mapping of key name to value; every entry is added to the table
        as a header key.

    The SHAPE column is always 'circle', COMPONENT numbers the regions
    from 1, and BG_PSFFRAC is filled with 0.0.
    """
    nreg = len(radii)

    # One (name, unit, description, values) row per output column.  Keeping
    # the four attributes together avoids parallel lists drifting apart.
    columns = [
        ('shape', '', 'region geometry', ['circle'] * nreg),
        ('x', 'pixel', 'X center', sky_x),
        ('y', 'pixel', 'Y center', sky_y),
        ('r', 'pixel', 'Radii', radii),
        ('component', '', 'Region Number', list(range(1, nreg + 1))),
        ('ra', 'deg', 'Right Ascension', ra_vals),
        ('dec', 'deg', 'Declination', dec_vals),
        ('theta', 'arcmin', 'Off axis angle', coords["theta"]),
        ('phi', 'deg', 'Azimuthal angle', coords["phi"]),
        ('detx', 'pixel', 'Detector X', coords["detx"]),
        ('dety', 'pixel', 'Detector Y', coords["dety"]),
        ('chip_id', '', 'Chip ID', coords["chip_id"]),
        ('chipx', 'pixel', 'Chip X', coords["chipx"]),
        ('chipy', 'pixel', 'Chip Y', coords["chipy"]),
        ('psffrac', '', 'PSF Fraction in Source region', frac),
        ('bg_psffrac', '', 'PSF Fraction in background region', [0.0] * nreg),
    ]

    import pycrates as pc
    tab = pc.TABLECrate()
    tab.name = 'REGION'

    for name, unit, description, values in columns:
        col = pc.CrateData()
        col.name = name.upper()    # column names are written in upper case
        col.desc = description
        col.unit = unit
        col.values = values
        tab.add_column(col)

    for key_name, key_value in mykeys.items():
        key = pc.CrateKey()
        key.name = key_name
        key.value = key_value
        tab.add_key(key)

    tab.write(outfile=outfile, clobber=True)


def parse_regions(region_stk, infile):
    """
    Parse region strings

    The stack in region_stk is expanded and each entry is run through
    dmmakereg (FITS kernel, with infile as the wcsfile) into its own
    scratch file.  The per-region files are then combined with dmmerge
    and the merged table is read back.

    An entry that read_file() can open is treated as a region file and
    wrapped as "region(<entry>)"; anything else is passed to dmmakereg
    as given.

    Returns a tuple (x, y, r) of lists taken from the "x", "y" and "r"
    columns of the merged table; for "r" only the first element of each
    row is kept.

    Raises ValueError if any row's shape is not a circle.  Scratch files
    are created under $ASCDS_WORK_PATH and are removed whether or not the
    function succeeds.
    """
    from ciao_contrib.runtool import dmmakereg, dmmerge
    from tempfile import NamedTemporaryFile

    import stk

    def _scratch_name():
        """Return the name of a just-deleted temp file in the work dir."""
        with NamedTemporaryFile(dir=os.environ["ASCDS_WORK_PATH"]) as tmp:
            return tmp.name

    scratch = []    # every file name handed to a tool, for cleanup
    try:
        region_files = []
        for region in stk.build(region_stk):

            # wrap in a region() if needed
            try:
                read_file(region)
            except Exception:    # not a readable file: use the text as is
                reg = region
            else:
                reg = f"region({region})"

            tfile = _scratch_name()
            scratch.append(tfile)    # register first so a failed run is cleaned

            # convert to FITS and in physical coordinates
            dmmakereg.region = reg
            dmmakereg.outfile = tfile
            dmmakereg.wcsfile = infile
            dmmakereg.kernel = 'fits'
            dmmakereg.clobber = True

            VERB2(dmmakereg())

            region_files.append(tfile)

        merged = _scratch_name()
        scratch.append(merged)

        VERB2(dmmerge(infile=region_files, outfile=merged, clobber=True))

        tab = read_file(merged)

        # shape must be a circle
        shapes = tab.get_column("shape").values.tolist()
        if not all(shape.lower() == 'circle' for shape in shapes):
            raise ValueError("Only circle shapes are allowed")

        xx = tab.get_column("x").values.tolist()
        yy = tab.get_column("y").values.tolist()

        # r's are arrays, get 1st element from each row
        rr = [row[0] for row in tab.get_column("r").values.tolist()]

        return (xx, yy, rr)
    finally:
        # All column values have been copied out by now, so the files can
        # go; a tool that failed may not have created its output at all.
        for path in scratch:
            if os.path.exists(path):
                os.remove(path)


#
# Main Routine
#
@lw.handle_ciao_errors(TOOLNAME, __REVISION__)
def main():
    """
    Main routine

    Reads the tool parameters, turns the input regions into circle
    centers and radii, converts those centers to celestial and then
    Chandra coordinates, looks up the PSF fraction for each region and
    writes the result table (with tool history) to the output file.
    """
    from ciao_contrib._tools.fileio import outfile_clobber_checks
    from ciao_contrib.runtool import add_tool_history
    from coords.chandra import cel_to_chandra
    from ciao_contrib.param_soaker import get_params

    # Read the parameter file / command line
    verbose_hooks = {"set": lw.set_verbosity, "cmd": VERB1}
    pars = get_params(TOOLNAME, "rw", sys.argv, verbose=verbose_hooks)

    outfile_clobber_checks(pars["clobber"], pars["outfile"])

    # Region stack -> circle centers and radii
    regions = parse_regions(pars["region"], pars["infile"])
    sky_x, sky_y, radii = regions

    # Energy band / value -> single energy
    energy = parse_ebands(pars["energy"])

    # Header keys of the input file, plus the energy that was used
    crate = read_file(pars["infile"])
    header = get_keys_from_crate(crate)
    header["ENERGY"] = energy

    # x,y -> ra,dec -> theta/phi and detector/chip coordinates
    ra_vals, dec_vals = sky_to_cel(crate, sky_x, sky_y)
    coords = cel_to_chandra(header, ra_vals, dec_vals)

    # PSF fraction inside each circle
    frac = get_psf_fractions(pars["psffile"],
                             coords["theta"],
                             coords["phi"],
                             energy,
                             radii,
                             coords["pixsize"])

    # Save the table and stamp it with the tool history
    write_output(pars["outfile"], sky_x, sky_y, radii,
                 ra_vals, dec_vals, coords, frac, header)
    add_tool_history(pars["outfile"], TOOLNAME, pars, toolversion=__REVISION__)


# Run the tool only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
