#!/usr/bin/env python
#
# Copyright (C) 2022-2024 Smithsonian Astrophysical Observatory
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#

'Create a region based on a PSF simulated at the input location'

import sys
import os
import numpy as np
from pycrates import read_file, CrateKey
from ciao_contrib.runtool import make_tool
from region import CXCRegion, ellipse, circle
import ciao_contrib.logger_wrapper as lw


# Limits used when shrinking overlapping regions: a region's target PSF
# fraction is lowered by FRACTION_STEP per pass, but only while it is still
# above MIN_FRACTION + FRACTION_STEP.
MIN_FRACTION = 0.6
FRACTION_STEP = 0.1

# Tool name and revision: used for the logger, the parameter file lookup,
# the error handler and the history records written to the outputs.
toolname = "psf_contour"
__revision__ = "15 July 2024"

# Module-wide logging helpers, one per verbosity level used in this file.
lw.initialize_logger(toolname)
verb0 = lw.get_logger(toolname).verbose0
verb1 = lw.get_logger(toolname).verbose1
verb2 = lw.get_logger(toolname).verbose2


class PSF_Region():
    '''Base class for region objects. This handles the PSF simulation.

    The class converts the source position to Chandra coordinates,
    simulates a PSF at that position and smooths it.  Derived classes
    supply make_region() to turn the smoothed PSF into a region file.
    '''

    # File-name suffixes appended to the 'outroot' parameter.
    psf_suffix = ".psf"
    smpsf_suffix = ".smpsf"
    rays_suffix = "_projrays.fits"
    out_suffix = "_src.reg"

    # Pixel scale used when sizing the smoothing kernel.  It stands in for
    # both arcsec_per_skypixel and arcsec_per_PSFpixel in the AE recipe
    # quoted in _kernel_sigma().
    arcsec_per_pixel = 0.492

    def __init__(self, params):
        '''constructor

        params is deep-copied so that later per-object changes to the
        parameters do not leak back into the caller's dictionary.
        '''
        from copy import deepcopy
        self.params = deepcopy(params)
        self.coords = None      # filled in by get_coords()
        self.psf_file = None    # smoothed PSF image, set by smooth_psf()
        self.outfile = None     # region file name, set by derived classes

    def get_coords(self):
        'Convert RA/Dec to chandra coordinates'

        from ciao_contrib._tools.fileio import get_keys_from_file
        from coords.chandra import cel_to_chandra

        keys = get_keys_from_file(self.params['infile'])
        self.coords = cel_to_chandra(keys, self.params['ra'], self.params['dec'])

    def make_psf(self):
        'Simulate PSF and smooth it'
        verb1(f"Making PSF for {self.params['outroot']}")
        marx_psf = self.sim_psf()
        self.smooth_psf(marx_psf)

    def sim_psf(self):
        '''Simulate the PSF with MARX

        Runs simulate_psf at the source position with a mono-energetic
        spectrum and returns the name of the PSF image it should create.

        Raises RuntimeError if that image is not found afterwards.
        '''

        sim = make_tool("simulate_psf")
        sim.infile = self.params['infile']
        sim.outroot = self.params['outroot']
        sim.ra = self.params['ra']
        sim.dec = self.params['dec']
        sim.spectrum = ""
        sim.monoenergy = self.params['energy']
        sim.flux = self.params['flux']
        sim.minsize = 256
        sim.random_seed = self.params['random_seed']
        sim.marx_root = self.params['marx_root']
        sim()

        return self._simulated_psf_file()

    def _simulated_psf_file(self):
        'Return the expected PSF file name; raise RuntimeError if it is missing'
        psf = f"{self.params['outroot']}{self.psf_suffix}"
        if not os.path.exists(psf):
            # Name the file so the user can tell which simulation failed.
            raise RuntimeError("simulate_psf did not create the expected " +
                               f"PSF file: '{psf}'")
        return psf

    def _kernel_sigma(self, theta):
        '''Return the Gaussian smoothing sigma, in PSF pixels, for theta.

        theta is the off-axis angle taken from the 'theta' coordinate
        (presumably arcmin, to match the AE formula -- TODO confirm).
        '''

        # From AE:
        # radius50 = (0.85 -0.25 *off_angle + 0.10 *off_angle^2) * arcsec_per_skypixel  ; arcsec
        # arcsec_per_sigma    = 0.1 * (radius50 / 0.85)
        # PSFpixel_per_sigma  = arcsec_per_sigma / arcsec_per_PSFpixel ; PSF pixel

        radius50 = (0.85 - (0.25 * theta) + (0.10 * theta**2)) * self.arcsec_per_pixel
        arcsec_per_sigma = 0.1 * (radius50 / 0.85)
        return arcsec_per_sigma / self.arcsec_per_pixel

    def smooth_psf(self, psf):
        '''Smooth the PSF, taken from acis_extract

        Convolves the image 'psf' with a 2D Gaussian whose width depends
        on the off-axis angle, writes the result to outroot + smpsf_suffix
        and stores that file name in self.psf_file.
        '''

        theta = self.coords["theta"][0]
        psfpixel_per_sigma = self._kernel_sigma(theta)

        conv = make_tool("aconvolve")
        conv.infile = psf
        conv.outfile = f"{self.params['outroot']}{self.smpsf_suffix}"
        conv.kernelspec = f"lib:gaus(2, 5, 5, {psfpixel_per_sigma}, {psfpixel_per_sigma})"
        conv(method="slide", edge="const", const=0, clobber=True)

        self.psf_file = conv.outfile

    def make_region(self):
        "Create the region"
        raise NotImplementedError("Implement in the derived classed")


class PolygonContour(PSF_Region):
    '''Create a polygon that encloses fraction of the PSF

    The tool that traces the polygon is reached through two per-instance
    hooks, make_dmtool and set_threshold, so a derived class can swap in
    a different tool without touching make_contour().
    '''

    def __init__(self, params):
        'This class will use dmcontour to create the contour polygon'

        super().__init__(params)
        self.make_dmtool = self._make_dmcontour_tool
        self.set_threshold = self._set_dmcontour_level

    def _make_dmcontour_tool(self):
        'create the tool instance'
        contour = make_tool("dmcontour")
        return contour

    def _set_dmcontour_level(self, contour, value):
        'set the parameter value'
        contour.levels = float(value)

    def make_region(self):
        'Create the region (just make the contour)'
        self.make_contour()

    def filter_contour(self):
        '''Filter the contour file.

        The dmcontour output may contain excluded regions and may contain
        small isolated clusters of pixels above the specified threshold.

        This grabs just the single largest, inclusive polygon and uses
        that; the region file is overwritten in place.

        Raises RuntimeError if no included polygon with a positive area
        is found.
        '''

        regions = CXCRegion(self.outfile)
        out_reg = None
        max_area = 0.0

        for reg in regions:
            # Skip excluded shapes
            if reg.shapes[0].include.val != 1:
                continue

            # Evaluate the area once per candidate
            area = reg.area()
            if area > max_area:
                out_reg = reg
                max_area = area

        if out_reg is None:
            raise RuntimeError("Problem finding correct polygon")

        out_reg.write(self.outfile, fits=True, clobber=True)

    def make_contour(self):
        '''Create contour to enclose fraction of PSF

        The pixel values of the smoothed PSF are sorted and their
        normalized cumulative sum is used to pick a starting threshold.
        The threshold is then moved one sorted pixel value at a time
        until the sum of the PSF inside the polygon is within 'tolerance'
        of 'fraction', or the iteration limit is reached.

        The polygon is written to outroot + out_suffix and that name is
        stored in self.outfile.
        '''

        frac = self.params['fraction']
        tol = self.params['tolerance']

        pixels = read_file(self.psf_file).get_image().values
        flat = pixels.flatten()
        pixels_sorted = np.sort(flat)

        # First sorted pixel at which the faint-end cumulative sum reaches
        # (1-frac) of the total.
        cumulative = np.cumsum(pixels_sorted)
        cumulative /= cumulative[-1]
        idx = np.argwhere(cumulative >= (1-frac))[0][0]

        contour = self.make_dmtool()
        contour.infile = self.psf_file
        contour.outfile = f"{self.params['outroot']}{self.out_suffix}"
        self.outfile = contour.outfile

        # The sum inside the polygon is compared directly with frac, which
        # presumes the PSF image sums to 1 -- TODO confirm.
        stat = make_tool("dmstat")
        stat.infile = f"{self.psf_file}[sky=region({contour.outfile})]"

        psffrac = 0
        numiter = 0

        while psffrac < frac - tol or psffrac > frac + tol:
            numiter += 1

            self.set_threshold(contour, float(pixels_sorted[idx]))
            contour(clobber=True)

            self.filter_contour()

            stat(centroid=False, sigma=False, median=False)
            psffrac = float(stat.out_sum)

            if numiter > 20:
                verb1(f"Too many iterations for {self.outfile}, go with what we've got. psffrac={psffrac}")
                break

            if psffrac < frac:
                # Too little enclosed: lower the threshold
                idx -= 1
                if idx < 0:
                    verb0("Fraction is too low")
                    break
            elif psffrac > frac:
                # Too much enclosed: raise the threshold
                idx += 1
                if idx >= len(pixels_sorted):
                    verb0("Fraction is too high")
                    break


class PolygonLasso(PolygonContour):
    '''Polygon region traced with dmimglasso instead of dmcontour.

    Only the tool-creation and threshold hooks differ from the parent.
    '''

    def __init__(self, params):
        'Install the dmimglasso hooks in place of the dmcontour ones'

        super().__init__(params)
        self.make_dmtool = self._make_dmimglasso_tool
        self.set_threshold = self._set_dmimglasso_level

    def _make_dmimglasso_tool(self):
        'Return a dmimglasso tool seeded at the source x, y position'

        settings = {
            "xpos": self.coords["x"][0],
            "ypos": self.coords["y"][0],
            "high_value": "INDEF",
            "coord": "physical",
            "value": "absolute",
        }

        lasso = make_tool("dmimglasso")
        for par_name, par_value in settings.items():
            setattr(lasso, par_name, par_value)

        return lasso

    def _set_dmimglasso_level(self, contour, value):
        'Use value as the lower threshold of the lasso'

        level = float(value)
        contour.low_value = level


class EllipseFitToPolygon(PolygonContour):
    'Fit an ellipse to the dmcontour, contour'

    def make_region(self):
        'method to create the region'

        self.make_contour()
        self.fit_contour()

    @staticmethod
    def fit_ellipse(x, y):
        '''Given a set of x, y, finds the ellipse that fits the data, in the
        least squares sense.

        x and y are 1D numpy arrays of the same length.  Returns the tuple
        (xc, yc, mjr, mnr, ang): the center, the two axis lengths in the
        order the fit produces them, and the rotation angle in degrees.

        numpy.linalg.LinAlgError propagates to the caller when the
        scatter matrix cannot be inverted.

        Adapted from:  https://github.com/ndvanforeest/fit_ellipse
        '''
        from numpy.linalg import eig, inv

        def _conic_terms(coef):
            'Unpack the fit vector into the (a, b, c, d, f, g) terms'
            return coef[0], coef[1]/2, coef[2], coef[3]/2, coef[4]/2, coef[5]

        def _ellipse_center(coef):
            'Compute center from fit parameters'
            a, b, c, d, f, _ = _conic_terms(coef)
            num = b*b-a*c
            x0 = (c*d-b*f)/num
            y0 = (a*f-b*d)/num
            return np.array([x0, y0])

        def _ellipse_angle_of_rotation(coef):
            'Compute angle from fit parameters'
            a, b, c, _, _, _ = _conic_terms(coef)
            retval = np.rad2deg(0.5*np.arctan(2*b/(a-c)))
            # Non-positive angles are shifted by 180 degrees
            return retval if retval > 0 else retval+180.0

        def _ellipse_axis_length(coef):
            'Compute major/minor axes from fit parameters'
            a, b, c, d, f, g = _conic_terms(coef)
            up = 2*(a*f*f+c*d*d+g*b*b-2*b*d*f-a*c*g)
            down1 = (b*b-a*c)*((c-a)*np.sqrt(1+4*b*b/((a-c)*(a-c)))-(c+a))
            down2 = (b*b-a*c)*((a-c)*np.sqrt(1+4*b*b/((a-c)*(a-c)))-(c+a))
            res1 = np.sqrt(up/down1)
            res2 = np.sqrt(up/down2)
            return np.array([res1, res2])

        # Design matrix with columns x^2, xy, y^2, x, y, 1
        x = x[:, np.newaxis]
        y = y[:, np.newaxis]
        D = np.hstack((x*x, x*y, y*y, x, y, np.ones_like(x)))
        S = np.dot(D.T, D)

        # Constraint matrix
        C = np.zeros([6, 6])
        C[0, 2] = C[2, 0] = 2
        C[1, 1] = -1

        # Use the eigenvector with the largest absolute eigenvalue
        E, V = eig(np.dot(inv(S), C))
        n = np.argmax(np.abs(E))
        a = V[:, n]

        xc, yc = _ellipse_center(a)
        mjr, mnr = _ellipse_axis_length(a)
        ang = _ellipse_angle_of_rotation(a)
        return (xc, yc, mjr, mnr, ang)

    def fit_contour(self, numiter=1):
        '''Fit the contour with an ellipse

        Reads the polygon in self.outfile, fits an ellipse to its
        vertices and overwrites the region file with that ellipse.

        If the fit raises LinAlgError the flux is increased, the PSF is
        simulated again and the contour rebuilt, for up to four retries
        (numiter counts the attempts).  After that a circle centered on
        the mean vertex position is written instead.
        '''
        from numpy.linalg import LinAlgError

        tab = read_file(self.outfile)
        xx = np.asarray(tab.get_column('X').values[0], dtype=float)
        yy = np.asarray(tab.get_column('Y').values[0], dtype=float)

        # Drop non-finite entries with one shared mask so the x and y
        # values of each vertex always stay paired.
        keep = np.isfinite(xx) & np.isfinite(yy)

        # problems with singular matrices so we pre-condition the data
        # by subtracting off the mean.   Add it back when writing region.
        # as the region becomes closer to circular, the greater chance of
        # getting a singular matrix when fitting to an ellipse (angle
        # becomes unconstrained).
        x_0 = np.mean(xx[keep])
        y_0 = np.mean(yy[keep])
        _x = xx[keep] - x_0
        _y = yy[keep] - y_0

        try:
            xc, yc, mjr, mnr, angle = self.fit_ellipse(_x, _y)
        except LinAlgError:
            verb0("Bad fit, trying again")
            if numiter < 5:
                # The contour only changes if the PSF itself changes, so
                # the PSF must be simulated again with the higher flux.
                # NOTE(review): assumes simulate_psf overwrites the files
                # left by the previous run -- confirm.
                self.params["flux"] *= 1.5
                self.make_psf()
                self.make_contour()
                self.fit_contour(numiter+1)
                return

            verb0("Problem fitting polygon data with an ellipse, using a circle instead")
            radius = np.mean(np.hypot(_x, _y))
            shape = circle(x_0, y_0, radius)
        else:
            shape = ellipse(xc+x_0, yc+y_0, mjr, mnr, angle)

        self.outfile = f"{self.params['outroot']}{self.out_suffix}"
        shape.write(self.outfile, fits=True, clobber=True)


class ConvexHull(EllipseFitToPolygon):
    '''Wrap the contour in a convex hull rather than fitting an ellipse.'''

    def fit_contour(self):
        '''Replace the region file with the hull found by dmimghull.

        The smoothed PSF is filtered by the current region and the hull
        is written back over the same region file.
        '''

        region_file = self.outfile
        region_filter = f"[(x,y)=region({region_file})]"

        hull_tool = make_tool("dmimghull")
        hull_tool.infile = self.psf_file + region_filter
        hull_tool.outfile = region_file
        hull_tool(clobber=True, tolerance=0)


class ECF_Ellipse(PSF_Region):
    '''Build the region directly as an ellipse, with no contour step.'''

    def make_region(self):
        '''Run dmellipse on the smoothed PSF with the centroid held fixed.

        The centroid is fixed at the source x, y position; the region
        file name reported by the tool is stored in self.outfile.
        '''

        tool = make_tool("dmellipse")

        settings = {
            "infile": self.psf_file,
            "outfile": f"{self.params['outroot']}{self.out_suffix}",
            "fraction": self.params['fraction'],
            "fix_centroid": True,
            "x_centroid": self.coords['x'][0],
            "y_centroid": self.coords['y'][0],
            "tolerance": self.params['tolerance'],
        }
        for par_name, par_value in settings.items():
            setattr(tool, par_name, par_value)

        tool(clobber=True)

        self.outfile = tool.outfile


def get_cli():
    '''Read the tool parameters and validate the MARX location.

    The numeric parameters are converted to float.  A blank marx_root is
    replaced by the directory two levels above the 'marx' executable
    found on the PATH.

    Raises ValueError if MARX cannot be located.
    '''

    from ciao_contrib.param_soaker import get_params

    verbose_hooks = {"set": lw.set_verbosity, "cmd": verb1}
    pars = get_params(toolname, "rw", sys.argv, verbose=verbose_hooks)

    for name in ('energy', 'fraction', 'tolerance', 'flux'):
        pars[name] = float(pars[name])

    if pars["marx_root"] == '':
        from shutil import which

        marx_path = which('marx')
        if marx_path is None:
            raise ValueError("ERROR: marx_root parameter cannot be blank; "
                             "set to the directory where MARX is installed")

        # <root>/bin/marx -> <root>
        pars["marx_root"] = os.path.dirname(os.path.dirname(marx_path))

    root = pars["marx_root"]
    if not os.path.isdir(root):
        raise ValueError("ERROR: marx_root parameter must be set to the "
                         "top level directory where MARX is installed")

    marx_exe = os.path.join(root, "bin", "marx")
    if not os.path.isfile(marx_exe):
        raise ValueError(f"ERROR: Cannot find marx executable in '{marx_exe}'"
                         "; check the marx_root parameter")

    return pars


def get_positions(params):
    '''Parse the 'pos' parameter into arrays of RA and Dec values.

    Returns the tuple (ra, dec) of numpy arrays.  Any failure while
    parsing is re-raised as a ValueError chained to the original error.
    '''
    from ciao_contrib.parse_pos import get_radec_from_pos

    try:
        raw_ra, raw_dec = get_radec_from_pos(params['pos'])
        positions = (np.array(raw_ra), np.array(raw_dec))
    except Exception as bad_parse:
        raise ValueError(f"Could not parse coordinates: {params['pos']}") from bad_parse

    return positions


def shrink_psfs(first_psf, second_psf):
    '''Shrink two overlapping regions until they no longer overlap.

    On each pass, every region whose target fraction is still above
    MIN_FRACTION + FRACTION_STEP has it lowered by FRACTION_STEP and is
    rebuilt.  The function returns as soon as the overlap area is zero.
    If neither region can be shrunk any further and they still overlap,
    each one has the other's first shape excluded from it.
    '''
    verb1(f"Shrinking regions for {first_psf.params['outroot']} and {second_psf.params['outroot']}")

    lowest_shrinkable = MIN_FRACTION + FRACTION_STEP

    while True:
        any_shrunk = False

        for psf in (first_psf, second_psf):
            if psf.params['fraction'] > lowest_shrinkable:
                psf.params['fraction'] -= FRACTION_STEP
                psf.make_region()
                any_shrunk = True

        common = CXCRegion(first_psf.outfile) * CXCRegion(second_psf.outfile)
        if common.area(bin=0.25) == 0:
            return

        if not any_shrunk:
            break

    # Both regions are at the minimum fraction and still overlap, so stop
    # shrinking and exclude them from each other instead.

    first_reg = CXCRegion(first_psf.outfile)
    second_reg = CXCRegion(second_psf.outfile)

    (first_reg - second_reg[0]).write(first_psf.outfile, fits=True, clobber=True)
    (second_reg - first_reg[0]).write(second_psf.outfile, fits=True, clobber=True)


def shrink_overlaps(psfs):
    '''The AE approach is to "shrink" overlapping regions.  This
    is done by recomputing them using a lower threshold (ie smaller
    PSF fraction).   But it will stop if it goes below a threshold
    and just exclude overlapping regions.

    psfs is the sequence of region objects; every pair is checked once
    and shrink_psfs() is called for each pair whose regions overlap.
    The region files are modified in place.
    '''

    verb1("Checking for region overlaps")

    num_psf = len(psfs)

    for ii in range(num_psf):
        this_region = CXCRegion(psfs[ii].outfile)

        for jj in range(ii+1, num_psf):
            that_region = CXCRegion(psfs[jj].outfile)

            overlap = this_region * that_region
            if overlap.area(bin=0.25) == 0:
                continue

            shrink_psfs(psfs[ii], psfs[jj])

            # shrink_psfs rewrites the region file, so reload it; otherwise
            # the remaining pairs would be tested against the old, larger
            # region and could be shrunk when they no longer overlap.
            this_region = CXCRegion(psfs[ii].outfile)


def write_history(psfs, pars):
    '''Finalize each region file: optional FOV clipping, history, keyword.

    pars is the original parameter set (not the per-source copies); it is
    what gets recorded in the history.  Each file also receives a
    PSFFRAC0 keyword holding the region's target PSF fraction.
    '''
    from math import isclose
    from ciao_contrib.runtool import add_tool_history

    verb1("Writing outputs")

    have_fov = pars['fovfile'].lower() not in ['', 'none']
    fov = CXCRegion(pars['fovfile']) if have_fov else None

    for psf in psfs:
        regfile = psf.outfile

        if fov:
            # Intersect the region with the FOV, but only keep the result
            # when the FOV actually clips the source region.
            #
            # The region library includes both the region and the FOV in
            # the intersection whenever their bounding boxes overlap, so
            # every source inside the FOV would otherwise carry the FOV
            # in its region file.
            #
            # A plain area comparison is not reliable either: simple
            # regions may have their area computed analytically while
            # composite ones are pixelated, and even when both are
            # pixelated only the step size can be controlled, not the
            # grid.  So math.isclose() is used to accept areas within 10%
            # of each other, with a half pixel absolute tolerance.
            #
            # Checking that all polygon end-points lie inside the FOV
            # would be an alternative, but this has to work for ellipses
            # too.

            src = CXCRegion(regfile)
            clipped = src * fov
            unchanged = isclose(src.area(bin=0.25), clipped.area(bin=0.25),
                                rel_tol=0.1, abs_tol=0.5)
            if not unchanged:
                clipped.write(regfile, fits=True, clobber=True)

        add_tool_history(regfile, toolname, pars, toolversion=__revision__)

        # Record the target fraction in the output file header
        frac_key = CrateKey()
        frac_key.name = "PSFFRAC0"
        frac_key.value = psf.params["fraction"]
        frac_key.desc = "Target PSF fraction"

        out_tab = read_file(regfile, mode="rw")
        out_tab.add_key(frac_key)
        out_tab.write()


def run_one_pos(pars):
    '''Build the PSF and the region for a single source.

    pars is the tuple (params, ra, dec, iter_num).  The region class is
    chosen from params['method'] and the output root gets an "_iNNNN"
    suffix built from iter_num.  Returns the region object.

    Raises ValueError for an unrecognized method.
    '''

    params, ra, dec, iter_num = pars

    region_classes = {
        "contour": PolygonContour,
        "lasso": PolygonLasso,
        "fitted_ellipse": EllipseFitToPolygon,
        "ecf_ellipse": ECF_Ellipse,
        "convex_hull": ConvexHull,
    }

    if params['method'] not in region_classes:
        raise ValueError(f"Unknown method: '{params['method']}'")

    psf = region_classes[params['method']](params)

    # These changes go to the object's own copy of the parameters
    psf.params['ra'] = ra
    psf.params['dec'] = dec
    psf.params['outroot'] = f"{psf.params['outroot']}_i{iter_num:04d}"

    for step in (psf.get_coords, psf.make_psf, psf.make_region):
        step()

    return psf


@lw.handle_ciao_errors(toolname, __revision__)
def main():
    '''Tool driver: one region per input position, then resolve overlaps.

    The positions are processed in parallel when the 'parallel' parameter
    is "yes"; 'nproc' limits the number of cores unless it is "INDEF".
    '''

    params = get_cli()
    ra, dec = get_positions(params)

    # One work item per source, numbered from 1
    pars = [(params, _ra, _dec, num)
            for num, (_ra, _dec) in enumerate(zip(ra, dec), start=1)]

    if params["parallel"] != "yes":
        psfs = [run_one_pos(item) for item in pars]
    else:
        from sherpa.utils import parallel_map

        core_opts = {}
        if "INDEF" != params["nproc"]:
            core_opts["numcores"] = int(params["nproc"])
        psfs = parallel_map(run_one_pos, pars, **core_opts)

    shrink_overlaps(psfs)
    write_history(psfs, params)


# Run the tool only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
