#!/opt/conda/envs/ciao-4.17.0/bin/python3

# 
#  Copyright (C) 2009-2012,2016-2018,2022-2023  Smithsonian Astrophysical Observatory
#
#
#  This program is free software; you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation; either version 3 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#
#  You should have received a copy of the GNU General Public License along
#  with this program; if not, write to the Free Software Foundation, Inc.,
#  51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#


import os
import re
import sys
import string
import numpy
import pycrates
import paramio
import math
import subprocess

import stk
from sherpa.utils import is_binary_file

import lev3_mho
import lev3_iss

import logging

'''Use the history module to add HISTORY records to crate'''
from pycrates import read_file 
from pycrates import add_history, CrateKey
from history import HistoryRecord


def set_verbose(verb):
    """Set the verbosity of ``lev3_mho`` and of this module's logger.

    Parameters
    ----------
    verb : int
      Requested verbosity. ``lev3_mho`` is handed ``verb - 1`` when ``verb``
      is positive and ``verb`` itself otherwise. The level of this module's
      logger is looked up with ``verb`` in ``lev3_mho.verbosity_conversion``.

    """
    mho_verbosity = verb - 1 if verb > 0 else verb
    lev3_mho.set_verbose(mho_verbosity)

    logging.getLogger(__name__).setLevel(lev3_mho.verbosity_conversion[verb])

def get_verbose():
    """Return the level currently set on this module's logger.

    Returns
    -------
    int
      The ``level`` attribute of ``logging.getLogger(__name__)``.

    """
    return logging.getLogger(__name__).level


def filter_filename(filename):
    """Strip quotes and DM virtual-file specifications from a file name.

    Every single and double quote character is removed first. Then the
    text from the first ``[`` up to the last ``]`` on a line is removed,
    provided at least one character lies between the brackets.

    Parameters
    ----------
    filename : str
      A file name that may carry DM virtual file specifications
      (for example an energy filter).

    Returns
    -------
    str
      The file name without quotes and without the bracketed part.

    Examples
    --------
    >>> filter_filename('acisf04425_000N003_evt1.fits"[energy=500:7000]"')
    'acisf04425_000N003_evt1.fits'

    """
    quote_chars = re.compile(r'"|\'')
    bracketed = re.compile(r'\[.+]')

    unquoted = quote_chars.sub('', filename)
    return bracketed.sub('', unquoted)


def stk_build(stack_str: str):
    """Expand a stack with :func:`stk.build`, tolerating empty input.

    Any ``Exception`` raised by :func:`stk.build` (this includes the
    ``ValueError`` it raises for an empty list) is swallowed and replaced
    by a placeholder result.

    Parameters
    ----------
    stack_str : str
      A .lis file or comma-separated string of file names.

    Returns
    -------
    list
      The input file names. When :func:`stk.build` fails, a list holding a
      single empty string.

    """
    try:
        return stk.build(stack_str)
    except Exception:
        return [""]


def get_parameters(filename):
    """Read the tool parameters and validate the source/output stacks.

    Parameters
    ----------
    filename : str
      Name of the parameter file. When the first command-line argument
      starts with ``@@`` the parameter file is opened with no explicit name
      and ``filename`` is only used in the error message.

    Returns
    -------
    tuple
      ``(params, verbose)``. ``params`` is a dict holding, for each of
      ``srcfile``, ``psffile``, ``regfile`` and ``outfile``, the raw string
      (same key) and its expanded stack (key with an ``s`` appended), plus
      ``x0``/``y0`` (``None`` when the parameter is INDEF), ``shape``
      (``True`` when the parameter starts with "gaussian"), ``srcsize``,
      ``theta``, ``imgsize``, ``binfactor``, ``mincounts``, ``minthresh``,
      ``sigmafactor``, ``psfblur``, ``clobber``, ``verbose`` and ``mode``.
      ``verbose`` repeats ``params["verbose"]``.

    Raises
    ------
    IOError
      If the parameter file cannot be opened, if the source or output stack
      is empty, if their lengths differ, or if an output file exists and
      clobber is not set.

    Notes
    -----
    Existing output files are deleted here when clobber is set.

    """
    args = sys.argv[:]
    try:
        if len(args) > 1 and args[1].startswith("@@"):
            pfile = paramio.paramopen(None, "rw", args)
        else:
            pfile = paramio.paramopen(filename, "rwL", sys.argv)
    except Exception as exc:
        # Only real errors are translated; KeyboardInterrupt and SystemExit
        # are allowed to propagate untouched.
        raise IOError("could not open parameter file '%s'" % filename) from exc

    params = {}
    try:
        for key in ["srcfile", "psffile", "regfile", "outfile"]:
            raw_value = paramio.pgetstr(pfile, key)
            params[key + 's'] = stk_build(raw_value)   # srcfiles, outfiles ...
            params[key] = raw_value                    # sat-151

        # x0/y0 are optional: INDEF means "not supplied".
        for coord in ("x0", "y0"):
            params[coord] = None
            if paramio.pgetstr(pfile, coord).strip().upper() != "INDEF":
                params[coord] = paramio.pgetd(pfile, coord)

        params["shape"]       = paramio.pgetstr(pfile, "shape")
        params["srcsize"]     = paramio.pgetd(pfile, "srcsize")
        params["theta"]       = paramio.pgetd(pfile, "theta")
        params["imgsize"]     = paramio.pgetd(pfile, "imgsize")
        params["binfactor"]   = paramio.pgeti(pfile, "binfactor")
        params["mincounts"]   = paramio.pgeti(pfile, "mincounts")
        params["minthresh"]   = paramio.pgeti(pfile, "minthresh")
        params["sigmafactor"] = paramio.pgetd(pfile, "sigmafactor")
        params["psfblur"]     = paramio.pgetd(pfile, "psfblur")

        params["clobber"]     = paramio.pgetb(pfile, "clobber")
        params["verbose"]     = paramio.pgeti(pfile, "verbose")
        params["mode"]        = paramio.pgetstr(pfile, "mode")
    finally:
        # Release the parameter file even when one of the reads fails.
        paramio.paramclose(pfile)

    clob    = params["clobber"]
    verbose = params["verbose"]

    # The shape string is reduced to a flag: True for a Gaussian source.
    params["shape"] = params["shape"].startswith("gaussian")

    if len(params["srcfiles"]) == 1 and params["srcfiles"][0] == "":
        raise IOError("no source file found")

    if len(params["outfiles"]) == 1 and params["outfiles"][0] == "":
        raise IOError("no output file found")

    if len(params["srcfiles"]) != len(params["outfiles"]):
        raise IOError("number of source files must match number of output files")

    for outfile in params["outfiles"]:
        if outfile == "":
            raise IOError("please enter an output file")
        if os.path.isfile(outfile):
            if not clob:
                raise IOError("output file '%s' exists and clobber is not set." % outfile)
            # Remove outfile if it exists
            os.remove(outfile)

    return (params, verbose)


def find_brightest_pixel(pix):
    """Return the array indices of the brightest pixel of a 2-D image.

    Parameters
    ----------
    pix : numpy.ndarray
      Two-dimensional pixel array.

    Returns
    -------
    tuple of int
      ``(i0, i1)``: index along the first and the second array axis. When
      several pixels share the maximum value, the middle entry of the
      matches (in the order ``numpy.where`` reports them) is chosen.

    Raises
    ------
    ValueError
      If no pixel compares equal to ``pix.max()`` (e.g. the maximum is NaN).

    """
    rows, cols = numpy.where(pix == pix.max())

    if len(rows) == 0:
        raise ValueError("could not locate the brightest pixel "
                         "(image maximum is not a finite value?)")

    # Index explicitly, also for a unique maximum, instead of relying on
    # int() of a 1-element array (deprecated by NumPy). For a single match
    # the pivot is 0.
    pivot = len(rows) // 2
    return (int(rows[pivot]), int(cols[pivot]))


def check_normalized_psf(pix, num):
    """Rescale the total of an image whose sum looks normalized.

    Parameters
    ----------
    pix : numpy.ndarray
      Pixel values.
    num : float
      Sum of the pixel values as computed by the caller.

    Returns
    -------
    float
      ``num`` unchanged unless ``0 < num <= 1``. In that case the pixels
      are divided by the smallest strictly positive pixel value (or by 1
      when there is none) and the sum of the rescaled pixels is returned.

    """
    looks_normalized = num > 0.0 and num <= 1.0
    if not looks_normalized:
        return num

    flat = pix.ravel()
    positive = flat[flat > 0.0]
    # The smallest positive pixel is taken as the value of "one count".
    unit = positive.min() if len(positive) > 0 else 1.
    return (flat / unit).sum()


def get_half_max(half_max_val, x, x_max, range):
    """Distance from the peak to the first sample at or below half maximum.

    Parameters
    ----------
    half_max_val : float
      Half of the peak value.
    x : sequence
      Coordinates matching ``range`` element by element.
    x_max : float
      Coordinate of the peak.
    range : iterable
      Profile values, scanned in order.

    Returns
    -------
    float
      ``|x[i] - x_max|`` for the first ``i`` with ``range[i] <= half_max_val``;
      ``x_max`` itself when no value qualifies.

    """
    first_hit = next((idx for idx, level in enumerate(range)
                      if level <= half_max_val), None)
    if first_hit is None:
        return x_max
    return numpy.abs(x[first_hit] - x_max)


def get_fwhm(vals):
    """Estimate the full width at half maximum of a 1-D profile.

    Samples are given the coordinates 1, 2, ..., ``len(vals)``. Starting at
    the peak, the profile is scanned towards higher and towards lower
    coordinates with :func:`get_half_max`, and the two results are added.

    Parameters
    ----------
    vals : sequence of float
      Profile values.

    Returns
    -------
    float
      The summed half widths, reduced by 1 when the sum exceeds 1.

    """
    profile = numpy.asarray(vals)
    coords = numpy.arange(1, len(vals)+1, dtype=float)

    peak = profile.argmax()
    peak_coord = coords[peak]
    half_level = profile.max() / 2.0

    # Scan from the peak (inclusive) towards higher coordinates ...
    upward = get_half_max(half_level, coords[peak:], peak_coord, profile[peak:])
    # ... and from just below the peak towards lower coordinates.
    downward = get_half_max(half_level, coords[:peak][::-1], peak_coord,
                            profile[:peak][::-1])

    width = upward + downward
    return width - 1 if width > 1 else width


def read_image(img):
    """Extract the pixels of an image crate together with sky coordinates.

    Parameters
    ----------
    img : crate
      Image crate providing pixel values, axis names and a sky transform.

    Returns
    -------
    tuple
      ``(x0, x1, ax0, ax1, y)``: ``x0``/``x1`` are the flattened sky
      coordinates of every pixel, ``ax0``/``ax1`` the sky coordinates along
      the second and first array axis, and ``y`` the (possibly trimmed)
      2-D pixel array.

    Raises
    ------
    AttributeError
      If none of the axis names is SKY/sky/POS/pos.

    """
    pixels = pycrates.get_piximgvals(img).squeeze()
    nrows, ncols = pixels.shape

    # ..if image axes are even length, drop the last row/column so they are odd
    if nrows % 2 == 0:
        nrows = nrows - 1
    if ncols % 2 == 0:
        ncols = ncols - 1
    pixels = pixels[:nrows, :ncols]

    # 1-based pixel coordinates along each axis and for the full grid
    col_axis = numpy.arange(1, ncols+1, dtype=float)
    row_axis = numpy.arange(1, nrows+1, dtype=float)
    grid_cols, grid_rows = numpy.meshgrid(col_axis, row_axis)
    grid_cols = grid_cols.ravel()
    grid_rows = grid_rows.ravel()

    # find the SKY name using the set intersection
    skynames = ['SKY', 'sky', 'pos', 'POS']
    matches = list(set(img.get_axisnames()) & set(skynames))
    if not matches:
        raise AttributeError("sky coordinates not found")
    sky = img.get_transform(matches[0])

    def to_sky(first, second):
        # Transform paired coordinates; returns one row per coordinate.
        return sky.apply(numpy.array([first, second]).transpose()).transpose()

    grid_cols, grid_rows = to_sky(grid_cols, grid_rows)

    # Each axis is transformed on its own, paired with a constant 1 for the
    # other coordinate; the constant's image is discarded.
    col_axis, _ = to_sky(col_axis, numpy.ones(col_axis.size, dtype=float))
    _, row_axis = to_sky(numpy.ones(row_axis.size, dtype=float), row_axis)

    return grid_cols, grid_rows, col_axis, row_axis, pixels


def calc_num_counts_tbl(file, reg):
    """Count the rows of an event table that fall inside a region filter.

    Parameters
    ----------
    file : str
      Table file name.
    reg : str
      Filter expression appended to ``file`` before reading.

    Returns
    -------
    int
      Number of values in the ``X`` column of the filtered table; 0 when
      the read raises ``IOError`` or the table has no ``X`` column.

    """
    try:
        filtered = pycrates.read_file(file + reg)
    except IOError:
        return 0

    if not pycrates.col_exists(filtered, "X"):
        return 0
    return pycrates.get_colvals(filtered, "X").size


def calc_num_counts_img(file, reg):
    """Sum the pixels of an image that fall inside a region filter.

    Parameters
    ----------
    file : str
      Image file name.
    reg : str
      Filter expression appended to ``file`` before reading.

    Returns
    -------
    number
      0 when the filtered image has no pixel values; otherwise the pixel
      sum after it has been passed through :func:`check_normalized_psf`.

    """
    filtered = pycrates.read_file(file + reg)

    # FIXME What to do if the image is normalized PSF??
    pixels = pycrates.get_piximgvals(filtered)
    if pixels is None:
        return 0

    return check_normalized_psf(pixels, pixels.sum())

#    if pix is None:
#        counts = 0
#    else:
#        counts = pix.sum()
#        if counts <= 1.0:
#            i = pix.nonzero()
#            if len(i) > 1:
#                pix = pix / pix[i].min()
#                counts = pix.sum()
#
#    return counts


def image_guess(ax0, ax1, y):
    """Guess the source position and size from the image itself.

    Parameters
    ----------
    ax0, ax1 : array-like
      Coordinates indexed with the flat (raveled) index of the brightest
      pixel of ``y``.
    y : numpy.ndarray
      2-D pixel array.

    Returns
    -------
    tuple
      ``(x0, y0, size)``: the coordinates at the flat index of the maximum
      and the quadrature sum ``sqrt(fwhm0**2 + fwhm1**2)`` of the FWHMs
      measured along the column and the row through the brightest pixel.

    """
    flat_peak = y.argmax()
    guess_x = ax0[flat_peak]
    guess_y = ax1[flat_peak]

    peak_row, peak_col = find_brightest_pixel(y)
    column_profile = y[:, peak_col]   # image column holding the brightest pixel
    row_profile = y[peak_row, :]      # image row holding the brightest pixel

    widths = [get_fwhm(column_profile), get_fwhm(row_profile)]
    # combine the two slice FWHMs in quadrature
    size = numpy.sqrt(numpy.square(widths[0]) + numpy.square(widths[1]))

    return (guess_x, guess_y, size)


def parse_regfile(regfile, _to_arcsecs=1.0):
    """Parse a region file with the parser matching its format.

    Parameters
    ----------
    regfile : str
      Region file name.
    _to_arcsecs : float, optional
      Conversion factor handed on to the selected parser.

    Returns
    -------
    tuple
      Whatever the selected parser returns: :func:`parse_binary_regfile`
      when ``is_binary_file`` reports a binary file, otherwise
      :func:`parse_ascii_regfile`.

    """
    parser = parse_binary_regfile if is_binary_file(regfile) else parse_ascii_regfile
    return parser(regfile, _to_arcsecs)


def _check_for_arcsecs(val=None, _to_arcsecs=1.0):
    if "\"" in val:
        val = float(val.strip("\""))
        val /= _to_arcsecs
    else:
        val = float(val)
    return val

def parse_ascii_regfile(regfile, _to_arcsecs=1.0):
    """Read the first region of an ASCII region file.

    Blank lines and lines starting with ``#`` are skipped; the first
    remaining line must be a circle or an ellipse such as
    ``circle(4096.5,4096.5,20")``.

    Parameters
    ----------
    regfile : str
      Name of the ASCII region file.
    _to_arcsecs : float, optional
      Divisor applied to radii that carry the arcsecond mark.

    Returns
    -------
    tuple
      ``(x0, y0, size)``. For a circle ``size`` is its radius; for an
      ellipse it is ``sqrt(r1**2 + r2**2) / sqrt(2)``.

    Raises
    ------
    IOError
      If the file holds no region line, the line is malformed, or the
      shape is neither a circle nor an ellipse.

    """
    parse_error = "could not parse ASCII region file"

    # Pick the first line that is neither blank nor a comment. A file
    # consisting of a single region line is accepted as well.
    line = None
    with open(regfile, "r") as handle:
        for raw in handle:
            candidate = raw.strip()
            if candidate and not candidate.startswith('#'):
                line = candidate
                break

    if line is None:
        raise IOError(parse_error)

    start = line.find('(')
    stop = line.find(')', start + 1) if start >= 0 else -1
    if start < 0 or stop < 0:
        raise IOError(parse_error)

    # Determine the region shape
    shape = line[:start].strip(string.punctuation + string.digits +
                               string.whitespace).lower()

    params = [item.strip() for item in line[start+1:stop].split(',')]

    if shape == 'circle':
        if len(params) != 3:
            raise IOError(parse_error)
        x0, y0, radius = params
        x0 = float(x0)
        y0 = float(y0)
        size = _check_for_arcsecs(radius, _to_arcsecs)

    elif shape == 'ellipse':
        if len(params) != 5:
            raise IOError(parse_error)
        # The rotation angle (last value) is not needed for the size.
        x0, y0, r1, r2, _angle = params
        x0 = float(x0)
        y0 = float(y0)
        r1 = _check_for_arcsecs(r1, _to_arcsecs)
        r2 = _check_for_arcsecs(r2, _to_arcsecs)
        size = numpy.sqrt(numpy.sum(numpy.square([r1,r2])))/numpy.sqrt(2.)

    else:
        raise IOError("cannot use '%s' shape," % shape.upper() +
                      " please use CIRCLE or ELLIPSE")

    return (x0,y0,size)


def parse_binary_regfile(regfile, _to_arcsecs=1.0):
    """Read the first region of a binary (table) region file.

    Parameters
    ----------
    regfile : str
      Name of the region file; it must be a table with ``x``, ``y`` and
      ``r`` columns.
    _to_arcsecs : float, optional
      Divisor applied to the radii when the unit of the ``r`` column
      contains "arcsec".

    Returns
    -------
    tuple
      ``(x0, y0, size)``: the first ``x`` and ``y`` values and
      ``sqrt(sum(r**2)) / sqrt(2)`` computed from the radii (the first row
      when the radius column is multi-dimensional).

    Raises
    ------
    IOError
      If the file is not a table, a required column is missing, or the
      table has no coordinate values.

    """
    reg = pycrates.read_file(regfile)
    if pycrates.get_crate_type(reg) != "Table":
        raise IOError('region file is not a table')

    required = (("x", 'region file does not contain a SKY X coordinate'),
                ("y", 'region file does not contain a SKY Y coordinate'),
                ("r", 'region file does not contain a radius'))
    for colname, message in required:
        if not pycrates.col_exists(reg, colname):
            raise IOError(message)

    xx = reg.get_column('x')
    yy = reg.get_column('y')
    rr = reg.get_column('r')

    rr_unit = ("arcsec" in rr.unit)

    xx = numpy.asarray(xx.values)
    yy = numpy.asarray(yy.values)
    rr = rr.values

    if xx.size == 0 or yy.size == 0:
        raise IOError('region file does not contain any region')

    # Reduce the radii to those of the first row.
    while len(rr.shape) > 1:
        rr = rr[0]

    if rr_unit:
        # Out-of-place division: leaves the crate's own array untouched and
        # also works when the radii are stored with an integer type.
        rr = rr / _to_arcsecs

    # FIXME make 0.355 a user-setable parameter
    # FIXME assumes that region is from wavdetect at 4.5 sigma
    # size = numpy.sqrt(numpy.sum(numpy.square(rr)))/numpy.sqrt(2.) * 0.355

    size = numpy.sqrt(numpy.sum(numpy.square(rr)))/numpy.sqrt(2.)

    # Take the very first element explicitly rather than calling float()
    # on a 1-element array, which NumPy has deprecated.
    x0 = float(xx.ravel()[0])
    y0 = float(yy.ravel()[0])

    return (x0,y0,size)


class MHO(lev3_mho.MHO):
    """Source-extent measurement on crate images and event tables.

    Adds to ``lev3_mho.MHO`` the set-up of the working image from an image
    crate or an event table, and the bookkeeping needed to convert between
    arcseconds and pixels. The optimisation itself
    (``improve_source_position``, ``mho_optimize_axes``, ``calc_errors``,
    ``counts_error``) is inherited and not defined here.
    """

    def __init__(self, x0, y0, crude_src_size, img_size=0,
                 bin_factor=0, min_counts=15, min_thresh=4):
        """Store unit-conversion defaults and delegate to the base class.

        All arguments are passed unchanged to ``lev3_mho.MHO.__init__``.
        The methods below read them back as ``self.x0``, ``self.y0``,
        ``self.crude_src_size``, ``self.img_size``, ``self.bin_factor``,
        ``self.min_counts`` and ``self.min_thresh`` (presumably set by the
        base class -- confirm in lev3_mho).
        """
        # Arcseconds per pixel; replaced in mho_find_source_extent once the
        # WCS of the input file has been read.
        self._to_arcsecs = 1.0
        # Radians-to-degrees factor.
        self._to_degrees = 180./numpy.pi
        # Unit label for reported sizes ('pixel' until a WCS is known).
        self._units_size = 'pixel'

        lev3_mho.MHO.__init__(self, x0, y0, crude_src_size, img_size,
                              bin_factor, min_counts, min_thresh)


    def setup_image(self, img, imgfile, regfile=None):
        """Prepare the working data from an image crate.

        Parameters
        ----------
        img : crate
          Image crate, read with :func:`read_image`.
        imgfile : str
          Name of the image file (not used here).
        regfile : str, optional
          Region file whose position/size override the guess made from the
          image when position or size were not supplied.

        Returns
        -------
        int
          0 on success, -1 when the (rescaled) pixel sum is below
          ``self.min_thresh``.
        """
        x0, x1, ax0, ax1, y = read_image(img)

        if((self.x0 is None) or (self.y0 is None) or (self.crude_src_size <= 0.0)):
            xx, yy, size = image_guess(x0, x1, y)
            if regfile is not None:
                xx, yy, size = parse_regfile(regfile, self._to_arcsecs)

            if self.x0 is None:
                self.x0 = xx
            if self.y0 is None:
                self.y0 = yy
            if self.crude_src_size <= 0.0:
                self.crude_src_size = size
            else:
                # convert user input from arcseconds to pixels
                # NOTE(review): this conversion only runs when x0 or y0 was
                # missing; with x0, y0 and srcsize all supplied the size is
                # left as given -- confirm that this is intended.
                self.crude_src_size /= self._to_arcsecs

        num = check_normalized_psf(y, y.sum())
        if num < self.min_thresh:
            return -1

        f = lev3_mho.MHOStruct(ax0, ax1, self.x0, self.y0, None,None,None, y)

        # Work in coordinates relative to the source position.
        f.x -= self.x0; f.y -= self.y0

        (f.x_0, f.x_1) = numpy.meshgrid(f.x, f.y)
        f.i = f.image.nonzero()

        self.f = f

        lev3_mho.set_data(lev3_mho.MHOData("",None,None,None))

        self.improve_source_position()

        return 0


    def setup_table(self, tbl, evtfile, psfbin, regfile=None):
        """Prepare the working image from an event table.

        Parameters
        ----------
        tbl : crate
          Table crate with ``X`` and ``Y`` columns.
        evtfile : str
          Name of the event file (not used here).
        psfbin : number or empty
          Bin size considered when the bin factor is chosen automatically;
          it is only applied when it is non-empty and smaller than 1.
        regfile : str, optional
          Region file supplying position/size that were not given.

        Returns
        -------
        int
          0 on success; -1 when there is no ``X`` column, fewer than
          ``self.min_thresh`` events, or no image could be extracted.

        Raises
        ------
        IOError
          If position or size is missing and no region file was supplied.
        """
        if((self.x0 is None) or (self.y0 is None) or (self.crude_src_size <= 0.0)):
            if regfile is None:
                raise IOError('A region file must be supplied for an event file')

            xx, yy, size = parse_regfile(regfile, self._to_arcsecs)
            if self.x0 is None:
                self.x0 = xx
            if self.y0 is None:
                self.y0 = yy
            if self.crude_src_size <= 0.0:
                self.crude_src_size = size
            else:
                # convert user input from arcseconds to pixels
                # NOTE(review): only reached when x0 or y0 was missing --
                # confirm that a fully specified input should stay unscaled.
                self.crude_src_size /= self._to_arcsecs

        if not pycrates.col_exists(tbl, "X"):
            return -1

        X = pycrates.get_colvals(tbl, "X")
        Y = pycrates.get_colvals(tbl, "Y")

        if X.size < self.min_thresh:
            return -1

        if self.img_size <= 0:
            # img_size should be _much_ larger than the max allowed
            # elliptical region size
            max_major_axis = 2 * (lev3_mho.Mho_Max_Size_Factor * self.crude_src_size)
            self.img_size = 10 * max_major_axis
        else:
            # convert user input from arcseconds to pixels
            self.img_size /= self._to_arcsecs

        adjustable_bin_factor = False
        if self.bin_factor <= 0:
            adjustable_bin_factor = True
            if self.img_size < 128:
                self.bin_factor = 1
                if psfbin: # leave bin_factor at 1 if psfbin was set to null string
                    if psfbin < 1: # set smaller psfbin if it was specified to be < 1
                        self.bin_factor = psfbin
                        self.img_size = self.img_size / psfbin
                        if self.img_size > 128:
                            self.img_size = 128
            else:
                self.bin_factor = int(numpy.round(self.img_size/128.))


        self.f = lev3_mho.extract_image(X,Y,self.x0, self.y0, self.img_size,
                               self.img_size, self.bin_factor)

        if self.f is None:
            return -1

        lev3_mho.set_data(lev3_mho.MHOData("",None,None,None))

        if adjustable_bin_factor and self.bin_factor > 1:
            # If the source region looks bright, don't use coarse binning.
            while True:
                dims = self.f.image.shape
                ny = dims[0]
                nx = dims[1]
                # Central window spanning about a quarter of each axis.
                core = self.f.image[ny//2-ny//8:ny//2+ny//8+1,
                                    nx//2-nx//8:nx//2+nx//8+1]
                if core.max() < 100 or self.bin_factor == 1:
                    break
                self.bin_factor //= 2
                self.f = lev3_mho.extract_image(X,Y,self.x0, self.y0, self.img_size,
                                       self.img_size, self.bin_factor)
                if self.f is None:
                    return -1

        # Re-extract the image when the position refinement reports a change.
        if self.improve_source_position():
            self.f = lev3_mho.extract_image(X,Y,self.x0, self.y0, self.img_size,
                                   self.img_size, self.bin_factor)
            if self.f is None:
                return -1


        return 0


    def mho_find_source_extent(self, file, psfbin, source_shape=True, regfile=None):
        """Measure the extent of the source in an image or event file.

        Parameters
        ----------
        file : str or None
          Input file name. ``None`` yields a result filled with NaN.
        psfbin : number or empty
          Passed to :meth:`setup_table` for event files.
        source_shape : bool, optional
          True scales the two axes by ``1/sqrt(3)`` (2-D Gaussian), False
          by ``sqrt(2)`` (uniform circular disk).
        regfile : str, optional
          Region file supplying a position/size that was not given.

        Returns
        -------
        list
          ``[x0, y0, pars, errs]``, where ``pars`` holds the two axes and
          the angle used for the counting ellipse. ``errs`` is NaN when
          fewer than ``self.min_counts`` counts fall inside the ellipse.
          When the set-up fails the result of the inherited
          ``counts_error`` is returned instead.

        Raises
        ------
        AssertionError
          If the WCS pixel scale is not a positive number.
        IOError
          If the file is neither a table nor an image.
        """
        if file is None:
            return [None, None,
                    numpy.array([numpy.nan]*3),
                    numpy.array([numpy.nan]*2)]

        obj = pycrates.read_file(file)

        # retrieve the WCS for pixel to arcsecond conversion
        pos = obj.get_transform("EQPOS")
        cdelt = pos.get_parameter_value("CDELT")
        self._to_arcsecs = numpy.abs(cdelt[1]) * 3600.0
        # Explicit raise (same exception type as before) so that the check
        # is not dropped when Python runs with -O.
        if not self._to_arcsecs > 0.0:
            raise AssertionError("WCS - CDELT is malformed, cannot convert arcseconds to pixels")
        self._units_size = 'arcsec'

        crate_type = pycrates.get_crate_type(obj)
        is_image = (crate_type == "Image")

        if is_image:
            if 0 != self.setup_image(obj, file, regfile):
                return self.counts_error(file)

        elif crate_type == "Table":
            if 0 != self.setup_table(obj, file, psfbin, regfile):
                return self.counts_error(file)

        else:
            raise IOError("source '%s' is not a table or image" % file)

        pars = self.mho_optimize_axes()

        if source_shape:
            # source = 2-D Gaussian => sigma_i = a_i / sqrt(3)
            pars[[0,1]] /= numpy.sqrt(3)

        else:
            # source = uniform circular disk =>  R = a * sqrt(2)
            pars[[0,1]] *= numpy.sqrt(2)

        # determine number of counts inside source ellipse
        reg = "[sky=ellipse(%s,%s,%s,%s,%s)]" % (self.x0, self.y0,
                                                 pars[0], pars[1], pars[2])

        if is_image:
            counts = calc_num_counts_img(file, reg)
        else:
            counts = calc_num_counts_tbl(file, reg)

        if counts < self.min_counts:
            print("Error: less than %s counts detected in " % self.min_counts +
                   "ellipse(%s,%s,%s,%s,%s)" %
                   ( self.x0, self.y0, pars[0], pars[1], pars[2]))
            return [self.x0, self.y0, pars, [numpy.nan]*2]

        errs = self.calc_errors(counts, pars)

        return [self.x0, self.y0, pars, errs]


def _save_params(x, y, pars, errs):
    d = {}
    d['x'] = x
    d['y'] = y
    d['sigma_1']=pars[0]
    d['sigma_2']=pars[1]
    d['sigma_1_err']=errs[0]
    d['sigma_2_err']=errs[1]
    d['pos']=pars[2]

    return d


def _transform_params(in_d, _to_degrees, _to_arcsecs):
    """Scale a parameter dictionary and add lower/upper bounds.

    Parameters
    ----------
    in_d : dict
      Dictionary as produced by :func:`_save_params`.
    _to_degrees : float
      Factor applied to the position angle.
    _to_arcsecs : float
      Factor applied to the sigmas and their bounds.

    Returns
    -------
    dict
      ``x`` and ``y`` copied unchanged; for each sigma the value, the value
      minus its error (``_lo``) and plus its error (``_hi``); ``pos``.
      When either sigma error is NaN, the sigmas, their bounds and ``pos``
      are all NaN. ``pos_lo`` and ``pos_hi`` are always NaN.

    """
    out = {'x': in_d['x'], 'y': in_d['y']}

    # ..if errors are NaN, set major/minor axis and pos to nan
    undefined = math.isnan(in_d['sigma_1_err']) or math.isnan(in_d['sigma_2_err'])

    axes = (('sigma_1', 'sigma_1_lo', 'sigma_1_hi', 'sigma_1_err'),
            ('sigma_2', 'sigma_2_lo', 'sigma_2_hi', 'sigma_2_err'))
    for centre, lower, upper, error in axes:
        if undefined:
            out[centre] = convert(numpy.nan, _to_degrees)
            out[lower] = convert(numpy.nan, _to_degrees)
            out[upper] = convert(numpy.nan, _to_degrees)
        else:
            out[centre] = convert(in_d[centre], _to_arcsecs)
            out[lower] = convert(in_d[centre] - in_d[error], _to_arcsecs)
            out[upper] = convert(in_d[centre] + in_d[error], _to_arcsecs)

    if undefined:
        out['pos'] = convert(numpy.nan, _to_degrees)
    else:
        out['pos'] = convert(in_d['pos'], _to_degrees)
    out['pos_lo'] = convert(numpy.nan, _to_degrees)
    out['pos_hi'] = convert(numpy.nan, _to_degrees)

    return out


def convert(val, factor):
    """Return ``val`` multiplied by ``factor``."""
    scaled = val * factor
    return scaled

def format(val, format="%.2f"):
    """Format a number as text, writing "INDEF" for NaN and infinities.

    Parameters
    ----------
    val : number
      Value to format.
    format : str, optional
      ``%``-style format applied to finite values.

    Returns
    -------
    str
      ``format % val``, or "INDEF" when ``val`` is NaN or infinite.

    """
    if not numpy.isfinite(val):
        return "INDEF"
    return format % val


def _get_size(val):
    return numpy.sqrt(numpy.sum(numpy.square(val)))/numpy.sqrt(2.)


def _transform_data(src_d, psf_d, _to_degrees, _to_arcsecs):
    """Derive the intrinsic size and scale all three parameter sets.

    Parameters
    ----------
    src_d, psf_d : dict
      Source and PSF dictionaries as produced by :func:`_save_params`.
    _to_degrees, _to_arcsecs : float
      Factors handed to :func:`_transform_params`.

    Returns
    -------
    tuple
      ``(int_d, src_d, psf_d, extended)``: the transformed intrinsic,
      source and PSF dictionaries and the extent flag. The flag stays 0
      and the intrinsic error stays NaN unless the intrinsic size returned
      by ``lev3_iss.calc_a_rss`` is a positive number.

    """
    intrinsic = lev3_iss.calc_a_rss(src_d, psf_d)
    intrinsic_err = numpy.nan
    extended = 0

    # A value compares equal to itself only when it is not NaN.
    if intrinsic == intrinsic and intrinsic > 0:
        intrinsic_err = lev3_iss.calc_a_rss_err(intrinsic, src_d, psf_d)

        extended = lev3_iss.calc_extended(intrinsic, intrinsic_err, psf_d,
                                          extent_flag_threshold=5)

    # The intrinsic size is stored as a circle: equal axes, zero angle.
    int_d = _save_params(None, None,
                         [intrinsic, intrinsic, 0.0],
                         [intrinsic_err, intrinsic_err])

    int_d, src_d, psf_d = [_transform_params(entry, _to_degrees, _to_arcsecs)
                           for entry in (int_d, src_d, psf_d)]

    return (int_d, src_d, psf_d, extended)


def set_column(tbl, col, colname, val, unit=None):
    """Fill a column object and add it to a table crate.

    Parameters
    ----------
    tbl : crate
      Table crate receiving the column.
    col : column object
      Object whose ``name``, ``unit`` and ``values`` are set.
    colname : str
      Column name; converted with ``str`` and stripped of whitespace.
    val : scalar or array-like
      Column values. A scalar becomes a one-element array.
    unit : str, optional
      Column unit; ignored when None or empty.

    """
    col.name = str(colname).strip()

    unit_given = not (unit is None or unit == '')
    if unit_given:
        col.unit = unit

    col.values = numpy.array([val]) if numpy.isscalar(val) else numpy.asarray(val)

    pycrates.add_col(tbl, col)



def data_substring(x,y,size,size_lo,size_hi,pos,pos_lo,pos_hi, hdr_str, _units_size):
    """Build the report text for one measured size.

    Parameters
    ----------
    x, y : number or None
      Pixel position; reported only when both are given.
    size, size_lo, size_hi : number
      Size and the bounds of its confidence interval.
    pos, pos_lo, pos_hi : number
      Position angle and the bounds of its confidence interval.
    hdr_str : str
      Text placed in front of the size.
    _units_size : str
      'pixel' prints "pixel" as the unit; anything else prints the
      arcsecond mark.

    Returns
    -------
    str
      The size line followed by the confidence-interval line and a blank
      line. Numbers are rendered with :func:`format`.

    """
    unit = _units_size if _units_size == 'pixel' else '"'

    pieces = ["%s%s %s @ PA %s deg" %
              (hdr_str, format(size) , unit, format(pos))]

    if not (x is None or y is None):
        pieces.append(" at pixel coords %s, %s" %
                      (format(x, "%.1f"),format(y, "%.1f")))

    pieces.append("\n    90%% Confidence intervals: (%s -- %s) @ (%s -- %s)\n\n" %
                  (format(size_lo),format(size_hi),format(pos_lo),format(pos_hi)))

    return "".join(pieces)


def format_data(id, psf, int_d, src_d, psf_d, extended, src_units_size, psf_units_size):
    """Build the text report for one source.

    Parameters
    ----------
    id : int
      Source number printed in the heading.
    psf : object or None
      When None only the observed source size is reported.
    int_d, src_d, psf_d : dict
      Transformed intrinsic, source and PSF dictionaries.
    extended : bool or int
      Whether the source is reported as extended.
    src_units_size, psf_units_size : str
      Size units handed to :func:`data_substring`.

    Returns
    -------
    str
      The report text.

    """
    size_keys = (('sigma_1', 'sigma_2'),
                 ('sigma_1_lo', 'sigma_2_lo'),
                 ('sigma_1_hi', 'sigma_2_hi'))

    def combined_sizes(entry):
        # size, lower bound, upper bound: each combined from both axes
        return [_get_size(numpy.array([entry[first], entry[second]]))
                for first, second in size_keys]

    src_size, src_size_lo, src_size_hi = combined_sizes(src_d)

    report = ["Results for Source %i\n\n" % id]
    report.append(data_substring(src_d['x'], src_d['y'],
                                 src_size, src_size_lo, src_size_hi,
                                 src_d['pos'], src_d['pos_lo'], src_d['pos_hi'],
                                 "Source Observed Size: ", src_units_size))

    if psf is None:
        return "".join(report)

    psf_size, psf_size_lo, psf_size_hi = combined_sizes(psf_d)
    report.append(data_substring(psf_d['x'], psf_d['y'],
                                 psf_size, psf_size_lo, psf_size_hi,
                                 psf_d['pos'], psf_d['pos_lo'], psf_d['pos_hi'],
                                 "PSF Observed Size:    ", psf_units_size))

    # already converted by ISS
    report.append(data_substring(int_d['x'], int_d['y'],
                                 int_d['sigma_1'], int_d['sigma_1_lo'],
                                 int_d['sigma_1_hi'],
                                 int_d['pos'], int_d['pos_lo'], int_d['pos_hi'],
                                 "Estimated Intrinsic Size: ", src_units_size))

    verdict = ["Source is", "extended at 90% confidence"]
    if not extended:
        verdict.insert(1, "not")
    report.append(' '.join(verdict))

    return "".join(report)


def write_file(outfile, srcfile, psffile, extended, int_d, src_d, psf_d,
               src_units_size, psf_units_size):
    """Write the results of one source to a table file.

    Parameters
    ----------
    outfile : str
      Name of the file to write.
    srcfile : str
      Source file name, stored in the FILENAME column.
    psffile : str or None
      PSF file name. When None only FILENAME and the raw source columns
      are written.
    extended : bool or int
      Extent flag, stored as an integer.
    int_d, src_d, psf_d : dict
      Transformed intrinsic, source and PSF dictionaries.
    src_units_size, psf_units_size : str
      Units of the source and PSF axis columns.

    """
    tbl = pycrates.TABLECrate()
    tbl.name = "TABLE"

    d_keys = ['x','y',
              'sigma_1','sigma_1_lo','sigma_1_hi',
              'sigma_2','sigma_2_lo','sigma_2_hi',
              'pos','pos_lo','pos_hi']

    int_cols = ['','',
                'major_axis','major_axis_lo','major_axis_hi',
                'minor_axis','minor_axis_lo','minor_axis_hi',
                'pos_angle', 'pos_angle_lo', 'pos_angle_hi']

    src_cols = ['x', 'y',
                'mjr_axis_raw','mjr_axis_raw_lo','mjr_axis_raw_hi',
                'mnr_axis_raw','mnr_axis_raw_lo','mnr_axis_raw_hi',
                'pos_angle_raw','pos_angle_raw_lo','pos_angle_raw_hi']

    psf_cols = ['psf_x', 'psf_y',
                'psf_mjr_axis','psf_mjr_axis_lo','psf_mjr_axis_hi',
                'psf_mnr_axis','psf_mnr_axis_lo','psf_mnr_axis_hi',
                'psf_pos_angle','psf_pos_angle_lo','psf_pos_angle_hi']

    # Units follow d_keys: position, six axis values, three angles.
    src_units = ['pixel']*2 + [src_units_size]*6 + ['deg']*3
    psf_units = ['pixel']*2 + [psf_units_size]*6 + ['deg']*3

    have_psf = psffile is not None

    def add(colname, value, unit=None):
        # Every column is built from a fresh CrateData object.
        set_column(tbl, pycrates.CrateData(), colname.upper(), value, unit)

    leading = [('filename', srcfile), ('psf', psffile),
               ('extent_flag', int(extended))]
    if not have_psf:
        leading = leading[:1]
    for colname, value in leading:
        add(colname, value)

    if have_psf:
        # The intrinsic size has no position; those entries are None.
        for colname, key, unit in zip(int_cols, d_keys, src_units):
            if int_d[key] is not None:
                add(colname, int_d[key], unit)

    for colname, key, unit in zip(src_cols, d_keys, src_units):
        add(colname, src_d[key], unit)

    if have_psf:
        for colname, key, unit in zip(psf_cols, d_keys, psf_units):
            add(colname, psf_d[key], unit)

    tbl.write(outfile)


def srcsize_from_approxPSF(theta):
    '''Return the MHO input parameter SRCSIZE for a given off-axis angle.

    The value is interpolated from a table of approximate PSF sizes
    [ref: level3 release 1 tool lev3_2d_fit.sl]. The lookup table holds
    radii in ACIS pixels (physical sky coords); the result is multiplied by
    the ACIS pixel size so that SRCSIZE is returned in arcseconds.

    Parameters
    ----------
    theta : float
      Off-axis angle in arcminutes.

    Returns
    -------
    float
      Approximate PSF size in arcseconds. Outside the tabulated range the
      first or last table value is used and a warning is printed unless
      the logger level is 40 (logging.ERROR) or higher.
    '''

    # off-axis angle [arcmin]
    tbl_theta = [0.000000e+00, 3.902054e-01, 1.127930e+00, 1.718110e+00, 2.234518e+00,
                 3.356987e+00, 4.705897e+00, 6.808413e+00, 1.020195e+01, 1.341105e+01,
                 1.739477e+01, 2.075142e+01, 2.462447e+01]

    # Approximate PSF 1-sigma radius [ACIS physical pixels]
    tbl_psf_size = [0.91031789,   0.91104024,   0.93522419,   1.00289959,  1.11370447,
                    1.57935   ,   2.42282114,   4.54356301,   8.979     , 16.40306504,
                    29.70497967,  41.39542683,  54.74180894]

    # The valid range is that of the theta column. (The limit used to be
    # 54.742, which is the last *PSF size*, so angles between 24.6 and
    # 54.7 arcmin were clamped without any warning.)
    theta_max = tbl_theta[-1]
    if theta < 0.0 or theta > theta_max:
        if get_verbose() < 40:
            print("WARNING: unexpected THETA value of {0}. Assuming max or min srcsize for theta < 0 or > {1:.3f} arcmin for PSF fit.".format(theta, theta_max))

    # numpy.interp holds the end values constant outside the table range.
    srcsize = numpy.interp(theta,tbl_theta,tbl_psf_size) * 0.492 # convert to arcseconds from ACIS physical pixels

    return srcsize


def calculate_regevtfile_theta(regevtfile,xx,yy):
    '''Call dmcoords to calculate theta (off-axis angle) for a sky position.

    Parameters
    ----------
    regevtfile : str
      File handed to dmcoords as ``infile``.
    xx, yy : number
      Sky coordinates handed to dmcoords as ``x`` and ``y``.

    Returns
    -------
    float
      The ``theta`` value read back from the dmcoords parameter file.

    Raises
    ------
    subprocess.CalledProcessError
      If dmcoords exits with a non-zero status.

    Notes
    -----
    The process is terminated with status 1 when the dmcoords parameter
    file cannot be opened afterwards.
    '''

    args = {}
    args['infile']   = regevtfile
    args["opt"] = "sky"
    args["asolfile"] = "none"
    args["x"] = xx
    args["y"] = yy

    paramio.punlearn("dmcoords")

    # Pass the arguments as a list instead of a shell string, so file names
    # containing quotes or other shell characters are handed over verbatim.
    cmd = ["dmcoords"] + ["%s=%s" % (par, args[par]) for par in args]
    try:
        subprocess.check_call(cmd, stdout=subprocess.DEVNULL)
    except subprocess.CalledProcessError:
        # check_call raises on a non-zero status; report which call failed
        # and let the error propagate.
        print("Non-zero status with command: dmcoords infile={0} x={1} y={2}".format(args['infile'],args['x'],args['y']))
        raise

    #...recover theta from dmcoords par file
    try:
        dmc_pfile = paramio.paramopen("dmcoords", "rw")
    except Exception:
        # Flush explicitly: os._exit does not flush buffered output.
        print("ERROR: can't open dmcoords parameter file", flush=True)
        os._exit(1)

    try:
        theta = float(paramio.pgetstr(dmc_pfile,"theta"))
    finally:
        paramio.paramclose(dmc_pfile)

    return theta



def srcsize_from_regfile(regfile,srcfile,sigmaFactor):
    '''Return the MHO input parameter SRCSIZE derived from a region file.

    Parameters
    ----------
    regfile : str
      Region file; its size is obtained with parse_regfile.
    srcfile : str
      Source file whose EQPOS transform supplies the pixel scale.
    sigmaFactor : float
      Factor applied to the region size.

    Returns
    -------
    float
      Region size times sigmaFactor times the pixel scale
      (|CDELT[1]| * 3600).
    '''

    # read the srcfile to get the sky pixel to arcsecond conversion factor
    source = pycrates.read_file(srcfile)
    eqpos = source.get_transform("EQPOS")
    pixel_scale = eqpos.get_parameter_value("CDELT")
    skypix_to_arcsecs = numpy.abs(pixel_scale[1]) * 3600.0

    # only the size of the region is needed, not its position
    _, _, region_size = parse_regfile(regfile)

    return region_size * sigmaFactor * skypix_to_arcsecs


def do_source(params, id, srcfile, outfile, psffile, psfbin, theta, regfile=None):
    """Fit the PSF and then the source with MHO, report and save the result.

    Parameters
    ----------
    params : mapping
        Tool parameters. The keys read here are "srcsize", "x0", "y0",
        "imgsize", "binfactor", "mincounts", "minthresh", "shape",
        "psfblur" and "sigmafactor".
    id
        Source identifier handed to format_data() for the screen report.
    srcfile : str
        Source file to fit.
    outfile : str
        Output file name handed to write_file().
    psffile : str or None
        PSF file to fit.
    psfbin
        PSF bin value handed to both mho_find_source_extent() calls.
    theta
        Off-axis angle used for the approximate PSF size lookup.
    regfile : str, optional
        Region file handed to the fits and, for a negative srcsize, used to
        derive the source size.
    """
    requested_size = params["srcsize"]

    def _make_fitter(size):
        # Both fits share every MHO setting except the size guess.
        return MHO(params["x0"], params["y0"], size,
                   img_size=params["imgsize"], bin_factor=params["binfactor"],
                   min_counts=params["mincounts"], min_thresh=params["minthresh"])

    # --- PSF fit ---------------------------------------------------------
    # A negative srcsize together with a PSF file selects the approximate
    # PSF size lookup by off-axis angle; any other value is passed to MHO
    # unchanged (the original notes say MHO treats 0 specially itself).
    if requested_size < 0 and psffile:
        psf_guess = srcsize_from_approxPSF(theta)
    else:
        psf_guess = requested_size

    psf = _make_fitter(psf_guess)
    psf_x, psf_y, psf_pars, psf_errs = psf.mho_find_source_extent(
        psffile, psfbin, source_shape=params["shape"], regfile=regfile)
    psf_d = _save_params(psf_x, psf_y, psf_pars, psf_errs)

    # Extra aspect blur is added in quadrature, but only when the PSF file
    # carries INSTRUME == "ACIS".
    blur = 0.0
    if psffile:
        instrume_keywd = pycrates.read_file(psffile).get_key("INSTRUME")
        if instrume_keywd and instrume_keywd.value == "ACIS":
            blur = params['psfblur']
    if blur > 0.0:
        blur_sqd = blur * blur
        for sigma_name in ("sigma_1", "sigma_2"):
            psf_d[sigma_name] = numpy.sqrt(psf_d[sigma_name]**2+blur_sqd)

    # --- source fit ------------------------------------------------------
    # Start from a crude size built from the first two PSF fit parameters,
    # falling back to the requested srcsize when that is NaN.
    src_guess = _get_size(psf_pars[[0,1]])
    if numpy.isnan(src_guess):
        src_guess = requested_size

    # A negative srcsize means: take the size from the region file, scaled
    # by sigmafactor (srcsize_from_regfile returns arcseconds).
    if requested_size < 0:
        src_guess = srcsize_from_regfile(regfile,srcfile,params["sigmafactor"])

    src = _make_fitter(src_guess)
    src_x, src_y, src_pars, src_errs = src.mho_find_source_extent(
        srcfile, psfbin, source_shape=params["shape"], regfile=regfile)
    src_d = _save_params(src_x, src_y, src_pars, src_errs)

    # --- combine, report, save -------------------------------------------
    int_d, src_d, psf_d, extended = _transform_data(src_d, psf_d,
                                                    src._to_degrees,
                                                    src._to_arcsecs)

    # Print the summary only when the module logger level is below 30.
    if get_verbose() < 30:
        print(format_data(id, psffile, int_d, src_d, psf_d, extended,
                          src._units_size, psf._units_size))

    write_file(outfile, srcfile, psffile, extended, int_d, src_d, psf_d,
               src._units_size, psf._units_size)


def get_file(infiles):
    """Take the next file name off the front of a list and validate it.

    Parameters
    ----------
    infiles : list or None
        Remaining file names. The first entry, if any, is removed from the
        list in place.

    Returns
    -------
    str or None
        The entry that was removed, or None when the list is empty/None or
        the entry is "", "INDEF" or "none" (any letter case).

    Raises
    ------
    IOError
        If the entry names a file that does not exist.
    """
    candidate = infiles.pop(0) if infiles else None

    # Placeholder entries stand for "no file"; str(None).lower() is "none",
    # so a missing entry is covered by the same test.
    as_text = str(candidate)
    if as_text in ("", "INDEF") or as_text.lower() == "none":
        return None

    if not os.path.isfile(candidate):
        raise IOError("file not found '%s'" % candidate)

    return candidate


def select_hiRes_psfext(fn):
    """Identify which PSF extension has the higher resolution binning.

    The BIN header keyword of the "[psf]" extension (or of the file's
    default block when "[psf]" cannot be read) is compared with that of the
    "[binpsf]" extension, and the extension with the smaller BIN wins.

    Parameters
    ----------
    fn : str
        PSF file name.

    Returns
    -------
    tuple
        ``(bin, ext)``:

        * ``("", "")`` when the file is not an image.
        * ``(bin, "[binpsf]")`` when both extensions have a BIN value and
          the one in "[binpsf]" is less than or equal to the other.
        * otherwise ``(bin, psf_ext)`` where psf_ext is "[psf]" or "", and
          bin is the float BIN value of that block, or the falsy result of
          get_key("BIN") when the keyword is absent.
          NOTE(review): that is presumably None rather than the empty
          string mentioned in earlier comments - confirm against callers.
    """
    # check to see if the psf file is an image. If not, return empty strings
    cr = pycrates.read_file(fn)
    if pycrates.get_crate_type(cr) != "Image":
        return "",""

    # read the psf extension; fall back to the crate already loaded above
    # (same file, default block) when there is no readable [psf] extension.
    # Catch Exception rather than everything so Ctrl-C / SystemExit still
    # propagate.
    try:
        psf_cr = pycrates.read_file(fn+"[psf]")
        psf_ext_str = "[psf]"
    except Exception:
        psf_cr = cr
        psf_ext_str = ""

    # look for the [binpsf] extension; it is optional
    try:
        binpsf_cr = pycrates.read_file(fn+"[binpsf]")
    except Exception:
        binpsf_cr = None

    # read the BIN header keyword from both extensions, converting to float
    # only when the keyword is present
    psf_cr_bin = psf_cr.get_key("BIN")
    if psf_cr_bin:
        psf_cr_bin = float(psf_cr_bin.value)

    binpsf_cr_bin = binpsf_cr.get_key("BIN") if binpsf_cr else None
    if binpsf_cr_bin:
        binpsf_cr_bin = float(binpsf_cr_bin.value)

    # if [binpsf] doesn't exist or either extension has no bin keyword value,
    # just return the [psf] extension data.
    if not psf_cr_bin or not binpsf_cr_bin:
        return psf_cr_bin,psf_ext_str

    # return the extension with the smallest bin keyword value
    if psf_cr_bin < binpsf_cr_bin:
        return psf_cr_bin,psf_ext_str
    return binpsf_cr_bin,"[binpsf]"
#end: select_hiRes_psfext()


def add_history_to_file(file_name, tool_name, param_names, param_values):
    """Add a tool HISTORY record to a file and advance its HISTNUM keyword.

    Nothing is done when file_name does not exist.

    Parameters
    ----------
    file_name : str
        File to update; it is opened in "rw" mode and written back.
    tool_name : str
        Tool name stored in the history record.
    param_names, param_values : sequence
        Parameter names and values stored in the history record.
    """
    if not os.path.exists(file_name):
        return

    crate = read_file(file_name, mode="rw")

    # Make sure the crate has a HISTNUM keyword, creating it with value 1
    # when it is missing.
    histnum = crate.get_key("HISTNUM")
    if histnum is None:
        histnum = CrateKey()
        histnum.name = "HISTNUM"
        histnum.value = 1
        crate.add_key(histnum)

    # The new record starts at the current HISTNUM value (1 if unset).
    startat = histnum.value
    if startat is None:
        startat = 1

    record = HistoryRecord(tool=tool_name,
                           param_names=param_names,
                           param_values=param_values,
                           asc_start=startat)
    add_history(crate, record)

    # Increment HISTNUM by the number of lines in the record's ASC FITS text
    n_lines = len(record.as_ASC_FITS().split("\n"))
    histnum.value = startat + n_lines

    crate.write()
#end: add_history_to_file


def run_srcextent():
    """Run the srcextent tool on every source file named in the parameters.

    Each source file is paired with the next entry of the "psffiles" and
    "regfiles" lists and with the output file at the same index, processed
    by do_source(), and the output file is then given a HISTORY record.

    Raises
    ------
    IOError
        If a source file (with any DM virtual file specification removed)
        does not exist.
    """
    params, verbose = get_parameters("srcextent.par")

    # Parameter names in the order stored in the parameter file, and the
    # matching values, for the HISTORY record.
    pnames = paramio.plist("srcextent")
    pvals = [params[name] for name in pnames]

    set_verbose(verbose)

    for src_num, srcfile in enumerate(params["srcfiles"], start=1):
        if not os.path.isfile(filter_filename(srcfile)):
            raise IOError("source file not found '%s'" % srcfile)

        # get_file() consumes the front entry of each list
        psffile = get_file(params["psffiles"])
        regfile = get_file(params["regfiles"])
        outfile = params["outfiles"][src_num - 1]

        if psffile:
            psfbin, psfext = select_hiRes_psfext(psffile)
            psf_fn = psffile + psfext
        else:
            # no psf file provided
            psf_fn, psfbin = None, ""

        do_source(params, src_num, srcfile, outfile, psf_fn, psfbin,
                  params['theta'], regfile)

        ## add history
        add_history_to_file(outfile, "srcextent", pnames, pvals)

# Script entry point: run the tool only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    run_srcextent()
