#!/usr/bin/env python
#
# Copyright (C) 2014-2020, 2023
# Smithsonian Astrophysical Observatory
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#

'Create bins based on Voronoi tessellation'


import os
import sys

import ciao_contrib.logger_wrapper as lw


# Tool name and revision string.  Both are handed to the error-handling
# decorator on main() and recorded via add_tool_history().
__toolname__ = "vtbin"
__revision__ = "24 August 2023"

# Module-wide logger, plus short aliases for its per-level output methods.
# NOTE(review): presumably verboseN only emits when the tool's verbosity
# is >= N -- confirm against ciao_contrib.logger_wrapper.
__lgr__ = lw.initialize_logger(__toolname__)
verb0 = __lgr__.verbose0
verb1 = __lgr__.verbose1
verb2 = __lgr__.verbose2
verb3 = __lgr__.verbose3
verb5 = __lgr__.verbose5


class CIAOTemporaryFile():
    """
    A little class to make sure that tmpfiles are forcefully removed at end.

    The file is created in the directory named by the ASCDS_WORK_PATH
    environment variable (a KeyError is raised if it is not set) with
    delete=False, so that it survives being closed and can be handed to
    other tools by name.  The file is closed and removed when this object
    is finalized, so callers must keep a reference to the object for as
    long as they need the file.

    Any positional/keyword arguments are forwarded to
    tempfile.NamedTemporaryFile (e.g. suffix=".fits").
    """

    def __init__(self, *args, **kwargs):
        from tempfile import NamedTemporaryFile

        # Define the attributes before anything can fail so that __del__
        # always finds them, even on a half-constructed object.
        self.tmpfile = None
        self.name = None

        self.tmpfile = NamedTemporaryFile(*args,
                                          dir=os.environ["ASCDS_WORK_PATH"],
                                          delete=False, **kwargs)
        self.name = self.tmpfile.name

    def __del__(self):
        # getattr() guards against objects whose __init__ never ran
        # (or failed before the attributes were assigned).
        tmpfile = getattr(self, "tmpfile", None)
        if tmpfile is not None:
            tmpfile.close()

        name = getattr(self, "name", None)
        if name is None:
            return

        # Remove directly rather than exists()+remove(): the file may
        # already be gone (e.g. removed by a previous finalizer call).
        try:
            os.remove(name)
        except FileNotFoundError:
            pass


def find_local_max(infile, mask="box(0, 0, 5, 5)"):
    """
    Locate the local peaks of the image in `infile`.

    A peak is a pixel whose value is the max() within the `mask` region
    around it (it must be strictly > all surrounding pixels).  The image
    is first run through dmimgfilt with function="peak", and the result
    is then passed to dmimgblob (thresh=0, src=True).

    Returns the CIAOTemporaryFile object holding the dmimgblob output.
    The file is removed when that object is finalized, so the caller has
    to keep the returned object alive while the file is in use.
    """
    from ciao_contrib.runtool import dmimgblob
    from ciao_contrib.runtool import dmimgfilt

    verb1("Finding local maxima")

    # The intermediate peak image only needs to live until dmimgblob
    # has read it; the blob image is what gets handed back.
    peaks_tmp = CIAOTemporaryFile()
    sites_tmp = CIAOTemporaryFile()

    dmimgfilt(infile, peaks_tmp.name, function="peak", mask=mask,
              clobber=True)
    dmimgblob(peaks_tmp.name, sites_tmp.name, thresh=0, src=True,
              clobber=True)

    return sites_tmp


def make_mask(shape, radius):
    """
    Build the region string of the mask used to determine local max.

    shape  : "circle" or "box"
    radius : circle radius, or half the side length of the box

    Returns "circle(0, 0, radius)" or "box(0, 0, 2*radius, 2*radius)";
    any other shape raises a RuntimeError.
    """

    if shape == "box":
        # The box is specified by its full side length
        side = 2.0 * float(radius)
        return f"box(0, 0, {side}, {side})"

    if shape == "circle":
        return f"circle(0, 0, {radius})"

    raise RuntimeError(f"Unknown shape '{shape}'")


def cercle_circonscrit(T):
    """Return the center of the circle circumscribing a triangle.

    T holds the three (x, y) vertices of the triangle.  The center is
    the solution of the 2x2 linear system built from the differences
    between the third vertex and each of the other two.

    Returns the tuple (x, y), or False when the determinant of that
    system is exactly zero (the three points are colinear).

    https://stackoverflow.com/questions/44231281/circumscribed-circle-of-a-triangle-in-2d
    """
    import numpy as np

    (x1, y1), (x2, y2), (x3, y3) = T

    lhs = np.array([[x3-x1, y3-y1],
                    [x3-x2, y3-y2]])

    # Colinear points: the system is singular, there is no circle
    if np.linalg.det(lhs) == 0:
        return False

    sq3 = x3**2 + y3**2
    rhs = np.array([sq3 - x1**2 - y1**2,
                    sq3 - x2**2 - y2**2])

    center = 0.5*np.dot(np.linalg.inv(lhs), rhs)

    # The radius would be sqrt((x-x1)**2+(y-y1)**2); it is not needed.
    return (center[0], center[1])


class Cell():
    """Class to hold Voronoi cells

    A cell is built around one site (x, y).  Callers append the (x, y)
    centers of the triangles that touch the site to `neighbors` and then
    call calc_midpts() to turn them into the ordered cell outline.
    """

    def __init__(self, x, y):
        'Store the site location and start with an empty cell outline'
        self.x = x
        self.y = y

        # Triangle centers, (x, y) tuples, filled in by the caller
        self.neighbors = []

        # Outputs of calc_midpts(); defined here so that a cell which
        # has not been processed yet simply has an empty outline.
        self.midpts = []    # sorted, unique (angle, (x, y)) tuples
        self.cellx = []     # x coordinates of the outline, by angle
        self.celly = []     # y coordinates of the outline, by angle

    def calc_midpts(self):
        'Convert center of triangles into voronoi cells'
        from math import atan2

        # Tag each triangle center with its angle as seen from the site.
        # A set drops duplicate points; sorting by angle then orders the
        # remaining points around the site.
        tagged = {(atan2(midy - self.y, midx - self.x), (midx, midy))
                  for (midx, midy) in self.neighbors}

        self.midpts = sorted(tagged)
        self.cellx = [pt[0] for _, pt in self.midpts]
        self.celly = [pt[1] for _, pt in self.midpts]


def make_vcells(x, y):
    """
    Use matplotlib's Triangulation to create the Delaunay triangulation,
    then find centers of triangles to create Voronoi cells.

    x, y : arrays of site coordinates (they are indexed with the
           triangle vertex indices, so plain lists will not do).

    Returns a list of Cell objects, one per site and in the same order
    as the input coordinates, with their outlines already computed.
    """

    import matplotlib.tri as mtri
    tri = mtri.Triangulation(x, y)

    pts = [Cell(_x, _y) for _x, _y in zip(x, y)]

    for t in tri.get_masked_triangles():
        # Find center of triangle; add center to list of
        # neighbors at each vertex of the triangle
        center = cercle_circonscrit(list(zip(x[t], y[t])))

        # cercle_circonscrit() returns False for colinear vertices;
        # such a triangle has no center so it cannot contribute.
        if not center:
            verb2(f"Skipping degenerate triangle with vertices {t}")
            continue

        for vertex in t:
            pts[vertex].neighbors.append(center)

    for p in pts:
        p.calc_midpts()

    return pts


def make_polygon(cell, img_shape):
    """Create the polygon for a cell and the pixel ranges it covers

    cell      : object providing the cellx and celly outline coordinates
    img_shape : image shape, indexed as (ny, nx)

    Returns (polygon, x pixel range, y pixel range), where the ranges
    span the extent of the polygon clipped to the image.  If the polygon
    cannot be built, or is empty, the outline is printed and the error
    is re-raised.
    """
    import numpy as np
    from region import polygon

    try:
        poly = polygon(cell.cellx, cell.celly)

        if not len(poly):
            raise RuntimeError("Bad polygon")
    except Exception as err:
        print(f"Cellx: {cell.cellx}")
        print(f"Celly: {cell.celly}")
        raise err

    # Only the pixels inside the extent of the polygon need to be
    # visited; clip that extent to the image.
    bounds = poly.extent()
    x_lo = max(0, int(bounds['x0']))
    x_hi = min(img_shape[1], int(bounds['x1']+1))
    y_lo = max(0, int(bounds['y0']))
    y_hi = min(img_shape[0], int(bounds['y1']+1))

    return poly, np.arange(x_lo, x_hi), np.arange(y_lo, y_hi)


def compute_vcells(infile, sitesfile, outfile, clobber):
    """
    Given a set of local max, grow the regions until they touch and
    cover the image.

    infile    : image whose valid() test decides which pixels may be used
    sitesfile : image whose pixels > 0 are the sites; the pixel value
                is the value written into that site's cell
    outfile   : output image name (written with block name "tess")
    clobber   : passed through to the write call
    """
    from crates_contrib.masked_image_crate import MaskedIMAGECrate
    import numpy as np

    verb1("Assigning pixels to maxima")

    # Image holding the sites
    site_crate = MaskedIMAGECrate(sitesfile, mode="r")
    site_image = site_crate.get_image()
    site_vals = site_image.values

    # The input image is only needed for its valid() test
    subspace = MaskedIMAGECrate(infile, mode="r")

    # (row, column) index of every site, and the value stored there
    site_idx = np.argwhere(site_vals > 0)
    labels = [site_vals[row, col] for row, col in site_idx]
    site_y = site_idx[:, 0]
    site_x = site_idx[:, 1]

    cells = make_vcells(site_x, site_y)

    # Paint each cell with the value of its site
    filled = np.zeros_like(site_vals)
    for cell, label in zip(cells, labels):

        try:
            poly, xs, ys = make_polygon(cell, site_vals.shape)
        except Exception:
            # No usable polygon for this cell
            continue

        for row in ys:
            for col in xs:
                # Tests run cheapest first: pixel still unassigned,
                # pixel valid in the input image, and only then the
                # point-in-polygon test.
                if (filled[row, col] == 0
                        and subspace.valid(col, row)
                        and poly.is_inside(col, row)):
                    filled[row, col] = label

    # Replace the pixel values and write the result
    site_image.values = filled
    site_crate.name = "tess"
    site_crate.write(outfile, clobber=clobber)


@lw.handle_ciao_errors(__toolname__, __revision__)
def main():
    """
    Tool entry point: read the parameters, create the tessellation map
    and, when a binimg name is given, the binned image.
    """

    def is_unset(value):
        'True when a file name parameter is blank or "none" (any case)'
        return len(value) == 0 or "none" == value.lower()

    # Load parameters
    from ciao_contrib.param_soaker import get_params
    pars = get_params(__toolname__, "rw", sys.argv,
                      verbose={"set": lw.set_verbosity, "cmd": verb1})

    from ciao_contrib._tools.fileio import outfile_clobber_checks
    for out_par in ("outfile", "binimg"):
        outfile_clobber_checks(pars["clobber"], pars[out_par])

    mask = make_mask(pars["shape"], pars["radius"])

    # Without a user supplied site file the sites are the local maxima
    # of the input image.  The temporary file object has to stay
    # referenced until the sites have been read.
    if is_unset(pars["sitefile"]):
        site_tmp = find_local_max(pars["infile"], mask=mask)
        sitefile = site_tmp.name
    else:
        sitefile = pars["sitefile"]

    compute_vcells(pars["infile"], sitefile, pars["outfile"], pars["clobber"])

    from ciao_contrib.runtool import add_tool_history
    add_tool_history(pars["outfile"], __toolname__, pars,
                     toolversion=__revision__)

    if is_unset(pars["binimg"]):
        return

    # Bin the input image using the map that was just written
    from ciao_contrib.runtool import dmmaskbin
    map_spec = pars["outfile"]+"[opt type=i4]"
    dmmaskbin(pars["infile"], map_spec, pars["binimg"], clobber=True)
    add_tool_history(pars["binimg"], __toolname__, pars,
                     toolversion=__revision__)


# Run the tool only when this file is executed as a script (not on import).
if __name__ == "__main__":
    main()
