#!/usr/bin/env python
#
# Copyright (C) 2014-2024 Smithsonian Astrophysical Observatory
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#

"Combine map regions together based on some metric"


import sys
import numpy as np
import pycrates as pc

import ciao_contrib.logger_wrapper as lw
# Tool identity: both values are passed to the error-handling decorator on
# main() and to add_tool_history(); __toolname__ also names the logger and
# the parameter file that is loaded.
__toolname__ = "merge_too_small"
__revision__ = "14 February 2024"

# Short module-level aliases for the logger's per-level output methods:
# verbN is the `verboseN` attribute of the object returned by
# lw.initialize_logger(__toolname__).
# NOTE(review): presumably a message sent through verbN is only shown when
# the tool's verbose setting is at least N -- confirm against the
# ciao_contrib.logger_wrapper documentation.
verb0 = lw.initialize_logger(__toolname__).verbose0
verb1 = lw.initialize_logger(__toolname__).verbose1
verb2 = lw.initialize_logger(__toolname__).verbose2
verb3 = lw.initialize_logger(__toolname__).verbose3
verb5 = lw.initialize_logger(__toolname__).verbose5


def find_neighbors(mask, maskval):
    """Identify which map ids neighbor the current mask value.

    Two pixels are neighbors when they share an edge (one step up, down,
    left or right); pixels that only touch diagonally are not neighbors.

    Parameters
    ----------
    mask : numpy.ndarray
        2D array of map ids.
    maskval : scalar
        The map id whose neighbors are wanted.

    Returns
    -------
    list of int
        The distinct ids found next to any ``maskval`` pixel, in ascending
        order.  ``maskval`` itself and the value 0 are never reported.
        The list is empty when ``maskval`` has no such neighbor (or does
        not appear in ``mask`` at all).
    """

    region = (mask == maskval)

    # Each pair lines up "is this pixel part of the region?" with "the
    # value of the pixel one step away" by trimming opposite edges of the
    # two arrays.  Pixels on the image border simply have no partner in
    # the trimmed direction, so no explicit bounds checks are needed.
    shifted_pairs = (
        (region[1:, :], mask[:-1, :]),    # pixel in the previous row
        (region[:-1, :], mask[1:, :]),    # pixel in the next row
        (region[:, 1:], mask[:, :-1]),    # pixel in the previous column
        (region[:, :-1], mask[:, 1:]),    # pixel in the next column
    )

    neighbors = set()
    for in_region, adjacent in shifted_pairs:
        touching = np.unique(adjacent[in_region])

        # Filter on the raw values before converting to int so that the
        # exclusion test is applied to exactly what is stored in the mask.
        keep = (touching != maskval) & (touching != 0)
        neighbors.update(int(val) for val in touching[keep])

    # Sorted so that the result does not depend on set iteration order.
    return sorted(neighbors)


def purge_too_small(mask, minarea, counts, joinfunc):
    """
    Merge undersized map regions into a neighboring region, in place.

    Every positive id in ``mask`` is measured either by its number of
    pixels (``counts`` is None) or by the sum of ``counts`` over its
    pixels.  The region with the smallest measure is handled first: as
    long as that measure is not above ``minarea`` its pixels are
    relabelled with the id of one of its edge-sharing neighbors, the
    neighbor being picked by ``joinfunc``.  The loop stops when the
    smallest remaining candidate is above ``minarea`` or when there are
    no candidates left.

    Parameters
    ----------
    mask : numpy.ndarray
        2D array of map ids; modified in place.  Only ids in the range
        1..max(mask) are measured.
    minarea : number
        Threshold; a region is merged while its measure is <= minarea.
    counts : numpy.ndarray or None
        Per-pixel weights handed to numpy.histogram, or None to measure
        regions by pixel count.
    joinfunc : callable
        Receives a list of ``(measure, id)`` tuples for the eligible
        neighbors and returns the chosen tuple (the caller in this file
        passes the builtins ``min`` or ``max``).

    Returns
    -------
    None
    """

    if mask.size == 0:
        # Nothing to measure; np.max would fail on an empty array.
        return

    mask_max = int(np.max(mask))
    if mask_max < 1:
        # No positive ids, so there is nothing to merge (and
        # np.histogram rejects a bin count of zero).
        return

    # Ids that are still candidates for processing / merge targets.
    mask_ids = set(np.unique(mask).tolist())
    # Ids that were too small but had no eligible neighbor to merge into.
    skiplist = set()

    while True:
        # make histogram of pixel values.  If counts=None, then
        # histogram is the area (logical pixels), if counts=value, then
        # counts=counts (or flux or whatever units the input image is).
        # One unit-wide bin per id: bin index i holds id i+1.
        hh = np.histogram(mask, bins=mask_max, range=(1, mask_max+1),
                          weights=counts)
        area = hh[0]
        maskid = hh[1][:-1].astype(int)

        am = [x for x in zip(area, maskid.tolist())
              if x[1] in mask_ids and x[1] not in skiplist]

        if not am:
            # Every id has been merged away, skipped, or found to be
            # isolated: nothing left to do.
            break

        working_on = min(am)

        if working_on[0] > minarea:
            # When all the mapvals are above the threshold we break out
            # and return.
            break

        current_id = working_on[1]
        verb2(f"Working on mask_id {working_on[1]} with value {working_on[0]}")

        # Whatever happens below, this id is not revisited and can not be
        # chosen as a merge target afterwards.
        mask_ids.remove(current_id)

        # Convert mapvalue to histogram index.  Ids outside 1..mask_max
        # (e.g. negative values) have no histogram bin and would
        # otherwise wrap around to an unrelated region.
        nn = [i-1 for i in find_neighbors(mask, current_id)
              if 1 <= i <= mask_max]
        if not nn:
            # If there are no neighbors the region is left untouched.
            continue

        zz = [x for x in zip(area[nn], maskid[nn].tolist())
              if x[1] in mask_ids and x[1] not in skiplist]

        if zz:
            replace_val = int(joinfunc(zz)[1])
            mask[mask == current_id] = replace_val
            verb2(f"Replacing {working_on[1]} with {replace_val}")
        else:
            # There are neighbors, but none of them is eligible as a
            # merge target; leave the region as it is.
            skiplist.add(current_id)


def parse_parameters(pars):
    '''
    Translate the "method" and "join" parameters into working objects.

    Returns a (counts, joinfunc) pair:

      counts   - the pixel values of the image named by pars["imgfile"]
                 when method is "counts", or None when method is "area"
      joinfunc - the builtin min or max, as named by pars["join"]

    A ValueError is raised for any other method or join value.
    '''
    method = pars["method"]
    if method not in ("counts", "area"):
        raise ValueError("Unknown value for method parameter")

    # The image is only needed (and only read) for the counts method.
    counts = None
    if method == "counts":
        counts = pc.read_file(pars["imgfile"]).get_image().values

    join = pars["join"]
    if join not in ("min", "max"):
        raise ValueError("Unknown value for join parameter")

    joinfunc = min if join == "min" else max
    return counts, joinfunc


@lw.handle_ciao_errors(__toolname__, __revision__)
def main():
    """
    Tool entry point.

    Reads the tool parameters, merges the undersized regions of the input
    map, writes the updated map to the output file and, when a binned
    image was requested, runs dmmaskbin on the image file using the new
    map.  A history record is added to every file that is written.
    """

    # Load parameters
    from ciao_contrib.param_soaker import get_params
    verbosity_hooks = {"set": lw.set_verbosity, "cmd": verb1}
    pars = get_params(__toolname__, "rw", sys.argv,
                      verbose=verbosity_hooks)

    # Check both possible output files before doing any work.
    from ciao_contrib._tools.fileio import outfile_clobber_checks
    for output_key in ("outfile", "binimg"):
        outfile_clobber_checks(pars["clobber"], pars[output_key])

    # The map is edited as an integer array and stored back in the crate.
    crate = pc.read_file(pars["infile"])
    mask = crate.get_image().values.astype(int)

    counts, joinfunc = parse_parameters(pars)
    threshold = int(pars["minvalue"])
    purge_too_small(mask, threshold, counts, joinfunc)

    crate.get_image().values = mask
    pc.write_file(crate, pars["outfile"], clobber=pars["clobber"])

    from ciao_contrib.runtool import add_tool_history
    add_tool_history(pars["outfile"], __toolname__, pars,
                     toolversion=__revision__)

    # The binned image is optional: blank or "none" means not requested.
    binimg = pars["binimg"]
    if len(binimg) == 0 or binimg.lower() == "none":
        return

    from ciao_contrib.runtool import dmmaskbin
    dmmaskbin(pars["imgfile"], pars["outfile"]+"[opt type=i4]",
              binimg, clobber=True)
    add_tool_history(binimg, __toolname__, pars,
                     toolversion=__revision__)


# Run the tool only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()
