#!/usr/bin/env python
#
# Copyright (C) 2022-2024 Smithsonian Astrophysical Observatory
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#

'Create background annulus regions with a fixed number of counts'

import sys
import numpy as np
from pycrates import read_file
import ciao_contrib.logger_wrapper as lw


# Tool identity; the name is what the logger below is registered under.
toolname = "bkg_fixed_counts"
__revision__ = "28 March 2024"

# Set the logger up once, then bind one writer per verbosity level.
lw.initialize_logger(toolname)
_logger = lw.get_logger(toolname)
verb0 = _logger.verbose0
verb1 = _logger.verbose1
verb2 = _logger.verbose2


def load_positions(pos):
    """Parse the position specification into RA and Dec arrays.

    The parsing is delegated to ciao_contrib's get_radec_from_pos and
    each of its two outputs is wrapped in a numpy array.

    Returns a (ra, dec) tuple of numpy arrays.

    Raises ValueError, chained to the underlying exception, if the
    positions can not be parsed or converted.
    """
    from ciao_contrib.parse_pos import get_radec_from_pos

    verb1("Loading source positions")
    try:
        ra_vals, dec_vals = get_radec_from_pos(pos)
        sky = (np.array(ra_vals), np.array(dec_vals))
    except Exception as bad_parse:
        # Any failure here is reported as a problem with the user's input
        raise ValueError(f"Could not parse coordinates: {pos}") from bad_parse

    return sky


def convert_coords(infile, ra_deg, dec_deg):
    """Convert RA/Dec into Chandra coordinates for the given file.

    The keys obtained from infile via get_keys_from_file are handed,
    together with the RA/Dec values, to coords.chandra.cel_to_chandra
    and its result is returned untouched.  The rest of this script
    reads the 'x', 'y', 'theta', 'phi' and 'pixsize' entries of that
    result.
    """

    from ciao_contrib._tools.fileio import get_keys_from_file
    from coords.chandra import cel_to_chandra

    verb1("Convert coordinates")

    return cel_to_chandra(get_keys_from_file(infile), ra_deg, dec_deg)


def get_psf_size(coords, energy, inner_ecf):
    """Store the PSF radius of every source, in pixels, in coords.

    For each theta/phi pair in coords the PSF size at the requested
    energy and enclosed counts fraction is looked up and divided by
    coords['pixsize'].  The resulting list is stored as
    coords['inner_rad_pix']; nothing is returned.
    """

    from ciao_contrib.psf_contrib import PSF
    psf_model = PSF()

    verb1("Getting PSF size for inner radius")

    # psfSize result is treated as arcsec; pixsize presumably is
    # arcsec per pixel -- TODO confirm against cel_to_chandra
    coords['inner_rad_pix'] = [
        psf_model.psfSize(energy, off_axis, azimuth, inner_ecf) / coords['pixsize']
        for off_axis, azimuth in zip(coords['theta'], coords['phi'])
    ]


def make_swiss_cheese(coords, src_region_list):
    """Exclude all the PSF ECF circles (and source regions) from the field.

    A circle of radius coords['inner_rad_pix'] is built at each
    coords['x'], coords['y'] position and the union of those circles
    is subtracted from the whole field.  When src_region_list is
    neither empty nor 'none' (case insensitive) it is expanded as a
    stack of region files; the first shape of every non-empty region
    is collected and subtracted as well.

    Returns a tuple (exclude_region, srcregs):
      exclude_region - the field with all of the sources cut out
      srcregs - the collected source shapes, or None when no source
                region stack was supplied
    """
    from region import field, circle, CXCRegion

    verb1("Building excluded region expression")

    inner_srcs = CXCRegion()
    for xpos, ypos, rad in zip(coords['x'], coords['y'], coords['inner_rad_pix']):
        inner_srcs += circle(xpos, ypos, rad)
    exclude_region = field() - inner_srcs

    srcregs = None
    if src_region_list.lower() not in ['', 'none']:
        import stk
        src_stk = stk.build(src_region_list)

        srcregs = CXCRegion()
        for src_file in src_stk:
            src = CXCRegion(src_file)

            # skip if srcreg is null
            if len(src) == 0:
                continue

            # Why [0]?  (See the long comment in psf_contour re:
            # region area and FOV boundaries.)
            #
            # If the FOV is included in the source region then
            # the exclusion operation below will take a REALLY, REALLY
            # long time (hours) .. and yet it won't matter since it's
            # intersected with the interesting shape. We know that the
            # interesting source region is the 1st shape, so we'll just
            # go ahead and only consider the 1st shape here:
            srcregs += src[0]

        exclude_region -= srcregs

    return exclude_region, srcregs


def filter_events(infile, exclude_region):
    """Return the positions of the events left once the sources are removed.

    The 'x' and 'y' columns of infile are read and every event is
    tested against exclude_region; only the events for which
    is_inside() is true are kept, so only bkg events remain.

    Returns a tuple (bkg_x, bkg_y) of numpy arrays.
    """

    verb1("Filtering events with excluded region expression")

    events = read_file(infile)
    all_x = events.get_column('x').values
    all_y = events.get_column('y').values

    # One region test per event, in file order
    kept = [(evt_x, evt_y) for evt_x, evt_y in zip(all_x, all_y)
            if exclude_region.is_inside(evt_x, evt_y)]

    bkg_x = np.array([evt[0] for evt in kept])
    bkg_y = np.array([evt[1] for evt in kept])
    return bkg_x, bkg_y


def find_outer_bkg_radius(coords, bkg_x, bkg_y, numcts, maxradius):
    """Determine the radius around each source that encloses a fixed number of counts.

    With the sources purged, we just need to compute the distance
    from each source's center to every background event, sort, and
    then pick the distance at index numcts, which becomes the outer
    radius of the annulus.  Kind of brute force, but it runs quickly.

    coords must provide the 'x', 'y' and 'inner_rad_pix' sequences
    (one entry per source); the results are stored, in the same order,
    as the list coords['outer_rad_pix'].  Nothing is returned.

    bkg_x, bkg_y are the numpy arrays of background event positions,
    numcts is the requested number of counts and maxradius is either
    None (no limit) or an upper limit on the outer radius, compared
    directly against the distances computed from x and y.

    Raises ValueError if there are fewer than numcts background events
    (including none at all), or if maxradius has to be applied to a
    source whose inner PSF radius is larger than maxradius.
    """

    verb1("Determining outer background radius")

    nbkg = len(bkg_x)
    if nbkg == 0:
        raise ValueError("No background events remain after excluding "
                         "the source regions")
    if nbkg < numcts:
        raise ValueError(f"Only {nbkg} background events are available, "
                         f"which is fewer than the {numcts} requested")

    if (numcts/float(nbkg)) > 0.1:
        verb0(f"WARNING: the requested number of counts ({numcts}) is more "
              f"than 10% of the {nbkg} available background events")

    # The radius is read at index numcts of the sorted distances.  When
    # exactly numcts events exist that index is one past the end, so fall
    # back to the farthest event, which still encloses all numcts of them.
    pick = min(numcts, nbkg - 1)

    coords['outer_rad_pix'] = []
    for xpos, ypos, rad_in in zip(coords['x'], coords['y'], coords['inner_rad_pix']):
        dist = np.hypot(xpos-bkg_x, ypos-bkg_y)
        outer_rad = np.sort(dist)[pick]

        if maxradius is not None and outer_rad > maxradius:
            if maxradius < rad_in:
                raise ValueError("Maximum radius is less than the inner PSF radius")
            outer_rad = maxradius

        coords['outer_rad_pix'].append(outer_rad)


def _remove_inner_circle(background, inner_circle):
    """Rebuild background without the exclusion of inner_circle.

    The ecf circle with the radius equal to the annulus' inner radius
    is always excluded.  We don't want that, so the shapes matching
    the negated circle are located and the region is put back together
    from the remaining shapes, combining each one according to its
    logic operator.

    Returns the rebuilt region.  Raises ValueError for a shape whose
    logic is not one of opNOOP, opOR or opAND.
    """
    from region import CXCRegion, opAND, opOR, opNOOP

    dropped = background.index(-inner_circle)

    # NOTE(review): rebuilt is only created by the opNOOP branch, so it is
    # unbound if shape 0 is one of the dropped shapes or there are no shapes.
    for pos, shape in enumerate(background.shapes):
        if pos in dropped:
            continue

        logic = shape.logic
        if logic == opNOOP:
            # Only the very first shape may come without an operator
            assert pos == 0, "Undefined region logic found"
            rebuilt = CXCRegion(shape)
        elif logic == opOR:
            rebuilt = rebuilt + CXCRegion(shape)
        elif logic == opAND:
            rebuilt = rebuilt * CXCRegion(shape)
        else:
            raise ValueError("Unknown region logic")

    return rebuilt


def write_output(input_pars, coords, src_regions):
    """Write the background region files.

    There is one file for each source, since nearby sources are also
    excluded from every annulus.  Each file is named
    <outroot>_iNNNN_bkg.reg, with NNNN counting the sources from 1 in
    input order, is written in FITS format with clobber enabled, and
    then has the tool history added to it.

    input_pars supplies 'outroot' and 'fovfile' and is also what gets
    recorded in the history; coords supplies the 'x', 'y',
    'inner_rad_pix' and 'outer_rad_pix' sequences; src_regions is an
    optional extra region to exclude.
    """

    from region import annulus, CXCRegion, circle
    from ciao_contrib.runtool import add_tool_history
    from math import isclose

    verb1("Writing outputs")

    centers_x = coords['x']
    centers_y = coords['y']
    inner = coords['inner_rad_pix']
    outer = coords['outer_rad_pix']

    # Everything that has to be cut out of every annulus
    all_srcs = CXCRegion()
    for src_x, src_y, src_rad in zip(centers_x, centers_y, inner):
        all_srcs += circle(src_x, src_y, src_rad)

    if src_regions:
        all_srcs += src_regions

    fovfile = input_pars['fovfile']
    fov = CXCRegion(fovfile) if fovfile.lower() not in ['', 'none'] else None

    sources = zip(centers_x, centers_y, inner, outer)
    for cmpt, (src_x, src_y, rad_in, rad_out) in enumerate(sources, start=1):
        bkg = annulus(src_x, src_y, rad_in, rad_out) - all_srcs
        own_psf = circle(src_x, src_y, rad_in)

        if fov:
            # The region lib keeps both the region and the fov whenever
            # their bounding boxes overlap, which would put the FOV into
            # the file of every background that lies inside it.  The FOV
            # is only wanted when it actually clips the region, so the
            # intersection is kept only if it changes the area.
            #
            # (see longer discussion in psf_contour)
            clipped = bkg * fov
            full_area = bkg.area(bin=0.25)
            clipped_area = clipped.area(bin=0.25)
            if not isclose(full_area, clipped_area, rel_tol=0.1, abs_tol=0.5):
                bkg = clipped

        bkg = _remove_inner_circle(bkg, own_psf)
        outfile = f"{input_pars['outroot']}_i{cmpt:04d}_bkg.reg"
        bkg.write(outfile, fits=True, clobber=True)
        add_tool_history(outfile, toolname, input_pars,
                         toolversion=__revision__)


def get_cli():
    """Get parameter values from the .par file and convert the numeric ones.

    min_counts becomes an int; inner_ecf and energy become floats;
    max_radius becomes a float, or None when it is set to "INDEF".
    Returns the parameter mapping produced by get_params.
    """

    from ciao_contrib.param_soaker import get_params
    pars = get_params(toolname, "rw", sys.argv,
                      verbose={"set": lw.set_verbosity, "cmd": verb1})

    pars['min_counts'] = int(pars['min_counts'])

    # "INDEF" means that no upper limit was requested
    raw_max = pars['max_radius']
    if raw_max == "INDEF":
        pars['max_radius'] = None
    else:
        pars['max_radius'] = float(raw_max)

    for key in ('inner_ecf', 'energy'):
        pars[key] = float(pars[key])

    return pars

@lw.handle_ciao_errors(toolname, __revision__)
def main():
    """Main routine

    Reads the parameters, converts the source positions, sizes the
    inner PSF circles, removes the sources from the event list, finds
    the outer radius for each source and writes the region files.
    """

    params = get_cli()

    sky_ra, sky_dec = load_positions(params["pos"])
    srcs = convert_coords(params['infile'], sky_ra, sky_dec)
    get_psf_size(srcs, params['energy'], params['inner_ecf'])

    cheese, user_regs = make_swiss_cheese(srcs, params["src_region"])
    evt_x, evt_y = filter_events(params['infile'], cheese)
    find_outer_bkg_radius(srcs, evt_x, evt_y, params['min_counts'],
                          params['max_radius'])

    # Put the "INDEF" string back before the parameters are handed to
    # write_output, which records them in the output file history.
    if params["max_radius"] is None:
        params["max_radius"] = "INDEF"
    write_output(params, srcs, user_regs)

# Only run the tool when this file is executed as a script
if __name__ == '__main__':
    main()
