#!/usr/bin/env python

# 
# Copyright (C) 2019,2020,2025 Smithsonian Astrophysical Observatory
# 
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
# 
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# 
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#


import sys
import os

from dax.dax_plot_utils import *


def blt_plot_fit(sherpa_plot_fit_obj, xpa="ds9",
                 tle="Plot Fit", xlb="x ", ylb="y "):
    """Plot a fit (model curve, data points and residuals) through ds9.

    Parameters
    ----------
    sherpa_plot_fit_obj
        Object exposing ``dataplot`` (with ``x``, ``y``, ``yerr``, ``xerr``)
        and ``modelplot`` (with ``y`` and either ``x`` or ``xlo``/``xhi``).
        It is only read; none of its attributes are modified.
    xpa : str
        XPA access point of the ds9 instance that receives the plots.
    tle, xlb, ylb : str
        Title, X-axis label and Y-axis label handed to the model plot.
    """
    data = sherpa_plot_fit_obj.dataplot
    model = sherpa_plot_fit_obj.modelplot

    # Residuals expressed in units of the data error, with unit error bars.
    delchi = (data.y - model.y) / data.yerr
    unit_err = delchi * 0 + 1.0

    # Fall back to zero-width X errors using a local variable, so the
    # caller's plot object is left untouched.
    xerr = data.xerr
    if xerr is None:
        xerr = data.x - data.x  # zeros

    if hasattr(model, "xlo"):
        # Binned model: repeat the last value at the upper edge of the last
        # bin so the step plot covers the final bin too.
        mx = list(model.xlo)
        mx.append(model.xhi[-1])
        my = list(model.y)
        my.append(model.y[-1])
        step = True
    else:
        mx = model.x
        my = model.y
        step = False

    blt_plot_model(xpa, mx, my,
                   tle, xlb, ylb, new=True, step=step,
                   winname="dax_model_editor")
    blt_plot_data(xpa, data.x, None, data.y, data.yerr)
    blt_plot_delchisqr(xpa, data.x, xerr / 2.0, delchi, unit_err, "")


# -----------------------------------------

import subprocess as sp
from tempfile import NamedTemporaryFile    
import sherpa.astro.ui as sherpa

# Per-process scratch directory: $DAX_OUTDIR/rprof_fit/<pid>.
# Evaluated when the module is loaded, so a missing DAX_OUTDIR environment
# variable raises KeyError before any function runs.
ROOT=os.path.join( os.environ["DAX_OUTDIR"], "rprof_fit", str(os.getpid()))


def translate_units(units):
    """Return the radius column prefix for the requested units.

    'physical' maps to 'r' and 'arcsec' maps to 'cel_r'; any other value
    raises ValueError.
    """
    known = (('physical', 'r'), ('arcsec', 'cel_r'))
    for keyword, column in known:
        if keyword == units:
            return column

    raise ValueError("Unknown units: {}".format(units))

def translate_model(model, model_name):
    """Create the Sherpa model component called for by *model*.

    Returns None when *model* is the string 'none'. Otherwise a component
    of type *model* named *model_name* is created; for "polynom1d" the
    ``c1`` and ``c2`` parameters are thawed as well.
    """
    if model == 'none':
        return None

    component = sherpa.create_model_component(model, model_name)

    if model == "polynom1d":
        for par_name in ("c1", "c2"):
            sherpa.thaw(getattr(component, par_name))

    return component

    

def load_arguments():
    """Collect the positional command-line arguments into a dict.

    sys.argv[1:6] are stored, in order, under the keys 'xpa', 'units',
    'model1', 'model2' and 'method'. IndexError is raised if fewer than
    five arguments were supplied.
    """
    names = ('xpa', 'units', 'model1', 'model2', 'method')
    pars = {}
    for position, name in enumerate(names, start=1):
        pars[name] = sys.argv[position]
    return pars


def get_ciao_stack_from_ds9(xpa):
    """Fetch the ds9 regions and convert the source regions to a CIAO stack.

    The background regions are requested in CIAO format and returned as
    text. The source regions are requested in ds9 format, written to a
    temporary file under ROOT and passed to the
    ``convert_ds9_region_to_ciao_stack`` command.

    Parameters
    ----------
    xpa : str
        XPA access point of the ds9 instance to query.

    Returns
    -------
    tuple of (str, str)
        Path of the stack file written by the converter (it is not deleted
        here; the caller owns it) and the background region text, which is
        "" when ds9 returned nothing.

    Raises
    ------
    subprocess.CalledProcessError
        If ``xpaget`` or the converter exits with a non-zero status.
    """
    cmd = ['xpaget', xpa, 'regions', '-format', 'ciao', '-system',
           'physical', 'background', '-strip', 'yes']
    # Empty output decodes to "", so no special case is needed.
    bkg = sp.run(cmd, check=True, stdout=sp.PIPE).stdout.decode()

    cmd = ['xpaget', xpa, 'regions', '-format', 'ds9', '-system',
           'physical', 'source']
    src_regions = sp.run(cmd, check=True, stdout=sp.PIPE).stdout.decode()

    # The ds9-format region file is only needed while the converter runs,
    # so it is removed automatically when the with-block exits.
    with NamedTemporaryFile(dir=ROOT, suffix="_ds9.reg", mode="w") as ds9_file:
        ds9_file.write(src_regions)
        # Make sure the text is on disk before the subprocess reads it.
        ds9_file.flush()

        # Reserve a unique output name and close our own handle straight
        # away; the converter overwrites the file (clobber=yes).
        with NamedTemporaryFile(dir=ROOT, suffix="_ciao.lis",
                                delete=False) as ciao_list:
            stack_name = ciao_list.name

        cmd = ['convert_ds9_region_to_ciao_stack', ds9_file.name,
               stack_name, "clobber=yes", "verbose=0"]
        try:
            sp.run(cmd, check=True)
        except Exception:
            # The caller never learns the name on failure, so clean up here.
            if os.path.exists(stack_name):
                os.unlink(stack_name)
            raise

    return (stack_name, bkg)


def save_ds9_image(xpa):
    """Save the image currently shown in ds9 to a temporary FITS file.

    Parameters
    ----------
    xpa : str
        XPA access point of the ds9 instance to query.

    Returns
    -------
    str
        Path of the FITS file created under ROOT. It is not deleted here;
        the caller owns it.

    Raises
    ------
    IOError
        If ds9 reports a block factor other than 1.
    subprocess.CalledProcessError
        If an ``xpaget`` call exits with a non-zero status.
    """
    cmd = ['xpaget', xpa, 'block']
    block = sp.run(cmd, check=True, stdout=sp.PIPE).stdout
    if int(block) != 1:
        raise IOError("ERROR: This task requires that the input image be blocked to 1")

    cmd = ['xpaget', xpa, 'fits']
    fits = sp.run(cmd, check=True, stdout=sp.PIPE).stdout

    # Create the persistent file only after every check has passed, so a
    # failure above cannot leave a stray file behind, and write through the
    # temp-file handle so it is closed deterministically.
    with NamedTemporaryFile(dir=ROOT, suffix="_ds9.fits",
                            delete=False) as ds9_file:
        ds9_file.write(fits)

    return ds9_file.name
    

def run_dmextract(infile, stack, bkg):
    """Run dmextract on *infile* using the regions in the *stack* file.

    Parameters
    ----------
    infile : str
        Image file name; the binning specification is appended to it.
    stack : str
        Name of the stack file, referenced as ``@-<stack>`` in the binning
        specification.
    bkg : str
        Background region text. When empty, no background is set.

    Returns
    -------
    str
        Path of the output file created under ROOT. It is not deleted
        here.
    """
    from ciao_contrib.runtool import dmextract

    # Only a unique name is needed; close our handle immediately and let
    # dmextract overwrite the file (clobber is enabled below).
    with NamedTemporaryFile(dir=ROOT, suffix="_radial.prof",
                            delete=False) as prof:
        outfile = prof.name

    dmextract.infile = "{}[bin (x,y)=@-{}]".format(infile, stack)
    dmextract.outfile = outfile
    dmextract.clobber = True
    dmextract.verbose = 0
    dmextract.opt = "generic"

    if bkg:
        dmextract.bkg = "{}[bin (x,y)={}]".format(infile, bkg)

    # Echo anything the tool reported.
    output = dmextract()
    if output:
        print(output)

    return outfile

    
def fit_profile( xpa, infile, model, xcol, method, ycol="sur_bri", ycol_err="sur_bri_err" ):
    """Load a profile table into Sherpa, fit it interactively and plot it.

    Parameters
    ----------
    xpa : str
        XPA access point of the ds9 instance used by the editor and plots.
    infile : str
        Table to load as Sherpa dataset 1.
    model
        Sherpa model expression to fit; its ``name`` is used in the title.
    xcol : str
        Radius column prefix; ``"mid"`` is appended to form the column
        that is loaded. 'r' is labelled as pixels, anything else as arcsec.
    method : str
        Sherpa optimisation method name.
    ycol, ycol_err : str
        Names of the value and error columns.

    The Sherpa session is saved to ROOT/sherpa.save whether or not the
    editor was cancelled.
    """
    from dax.dax_model_editor import DaxModelEditor, DaxCancel

    # The bin-midpoint column is loaded directly. An earlier approach, kept
    # here as a note, read the lo/hi bin edges with pycrates and loaded
    # them with sherpa.load_arrays(..., sherpa.Data1DInt).
    sherpa.load_data(1, infile, colkeys=[xcol + "mid", ycol, ycol_err])
    sherpa.set_model(model)
    sherpa.set_method(method)

    # The editor takes a sequence of components; a lone model has no
    # ``parts`` attribute, so wrap it in a tuple.
    source = sherpa.get_source()
    components = source.parts if hasattr(source, "parts") else (source,)

    # xcol is the column prefix ('r' or 'cel_r' from translate_units), not
    # the full column name, so compare against 'r'. Comparing against
    # 'rmid' never matched and labelled every plot as arcsec.
    radius_units = 'pixels' if xcol == 'r' else 'arcsec'

    xlabel = "Radius [{}]".format(radius_units)
    ylabel = "sur_bri [counts/pixel**2]"

    try:
        editor = DaxModelEditor(components, xpa, xlabel=xlabel, ylabel=ylabel)
        # The editor is handed the fit and confidence functions to invoke.
        editor.run(sherpa.fit, sherpa.conf)
        blt_plot_fit(sherpa.get_fit_plot(), xpa,
                     tle="Radial Profile: {}".format(model.name),
                     xlb=xlabel,
                     ylb=ylabel)
    except DaxCancel:
        print("Cancel button pressed")

    sav = os.path.join(ROOT, "sherpa.save")
    sherpa.save(sav, clobber=True)
    print("\nTo restore session, start sherpa and type\n\nrestore('{}')".format(sav))
    print("-----------------------")


def doit():
    """Run the whole task: read arguments, extract the profile and fit it.

    Raises
    ------
    ValueError
        If the units are unknown or the first model is 'none'.
    """
    pars = load_arguments()

    # ROOT is keyed on the process id and keeps files from earlier runs,
    # so a reused pid must not make directory creation fail.
    os.makedirs(ROOT, exist_ok=True)

    xcol = translate_units(pars["units"])
    my_model = translate_model(pars["model1"], "first_cpt")
    if my_model is None:
        raise ValueError("First model cannot be none")

    model2 = translate_model(pars["model2"], "second_cpt")
    if model2 is not None:
        my_model = my_model + model2

    # Pre-set so the cleanup below works even if an early step fails.
    stk = fits = prof = None

    try:
        stk, bkg = get_ciao_stack_from_ds9(pars["xpa"])
        fits = save_ds9_image(pars["xpa"])
        prof = run_dmextract(fits, stk, bkg)

        fit_profile(pars["xpa"], prof, my_model, xcol, pars["method"])
    finally:
        # Only the stack and image files are removed; the profile is left
        # in place (presumably because the saved Sherpa session refers to
        # it - TODO confirm).
        for leftover in (stk, fits):
            if leftover is not None and os.path.exists(leftover):
                os.unlink(leftover)
    
    
    

# Run the task as soon as this file is executed. NOTE: there is no
# __main__ guard, so importing this file runs it too, and the definitions
# below are only reached once doit() has returned.
doit()


def test_plot():
    """Manual check: fit a constant to three points and plot it via ds9."""
    import sherpa.astro.ui as ui

    xvals = [1, 2, 3]
    yvals = [3, 4, 3]
    ui.load_arrays(1, xvals, yvals)

    ui.set_model(ui.const1d.c0)
    ui.fit()

    blt_plot_fit(ui.get_fit_plot(), "ds9")
