#!/opt/conda/envs/ciao-4.17.0/bin/python3
# 
#  Copyright (C) 2009-2010,2012,2016,2022  Smithsonian Astrophysical Observatory
#
#
#  This program is free software; you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation; either version 3 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#
#  You should have received a copy of the GNU General Public License along
#  with this program; if not, write to the Free Software Foundation, Inc.,
#  51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#


# (8/2009) - initial version

from __future__ import print_function
import os
import sys
import glob
import paramio
import string

'''Use the history module to add HISTORY records to crate'''
from pycrates import read_file
from pycrates import add_history, CrateKey
from history import HistoryRecord

##  get parameters
#
def get_params( toolpar ) :
   """Read the create_bkg_map parameters through paramio.

   toolpar : parameter file name handed to paramio.paramopen.  main()
             passes None; presumably paramio then derives the file from
             sys.argv -- TODO confirm.

   Returns a dict keyed by parameter name.  "scale" and "verbose" are
   converted with int(); every other value is the string returned by
   paramio.pgetstr.  Exits through error_out() when the parameter file
   cannot be opened; a non-integer scale/verbose raises ValueError.
   """
   try :
      pfile = paramio.paramopen(toolpar, "rw", sys.argv)
   except Exception :
      # Format with %s: toolpar is None when called from main(), and
      # "..."+None would raise TypeError instead of reporting this error.
      error_out("ERROR: can't open parameter file %s\n" % (toolpar,))

   # Parameter names, in the order they are read.  Values of the names in
   # int_names are converted to int; all others stay strings.
   names = ("infile", "fovfile", "expmap", "outfile", "xygrid",
            "outimgfile", "outstreakmap", "outlowfreqmap", "tmpdir",
            "expcorr", "scale", "smoothing_kernel", "clobber", "verbose",
            "mode")
   int_names = ("scale", "verbose")

   pp = {}
   try :
      for name in names :
         value = paramio.pgetstr(pfile, name)
         pp[name] = int(value) if name in int_names else value
   finally :
      # close the parameter file even if a value fails to convert
      paramio.paramclose(pfile)
   return pp


## error out w/ message
#
def error_out( msg ) :
    """Write msg verbatim to stderr (no newline is added) and exit with status 1."""
    stream = sys.stderr
    stream.write( msg )
    raise SystemExit(1)


## Error out if clobber is "no" and the output file already exists.
#
def check_clob( Clob, File ) :
   """Exit through error_out() when clobber is off and File already exists.

   Clob : clobber parameter value; only its first character is examined,
          and "n" or "N" means existing files must not be overwritten.
   File : name of the output file to protect.
   """
   if Clob[0] not in ("n", "N") :
      return
   if os.path.exists( File ) :
      error_out("\nERROR: outfile '"+File+"' exists and clobber=no.\n")


## execute task (cmd); if there's a problem, error_out w/ message 
#
def exe_task (tool, cmd) :
   """Run CIAO tool `tool` with argument string `cmd` through the shell.

   paramio.punlearn(tool) is called first (presumably restoring the tool's
   default parameters, so only the values in cmd apply).  Exits through
   error_out() when os.system reports a non-zero status.

   NOTE(review): the command is a plain shell string, so file names that
   contain spaces or shell metacharacters are not protected.
   """
   paramio.punlearn( tool )
   command_line = "%s %s" % (tool, cmd)
   if os.system( command_line ) :
      error_out( "problem with %s\n" % tool )


## get the key value from File header
#
def get_key_value ( File, Key ) :
   """Return header keyword Key of File as reported by dmkeypar.

   dmkeypar runs with echo-, and the result is read back from its "value"
   parameter with paramio.pget (so it is returned as paramio gives it).
   """
   exe_task("dmkeypar", "%s %s echo-" % (File, Key))
   return paramio.pget("dmkeypar", "value")


##  set xygrid
#
def set_xygrid( par_xygrid ) :
   """Split the xygrid parameter at commas into its x and y grid specs.

   Returns the list produced by str.split(","); main() uses its first two
   elements.  Exits through error_out() when there are fewer than two.
   """
   grid = par_xygrid.split(",")
   too_short = len(grid) < 2
   if too_short :
      error_out("ERROR: problem with xygrid.\n")
   return grid


##  create streak map for acis
#   
def create_streakmap( evtfile, fovFile, T_bkgroot,
                      T_reg, T_dir, cbv ) :
   """Run acis_streak_map on evtfile.

   Any existing T_bkgroot*.fits files are removed first.  The tool gets
   infile=evtfile, fovfile=fovFile, bkgroot=T_bkgroot, regfile=T_reg and
   tmppath=T_dir; cbv (appended as-is) carries the clobber/verbose options.
   """
   remove_files(T_bkgroot+"*.fits")
   args = ["infile="+evtfile, "fovfile="+fovFile, "bkgroot="+T_bkgroot,
           "regfile="+T_reg, "tmppath="+T_dir]
   exe_task("acis_streak_map", " ".join(args)+cbv)


##  reproject per-chip ccd streak maps to match exposure map
#   ( Note: always set lookupTab="" )
#   
def reproject_streakmap( T_bkgroot, expFile,
                         T_streakmap, cbv ) :
   """Reproject the per-chip streak maps onto the exposure-map grid.

   Runs reproject_image with infile=T_bkgroot?.fits, matchfile=expFile,
   outfile=T_streakmap, method=sum, coord_sys=world, resolution=1 and an
   empty lookupTab.  Exits through error_out() if T_streakmap does not
   exist afterwards.
   """
   per_chip = T_bkgroot + r"?.fits"
   cmd = "infile=%s matchfile=%s outfile=%s" % (per_chip, expFile, T_streakmap)
   # lookupTab is deliberately left empty
   cmd += r" method=sum coord_sys=world resolution=1 lookupTab=  " + cbv
   exe_task("reproject_image", cmd)
   if not os.path.exists(T_streakmap) :
      error_out ("ERROR: problem with streak maps.\n")


#--- Note: always use lookupTab=""
def exe_dmimgcalc( infile1, infile2, outfile, op, cbv ) :
   """Run dmimgcalc on infile1/infile2 with operation op, writing outfile.

   op is wrapped in double quotes on the command line; lookupTab is always
   left empty; cbv (appended as-is) carries the clobber/verbose options.
   """
   parts = ("infile=", infile1,
            " infile2=", infile2,
            " outfile=", outfile,
            ' op="', op, '" lookupTab=  ',
            cbv)
   exe_task("dmimgcalc", "".join(parts))


##  create lffile for ACIS
#   
def acis_lffile( evtfile, fovFile, expFile, T_dir,
                 T_image, T_bkgroot, T_reg, T_streakmap,
                 T_rmstreak, cbv ) :
   """Build the streak-subtracted ACIS image that main() feeds to dmimgpm.

   Three steps, in order:
     1. create_streakmap    -- acis_streak_map on evtfile (bkgroot=T_bkgroot)
     2. reproject_streakmap -- T_bkgroot?.fits reprojected onto expFile,
                               summed into T_streakmap
     3. exe_dmimgcalc       -- T_image minus T_streakmap -> T_rmstreak
   cbv carries the clobber/verbose options passed to every tool.
   """
   create_streakmap( evtfile=evtfile, fovFile=fovFile, T_bkgroot=T_bkgroot,
                     T_reg=T_reg, T_dir=T_dir, cbv=cbv )
   reproject_streakmap( T_bkgroot=T_bkgroot, expFile=expFile,
                        T_streakmap=T_streakmap, cbv=cbv )
   exe_dmimgcalc( infile1=T_image, infile2=T_streakmap, outfile=T_rmstreak,
                  op="sub", cbv=cbv )


##  run dmcopy 
#   
def exe_dmcopy(infile, outfile, cbv) :
   """Run dmcopy from infile (may include a DM filter) to outfile; cbv appended as-is."""
   exe_task("dmcopy", "infile=%s outfile=%s%s" % (infile, outfile, cbv))


##  run aconvolve for HRC
def exe_aconvolve(Infile, Outfile, Kernelspec, cbv ) :
   """Smooth Infile with aconvolve using Kernelspec, writing Outfile.

   Fixed options: method=slide edge=const const=0 pad=no norm=area.  The
   kernel spec is double-quoted on the command line; cbv is appended as-is.
   """
   options = " method=slide edge=const const=0 pad=no norm=area"
   cmd = 'infile=%s outfile=%s kernelspec="%s"%s%s' % (
      Infile, Outfile, Kernelspec, options, cbv)
   exe_task("aconvolve", cmd )


##  run dmimgpm
#   
def exe_dmimgpm( lffile, T_lfreq, expFile, T_smthexp, scale, cbv) :
   """Run dmimgpm on lffile with exposure map expFile.

   Writes outfile=T_lfreq and outexpfile=T_smthexp; xhalf and yhalf are
   both set to repr(scale).  cbv is appended as-is.
   """
   half = repr(scale)
   fields = [("infile", lffile), ("outfile", T_lfreq), ("expfile", expFile),
             ("outexpfile", T_smthexp), ("xhalf", half), ("yhalf", half)]
   cmd = " ".join(key + "=" + val for key, val in fields) + cbv
   exe_task("dmimgpm", cmd)


## normalized exposure map
#
def norm_expmap( infile, outfile, cbv ) :
   """Write infile divided by its maximum pixel value to outfile.

   The maximum is taken from dmstat's out_max parameter after running
   dmstat on infile; the division is done with dmimgcalc and the result
   is cast to float.  cbv carries the clobber/verbose options.

   Exits through error_out() when out_max is not a positive number:
   dividing by zero (or by a non-number) would make every pixel of
   outfile, and the background map built from it, meaningless.
   """
   exe_task("dmstat", infile+r" sig- med- cen- verb=0")
   out_max = paramio.pget("dmstat", "out_max")

   try :
      max_val = float(out_max)
   except (TypeError, ValueError) :
      max_val = None
   # "not > 0" also rejects NaN
   if max_val is None or not max_val > 0 :
      error_out("ERROR: cannot normalize '"+infile+"': maximum value is "
                +str(out_max)+".\n")

   exe_dmimgcalc( infile, "none", outfile,
                 r"imgout=((float)(img1/"+repr(max_val)+r"))", cbv )


## is it a blank string ?
#
def isBlank( ss ) :
   """Return True when ss is empty or consists only of space characters.

   Only ' ' is treated as blank; tabs and newlines count as content.
   """
   return not ss.strip(" ")


## remove a list of files
#
# allow any number of spaces between files.
# allow syntax w/ '*' ( eg.  *file*  or file* ).
#     eg. List=" f1 f2  any*  *any*  f4  "
#
def remove_files( List ) :
   """Delete every file matching the whitespace-separated names in List.

   Each token is expanded with glob.glob, so wildcards such as 'file*' or
   '*file*' are allowed, e.g. List=" f1 f2  any*  *any*  f4  ".  Tokens
   that match nothing are ignored.  Removal is best-effort: a path that
   cannot be unlinked (already gone, a directory, no permission) is
   skipped, but interrupts and exits are no longer swallowed.
   """
   for pattern in List.split() :
      for fn in glob.glob( pattern ) :
         # skip names made only of spaces (same rule as isBlank)
         if not fn.strip(" ") :
            continue
         try :
            os.unlink( fn )
         except OSError :
            # best-effort cleanup of temporary files
            pass


## Helper for the variables tmp_image / tmp_lfreq / tmp_streakmap.
#
#  If the user parameter (outimgfile, outlowfreqmap or outstreakmap) is
#  blank or "none" (any case), the temporary name set2tmp is used instead
#  (main passes <tmpdir>/tmp_*.fits) and the file is flagged for deletion
#  at the end.  Otherwise the user's file name is kept and not deleted.
#
def set_file_flag( set2par, set2tmp) :
   """Return (delete_flag, file_name) for an optional output parameter.

   set2par : value of the user parameter.
   set2tmp : fallback temporary file name.
   Returns (True, set2tmp) when set2par is empty/only spaces or equals
   "none" case-insensitively (without stripping); otherwise (False, set2par).
   """
   use_tmp = (not set2par.strip(" ")) or set2par.lower() == "none"
   if use_tmp :
      return True, set2tmp
   return False, set2par


def add_history_to_file(file_name, tool_name, param_names, param_values):
   """Append an ASC HISTORY record for tool_name to file_name, in place.

   file_name    : path of the file to update.  If os.path.exists() is
                  false for it the function returns without doing
                  anything (a glob pattern is checked as a literal name).
   tool_name    : tool name recorded in the HistoryRecord.
   param_names  : parameter names passed to HistoryRecord.
   param_values : matching parameter values passed to HistoryRecord.

   The file is opened with pycrates read_file(mode="rw") and saved with
   crate.write().  The HISTNUM header keyword supplies the starting record
   number (it is created with value 1 if missing) and is then advanced by
   the number of lines in the generated record.
   """
   if os.path.exists( file_name) == False : 
      return 
   crate = read_file(file_name, mode="rw")

   # HISTNUM gives the start number for the new record; create it if absent.
   histnum = crate.get_key("HISTNUM")
   if histnum is None:
      histnum = CrateKey()
      histnum.name = "HISTNUM"
      histnum.value = 1
      crate.add_key(histnum)

   startat = histnum.value
   # fall back to 1 when the keyword exists but has no value
   startat = 1 if startat is None else startat

   hh = HistoryRecord(tool=tool_name,param_names=param_names,
                      param_values=param_values, asc_start=startat)    
   add_history(crate, hh)

   # Increment HISTNUM for new lines
   # NOTE(review): this counts the pieces of as_ASC_FITS() split on "\n";
   # if that text ends with a newline the count includes one empty piece.
   hlen = hh.as_ASC_FITS().split("\n") 
   # presumably histnum is the crate's own key object (added or fetched
   # above), so this update is what crate.write() saves -- TODO confirm.
   histnum.value = startat+len(hlen)

   crate.write()
# end: add_history_to_file  


# ------------ main code --------------
#
def main() :
   """Entry point: build the background map described by the parameters.

   Outline (see below for the exact tool invocations):
     1. read the parameters; exit if outfile exists and clobber=no;
     2. bin infile onto the xygrid with dmcopy (-> tmp_image);
     3. build the dmimgpm input image: for non-ACIS data (HRC branch)
        tmp_image smoothed with aconvolve; for ACIS, tmp_image minus the
        acis_streak_map streak map reprojected onto expmap;
     4. run dmimgpm on that image filtered by the FOV region
        (-> tmp_lfreq and tmp_smthexp);
     5. normalize expmap and tmp_smthexp by their maxima and combine:
          non-ACIS: lfreq / normsmthexp
          ACIS:     lfreq / normsmthexp + streakmap / normexp
     6. if expcorr is "no", multiply outfile by the normalized expmap;
     7. add HISTORY to outfile and to the optional outputs the user kept,
        then delete the temporary files.
   Exits with status 0 on success; the helpers exit with status 1 on error.
   """
   ##  get parameters
   pp = get_params( None )

   ##  check clobber
   check_clob( pp["clobber"], pp["outfile"] )

   verb  = pp["verbose"]
   cv =" clob="+pp["clobber"]+" verb="+repr(verb)

   ##  set xygrid
   #   (author's hint for finding it in an exposure map's history:
   #    dmhistory emap mkexpmap |tr " " "\012" |grep xygrid |tr ',="' "   ")
   xygrid = set_xygrid( pp["xygrid"] )

   ##  create tmp_image: infile binned onto the xygrid
   del_tmp_image, tmp_image = set_file_flag ( pp["outimgfile"],
                              pp["tmpdir"]+r"/tmp_image.fits")

   exe_dmcopy(pp["infile"]+r'"[bin x='+xygrid[0]+",y="+xygrid[1]+r']"',
              tmp_image, cv)

   ##  get instrume (first four characters)
   vv = get_key_value ( pp["infile"], "INSTRUME")
   instr = vv[0:4]
   is_acis = ( instr == "ACIS" )

   ## set lffile (input image for dmimgpm) for HRC and ACIS
   if not is_acis :                # HRC
      tmp_smoothed_image = pp["tmpdir"]+r"/tmp_smoothed_image.fits"
      exe_aconvolve( tmp_image, tmp_smoothed_image, pp["smoothing_kernel"], cv)
      lffile =  tmp_smoothed_image
   else:                           # ACIS
      if verb > 0 :
         print(".... creating streak maps for acis ....\n")

      tmp_bkgroot = pp["tmpdir"]+r"/tmp_bkgroot"
      tmp_reg = pp["tmpdir"]+r"/tmp_reg"

      del_tmp_streakmap, tmp_streakmap = set_file_flag( pp["outstreakmap"],
                                       pp["tmpdir"]+r"/tmp_streakmap.fits" )

      tmp_rmstreak = pp["tmpdir"]+r"/tmp_rmstreak.fits"

      acis_lffile( pp["infile"], pp["fovfile"], pp["expmap"],
                   pp["tmpdir"], tmp_image, tmp_bkgroot,
                   tmp_reg, tmp_streakmap, tmp_rmstreak, cv)

      lffile = tmp_rmstreak

   ## add FOV filter to image
   lffile = lffile+r'"[sky=region('+pp["fovfile"]+r')][opt full]"'

   ## run dmimgpm
   del_tmp_lfreq, tmp_lfreq = set_file_flag( pp["outlowfreqmap"],
                                    pp["tmpdir"]+r"/tmp_lfreq.fits" )

   tmp_smthexp = pp["tmpdir"]+"/tmp_smthexp.fits"
   exe_dmimgpm(lffile, tmp_lfreq, pp["expmap"], tmp_smthexp, pp["scale"], cv)

   if verb > 0 :
      print(".... creating normalized exposure maps ....\n")

   ## normalized exposure maps: expmap -> tmp_normexp,
   #  tmp_smthexp -> tmp_normsmthexp
   tmp_normexp = pp["tmpdir"]+"/tmp_normexp.fits"
   norm_expmap ( pp["expmap"], tmp_normexp, cv )

   tmp_normsmthexp = pp["tmpdir"]+"/tmp_normsmthexp.fits"
   norm_expmap ( tmp_smthexp,  tmp_normsmthexp, cv)

   if verb > 0 :
      print(".... creating bkg map....\n")

   ##  combine the pieces
   LL   = tmp_lfreq
   FF_n = tmp_normsmthexp
   EE_n = tmp_normexp

   if not is_acis :                    #--- HRC
      exe_dmimgcalc( LL+","+FF_n, "none", pp["outfile"],
                     "imgout=((float)(img1/img2))", cv )
   else :                              #--- ACIS
      SS = tmp_streakmap
      exe_dmimgcalc( LL+","+SS+","+FF_n+","+EE_n, "none", pp["outfile"],
                     "imgout=((float)((img1/img3)+(img2/img4)))", cv )

   ## expcorr=no: multiply outfile by the normalized exposure map.  The
   #  product overwrites outfile, so the map is first copied to outfile.nc.
   expcorr = pp["expcorr"]
   no_expcorr = (expcorr[0] == "n") or (expcorr[0] == "N")
   if no_expcorr :
      tmp_nc = pp["outfile"]+".nc"
      cv2 = r" clob+ verb="+repr(verb)     ### always set clob=yes here
      remove_files( tmp_nc )
      exe_dmcopy( pp["outfile"],  tmp_nc, cv2 )
      exe_dmimgcalc( tmp_nc+","+EE_n,  "none", pp["outfile"],
                     "imgout=((float)(img1*img2))", cv2 )

   ## add history
   # Get parameter name/values in order stored in parameter file
   pnames = paramio.plist("create_bkg_map")
   pvals = [pp[x] for x in pnames]
   add_history_to_file(pp["outfile"], "create_bkg_map", pnames, pvals)

   ##  cleanup temp files
   #   Files that are always deleted below get no HISTORY: writing it would
   #   only open and rewrite files that are removed a moment later.  The
   #   optional outputs the user chose to keep do get it.
   to_remove = [tmp_smthexp, tmp_normexp, tmp_normsmthexp]

   if del_tmp_image :
      to_remove.append(tmp_image)
   else :
      add_history_to_file( tmp_image, "create_bkg_map", pnames, pvals)

   if del_tmp_lfreq :
      to_remove.append(tmp_lfreq)
   else :
      add_history_to_file( tmp_lfreq, "create_bkg_map", pnames, pvals)

   if no_expcorr :
      to_remove.append(tmp_nc)

   if is_acis :
      to_remove += [tmp_reg, tmp_bkgroot+r"*.fits", tmp_rmstreak]
      if del_tmp_streakmap :
         to_remove.append(tmp_streakmap)
      else :
         add_history_to_file( tmp_streakmap, "create_bkg_map", pnames, pvals)
   else :      # HRC
      to_remove.append(tmp_smoothed_image)

   remove_files( " ".join(to_remove) )

   ##  all done
   if verb > 0 :
      print(".... done! ....\n")

   sys.exit(0)


if __name__ == "__main__":
    # Run the tool only when executed as a script, not when imported.
    main()
