#!/usr/bin/env python

"""
ecf_calc

Purpose: Output a table of encircled energy fractions as a function of radius via dmellipse.

Input:
    event file
    radius that encloses the total counts
    fraction: set of ECF values to evaluate
    x,y sky positions of the source
    binsize
    output file name

Usage:
% ecf_calc evtfile outfile radius fraction xpos ypos binsize

"""

# Tool name handed to open_param_file, handle_ciao_errors and add_tool_history.
toolname = "ecf_calc"
# Revision string handed to the handle_ciao_errors decorator.
__revision__ = "28 March 2023"

import numpy, tempfile, sys, os, paramio

import stk
import pycrates as pcr

from ciao_contrib.logger_wrapper import handle_ciao_errors
from ciao_contrib.param_wrapper import open_param_file
from ciao_contrib.cxcdm_wrapper import get_block_info_from_file
from ciao_contrib.runtool import dmhistory, dmellipse, new_pfiles_environment, add_tool_history


class ScriptError(RuntimeError):
    """Failure detected by the script itself while it is running.

    Having a dedicated RuntimeError subclass lets a caller tell a problem
    reported by this script apart from an exception raised by the
    underlying code, should the two ever need different handling.
    """


def error_out(msg):
    """Raise a ScriptError carrying msg as its message."""
    failure = ScriptError(msg)
    raise failure


def get_detnam(fname):
    """Return the value of the DETNAM header keyword of the given file.

    The header is obtained from get_keys_from_file as a dictionary of
    keyword values (taken from the 'most-interesting-block' of the file);
    a KeyError propagates to the caller when DETNAM is not among them.
    """
    from ciao_contrib._tools.fileio import get_keys_from_file

    header = get_keys_from_file(fname)
    return header["DETNAM"]

################################################################

def radii_set(radius,bin):
    """Return the list of radii bin, 2*bin, ... that do not exceed radius.

    The number of entries is int(radius/bin), so the list is empty when
    radius is smaller than bin.
    """
    n_steps = int(radius / bin)

    radii = []
    for idx in range(n_steps):
        radii.append((idx + 1) * bin)

    return radii


def pause():
    """Wait at an interactive prompt until the user presses ENTER."""
    prompt = "Press ENTER to close the plot window and exit:"
    input(prompt)


def get_par(argv):
    """Read the ecf_calc parameters from the tool's parameter file.

    Parameters
    ----------
    argv : list
        Command-line arguments, handed to open_param_file.

    Returns
    -------
    dict
        Keys "infile", "outfile", "fraction" (str), "radius", "xpos",
        "ypos", "binsize" (float) and "plot", "clobber" (bool).

    Raises
    ------
    ValueError
        If infile or outfile is empty.
    IOError
        If clobber is off and outfile already exists.
    """

    pfile = open_param_file(argv,toolname=toolname)["fp"]

    # Close the parameter file on every exit path, including the
    # validation errors raised below.
    try:
        pars = {}

        pars["infile"] = paramio.pgetstr(pfile,"infile")
        if pars["infile"] == "":
            raise ValueError("The infile parameter is empty.")

        pars["outfile"] = paramio.pgetstr(pfile,"outfile")
        if pars["outfile"] == "":
            raise ValueError("The outfile parameter is empty.")

        pars["radius"] = paramio.pgetd(pfile,"radius")
        pars["xpos"] = paramio.pgetd(pfile,"xpos")
        pars["ypos"] = paramio.pgetd(pfile,"ypos")
        pars["binsize"] = paramio.pgetd(pfile,"binsize")
        pars["fraction"] = paramio.pgetstr(pfile,"fraction")

        # plot results?
        pars["plot"] = paramio.pgetb(pfile, "plot") == 1

        # The comparison already yields a bool, so no further
        # normalisation of the clobber flag is needed.
        pars["clobber"] = paramio.pgetb(pfile, "clobber") == 1

        if not pars["clobber"] and os.path.isfile(pars["outfile"]):
            raise IOError("clobber=no and outfile already exists.")
    finally:
        paramio.paramclose(pfile)

    return pars


def reg(radius,x,y,binsize,infile):
    """Build the region-plus-binning filter string to append to infile.

    Parameters
    ----------
    radius, x, y
        Radius and centre of the circle used as the spatial filter.
    binsize
        Bin size; for an image it is multiplied by the mean SCALE of the
        axis transform before being used in the binning specifier.
    infile : str
        File whose block information selects the coordinate names.

    Returns
    -------
    str
        "[<coords>=circle(x,y,radius)][bin <coords>=<binsize>]".

    Raises
    ------
    ScriptError
        (via error_out) if an image is given with binsize < 1, if no axis
        is left to bin on, or if the block is neither a TABLE nor an IMAGE.
    """

    shape = "circle({0},{1},{2})".format(str(x),str(y),str(radius))

    # check for sky or det coords
    # NOTE(review): the block is addressed as block[1], as in the original
    # code; the layout of get_block_info_from_file's return value is not
    # visible here.
    block = get_block_info_from_file(infile)
    block_type = block[1]["type"]

    if block_type == "TABLE":
        columns = block[1]["columns"].keys()

        if "SKYPOS" in columns or "SKY" in columns or "sky" in columns:
            coords = "sky"
        else:
            coords = "detpos"

        region = "[{0}={1}]".format(coords,shape)
        binspec = "[bin {0}={1}]".format(coords,str(binsize))

    elif block_type == "IMAGE":
        if binsize < 1.0:
            error_out("binsize for an image must be an integer multiple of the pixel.")

        # check the logical binning factor of the image first, to avoid DM clash
        cr_img = pcr.read_file(infile)
        axisnames = cr_img.get_axisnames()

        # drop the first of these axis names that is present
        for coord in ["eqpos","Eqpos","EQPOS","msc","MSC"]:
            try:
                axisnames.remove(coord)
                break
            except ValueError:
                continue

        if not axisnames:
            # previously this fell through to an unbound 'transform'
            error_out("No axis left in {0} to bin on.".format(infile))

        # every axis transform is fetched; the last one supplies the scale
        for axis in axisnames:
            transform = cr_img.get_transform(axis)

        pixelscale = numpy.mean(transform.get_parameter_value("SCALE"))

        if "sky" in axisnames or "SKY" in axisnames or "Sky" in axisnames:
            coords = "sky"
        elif len(axisnames) == 1:
            coords = axisnames[0]
        else:
            coords = "({0},{1})".format(axisnames[0],axisnames[1])

        region = "[{0}={1}]".format(coords,shape)
        binspec = "[bin {0}={1}]".format(coords,str(binsize*pixelscale))

    else:
        # previously this fell through to an unbound 'region'
        error_out("Unsupported block type '{0}' in {1}: expected TABLE or IMAGE.".format(block_type,infile))

    return region+binspec


def _run_dmellipse(filtered_infile, ellipse_file, frac, x, y):
    """Run dmellipse on filtered_infile for the fractions in frac, with the
    centroid fixed at (x, y), ellipticity fixed at 1 and angle fixed at 0,
    writing the result to ellipse_file."""

    with new_pfiles_environment(ardlib=False, copyuser=False):
        dmellipse.punlearn()
        dmellipse.infile = filtered_infile
        dmellipse.outfile = ellipse_file
        dmellipse.fraction = frac
        dmellipse.shape = "ellipse"
        dmellipse.fix_centroid = "yes"
        dmellipse.x_centroid = x
        dmellipse.y_centroid = y
        dmellipse.fix_ellipticity = "yes"
        dmellipse.ellipticity = "1"
        dmellipse.fix_angle = "yes"
        dmellipse.angle = "0"
        dmellipse.tolerance = "1e-3"
        dmellipse.minstep = "1e-3"
        dmellipse.maxwalk = "10"
        dmellipse.normalize = "yes"
        dmellipse.clobber = "yes"
        dmellipse.verbose = "1"

        dmellipse()


def _add_angular_radius(table, r_col):
    """Add a "cell_r_mid" column (unit arcsec) to table as a linear
    transform of r_col when the TELESCOP keyword says Chandra; the scale is
    0.492 for ACIS and 0.13175 for HRC detectors."""

    try:
        if table.get_key("TELESCOP").value.lower() == "chandra":
            import pytransform as pyt

            tr = pyt.LINEARTransform("angular_sep")
            tr.get_parameter("OFFSET").set_value(0)

            detnam = table.get_key("DETNAM").value.lower()
            if detnam.startswith("acis"):
                tr.get_parameter("SCALE").set_value(0.492)
            elif detnam.startswith("hrc"):
                tr.get_parameter("SCALE").set_value(0.13175)

            trcol = pcr.CrateData()
            trcol.name = "cell_r_mid"
            trcol.unit = "arcsec"
            trcol._set_eltype(1)
            trcol.source = r_col
            trcol._set_transform(tr)

            table.add_column(trcol)

    except (AttributeError,KeyError):
        # Deliberately best effort: the angular column is simply
        # omitted when the keyword lookups above fail.
        pass


def get_ecf(infile,reg,outfile,x,y,radius,binsize,fraction,clobber):
    """Compute the encircled counts fractions with dmellipse and write the
    resulting table to outfile.

    Parameters
    ----------
    infile : str
        Input file name.
    reg : str
        Filter string appended to infile for the dmellipse run.
    outfile : str
        Name of the table to write.
    x, y
        Fixed centroid passed to dmellipse.
    radius, binsize
        Used to build the evenly spaced radii (binsize, 2*binsize, ...
        up to radius) when no fractions are requested.
    fraction : str
        Stack of requested fractions. "", "none" and "indef" (any case)
        select the default grid of 0.01 to 1.0 in steps of 0.01, whose
        result is interpolated onto the evenly spaced radii.
    clobber : bool
        Whether an existing outfile may be overwritten.
    """

    # Decide once whether the default grid is in use, so that every later
    # branch agrees with the choice made here ("none"/"indef" used to take
    # the default grid and then be parsed as a list of numbers).
    use_default = fraction.lower() in ("","none","indef")

    if use_default:
        frac = "lgrid(0.01:1.0:0.01)"
    else:
        frac = ",".join(stk.build(fraction))

    # The context manager releases the temporary file even on error.
    with tempfile.NamedTemporaryFile(suffix="_ecf.fits") as ecf:

        # run dmellipse to obtain a radii for a given set of ECFs
        _run_dmellipse(infile+reg, ecf.name, frac, x, y)

        table = pcr.read_file(ecf.name)

        rad = table.get_column("r").values

        colnames = table.get_colnames(vectors=True)
        colnames.remove("r")
        colnames.remove("fraction")

        if not use_default:
            colnames.remove("shape")
            colnames.remove("rotang")
            colnames.remove("component")

        # NOTE(review): the loop stops after the first successful delete,
        # so at most one of the remaining columns is removed. Kept as in
        # the original since changing it alters the output table - confirm
        # whether all of them were meant to go.
        for name in colnames:
            try:
                table.delete_column(name)
                break
            except (ValueError,AttributeError):
                continue

        # one radius per row: the mean of that row's "r" entry
        r_mid = [numpy.mean(row) for row in rad]

        # determine a set of fractions and interpolate to match
        # the evenly-spaced radii
        if use_default:
            from sherpa.utils import interpolate, linear_interp

            frac_vals = table.get_column("fraction").values
            table.delete_column("fraction")

            radii = radii_set(radius,binsize)

            # force the fraction to 1.0 when encompassed by entire radius
            # of interest to place ceiling on interpolation
            frac_vals = numpy.append(frac_vals,[1.0])
            r_known = numpy.append(r_mid,[radius])

            frac_interp = interpolate(numpy.array(radii),numpy.array(r_known),frac_vals,linear_interp)
            r_mid = radii

            frac_col = pcr.CrateData()
            frac_col.name = "fraction"
            frac_col.values = frac_interp

            table.add_column(frac_col)
            table.delete_column("r")

        # add 'requested fraction' column to results
        else:
            req_col = pcr.CrateData()
            req_col.name = "frac_in"
            req_col.values = [float(f) for f in frac.split(",")]
            req_col.desc = "requested fraction"

            table.add_column(req_col)

        r_col = pcr.CrateData()
        r_col.name = "r_mid"
        r_col.unit = "pixels"
        r_col.values = r_mid

        table.add_column(r_col)

        # add angular radii transform information to radii vector
        _add_angular_radius(table, r_col)

        # Pass clobber by keyword: the former ["clobber", clobber] list was
        # always truthy, so the flag was never honoured.
        table.write(outfile, clobber=clobber)


def plot_ecf(rad_file):
    """Plot the "fraction" column of rad_file against its "r_mid" column
    on log-log axes, in an interactive TkAgg matplotlib window.

    The x-axis label names the detector (HRC or ACIS) when the DETNAM
    keyword of rad_file contains it, and is plain "Pixels" otherwise.
    """

    import matplotlib
    from matplotlib import pyplot as plt
    from matplotlib.ticker import ScalarFormatter

    matplotlib.use("TkAgg")
    matplotlib.interactive(True)

    # a missing DETNAM keyword just means a generic axis label
    try:
        detector = get_detnam(rad_file)
    except KeyError:
        detector = ""

    table = pcr.read_file(rad_file)
    radii = pcr.get_colvals(table,"r_mid")
    fractions = pcr.get_colvals(table,"fraction")

    # remove vertical dropline from the zero-values that would be plotted by matplotlib
    not_positive = fractions <= 0
    fractions[not_positive] = numpy.nan

    fig,axs = plt.subplots()
    axs.plot(radii,fractions,
             linestyle="solid",
             marker="D",
             markersize=3.5,
             fillstyle="none")
    axs.loglog()
    axs.set_ylabel("Encircled Counts Fraction")

    if "HRC" in detector:
        xlabel = "HRC Pixels"
    elif "ACIS" in detector:
        xlabel = "ACIS Pixels"
    else:
        xlabel = "Pixels"
    axs.set_xlabel(xlabel)

    # plain numbers rather than scientific notation on both axes
    for axis in (axs.xaxis, axs.yaxis):
        plain = ScalarFormatter()
        plain.set_scientific(False)
        axis.set_major_formatter(plain)


############# Run Script #############


@handle_ciao_errors(toolname, __revision__)   # decorate top-level routine
def doit():
    """Top-level routine: read the parameters, build the region filter,
    compute and write the ECF table, record the tool history and, when
    requested and a display is available, plot the result."""

    params = get_par(sys.argv)

    infile = params["infile"]
    outfile = params["outfile"]
    radius = params["radius"]
    xpos = params["xpos"]
    ypos = params["ypos"]
    binsize = params["binsize"]
    fraction = params["fraction"]
    want_plot = params["plot"]
    clobber = params["clobber"]

    region_filter = reg(radius,xpos,ypos,binsize,infile)
    get_ecf(infile,region_filter,outfile,xpos,ypos,radius,binsize,fraction,clobber)

    add_tool_history(outfile,toolname,params)

    if not want_plot:
        return

    # plotting needs a display; an unset or empty DISPLAY means there is none
    if not os.environ.get('DISPLAY'):
        sys.stderr.write("DISPLAY not set, skipping plot.\n")
        return

    plot_ecf(outfile)
    pause()


# Run the tool only when this file is executed as a script.
if __name__ == "__main__":
    doit()
