#!/usr/bin/env python
#
# Copyright (C) 2019-2022, 2024 Smithsonian Astrophysical Observatory
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#

"Create a map of a certain statistic of a chosen column"


import sys
import os

import numpy as np
from pycrates import read_file

import ciao_contrib.logger_wrapper as lw
# Tool name: used to set up the logger here and passed to the parameter
# reader, error handler and history writer further down this file.
toolname = "statmap"
# Revision string handed to the error-handling decorator and recorded as
# the tool version when history is added to the output file.
__revision__ = "18 July 2024"
lw.initialize_logger(toolname)
lgr = lw.get_logger(toolname)
# Short aliases for the logger's verbose0..verbose5 methods (presumably one
# per verbosity level) so the rest of the file can call verbN("message").
verb0 = lgr.verbose0
verb1 = lgr.verbose1
verb2 = lgr.verbose2
verb3 = lgr.verbose3
verb4 = lgr.verbose4
verb5 = lgr.verbose5


class CIAONamedTemporaryFile():
    """Named temporary file that is deleted when the object is released.

    The file is created in the directory given by the ASCDS_WORK_PATH
    environment variable; when that variable is not set the default
    temporary directory chosen by the tempfile module is used instead.
    The file is created with delete=False so that it can be written to
    by name; it is removed by close(), which is also invoked when the
    object is garbage collected.
    """

    # Use .get() so that defining the class does not raise KeyError when
    # ASCDS_WORK_PATH is undefined; dir=None lets tempfile pick a default.
    _tmpdir = os.environ.get("ASCDS_WORK_PATH")

    def __init__(self):
        from tempfile import NamedTemporaryFile
        # Set first so close()/__del__ are safe even if creation fails.
        self._tmp = None
        self._tmp = NamedTemporaryFile(dir=self._tmpdir, delete=False)

    @property
    def name(self):
        'Get file name'
        return self._tmp.name

    def close(self):
        """Close the handle and remove the file; safe to call repeatedly."""
        tmp = getattr(self, "_tmp", None)
        if tmp is None:
            # __init__ never ran or failed before the file was created.
            return

        tmp.close()
        try:
            os.unlink(tmp.name)
        except FileNotFoundError:
            # Already removed (earlier close() or by someone else).
            pass

    def __del__(self):
        self.close()


def weighted_median(vals, weights):
    """Return the weighted median of vals.

    The values are sorted, the weights are accumulated in that order and
    the value at which the running weight first reaches half of the total
    weight is returned.  When the running weight is exactly half of the
    total at that point, the mean of that value and the next larger one
    is returned instead.

    Parameters
    ----------
    vals : array-like
        One-dimensional data values (at least one element).
    weights : array-like
        Weight of each entry in vals; same length as vals.

    Returns
    -------
    A scalar.  A single value is returned as is, and for exactly two
    values the plain (unweighted) average of the pair is returned.
    """
    # from :https://stackoverflow.com/questions/20601872/numpy-or-scipy-to-calculate-weighted-median

    vals = np.asarray(vals)
    weights = np.asarray(weights)
    nvals = len(vals)

    if nvals == 1:
        # Return the element, not the one-element array, so that every
        # code path yields a scalar.
        return vals[0]

    if nvals == 2:
        return np.average(vals)

    quantile = 0.5  # median
    order = np.argsort(vals)
    csum = np.cumsum(weights[order])
    quant = np.searchsorted(csum, quantile * csum[-1])

    if quant == nvals - 1:
        # The half-weight point falls on the largest value.  vals is not
        # sorted, so it has to be looked up through the sort order.
        return vals[order[-1]]

    if csum[quant] / csum[-1] == quantile:
        # Exactly half the weight lies at or below this value: split the
        # difference with the next value up.
        return 0.5 * (vals[order[quant]] + vals[order[quant + 1]])

    return vals[order[quant]]


def weighted_mean(vals, weights):
    """Return the average of vals, each entry weighted by weights."""
    mean = np.average(vals, weights=weights)
    return mean


def weighted_min(vals, weights):
    """Return the entry of vals whose weight is the smallest."""
    # keepdims=True makes argmin return an array instead of a bare integer;
    # its first element is the position used to index vals.
    position = np.argmin(weights, keepdims=True)[0]
    return vals[position]


def weighted_max(vals, weights):
    """Return the entry of vals whose weight is the largest."""
    # keepdims=True makes argmax return an array instead of a bare integer;
    # its first element is the position used to index vals.
    position = np.argmax(weights, keepdims=True)[0]
    return vals[position]


def weighted_sum(vals, weights):
    """Return the weighted sum of vals, i.e. sum(vals[i] * weights[i]).

    Both arguments may be numpy arrays or plain sequences of numbers of
    the same length.
    """
    # np.multiply converts array-likes, so lists work as well as arrays
    # (a bare ``vals*weights`` raises TypeError for two lists).
    return np.sum(np.multiply(vals, weights))


def map_stat_function(stat):
    """Convert a statistic name into the function that computes it.

    The unweighted statistics ('median', 'max', 'min', 'mean', 'count',
    'sum') map to functions taking the values only; the weighted ones
    ('wmedian', 'wmax', 'wmin', 'wmean', 'wsum') map to functions taking
    the values and their weights.

    Raises
    ------
    ValueError
        If stat is not one of the names listed above.
    """

    do_stat = {'median': np.median,
               'max': np.max,
               'min': np.min,
               'mean': np.mean,
               'count': len,
               'sum': np.sum,
               'wmedian': weighted_median,
               'wmax': weighted_max,
               'wmin': weighted_min,
               'wmean': weighted_mean,
               'wsum': weighted_sum,
               }

    # Explicit raise rather than assert: assertions are removed when
    # Python runs with -O, which would turn this into a bare KeyError.
    try:
        return do_stat[stat]
    except KeyError:
        known = ", ".join(do_stat)
        raise ValueError(f"Unknown statistic '{stat}'; "
                         f"expected one of: {known}") from None


def assign_mapid_to_events(evtfile, mapfile, column, xcol, ycol, wcol=None):
    """Run dmimgpick to attach the map value to each event.

    Only the position columns, the optional weight column and the column
    being analysed are passed through.  The result is written to a
    temporary file whose wrapper object is returned; the caller has to
    keep that object for as long as the file is needed.
    """

    from ciao_contrib.runtool import make_tool

    verb2("Mapping events")

    scratch = CIAONamedTemporaryFile()

    # fewer cols = faster
    if wcol is not None:
        col_filter = f"[cols {xcol},{ycol},{wcol},{column}]"
    else:
        col_filter = f"[cols {xcol},{ycol},{column}]"

    picker = make_tool("dmimgpick")
    picker.infile = evtfile + col_filter
    picker.imgfile = mapfile
    picker.outfile = scratch.name
    picker.method = "closest"
    picker.clobber = True
    picker()

    return scratch


def load_event_file(infile, column, xcol, ycol, wcol=None):
    """Load the event file that has had the map values attached.

    Parameters
    ----------
    infile : str
        File to read.
    column : str
        Name of the column whose values are to be summarised.
    xcol, ycol : str
        Names of the two position columns.
    wcol : str or None
        Name of the weight column, or None when no weights are used.

    Returns
    -------
    (map_vals, col_vals, wgt_vals) : the values of the map column, of
    `column`, and of `wcol` (None when wcol is None).

    Raises
    ------
    RuntimeError
        If the file does not contain exactly one column besides the
        ones named in the arguments.
    """

    def find_map_column_name(evtfile):
        "find map column name by process of elimination"
        # Compare case-insensitively on BOTH sides: the file's column
        # names are lowercased below, so the user-supplied names must be
        # lowercased too or e.g. column="ENERGY" would never be excluded.
        known_cols = {name.lower()
                      for name in (xcol, ycol, column, wcol)
                      if name is not None}

        map_col = [name for name in evtfile.get_colnames(vectors=False)
                   if name.lower() not in known_cols]

        # Explicit raise rather than assert so the check survives -O.
        if len(map_col) != 1:
            raise RuntimeError("Expected exactly one map column in "
                               f"'{infile}' but found: {map_col}")
        return map_col[0]

    verb2("Loading event file")
    evtfile = read_file(infile)
    map_col = find_map_column_name(evtfile)
    map_vals = evtfile.get_column(map_col).values
    col_vals = evtfile.get_column(column).values
    wgt_vals = evtfile.get_column(wcol).values if wcol is not None else None
    return map_vals, col_vals, wgt_vals


def compute_stats(map_vals, col_vals, wgt_vals, func):
    """Compute the statistic for each map ID present in the events.

    Parameters
    ----------
    map_vals : ndarray
        Map value attached to each event (1D); non-finite entries are
        ignored.
    col_vals : ndarray
        Value of the column of interest for each event.
    wgt_vals : ndarray or None
        Weight for each event, or None for an unweighted statistic.
    func : callable
        Called as func(vals) when wgt_vals is None, otherwise as
        func(vals, weights), with the entries belonging to one map ID.

    Returns
    -------
    dict mapping each unique finite map value to the statistic computed
    from its events.
    """

    verb2("Computing stats")

    # Ignore any NaN or Inf values
    good_rows = np.flatnonzero(np.isfinite(map_vals))

    # Group the events by map value with one sort instead of scanning the
    # whole event list once per map value.  The sort is stable, so inside
    # each group the events stay in their original file order.
    order = np.argsort(map_vals[good_rows], kind="stable")
    rows = good_rows[order]

    # starts[i] is where the i-th unique value begins in the sorted rows.
    unique_pixel_vals, starts = np.unique(map_vals[rows], return_index=True)
    npix = len(unique_pixel_vals)
    verb3(f"Number of unique map values in event file: {npix}")

    stops = np.append(starts[1:], len(rows))

    stat_vals = {}
    for pixel_val, first, last in zip(unique_pixel_vals, starts, stops):
        # Every unique value came from the events, so a group is never
        # empty and no special case is needed for that.
        members = rows[first:last]
        vals = col_vals[members]

        if wgt_vals is None:
            stat_vals[pixel_val] = func(vals)
        else:
            stat_vals[pixel_val] = func(vals, wgt_vals[members])

    return stat_vals


def replace_mapid_with_stats(stat_vals, mapfile):
    """Build an image in which each map ID is replaced by its statistic.

    The map image is read from mapfile.  Every pixel whose value is a key
    of stat_vals receives the corresponding statistic; pixels whose value
    has no entry in stat_vals are set to NaN.  Returns a float array with
    the same shape as the map image.
    """

    verb2("Paint by numbers")

    verb2(f"Reading mapfile '{mapfile}'")
    id_image = read_file(mapfile).get_image().values.copy()
    painted = np.zeros_like(id_image, dtype=float)

    map_ids = np.unique(id_image)
    npix = len(map_ids)
    verb3(f"Number of unique map values in mapfile: {npix}")

    # NOTE(review): a NaN map pixel never compares equal to anything, so
    # such pixels keep the initial 0.0 - confirm that this is intended.
    for map_id in map_ids:
        painted[id_image == map_id] = stat_vals.get(map_id, np.nan)

    return painted


def write_output(outvals, mapfile, outfile, stat, column, clobber):
    """Write outvals to outfile, using mapfile as the template.

    The map file is re-read, its name is set to "<stat>_<column>", its
    image pixel values are replaced with outvals, and the result is
    written to outfile with the given clobber setting.
    """

    verb2(f"Writing output file '{outfile}'")

    template = read_file(mapfile)
    template.name = f"{stat}_{column}"

    image = template.get_image()
    image.values = outvals

    template.write(outfile, clobber=clobber)


def process_parameters():
    """Read the tool parameters and normalise them.

    - "clobber" is converted to a bool and the output file is checked
      against it.
    - "wcolumn" is set to None when it is blank or "none", and also (with
      a warning) when an unweighted statistic was requested.

    Returns
    -------
    The parameter dictionary.

    Raises
    ------
    ValueError
        If a weighted statistic (name starting with 'w') is requested but
        no weight column is given.
    """

    from ciao_contrib.param_soaker import get_params
    from ciao_contrib._tools.fileio import outfile_clobber_checks

    pars = get_params(toolname, "rw", sys.argv,
                      verbose={"set": lw.set_verbosity, "cmd": verb1})

    pars["clobber"] = (pars["clobber"] == "yes")
    outfile_clobber_checks(pars["clobber"], pars["outfile"])

    is_weighted = pars["statistic"].lower().startswith('w')

    if pars["wcolumn"].lower() in ("", "none"):
        if is_weighted:
            # Without this check the weighted function is later called
            # with the values only and fails with an obscure TypeError.
            raise ValueError(f"Statistic '{pars['statistic']}' requires a "
                             "weight column; set the wcolumn parameter")
        pars["wcolumn"] = None
    elif not is_weighted:
        verb0("WARNING: Weight column specified but unweighted statistic chosen; ignoring weights")
        pars["wcolumn"] = None

    return pars


@lw.handle_ciao_errors(toolname, __revision__)
def main():
    """Run the tool: read parameters, compute the map and write it out."""

    pars = process_parameters()
    statistic = pars["statistic"]
    mapfile = pars["mapfile"]
    outfile = pars["outfile"]

    do_stat_func = map_stat_function(statistic)

    # Both helpers take the column names in the same order:
    # column, xcol, ycol, wcol.
    colnames = (pars["column"], pars["xcolumn"],
                pars["ycolumn"], pars["wcolumn"])

    # Keep the temporary-file object in a local for the rest of the
    # function: the file is deleted once the object is released.
    tmpevt = assign_mapid_to_events(pars["infile"], mapfile, *colnames)
    map_vals, col_vals, wgt_vals = load_event_file(tmpevt.name, *colnames)

    stat_vals = compute_stats(map_vals, col_vals, wgt_vals, do_stat_func)
    outvals = replace_mapid_with_stats(stat_vals, mapfile)

    write_output(outvals, mapfile, outfile,
                 statistic, pars["column"], pars["clobber"])

    from ciao_contrib.runtool import add_tool_history
    add_tool_history(outfile, toolname, pars,
                     toolversion=__revision__)


# Run the tool only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
