#!/usr/bin/env python
#
# Copyright (C) 2020 Smithsonian Astrophysical Observatory
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#

"Covert region from physical to celestial coordinates"

import sys
import ciao_contrib.logger_wrapper as lw

# Tool name and revision; passed to the logger, the parameter reader and
# the CIAO error handler further down in this file.
toolname = "regphystocel"
__revision__ = "19 Aug 2020"

# Create the tool's logger and bind short aliases for its verbosity levels.
lw.initialize_logger(toolname)
lgr = lw.get_logger(toolname)
verb0 = lgr.verbose0
verb1 = lgr.verbose1
verb2 = lgr.verbose2

from pycrates import read_file, get_transform

class PhysicalRegion(object):
    """Region in physical coordinates, convertible to a ds9 celestial region.

    Typical use is construct -> run() -> write().  The constructor loads
    the region and the WCS transform, run() builds the converted region
    text in ``self.cel_region_str`` and write() emits it with a ds9 header.
    """

    def __init__(self, infile, wcsfile, text, tags):
        """Load the region and obtain the WCS transform.

        infile  : region file name or region string (physical coordinates)
        wcsfile : file to take the WCS from; blank or "none" means infile
        text    : key/column name used for the ds9 text attribute
        tags    : stack of key/column names used for ds9 tag attributes
        """
        self.read_region(infile, text, tags)

        # A blank or "none" wcsfile means the region file carries the WCS
        if len(wcsfile) > 0 and wcsfile.lower() != "none":
            self.get_xform(wcsfile)
        else:
            self.get_xform(infile)

    def read_region(self, infile, text, tags):
        """Parse the region, then load any requested text and tags.

        Raises ValueError when the region string contains a single or
        double quote character, since radii must be in physical pixels.
        """
        from region import CXCRegion

        if "'" in infile:
            raise ValueError("ERROR: Found \"'\" in region string; make sure radii are in physical pixels")
        if '"' in infile:
            raise ValueError("ERROR: Found '\"' in region string; make sure radii are in physical pixels")

        self.physical_region = CXCRegion(infile)
        self.load_tags_and_text(infile, text, tags)

    def load_tags(self, tab, tags):
        """Build the per-shape tag strings (tags are aka groups in ds9).

        For each name in tags a header key is tried first (same value for
        every shape), then a column (one value per shape).  A name that is
        neither is reported and contributes nothing.  The result is stored
        in ``self.tags``, one concatenated string per shape.
        """
        num_shapes = len(self.physical_region.shapes)

        self.tags = [""] * num_shapes
        for tag in tags:
            if tab.key_exists(tag):
                per_shape = [" tag={{{}={}}}".format(tag, tab.get_key_value(tag))] * num_shapes
            elif tab.column_exists(tag):
                per_shape = [" tag={{{}={}}}".format(tag, val)
                             for val in tab.get_column(tag).values]
            else:
                verb0("Could not find key or column '{}' for tags attribute".format(tag))
                per_shape = [""] * num_shapes

            # Indexed by shape number on purpose: a column shorter than the
            # number of shapes is an error rather than silently truncated.
            self.tags = [self.tags[ii] + per_shape[ii]
                         for ii in range(num_shapes)]

    def load_text(self, tab, text):
        """Build the per-shape text strings (text displayed above shapes).

        A header key named text is tried first (same value for every
        shape), then a column.  If neither exists a message is logged and
        ``self.text`` is set to None.
        """
        num_shapes = len(self.physical_region.shapes)

        if tab.key_exists(text):
            self.text = ["text={{{}}}".format(tab.get_key_value(text))] * num_shapes
        elif tab.column_exists(text):
            self.text = ["text={{{}}}".format(val)
                         for val in tab.get_column(text).values]
        else:
            verb0("Could not find key or column '{}' for text attribute".format(text))
            self.text = None

    @staticmethod
    def _is_unset(value):
        """Return True when a parameter value is blank or "none" (any case)."""
        return value == "" or value.lower() == "none"

    def load_tags_and_text(self, infile, text, tags):
        """Process the text and tags requests.

        Sets ``self.text`` and ``self.tags``; each is None when it was not
        requested or when infile cannot be opened as a file.
        """
        want_text = not self._is_unset(text)
        want_tags = not self._is_unset(tags)

        self.text = None
        self.tags = None
        if not (want_text or want_tags):
            return

        try:
            tab = read_file(infile)
        except OSError as ee:
            # A region given as a string, rather than a file, lands here
            if 'does not exist' in str(ee):
                verb0("WARNING: text and tags can only be used with FITS region files.")
                return
            raise

        if want_text:
            self.load_text(tab, text)

        if want_tags:
            import stk as stk
            self.load_tags(tab, stk.build(tags))

    def get_xform(self, infile):
        """Get the WCS transform: try eqpos, then eqsrc, then give up.

        Stores the transform in ``self.xform`` and the pixel size in
        ``self.delta`` (arcsec).

        Raises IOError when neither transform exists in infile, and
        RuntimeError when no "delt" parameter is found or when the X and
        Y pixel sizes differ.
        """
        input_table = read_file(infile)

        for name in ("eqpos", "eqsrc"):
            try:
                self.xform = get_transform(input_table, name)
            except ValueError as missing_col:
                # Only a missing transform falls through to the next name
                if 'does not exist' not in str(missing_col):
                    raise
            else:
                break
        else:
            raise IOError("Could not find EQPOS nor EQSRC coordinate transform in '{}'".format(infile))

        delta = [par.get_value() for par in self.xform.get_parameter_list()
                 if "delt" in par.get_name().lower()]

        if not delta:
            raise RuntimeError("Cannot find CDELT transform parameter in {}".format(infile))

        xdelt = abs(delta[0][0])
        ydelt = abs(delta[0][1])
        if xdelt != ydelt:
            raise RuntimeError("Must have same X and Y pixel size (square pixels)")

        self.delta = xdelt * 3600.0      # arcsec

    @staticmethod
    def trim_coordinate(coord, digits=5):
        """Remove excess precision from a formatted coordinate string.

        The string is cut ``digits`` characters after the position of the
        decimal point, i.e. the point plus ``digits - 1`` decimals are
        kept.  A string without a decimal point is returned unchanged.
        """
        dot = coord.find(".")
        if dot < 0:
            return coord
        return coord[:dot + digits]

    @staticmethod
    def write_shape(inc, shape, pos, rad, angle):
        """Build the region string for one shape based on its name.

        inc   : "-" for an excluded shape, "" otherwise
        shape : lower-case shape name
        pos   : comma separated position string
        rad   : comma separated radii string
        angle : comma separated angle string

        Returns the region string, or None for shapes that are skipped
        (field, pie, sector).  Raises ValueError for any other shape name.
        """
        # NOTE(review): process_shape lets 'rectangle' through but it is not
        # handled here, so it ends in the ValueError below -- confirm intent.
        prefix = "{}{}({}".format(inc, shape, pos)

        if shape in ('point', 'polygon'):
            return prefix + ")"
        if shape in ('circle', 'box', 'annulus'):
            return prefix + ",{})".format(rad)
        if shape in ('ellipse', 'rotbox'):
            return prefix + ",{},{})".format(rad, angle)
        if shape in ('field', 'pie', 'sector'):
            verb0("Skipping {}".format(shape))
            return None

        raise ValueError("Unsupported shape {}".format(shape))

    def process_position(self, shape):
        """Convert the shape's points from physical to celestial in HMS.

        Returns one comma separated string of "ra,dec" pairs.
        """
        from coords.format import deg2ra, deg2dec

        pos = list(zip(shape.xpoints, shape.ypoints))
        radec = self.xform.apply(pos)

        cel = []
        for ra_deg, dec_deg in ((rd[0], rd[1]) for rd in radec):
            ra_str = self.trim_coordinate(deg2ra(ra_deg, ":"))
            dec_str = self.trim_coordinate(deg2dec(dec_deg, ":"))

            pair = "{},{}".format(ra_str, dec_str)
            if pair.index(":") == 1:
                pair = '0' + pair    # Add a leading 0, looks better IMO
            cel.append(pair)

        return ",".join(cel)

    def process_radii(self, shape):
        """Convert radii from physical pixels to arcsec.

        Returns a comma separated string with a trailing '"' on each
        value, or None when the shape has no radii attribute value.
        """
        radii = shape.radii
        if radii is None:
            return None

        # presumably radii is a numpy array so this scales element-wise
        # -- TODO confirm against the region module
        scaled = radii * self.delta
        return ",".join('{:g}"'.format(val) for val in scaled)

    def process_angles(self, shape):
        """Format the angles as a comma separated string (None if absent)."""
        angles = shape.angles
        if angles is None:
            return None
        return ",".join('{:g}'.format(val) for val in angles)

    def process_shape(self, shape):
        """Convert one shape to celestial notation.

        Returns the region string, or None when the shape is skipped.
        """
        shape_name = shape.name.lower()
        if shape_name not in ('circle', 'box', 'rotbox', 'ellipse',
                              'polygon', 'rectangle', 'point', 'annulus'):
            verb0("Skipping unsupported shape={}".format(shape_name))
            return None

        inc = "-" if shape.include.str == "!" else ""
        cel = self.process_position(shape)
        rr = self.process_radii(shape)
        aa = self.process_angles(shape)

        return self.write_shape(inc, shape_name, cel, rr, aa)

    def run(self):
        """Convert every shape and store the result in cel_region_str.

        Shapes for which process_shape returns None are left out of the
        output; the text and tags of the remaining shapes are still taken
        from the shape's original position in the region.
        """
        lines = []
        for ii, shape in enumerate(self.physical_region.shapes):
            cel_shape = self.process_shape(shape)
            if cel_shape is None:
                # Must be dropped before formatting, otherwise the line
                # would literally start with "None".
                continue

            text = self.text[ii] if self.text else ""
            tags = self.tags[ii] if self.tags else ""
            lines.append("{} # {} {}".format(cel_shape, text, tags))

        self.cel_region_str = "\n".join(lines)

    def _write_region(self, fp):
        """Write the ds9 header and the converted shapes to the open file fp."""
        fp.write('# Region file format: DS9 version 4.1\n')
        fp.write('global color=green dashlist=8 3 width=1 font="helvetica 10 normal roman" select=1 highlite=1 dash=0 fixed=0 edit=1 move=1 delete=1 include=1 source=1\n')
        fp.write('fk5\n')
        fp.write(self.cel_region_str)
        fp.write("\n")

    def write(self, outfile, clobber=True):
        """Write the converted region; run() must have been called first.

        outfile : output file name, or "-" for standard output
        clobber : passed to the clobber check done before opening outfile
        """
        if outfile == "-":
            # stdout is not ours to close
            self._write_region(sys.stdout)
            return

        from ciao_contrib._tools.fileio import outfile_clobber_checks
        outfile_clobber_checks(clobber, outfile)

        # Context manager so the file is flushed and closed even on error
        with open(outfile, "w") as fp:
            self._write_region(fp)

@lw.handle_ciao_errors(toolname, __revision__)
def main():
    """Read the tool parameters, convert the region and write the result."""

    from ciao_contrib.param_soaker import get_params

    # Hooks the parameter reader uses to set and report the verbosity
    verbosity_hooks = {"set": lw.set_verbosity, "cmd": verb1}
    params = get_params(toolname, "rw", sys.argv, verbose=verbosity_hooks)

    infile = params["infile"]
    wcsfile = params["wcsfile"]
    text = params["text"]
    tag = params["tag"]
    converter = PhysicalRegion(infile, wcsfile, text, tag)

    converter.run()

    outfile = params["outfile"]
    clobber = params["clobber"]
    converter.write(outfile, clobber)


# Run the tool only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
