#!/usr/bin/env python

# Copyright (C) 2005,2014,2016,2019
#               Smithsonian Astrophysical Observatory
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#

import pycrates as pc
import numpy as np

import sys

import ciao_contrib.logger_wrapper as lw

# Name of this tool.  It is used to initialise and look up the logger
# below, and is passed to the parameter interface, the error handler,
# the history record and the fallback error message further down.
toolname = "monitor_photom"
# Revision string; reported as the tool version in the history record
# and in error messages.
__revision__ = "22 January 2023"

# Create the tool's logger and bind short aliases to its per-level
# output methods.
# NOTE(review): presumably verbN only emits when the verbosity is >= N;
# confirm against ciao_contrib.logger_wrapper.
lw.initialize_logger(toolname)
lgr = lw.get_logger(toolname)
verb0 = lgr.verbose0
verb1 = lgr.verbose1
verb2 = lgr.verbose2
verb3 = lgr.verbose3
verb5 = lgr.verbose5

#
#
# From the original S-Lang version:
#
#
#%
#% $Id: monitor_photom,v 1.0 2005/06/23 taldcroft Exp $
#%
#% Purpose: Generate a photometric light curve for a Chandra
#%	   target which was observed using an ACA monitor
#%	   window.
#%


def get_data( infile):
    """
    Read the monitor window table and return its contents as a flat object.

    The file is opened with pycrates.  The returned object carries one
    attribute per column that is needed later on:

        time, img_raw, img_row0, img_col0, aca_comp_bkg_avg

    plus ``integ_time``, taken from the INTGTIME header keyword.
    """

    class FlatCrate():
        # Bare attribute container: something crate-like, which makes
        # the translation from the S-Lang structure easier.
        pass

    # (column name, multiplier).  Multiplying produces a new array that
    # is detached from the crate.  The image corner columns use an
    # integer 1 rather than 1.0.
    # NOTE(review): presumably so they keep an integer dtype and can be
    # used in index arithmetic later - confirm.
    wanted = (
        ("time", 1.0),
        ("img_raw", 1.0),
        ("img_row0", 1),
        ("img_col0", 1),
        ("aca_comp_bkg_avg", 1.0),
    )

    table = pc.read_file(infile)

    flat = FlatCrate()
    for colname, unity in wanted:
        setattr(flat, colname, table.get_column(colname).values * unity)

    flat.integ_time = table.get_key_value("INTGTIME") * 1.0

    return flat


def cosmic_ray_removal( dat ):
    """
    Median filter the image data in time (cosmic ray removal).

    Every pixel of every image is replaced by the median of that pixel
    over the current image and up to two images on either side.  The
    image window can shift in position between readouts, so the
    img_row0 / img_col0 values are used to line the neighbouring images
    up with the current one.  A pixel is left untouched when fewer than
    three samples are available.

    Returns the filtered image stack, which has the same shape as
    dat.img_raw and is scaled by 5.0.
    """
    n_img = len(dat.img_row0)
    side = dat.img_raw[0].shape[0]
    window = (-2, -1, 0, 1, 2)

    # NOTE(review): the origin of the factor 5.0 is not documented (the
    # original carries a "TBR: Where does '5' come from?" remark);
    # presumably a conversion of raw counts to e- - confirm.
    filtered = dat.img_raw*5.0

    verb1("Filtering image data (cosmic ray removal)...")

    for cur in range(n_img):
        if cur % 100 == 0:
            verb2("  image {} of {}".format(cur+1,n_img))

        for row in range(side):          # Rows of one readout image
            for col in range(side):      # Cols of one readout image

                # Collect this pixel from the current image +/-2 images.
                # The samples are read from the array being filtered, so
                # earlier images contribute their already filtered
                # values while later images contribute unfiltered ones.
                samples = []

                for step in window:
                    other = cur + step
                    if other < 0 or other >= n_img:
                        continue

                    # Account for row/column offset between images
                    src_row = row+dat.img_row0[cur] - dat.img_row0[other]
                    src_col = col+dat.img_col0[cur] - dat.img_col0[other]

                    if 0 <= src_row < side and 0 <= src_col < side:
                        samples.append(filtered[other, src_row, src_col])

                # Only compute median if at least 3 values
                if len(samples) >= 3:
                    filtered[cur, row, col] = np.median(
                        np.array(samples, dtype=float))

    return filtered


def get_edge_pixels( img_size ):
    """
    Define the 'edge pixels' used for detecting warm pixels.

    Returns a tuple (rows, cols) of two equally long lists holding the
    row and column index of each edge pixel within a readout image.

    For a 6x6 image these are the outermost rows and columns without
    the four corners.  For an 8x8 image they are the complete outermost
    rows and columns plus the four pixels (1,1), (6,1), (1,6), (6,6).

    Raises RuntimeError for any other image size.
    """

    if 6 == img_size:
        inner = [1, 2, 3, 4]
        # top row, bottom row, left column, right column
        rows = [0]*4 + [5]*4 + inner*2
        cols = inner*2 + [0]*4 + [5]*4
        return rows, cols

    if 8 == img_size:
        whole = list(range(8))
        inner = list(range(1, 7))
        # top row, bottom row, left column, right column, then the
        # four extra pixels just inside the corners
        rows = [0]*8 + [7]*8 + inner*2 + [1, 6, 1, 6]
        cols = whole*2 + [0]*6 + [7]*6 + [1, 1, 6, 6]
        return rows, cols

    raise RuntimeError("Image data must be 6x6 or 8x8")


def stack_dark_current_data(dat, img_corr, rc0 ):
    """
    Make a stack of dark current measurements.

    The edge pixels of every image (see get_edge_pixels) are treated as
    dark current samples.  Each sample is filed under the pixel it falls
    on in a square region of 2*rc0 + image size pixels per side; the
    first image is placed rc0 pixels in from the region's corner and the
    other images are shifted by their img_row0 / img_col0 difference to
    the first one.  Samples falling outside the region are reported and
    skipped.

    Returns (dark, i_dark, n_dark):

        dark[k, r, c]    k-th sample collected for region pixel (r, c)
        i_dark[k, r, c]  index of the image that sample came from
        n_dark[r, c]     number of samples collected for the pixel
    """
    n_img = len(dat.img_row0)
    side = dat.img_raw[0].shape[0]
    # Size of pixel region on CCD that completely contains all mon. window data
    region = 2*rc0 + side

    edge_rows, edge_cols = get_edge_pixels( side )

    # output arrays
    dark = np.zeros( [n_img, region, region], dtype=float)
    # Not used anywhere else in this file, but part of the return value
    i_dark = np.zeros( [n_img, region, region], dtype=np.int32)
    n_dark = np.zeros( [region, region], dtype=np.int32 )

    verb1("Stacking dark current data...")
    for img in range(n_img):
        if img % 100 == 0:
            verb2("  image {} of {}".format(img+1,n_img))

        # Account for shift between images
        rowoff = rc0 + dat.img_row0[img] - dat.img_row0[0]
        coloff = rc0 + dat.img_col0[img] - dat.img_col0[0]

        for edge_r, edge_c in zip(edge_rows, edge_cols):
            c0 = edge_c + coloff
            r0 = edge_r + rowoff

            if not (0 <= c0 < region and 0 <= r0 < region):
                verb0("Warning: pixel ({},{},{})) has out-of-bounds (c0,r0)=({},{})".format(img, edge_c, edge_r, c0, r0))
                continue

            # Append the sample to this region pixel's stack
            slot = n_dark[r0, c0]
            dark[slot, r0, c0] = img_corr[img, edge_r, edge_c]
            i_dark[slot, r0, c0] = img
            n_dark[r0, c0] = slot + 1

    return dark, i_dark, n_dark


def make_median_dark_current( dat, img_corr, n_dark, dark, pars ):
    """
    Make the median dark current map.

    A median is computed for every region pixel that has more than
    pars["min_dark_meas"] stacked dark current measurements; all other
    pixels start out at -9999.  A pixel whose median exceeds the warm
    limit is reported as a warm pixel.  The warm limit is the larger of
    pars["min_dark_limit"] and pars["dark_ratio"] times the average
    summed counts per image.

    Finally every pixel whose value is below the warm limit (which
    includes the unmeasured ones) is set to the median of
    dat.aca_comp_bkg_avg.

    Returns the 2-D median dark current map; its side is
    2*pars["max_dither_motion"] + image size.
    """
    min_meas = int(pars["min_dark_meas"])
    floor_limit = float(pars["min_dark_limit"])
    ratio = float(pars["dark_ratio"])        # hot limit as a fraction of image counts
    rc0 = int(pars["max_dither_motion"]) # Maximum possible dither motion in ACA pixels

    side = dat.img_raw[0].shape[0] # must be at least 1 row.
    # Size of pixel region on CCD that completely contains all mon. window data
    region = 2*rc0 + side

    avg_cts = np.sum(img_corr) / len(dat.time)
    warm_dark_limit = max(floor_limit, avg_cts * ratio)

    verb1( "Average counts  (e-) = {}".format( avg_cts))
    verb1( "Warm dark limit (e-) = {}".format(warm_dark_limit))

    median_dark = np.full( (region, region), -9999.0 )

    # Column-major walk; this fixes the order of the warm pixel messages
    for c0 in range(region):
        for r0 in range(region):

            n_meas = n_dark[r0,c0]
            if n_meas > min_meas:
                median_dark[r0,c0] = np.median( dark[ np.arange(n_meas), r0, c0] )

                if median_dark[r0,c0] > warm_dark_limit:
                    # Convert the region position back to CCD coordinates
                    r_ccd = int(r0 - rc0 + dat.img_row0[0])
                    c_ccd = int(c0 - rc0 + dat.img_col0[0])
                    verb1( "Warm pixel at CCD (row,col) = ({},{})\t Dark current (e-) = {}".format( r_ccd, c_ccd, median_dark[r0,c0]))

    # Calculate the median reported dark current
    avg_median_dark = np.median( dat.aca_comp_bkg_avg )

    # Use the median dark current reported by ACA for all pixels that are not "warm"
    # NOTE(review): a pixel exactly equal to the limit is neither
    # reported as warm above nor replaced here.
    not_warm = median_dark < warm_dark_limit
    median_dark[not_warm] = avg_median_dark

    return median_dark


def subtract_dark( dat, img_corr, rc0, median_dark ):
    """
    Subtract the dark current map from each image.

    Parameters:
        dat          object with img_row0, img_col0 (per-image window
                     corner) and img_raw (used for the image size)
        img_corr     image stack, indexed [image, row, col]
        rc0          offset of the first image inside median_dark
        median_dark  2-D dark current map covering the pixel region

    Each image is shifted by rc0 plus its img_row0 / img_col0 difference
    to the first image, and the matching window of median_dark is
    subtracted from it.

    Returns an array like img_corr holding the subtracted images.

    Raises IndexError when the window of an image does not lie
    completely inside median_dark.  Negative offsets are rejected
    explicitly: plain indexing would wrap them around to the far side
    of the map and silently subtract the dark current of unrelated
    pixels.
    """

    n = len(dat.img_row0)
    img_size = dat.img_raw[0].shape[0]
    n_rows, n_cols = median_dark.shape

    img_sub = np.zeros_like( img_corr )

    for i in range(n):
        rowoff = rc0 + dat.img_row0[i] - dat.img_row0[0] # Account for image offsets
        coloff = rc0 + dat.img_col0[i] - dat.img_col0[0] # Account for image offsets
        row_end = rowoff + img_size
        col_end = coloff + img_size

        if rowoff < 0 or coloff < 0 or row_end > n_rows or col_end > n_cols:
            raise IndexError(
                "image {}: window at (row,col) offset ({},{}) falls outside "
                "the {}x{} dark current map; the image shift exceeds "
                "rc0={}".format(i, rowoff, coloff, n_rows, n_cols, rc0))

        # One window-sized slice replaces the pixel-by-pixel loops
        img_sub[i] = img_corr[i] - median_dark[rowoff:row_end, coloff:col_end]

    return img_sub


def write_output(dat, img_sub, outfile, clobber ):
    """
    Compute the light curve and write it to a FITS table.

    Parameters:
        dat      object providing time (one value per image) and
                 integ_time
        img_sub  dark subtracted image stack, indexed [image, row, col]
        outfile  name of the output file
        clobber  'yes' to allow an existing file to be overwritten;
                 any other value means no

    The output columns are:

        time        dat.time
        counts      sum of each image in img_sub
        count_rate  counts / dat.integ_time
        mag         mag0 - 2.5*log10(rate / cnt_rate_mag0), where the
                    rate is raised to min_cnt_rate when it is lower
        img_sub     the images themselves

    The minimum rate is applied to a separate array for the magnitude
    only.  Clipping in place through an alias of the count rate array
    would also overwrite the count_rate column and make it disagree
    with the counts column.
    """

    from crates_contrib.utils import write_columns

    mag0 = 10.32                # magnitude at which the rate is cnt_rate_mag0
    cnt_rate_mag0 = 5263.0      # e-/sec
    min_cnt_rate  = 10.0        # min allowed count rate (e-/sec)

    counts = np.sum( img_sub, axis=(1,2) )
    cnt_rate = counts / dat.integ_time

    # np.where builds a new array, so cnt_rate keeps the measured values;
    # the floor keeps the argument of log10 positive.
    mag_rate = np.where(cnt_rate < min_cnt_rate, min_cnt_rate, cnt_rate)
    mag = mag0 - 2.5 * np.log10(mag_rate/cnt_rate_mag0)

    clb = ('yes' == clobber)
    write_columns( outfile, dat.time, counts, cnt_rate, mag, img_sub,
        colnames=["time", "counts", "count_rate", "mag", "img_sub"],
        clobber=clb, format="fits" )


@lw.handle_ciao_errors( toolname, __revision__)
def main():
    """
    Run the tool.

    Reads the parameters, checks that the output file may be written,
    loads the input table and runs the processing steps in order:
    cosmic ray removal, dark current stacking, median dark current map,
    dark subtraction.  The light curve is then written out and a
    history record for this tool is added to the output file.
    """
    from ciao_contrib.param_soaker import get_params
    from ciao_contrib._tools.fileio import outfile_clobber_checks
    from ciao_contrib.runtool import add_tool_history

    verbose_hooks = {"set":lw.set_verbosity, "cmd":verb1}
    pars = get_params(toolname, "rw", sys.argv, verbose=verbose_hooks)

    clobber = pars["clobber"]
    outfile = pars["outfile"]
    outfile_clobber_checks(clobber, outfile)

    # Maximum possible dither motion in ACA pixels
    max_dither = int(pars["max_dither_motion"])

    dat = get_data( pars["infile"] )

    img_corr = cosmic_ray_removal( dat )

    # The image index array that is also returned is not needed here
    dark, _, n_dark = stack_dark_current_data( dat, img_corr, max_dither)

    median_dark = make_median_dark_current( dat, img_corr, n_dark, dark, pars )

    img_sub = subtract_dark( dat, img_corr, max_dither, median_dark )

    write_output( dat, img_sub, outfile, clobber )

    add_tool_history( outfile, toolname, pars, toolversion=__revision__)


if __name__ == "__main__":
    # Script entry point: exit status 0 on success.  Any exception that
    # escapes main() is reported on stderr and gives exit status 1.
    try:
        main()
    except Exception as err:
        message = "\n# "+toolname+" ("+__revision__+"): ERROR "+str(err)+"\n"
        print(message, file=sys.stderr)
        sys.exit(1)
    else:
        sys.exit(0)
