#!/usr/bin/env python

#
#  Copyright (C) 2012, 2015, 2016
#            Smithsonian Astrophysical Observatory
#
#
#  This program is free software; you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation; either version 2 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#
#  You should have received a copy of the GNU General Public License along
#  with this program; if not, write to the Free Software Foundation, Inc.,
#  51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#

"""
Usage:

  update_column_range infile colnames [round=yes]

Aim:

  Update the TLMIN/MAX values of the columns (colnames) in
  the input filename (infile) so that it reflects the current
  data range. The colnames argument is a stack and can include
  the name of vector columns (e.g. sky).

  The file is edited in place.

  This is intended for use with the SKY column on a Chandra
  events file that has been reprojected with reproject_events,
  but it can be used for other cases.

  The TLMIN/MAX values are set to the union of the existing
  TLMIN/MAX range and the data range of the column. This means
  that the range can only increase, and not decrease. If round
  is set to yes then the ranges are rounded down/up to the
  nearest half-integer value (e.g. 0.5, 1.5), as used by the
  Chandra coordinate systems (the SKY column for ACIS observations
  is limited by default to the range 0.5 to 8192.5).

"""


# Name of the tool: used for the logger, the parameter file lookup and
# the history record written to the edited file.
toolname = "update_column_range"

# Release date of this version; reported as the tool version.
revision = "12 September 2016"

import sys
import os
import time

import numpy as np
import paramio as pio
import cxcdm
import stk

# This is only needed for development.
#
# When this script does not live under $ASCDS_INSTALL, look for a
# ../lib/pythonX.Y/site-packages directory relative to the script and
# add it to the module search path (presumably so that the
# ciao_contrib imports below pick up the matching development copy).
# A missing ASCDS_INSTALL variable is reported as an IOError, since it
# suggests CIAO has not been started.
try:
    if not __file__.startswith(os.environ['ASCDS_INSTALL']):
        _thisdir = os.path.dirname(__file__)
        _libname = "python{}.{}".format(sys.version_info.major,
                                        sys.version_info.minor)
        _pathdir = os.path.normpath(os.path.join(_thisdir, '../lib', _libname, 'site-packages'))
        if os.path.isdir(_pathdir):
            # Insert at position 1 so the first entry of sys.path is
            # left untouched.
            os.sys.path.insert(1, _pathdir)
        else:
            print("*** WARNING: no {}".format(_pathdir))

        # Do not leave the temporary names in the module namespace.
        del _libname
        del _pathdir
        del _thisdir

except KeyError:
    raise IOError('Unable to find ASCDS_INSTALL environment variable.\nHas CIAO been started?')

import ciao_contrib.cxcdm_wrapper as cw
import ciao_contrib.logger_wrapper as lw
import ciao_contrib.runtool as rt

from ciao_contrib.param_wrapper import open_param_file

# Set up logging for the tool and create one message function per
# verbose level: vN(msg) logs msg at verbose level N.
lw.initialize_logger(toolname)
(v1, v2, v3, v4, v5) = (lw.make_verbose_level(toolname, level)
                        for level in range(1, 6))


def _round_down_to_half(val):
    """Return the largest half-integer (x.5) value that is <= val.

    The result is converted back to the data type of val.
    """

    base = np.floor(val)
    if val - base >= 0.5:
        nval = base + 0.5
    else:
        nval = base - 0.5

    # np.asarray lets this work for plain Python numbers as well as
    # numpy scalars (which is what the limits are assumed to be).
    #
    # NOTE(review): for an integer data type the cast truncates the
    # half-integer value - confirm this is what is wanted for
    # integer columns.
    return nval.astype(np.asarray(val).dtype)


def _round_up_to_half(val):
    """Return the smallest half-integer (x.5) value that is >= val.

    The result is converted back to the data type of val.
    """

    base = np.floor(val)
    if val - base > 0.5:
        nval = base + 1.5
    else:
        nval = base + 0.5

    # See _round_down_to_half for the handling of the data type.
    return nval.astype(np.asarray(val).dtype)


def find_new_limits(bl, nrows, dss, roundval=True):
    """Given a list of data descriptors, return
    the new min/max values for each descriptor
    that needs changing.

    bl is the block descriptor and nrows is the
    number of rows to look through, starting at the
    current row of the block. dss is the list of
    descriptors (one per scalar component) to check.

    The return value is a list of
       (colname, newmin, newmax)
    values for those descriptors where the data range falls
    outside the existing TLMIN/MAX range. The new limits are the
    union of the existing range and the data range. When roundval
    is True both limits of a returned descriptor are rounded
    (down for the minimum, up for the maximum) to a half-integer
    value, even if only one of them was exceeded by the data.

    An empty list means no limits need changing.
    """

    # TODO: what happens for columns with no TLMIN/MAX?

    # Each entry is a mutable record of
    #   [descriptor, min, min changed?, max, max changed?]
    # which starts out with the existing TLMIN/MAX values.
    lims = []
    for ds in dss:
        (tlo, thi) = cxcdm.dmDescriptorGetRange(ds)
        v5("Original TLMIN/MAX of {} is {} to {}".format(
            cxcdm.dmGetName(ds), tlo, thi))
        lims.append([ds, tlo, False, thi, False])

    # It would be nice to chunk this but I do not know
    # how to get the 'preferred' buffer size (or even if
    # it is possible).
    #
    for _ in range(nrows):

        # dmGetData reads the value for the current row, so all the
        # descriptors are checked before moving to the next row.
        for lim in lims:
            dval = cxcdm.dmGetData(lim[0])
            if dval < lim[1]:
                lim[1] = dval
                lim[2] = True
            if dval > lim[3]:
                lim[3] = dval
                lim[4] = True

        cxcdm.dmTableNextRow(bl)

    # Only report the descriptors for which the data fell outside
    # the original range.
    out = []
    for (ds, tmin, min_changed, tmax, max_changed) in lims:
        if not (min_changed or max_changed):
            continue

        if roundval:
            tmin = _round_down_to_half(tmin)
            tmax = _round_up_to_half(tmax)

        out.append((cxcdm.dmGetName(ds), tmin, tmax))

    return out


def validate_legal_range(filename, colnames, roundval=True):
    """Edit the TLMIN/MAX values for the given columns
    so that they are the union of the existing TLMIN/MAX
    values and the actual range of the column.

    colnames is an array of column names, each element can
    refer to a vector column - e.g. SKY - in which case each
    component of the column is checked. A component is only
    checked once, even if it is named more than once
    (the comparison is case insensitive).

    The actual limits are rounded down/up to the nearest
    half integer value (so -1234.56:9123.45 is rounded to
    -1235.5:9123.5) when roundval is True.

    The file is edited in place. If the table contains no
    rows then no change is made. The return value is True
    if the file was changed and False otherwise.

    An IOError is raised if the file can not be read or is
    write protected.
    """

    # It looks like the DM doesn't recognize write-protected
    # files on open, so manually check.
    #
    if not os.access(filename, os.R_OK):
        raise IOError("Unable to find infile={}".format(filename))
    if not os.access(filename, os.W_OK):
        raise IOError("Unable to edit {} as write protected.".format(filename))

    # The first pass is read only: the file is only opened for
    # editing (by correct_ranges) when a limit has to be changed.
    # Everything that uses the open block is within the try block
    # so that the block is closed whatever happens.
    #
    bl = cxcdm.dmTableOpen(filename, update=False)
    try:
        nrows = cxcdm.dmTableGetNoRows(bl)
        v3("Opened {} and found {} rows.".format(filename, nrows))

        # special case the empty file
        if nrows == 0:
            return False

        colinfo = []
        cptnames = set()
        for colname in colnames:
            col = cw.open_column(bl, colname, filename)
            dss = [cxcdm.dmGetCpt(col, n) for n in
                   range(1, cxcdm.dmGetElementDim(col) + 1)]

            if len(dss) == 1:
                v3("Column {} contains 1 component.".format(colname))
            else:
                v3("Column {} contains {} components.".format(colname,
                                                              len(dss)))

            # Just in case a user has said colnames=sky,x
            for ds in dss:
                cptname = cxcdm.dmGetName(ds).lower()
                if cptname not in cptnames:
                    colinfo.append(ds)
                    cptnames.add(cptname)

        dlims = find_new_limits(bl, nrows, colinfo, roundval=roundval)
        colinfo = None

    finally:
        cxcdm.dmTableClose(bl)
        bl = None

    if not dlims:
        v1("The TLMIN/MAX values of {} in {} do not need changing.".format(' '.join(colnames), filename))
        return False

    correct_ranges(filename, dlims)
    return True


def correct_ranges(filename, dlims):
    """Write new TLMIN/MAX values to the file, editing it in place.

    dlims is a sequence of (column name, new minimum, new maximum)
    values. For each one the range of the column is changed and a
    COMMENT line, recording the change and the time of the edit,
    is added to the block.
    """

    # We assume the r/w access hasn't changed during the run-time of
    # this script
    stamp = time.asctime()
    table = cxcdm.dmTableOpen(filename, update=True)
    v3("Opened {} for editing ({} columns).".format(filename, len(dlims)))
    try:
        for entry in dlims:
            (name, lo, hi) = entry
            v1("Editing column {} to have TLMIN={} and TLMAX={}".format(name,
                                                                        lo,
                                                                        hi))

            # The descriptor is only needed to set the range, so it
            # is not kept around.
            cxcdm.dmDescriptorSetRange(cw.open_column(table, name, filename),
                                       lo, hi)

            # I originally used a HISTORY keyword for the following but
            # it doesn't seem to be retained by some downstream processing
            # so converting to a COMMENT line instead.
            note = "{} range of {} now {}:{}".format(stamp, name, lo, hi)
            cxcdm.dmBlockWriteComment(table, 'COMMENT', note)

    finally:
        # Always close the block, even if an edit failed.
        cxcdm.dmTableClose(table)
        table = None


def add_history_records(opts):
    """Record the run of this tool in the history of the file.

    opts is the dictionary returned by process_command_line; the
    infile, columns, roundval and verbose entries are used. The
    column names are stored as a comma-separated string.
    """

    infile = opts['infile']
    v3("Adding {} HISTORY record to {}".format(toolname, infile))

    # The keys are the parameter names of the tool.
    params = dict(infile=infile,
                  columns=','.join(opts['columns']),
                  round=opts['roundval'],
                  verbose=opts['verbose'])
    rt.add_tool_history(infile, toolname, params, toolversion=revision)


def process_command_line(argv):
    """Read the tool parameters from the command line.

    argv is the argument list (e.g. sys.argv). The return value is
    a dictionary with the keys progname, parname, infile, columns
    (a list of names, after stack expansion), roundval (a boolean)
    and verbose (an integer). The verbosity of the logger is set
    to the verbose value.

    A ValueError is raised if argv is None or an empty list, or if
    the infile or columns parameter is blank.
    """

    if argv is None or argv == []:
        raise ValueError("argv argument is None or empty.")

    pinfo = open_param_file(argv, toolname=toolname)
    handle = pinfo["fp"]

    def required(name):
        "Return the string parameter, which must not be blank."
        value = pio.pget(handle, name)
        if value.strip() == '':
            raise ValueError('{} parameter is empty'.format(name))
        return value

    infile = required('infile')
    columns = stk.build(required('columns'))

    roundval = pio.pgetb(handle, 'round') == 1
    verbose = pio.pgeti(handle, 'verbose')

    pio.paramclose(handle)
    lw.set_verbosity(verbose)

    opts = {}
    opts['progname'] = pinfo['progname']
    opts['parname'] = pinfo['parname']
    opts['infile'] = infile
    opts['columns'] = columns
    opts['roundval'] = roundval
    opts['verbose'] = verbose
    return opts


@lw.handle_ciao_errors(toolname, revision)
def update_column_range(args):
    """Run the tool.

    args is the command-line argument list (e.g. sys.argv). The
    parameters are read and reported, the TLMIN/MAX values of the
    columns are updated if needed, and a history record is added
    to the file only when it has been changed.
    """

    opts = process_command_line(args)
    infile = opts['infile']
    roundval = opts['roundval']

    v1("Running: {}".format(toolname))
    v1("  version: {}".format(revision))

    v2('Parameter values')
    v2('  infile:  ' + infile)
    v2('  columns: ' + " ".join(opts['columns']))
    v2('  round:   yes' if roundval else '  round:   no')

    # The return value is only True when the file was edited.
    if validate_legal_range(infile, opts['columns'], roundval=roundval):
        add_history_records(opts)


# Only run the tool when this file is executed as a script, passing
# through the full argument list (including the program name).
if __name__ == "__main__":
    update_column_range(sys.argv)
