#!/usr/bin/env python
# 
#  Copyright (C) 2006,2010-2011,2014-2016,2022  Smithsonian Astrophysical Observatory
#
#
#  This program is free software; you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation; either version 3 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#
#  You should have received a copy of the GNU General Public License along
#  with this program; if not, write to the Free Software Foundation, Inc.,
#  51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#


# 9/2011 - initial python version

import os
import sys
import glob
import paramio
import subprocess
import time
from numpy import *
from cxcdm import *
from region import *
from psf import *
from math import *

'''Use the history module to add HISTORY records to crate'''
from pycrates import read_file 
from pycrates import add_history, CrateKey 
from history import HistoryRecord

# Value of pi used by calc_cf_heights() when converting enclosed-count
# fractions into count-fraction heights.
PI = 3.1415926535898

## get parameters
#
def get_params( toolpar ) :
   """Read the tool parameters from the parameter file *toolpar*.

   The parameter file is opened read/write with the command-line
   arguments in sys.argv, every parameter is fetched as a string and the
   numeric ones are converted, then the file is closed.

   Returns a dict keyed by parameter name.  "x", "y", "energy" and
   "radlim" are floats; "e_step", "nsubpix" and "nfracpix" are ints; all
   other entries are left as the strings returned by paramio.

   Calls error_out() (which exits) when the parameter file cannot be
   opened.
   """
   try :
      pfile = paramio.paramopen(toolpar, "rw", sys.argv)
   except Exception :
      # Only real failures are reported as "can't open"; Ctrl-C and
      # SystemExit are allowed to propagate untouched.
      error_out("ERROR: can't open parameter file "+toolpar )

   # (name, converter) pairs, read in this exact order.  A converter of
   # None keeps the raw string.
   layout = (
      ("infile",   None),
      ("arf",      None),
      ("outfile",  None),
      ("region",   None),
      ("x",        float),     # sky_x
      ("y",        float),     # sky_y
      ("energy",   float),
      ("e_step",   int),
      ("radlim",   float),
      ("nsubpix",  int),
      ("nfracpix", int),
      ("ecffile",  None),
      ("tmpdir",   None),
      ("clobber",  None),
      ("verbose",  None),
      ("mode",     None),
   )

   pp = {}
   for name, convert in layout :
      raw = paramio.pgetstr(pfile, name)
      pp[name] = raw if convert is None else convert(raw)

   paramio.paramclose(pfile)
   return pp

## display message
def display_msg( msg ) :
    """Write *msg*, followed by a newline, to standard error."""
    line = msg + "\n"
    sys.stderr.write( line )

## error out w/ message
def error_out( msg ) :
    """Report *msg* through display_msg() and terminate with exit status 1."""
    display_msg( msg )
    # sys.exit(1) is exactly "raise SystemExit(1)".
    raise SystemExit(1)

## execute task (cmd); if there's a problem, error_out w/ message
def exe_task (tool, cmd) :
   """Run the command-line *tool* with the argument string *cmd*.

   The tool's parameter file is reset with paramio.punlearn() first, then
   "<tool> <cmd>" is handed to os.system().  Any non-zero status makes
   the program stop through error_out().
   """
   paramio.punlearn( tool )
   command_line = tool+" "+cmd
   if os.system( command_line ) != 0 :
      error_out( "problem with "+tool+"\n")

## check for the output file existence :
## If found, return True
## Else return False
def find_output_file( File ) :
   """Return True when the path *File* exists, False otherwise."""
   return os.path.exists( File )

## if the outfile exists: error out when clobber is no, otherwise remove it.
def check_clob( Clob, File ) :
   """Apply the clobber setting to an already existing output file.

   Nothing happens when *File* does not exist.  When it does exist and
   *Clob* starts with "n" or "N" the program stops through error_out();
   otherwise the existing file is removed with remove_files().
   """
   if not find_output_file( File ) :
      return

   if Clob[0] in ("n", "N") :
      error_out("\nERROR: outfile '"+File+"' exists and clobber=no.\n")
   else :
      remove_files( File )

## is it a blank string ?
def isBlank( ss ) :
   """Return True when *ss* is empty or consists only of space characters.

   Only the space character is stripped, so tabs and newlines count as
   content.
   """
   return not ss.strip(" ")

## remove a list of files
# allow any number of spaces between files.
# allow syntax w/ '*' ( eg.  *file*  or file* ).
#     eg. List=" f1 f2  any*  *any*  f4  "
def remove_files( List ) :
   """Delete every file named in the whitespace-separated string *List*.

   Each entry may be a plain name or a glob pattern (e.g. "file*" or
   "*file*"); any amount of whitespace may separate entries.  Removal is
   best effort: entries that match nothing are ignored, and a path that
   cannot be unlinked (already gone, a directory, no permission) is
   skipped silently.
   """
   for pattern in List.split() :
      # glob.glob() only returns existing, non-empty path names, so no
      # extra blank-name check is needed.
      for path in glob.glob( pattern ) :
         try :
            os.unlink( path )
         except OSError :
            # Deliberately ignored: only file-system failures are
            # swallowed, anything else (e.g. Ctrl-C) propagates.
            pass

## Open an input table file :
## If fail to open, error out.
## Else return dmBlock
#
# note: no check on the existence of infile
#
def open_input_tabFile( File ) :
   """Open the table *File* with dmTableOpen() and return the block.

   The existence of the file is not checked beforehand.  When the open
   fails the program stops through error_out().
   """
   inBlock = None
   try :
      inBlock = dmTableOpen( File )
   except Exception :
      # Ctrl-C / SystemExit are no longer misreported as an open failure.
      error_out("ERROR: can't open the file "+File+".\n" )

   return inBlock


## Open an input image file : 
## If fail to open, error out.
## Else return dmBlock, img_dd
#
# no check on the existence of infile
#
def open_input_imgFile( File ) :
   """Open the image *File* and fetch its image data descriptor.

   Returns the tuple (img_block, img_dd) obtained from dmImageOpen() and
   dmImageGetDataDescriptor().  The existence of the file is not checked
   beforehand.  When either call fails the program stops through
   error_out().
   """
   img_block = None
   img_dd    = None

   try :
      img_block = dmImageOpen( File )
   except Exception :
      # Ctrl-C / SystemExit are no longer misreported as an open failure.
      error_out("arfcorr: ERROR: can't open image file "+File+".\n" )

   try :
      img_dd = dmImageGetDataDescriptor( img_block )
   except Exception :
      error_out("arfcorr: ERROR: can't open image data descriptor for "+File+".\n")

   return img_block, img_dd


## write the out_img data
def write_image(outfile, out_img, pstmp) :
   """Write the count-fraction image *out_img* to block PSFAPPRX of *outfile*.

   outfile : name of the file to create.
   out_img : 2-D numpy array; element [i, j] is written to element
             [j, i] of the output, i.e. the data are stored transposed.
   pstmp   : dict providing "pcrpix", "pcrval" and "pcdelt", which are
             set as the transform of the ('x', 'y') axis group created on
             the image.
   """
   # Transposed float copy of the input.  This replaces the former
   # element-by-element Python loops (flatten into a scratch vector,
   # then refill a second array); the values and the resulting
   # (shape[1], shape[0]) layout are the same.
   data_2dim = array( out_img, dtype=float ).T.copy()
   dimy, dimx = data_2dim.shape

   axes = array([ dimy, dimx ])
   img = dmImageCreate( outfile+"[PSFAPPRX]", float64, axes)

   img_dd = dmImageGetDataDescriptor( img )
   dmImageSetData( img_dd, data_2dim)

   names   = ['x', 'y']
   axis_dd = dmArrayCreateAxisGroup( img_dd, 'cntFrac', float64, "unit", names)
   dmCoordSetTransform ( axis_dd, pstmp["pcrpix"], pstmp["pcrval"], pstmp["pcdelt"])

   dmImageClose ( img )

   return

# end: write_image


def read_pstamp( pp, pstmp, debug ) :
   """Fill *pstmp* with the geometry of the postage-stamp image pp["infile"].

   pp    : parameter dict; "infile", "x" and "y" are used.
   pstmp : dict filled in place with
             "size"     [dims[1], dims[0]] of the image array,
             "theta", "phi", "log_pos" (keys 0 and 1: logicalx, logicaly)
                        as reported by dmcoords for sky position (x, y),
                        each rounded to 6 decimals,
             "pcrpix", "pcrval", "pcdelt"  transform of axis group 1,
             "wcrpix", "wcrval", "wcdelt"  transform of the coordinate
                        descriptor attached to that axis group,
             "scale"    |wcdelt[1]| * |pcdelt[1]| * 3600,
             "sky_pos"  array holding (x, y).
   debug : verbosity; values above 2 print what was read.

   The program stops through error_out() when the image cannot be opened
   or dmcoords fails.
   """
   # NOTE(review): the image block returned here is never closed in this
   # function - confirm whether that is intentional.
   blk, img_dd = open_input_imgFile( pp["infile"] )
   dims = dmGetArrayDimensions( img_dd )

   pstmp["size"] = [ dims[ 1 ], dims[ 0 ] ]
   if debug > 2 :
      display_msg("pstmp.size is : "+str(pstmp["size"][0])+"  "+str(pstmp["size"][1]))

   # Run dmcoords on the sky position; its results are read back from
   # the dmcoords parameter file below.
   cmd = "infile='"+pp["infile"]+"' asol= opt=sky  x="+str(pp["x"])+" y="+str(pp["y"])+"  verb=0"
   exe_task( "dmcoords", cmd )

   def rounded( name ) :
      # dmcoords result *name*, rounded to 6 decimal places
      return float( "%.6f"%(float(paramio.pget("dmcoords", name))) )

   pstmp["theta"]   = rounded("theta")      # arcm
   pstmp["phi"]     = rounded("phi")        # deg
   pstmp["log_pos"] = { 0 : rounded("logicalx"), 1 : rounded("logicaly") }

   # Placeholders, replaced by the transforms read just below.
   for key in ("pcdelt", "pcrpix", "pcrval", "wcdelt", "wcrpix", "wcrval") :
      pstmp[key] = zeros( 2, float )

   # presumably "p" = physical and "w" = world coordinates - TODO confirm
   first_axes  = dmArrayGetAxisGroup(img_dd,1)
   first_coord = dmDescriptorGetCoord( first_axes )
   ( pstmp["pcrpix"], pstmp["pcrval"], pstmp["pcdelt"] ) = dmCoordGetTransform( first_axes )
   ( pstmp["wcrpix"], pstmp["wcrval"], pstmp["wcdelt"] ) = dmCoordGetTransform( first_coord )

   if debug > 2 :
      for key in ("pcrpix", "pcrval", "pcdelt", "wcrpix", "wcrval", "wcdelt") :
         pair = pstmp[key]
         display_msg("pstmp."+key+" is : "+str(pair[0])+"  "+str(pair[1]))

   pstmp["scale"] = abs(pstmp["wcdelt"][1])  * abs(pstmp["pcdelt"][1]) * 3600.0

   sky = zeros( 2, float )
   sky[0] = pp["x"]
   sky[1] = pp["y"]
   pstmp["sky_pos"] = sky

   if debug > 2 :
      display_msg( "pstmp centroid physical (sky) x, y were: "+str(pstmp["sky_pos"][0])+"  "+str(pstmp["sky_pos"][1]) )
      display_msg( "pstmp centroid msc coords (theta, phi) were: "+str(pstmp["theta"])+"  "+str(pstmp["phi"]) )

   return

# end: read_pstamp()


##  get the column info. from input arf file.
def get_colInfo ( infile, in_block, colname, tname) :
   """Open column *colname* of table block *in_block* and describe it.

   infile : file name, used only in the error message.
   tname  : tool name, used as the prefix of the error message.

   Returns (column descriptor, data type, unit, description) as given by
   dmTableOpenColumn(), dmGetDataType(), dmGetUnit() and dmGetDesc().
   When the column cannot be opened the program stops through
   error_out().
   """
   try :
      col_ptr  = dmTableOpenColumn( in_block, colname)
   except Exception :
      # Ctrl-C / SystemExit are no longer misreported as a missing column.
      cmd = tname+": ERROR: unable to open "+colname
      cmd+= " in input file '"+infile+"'\n"
      error_out(cmd)
   col_type = dmGetDataType( col_ptr )
   col_unit = dmGetUnit( col_ptr )
   col_desc = dmGetDesc( col_ptr )
   return  col_ptr, col_type, col_unit, col_desc


##  calling read_ARF Only If we've the param for energy~=0, 
##  so, the ARF file is required
def read_ARF ( arf_in_F, e_step, debug, tname ) :
   """Read an ARF table and pick the energies at which the PSF is sampled.

   arf_in_F : ARF file name (its existence is not checked beforehand).
   e_step   : sampling stride, in rows, through the ENERG_HI column.
   debug    : verbosity (not used here).
   tname    : tool name, used as the prefix of error messages.

   Returns (energies_samp, inrow, ptr, type, unit, desc):
     energies_samp  dict {0: e0, 1: e1, ...} holding ENERG_HI of the
                    first row, of every e_step-th row after it, and of
                    the last row when the stride did not land on it;
     inrow          dict with "block" (the open table block), "nRows",
                    and the full "elo", "ehi", "spr" column data;
     ptr/type/unit/desc  dicts keyed "elo", "ehi", "spr" holding the
                    column descriptor, data type, unit and description.

   The program stops through error_out() when the file cannot be opened,
   has no rows, or lacks one of the three columns.
   """
   table = open_input_tabFile( arf_in_F )

   # an empty table is an error
   nRows = dmTableGetNoRows( table )
   if nRows < 1 :
      error_out(tname+" ERROR: No data in arf file "+arf_in_F+"\n")

   inrow = { "block" : table, "nRows" : nRows }

   # descriptor, dtype, unit and comment of each column
   ptr = {} ; dtypes = {} ; unit = {} ; desc = {}
   ptr["elo"],dtypes["elo"],unit["elo"],desc["elo"] = get_colInfo(arf_in_F,table,"ENERG_LO",tname)
   ptr["ehi"],dtypes["ehi"],unit["ehi"],desc["ehi"] = get_colInfo(arf_in_F,table,"ENERG_HI",tname)
   ptr["spr"],dtypes["spr"],unit["spr"],desc["spr"] = get_colInfo(arf_in_F,table,"SPECRESP",tname)

   # read all rows of the three columns, starting at row 1
   dmTableSetRow( table, 1 )
   inrow["elo"] = dmGetData( ptr["elo"],  1, nRows )
   inrow["ehi"] = dmGetData( ptr["ehi"],  1, nRows )    # energies_full
   inrow["spr"] = dmGetData( ptr["spr"],  1, nRows )

   # Sample ENERG_HI: the first row, then every e_step-th row.
   upper = inrow["ehi"]
   picked = [ upper[0] ]
   since_last = 1                 # rows counted since the last pick
   for row in range( 1, nRows ) :
      if since_last == e_step :
         picked.append( upper[row] )
         since_last = 1
      else :
         since_last += 1

   # Always finish with the last energy of the ARF when the stride did
   # not end exactly on it.
   if since_last > 1 :
      picked.append( upper[nRows-1] )

   energies_samp = dict( enumerate( picked ) )

   return energies_samp, inrow, ptr,dtypes,unit,desc

# end: read_ARF


def calc_cf_heights ( rad, ecf, debug ) :
  """Convert enclosed-count fractions into count-fraction heights.

  rad   : the 101 radii rad[0..100] at which *ecf* is tabulated.
  ecf   : the 101 enclosed-count fractions matching *rad*.
  debug : verbosity; above 2 prints every height, above 3 also the
          intermediate frac and denominator.

  Returns a 101-element float array hh of heights.  hh[100] is 0 and the
  others are obtained by working inwards from it.  The program stops
  through error_out() when two neighbouring radii are identical or the
  denominator of the height formula is zero.
  """
  hh = zeros( 101, float )   # hh = count fractions at a given radius
  for ii in range( 99, -1, -1 ) :
     ra = rad[ii]
     rb = rad[ii+1]
     ra2   = ra*ra
     rb2   = rb*rb
     rarb2 = ra*rb2
     ra3   = ra2*ra
     rb3   = rb2*rb
     if (abs(rb-ra) == 0.0) :
        error_out("arfcorr: ERROR: ecf radius values at "+repr(ii)+" and "+repr(ii+1)+" are identical.\n")

     frac = (rarb2 - (1.0/3.0)*ra3 - (2.0/3.0)*rb3)/(rb - ra)
     denom = rb2 - ra2 + frac
     if (abs(denom) == 0) :
        error_out ("arfcorr: ERROR: count fraction height calculation denominator is zero.\n")

     #  formula for hh derived from solving axially symmetric integral equation:
     #  dECF = integral from radius = r_n to r_n+1 of (2*PI*radius)*(height)*dRADIUS
     #
     #  and solving for  h_n given  h_n+1 ( h_max assumed to be zero for finite PSF)
     #  height assumed to vary linearly between r_n and r_n+1
     hh[ii] = ((ecf[ii+1] - ecf[ii])/PI) +  hh[ii+1] * frac
     hh[ii] = hh[ii] / denom

     if (debug > 2) :
        display_msg ("ii, CF height are: "+str(ii)+"  "+str(hh[ii]))

        if (debug > 3) :
           # report the denominator actually used above (the message
           # previously showed rb2 - ra2 under the "denom" label)
           display_msg("    frac, denom are: "+str(frac)+"  "+str(denom))
  # end: for ii in range ( 99, -1, -1 ) :

  return hh

# end: calc_cf_heights


##  return the value (height) of the cf (counts fraction) function
##  at a specific radius
def get_cf_height( dd, rad, hh, max_rad, debug ) :
  """Return the count-fraction height at distance *dd*.

  dd      : distance at which the height is wanted (same units as rad).
  rad     : numpy array of tabulated radii.
  hh      : heights matching *rad* (see calc_cf_heights).
  max_rad : distances at or beyond this value give a height of 0.
  debug   : verbosity; above 4 prints the intermediate values.

  The height is linearly interpolated between the two tabulated radii
  that bracket *dd*.  A distance smaller than rad[0] gives hh[0], and a
  distance with no larger tabulated radius gives 0.
  """
  #  if the distance is outside the maximum radius, the cf height
  #  (and therefore volume = height*area) is zero for this pixel.
  if not ( dd < max_rad ) :
    return 0.0

  #  Find first radius in rad array greater than distance to pixel
  #  (i.e., dist is between rad[radindex-1] and rad[radindex])
  larger = where ( dd < rad )[0]
  if len( larger ) == 0 :
    # max_rad lies beyond the table: nothing tabulated out here
    return 0.0
  radindex = larger[0]

  if (debug > 4) :
    display_msg("get_cf_h:   dist = "+str(dd)+",  radindex = "+ str(radindex))

  if radindex == 0 :
    # dd is inside the innermost tabulated radius.  There is no lower
    # bracket, and rad[radindex-1] would silently wrap around to the
    # last table entry, so use the innermost height as it is.
    cf_height = hh[0]
  else :
    # linearly interpolate count fraction (h) between radindex & radindex-1
    radfrac = (dd - rad[radindex-1])/(rad[radindex]-rad[radindex-1])
    if (debug > 4) :
      display_msg("get_cf_h:   radfrac = "+str(radfrac))

    cf_height = hh[radindex-1] + radfrac*( hh[radindex] - hh[radindex-1])

  if (debug > 4) :
    display_msg("get_cf_h:   cf_height = "+str(cf_height))

  return cf_height

# end: get_cf_height


##  get_pixel_height calculates pixel height by dividing the pixel
##  into (npix x npix) smaller pixels (subpixelating), calculating
##  the cf_height at the center of each smaller pixel, and then
##  taking the average of all subpixel heights to use as the height
##  of the input pixel.
def get_pixel_height(pstmp,rad,hh,max_rad,npix,x_pix_cntr,y_pix_cntr,pixel,debug):
  """Store the average count-fraction height of one pixel in pixel["height"].

  The pixel is split into npix x npix sub-pixels; the height at the
  centre of every sub-pixel is obtained from get_cf_height() and the
  mean of those heights becomes the height of the pixel.

  pstmp      : dict; "scale" converts logical pixels to the units of rad.
  rad, hh    : radii and matching heights (see calc_cf_heights).
  max_rad    : radius beyond which the height is zero.
  npix       : number of sub-pixels along each side of the pixel.
  x_pix_cntr : x offset of the pixel centre from the image centre.
  y_pix_cntr : y offset of the pixel centre from the image centre.
  pixel      : dict updated in place: "height" is set, and when
               debug > 1 the running totals "sum25", "sum50", "sum75"
               are incremented.
  debug      : verbosity.
  """
  n_sub = npix*npix      # npix = pp["nsubpix"]

  #  offset of the first sub-pixel centre from the pixel centre
  first = -0.5 + 1.0/(2.0*npix)

  total = 0.0
  steps = range( npix )
  for ii in steps :
    rx = x_pix_cntr+first+double(ii)/npix
    for jj in steps :
      ry = y_pix_cntr+first+double(jj)/npix

      if debug > 4 :
        display_msg( "get_pix_h: ii, jj, rx, ry are:  "+str(ii)+", "+ str(jj)+", "+str(rx)+", "+str(ry) )

      #  distance of the sub-pixel centre from the image centre,
      #  converted from logical pixels with pstmp["scale"]
      dist = hypot(rx, ry) * pstmp["scale"]

      cf_height = get_cf_height ( dist, rad, hh, max_rad, debug)

      if debug > 1 :
        #  track summed pixel values within a given radius, to see how well
        #  summed values compare to the expected enclosed count fraction.
        if dist <= rad[25] :
          inside = ("sum25", "sum50", "sum75")
        elif dist <= rad[50] :
          inside = ("sum50", "sum75")
        elif dist <= rad[75] :
          inside = ("sum75",)
        else :
          inside = ()
        for key in inside :
          pixel[key] += cf_height/n_sub

      total += cf_height

  #  mean of the sub-pixel heights
  pixel["height"] = total / n_sub

  return

# end: get_pixel_height


def fill_cf_img( pstmp, rad, hh, npix, debug ) :
  """Build the count-fraction image for one set of ECF radii.

  pstmp : dict; "size", "log_pos" and "scale" are used.
  rad   : tabulated radii; the last one is the outer limit of the PSF.
  hh    : heights matching *rad* (see calc_cf_heights).
  npix  : sub-pixels per pixel side handed to get_pixel_height().
  debug : verbosity.

  Returns a float array of shape (size[0], size[1]).  Element [ii, jj]
  is the pixel height times the pixel area (scale squared), with
  negative values replaced by 0.
  """
  n_first  = pstmp["size"][0]
  n_second = pstmp["size"][1]
  img = zeros ( [ n_first, n_second ], float )

  #  PSF centroid in logical (image) coordinates
  xcen = pstmp["log_pos"][0]
  ycen = pstmp["log_pos"][1]

  if debug > 3 :
    display_msg ("xcen, ycen are: "+repr(xcen)+" "+repr(ycen))

  #  pstmp["scale"] is the length of one side of the square pixel
  pixel_area = pstmp["scale"] * pstmp["scale"]
  max_rad    = rad[ len ( rad ) - 1 ]

  pixel = { "sum25" : 0.0, "sum50" : 0.0, "sum75" : 0.0 }

  for ii in range( n_first ) :
    #  Logical (image) coordinates run from 1 while the indices run
    #  from 0, so ii+1 and jj+1 are the pixel centres in image coords.
    rx = abs(xcen-(ii+1))
    for jj in range( n_second ) :
      ry = abs(ycen-(jj+1))

      get_pixel_height (pstmp, rad, hh, max_rad, npix, rx,ry, pixel, debug)
      img[ii,jj] = ( pixel["height"] ) * pixel_area

      # don't allow negative psf values
      if img[ii,jj] < 0.0 :
        img[ii,jj] = 0.0

      if debug > 4 :
        display_msg("fill_cf_img: pix cntr ii+1,jj+1,rx,ry,pix value are: ["+repr(ii+1)+", "+repr(jj+1)+" ] ["+repr(rx)+", "+repr(ry)+" ] "+repr(img[ii,jj]))

  if debug > 1 :
    for key in ("sum25", "sum50", "sum75") :
      pixel[key] *= pixel_area
    display_msg("25% ecf table radius encloses "+repr(pixel["sum25"])+" img pixel counts")
    display_msg("50% ecf table radius encloses "+repr(pixel["sum50"])+" img pixel counts")
    display_msg("75% ecf table radius encloses "+repr(pixel["sum75"])+" img pixel counts")

  return  img

# end:  fill_cf_img


def calc_psf_frac_in_reg( img, pstmp, srcreg, nfracpix, debug ) :
  """Return the fraction of the PSF image that falls inside a region.

  img      : count-fraction image (see fill_cf_img).
  pstmp    : dict; "size", "pcrpix", "pcrval" and "pcdelt" are used.
  srcreg   : parsed region handed to regInsideRegion().
  nfracpix : every pixel is tested at nfracpix x nfracpix points.
  debug    : verbosity.

  Each pixel contributes img[ii, jj]/(nfracpix*nfracpix) for every one
  of its test points that lies inside the region.  The total is capped
  at 1.0.
  """
  subpixfrac = 1.0/(nfracpix*nfracpix)
  psf_frac = 0.0
  rinr = 0

  #   pre-compute parts of the calculation of the pixel's
  #   sky x,y coords to improve computation speed, from:
  #   pix_sky_x = (ii+1-pstmp.pcrpix[0])*pstmp.pcdelt[0] + pstmp.pcrval[0]
  #   pix_sky_y = (jj+1-pstmp.pcrpix[1])*pstmp.pcdelt[1] + pstmp.pcrval[1]
  #
  #   Note: logical (image) coords range from 1 to pstamp.size but
  #   indices ii and jj range from 0 to pstamp.size-1.  Therefore, ii+1
  #   and jj+1 are the correct values of the pixel center in image coords.
  #
  temp = [ -pstmp["pcrpix"][0]*pstmp["pcdelt"][0] + pstmp["pcrval"][0],
           -pstmp["pcrpix"][1]*pstmp["pcdelt"][1] + pstmp["pcrval"][1] ]

  if (debug > 3) :
     display_msg("calc_psf_frac: Subpix frac is :  "+str(subpixfrac))
     display_msg("calc_psf_frac: temp[0-1] is :  "+str(temp[0])+", "+str(temp[1]))

  #  Offsets of the test points from the pixel centre.  They are the
  #  same for every pixel and for both axes, so compute them once.
  shifts = [ -0.5+1.0/(2.0*nfracpix)+double(kk)/nfracpix
             for kk in range( nfracpix ) ]

  #  loop through all PSF CF image pixels.  If any test point of a pixel
  #  is inside the source region, add that point's share of the pixel
  #  value to the running total of the fraction of PSF counts within the
  #  source region.
  #
  for ii in range( pstmp["size"][0] ) :
     for jj in range( pstmp["size"][1] ) :
        for ii_shift in shifts :
           pix_sky_x = (ii+1+ii_shift)*pstmp["pcdelt"][0] + temp[0]

           for jj_shift in shifts :
              pix_sky_y = (jj+1+jj_shift)*pstmp["pcdelt"][1] + temp[1]

              if (regInsideRegion( srcreg, pix_sky_x, pix_sky_y)) :
                 psf_frac += img[ii,jj] * subpixfrac
                 rinr = 1
              else :
                 rinr = 0

              if (debug > 4) :
                 # show the inside/outside flag itself (the message
                 # used to contain the literal text 'rinr')
                 cmd =  "In Reg: "+repr(rinr)+"  ii,jj: "+repr(ii)+", "+repr(jj)+"  Sky x,y: "
                 cmd += repr(pix_sky_x)+", "+repr(pix_sky_y)+"  pix psf_frac: "+repr(img[ii,jj])
                 cmd += "  total psf_frac: "+repr(psf_frac)
                 display_msg( cmd )

           # end: for jj_shift
        # end: for ii_shift
     # end: for jj
  # end: for ii

  #  PSF_FRAC can be greater than 1 due to the quantization/binning
  #  type error of approximating the volume under a square pixel with
  #  a single height when, in fact, the height changes continually over
  #  the surface of the pixel.  Since PSF_FRAC should never be > 1,
  #  reset it to 1 if it is found to be greater.
  #
  if (psf_frac > 1.0) :
     psf_frac = 1.0

  return  psf_frac

#end: calc_psf_frac_in_reg


def init_out_arf( inrow, outfile, ecffile, region, cmd_line, ptr,type,unit,desc):
  """Create the output ARF with its header and (still empty) columns.

  inrow    : dict from read_ARF(); inrow["block"] is the open input
             table.  Its HEADER and SUBSPACE are copied to the output
             and the input block is then closed.
  outfile  : output file name; the table is created as block SPECRESP.
  ecffile  : written to header keyword ECFFILE.
  region   : written to header keyword ECFREG.
  cmd_line : written as a HISTORY comment.
  ptr      : column descriptors from read_ARF() (not used here).
  type, unit, desc : dicts keyed "elo", "ehi", "spr" giving the data
             type, unit and description of the three copied columns.

  Columns ENERG_LO, ENERG_HI, SPECRESP and a new float64 PSF_FRAC are
  created, in that order, and the output table is closed.
  """
  source = inrow["block"]
  o_blk = dmTableCreate( outfile+"[SPECRESP]" )

  #  copy header & subspace info from input ARF, then release the input
  for section in ("HEADER", "SUBSPACE") :
     dmBlockCopy(source, o_blk, section)
  dmTableClose( source )

  #  add ECF related header keywords, update history
  dmKeyWrite(o_blk, "ECFFILE", ecffile, "", "ECF file used to generate approximate PSF")
  dmKeyWrite(o_blk, "ECFREG", region, "", "Value of supplied region parameter")
  dmBlockWriteComment(o_blk, "HISTORY", cmd_line)

  #  columns taken over from the input ARF
  for colname, key in (("ENERG_LO", "elo"), ("ENERG_HI", "ehi"), ("SPECRESP", "spr")) :
     dmColumnCreate(o_blk, colname, type[key], unit=unit[key], desc=desc[key])

  #  new PSF_FRAC column
  dmColumnCreate(o_blk, "PSF_FRAC", float64,
                 unit="", desc="Fraction of PSF enclosed in SRC region")

  dmTableClose(o_blk)

# end: init_out_arf


## variables 
##      inrow     :  arffile's data info from read_ARF()
##      outfile   :  par[ outfile ] ;  
##      psf_frac  :  the computed psf_frac_full_array
##      tname     :  tool_name
def write_updated_arf(inrow, outfile, psf_frac, tname):
  """Fill the columns of the output ARF created by init_out_arf().

  inrow    : dict from read_ARF() with "nRows", "elo", "ehi", "spr".
  outfile  : name of the output file, opened for update.
  psf_frac : PSF fraction per ARF row, indexable by row number; its
             length must equal inrow["nRows"].
  tname    : tool name, used as the prefix of the error message.

  ENERG_LO and ENERG_HI are copied unchanged, SPECRESP is written as the
  input SPECRESP times the PSF fraction, and PSF_FRAC receives the PSF
  fraction itself.  The program stops through error_out() when the
  lengths disagree.
  """
  n_rows = inrow["nRows"]
  if n_rows != len( psf_frac ) :
     error_out(tname+": ERROR: total psf_frac mismatch w/ input ARF file.\n")

  tab = dmTableOpen( outfile, update=True )

  dmSetData( dmTableOpenColumn( tab, "ENERG_LO"), inrow["elo"] )
  dmSetData( dmTableOpenColumn( tab, "ENERG_HI"), inrow["ehi"] )

  rows = range( n_rows )
  scaled_resp = array( [ inrow["spr"][ k ] * psf_frac[ k ] for k in rows ], float )
  fractions   = array( [ psf_frac[ k ] for k in rows ], float )

  dmSetData( dmTableOpenColumn( tab, "SPECRESP"), scaled_resp )
  dmSetData( dmTableOpenColumn( tab, "PSF_FRAC"), fractions )

  dmTableClose(tab)

  return

# end: write_updated_arf


def interpolate_psf_frac( psf_frac, energies_samp, energies_full, debug, tname) :
  """Expand PSF fractions from the sampled energies to every ARF energy.

  psf_frac      : PSF fraction at each sampled energy.
  energies_samp : the sampled energies, indexable 0..len(psf_frac)-1.
  energies_full : every energy of the ARF, in the same order; the
                  sampled energies are expected to be a subset of it
                  whose last entry equals the last full energy.
  debug         : verbosity (not used here).
  tname         : tool name, used as the prefix of the error message.

  Returns a dict {row index: PSF fraction} with one entry per element of
  *energies_full*.  The first row always gets psf_frac[0]; a row whose
  energy equals the next sampled energy gets that sample's fraction; any
  other row is linearly interpolated between the two surrounding
  samples.  The program stops through error_out() when energies remain
  after the last sampled energy has been used.
  """
  n_samp = len( psf_frac )
  psf_frac_full_array = {}
  nxt = 0       # index of the next sampled energy still to be reached

  for uu in range( len(energies_full) ) :
    ee = energies_full[ uu ]

    if nxt == 0 :
      psf_frac_full_array[ uu ] = psf_frac[ 0 ]
      nxt = 1
      continue

    #  check for logic error - the last sampled energy must equal the
    #  last full energy, so no energy may follow it.  This has to be
    #  tested BEFORE energies_samp[nxt] is looked up, otherwise the
    #  lookup itself fails and the message below is never shown.
    if nxt >= n_samp :
      error_out(tname+": ERROR: psf_frac interpolation error: too few sampled energies.\n")

    if ee == energies_samp[ nxt ] :
      psf_frac_full_array[ uu ] = psf_frac[ nxt ]
      nxt += 1
    else :
      #  interpolate between sampled energy points
      ff = (ee-energies_samp[nxt-1]) / (energies_samp[nxt]-energies_samp[nxt-1])
      psf_frac_full_array[ uu ] = psf_frac[nxt-1] + ff * ( psf_frac[nxt] - psf_frac[nxt-1] )

  return  psf_frac_full_array

# end: interpolate_psf_frac


def add_history_to_file(file_name, tool_name, param_names, param_values):
   """Append a HISTORY record for *tool_name* to *file_name*.

   Nothing is done when the file does not exist.  Otherwise the file is
   read in "rw" mode, a HistoryRecord built from the parameter names and
   values is added starting at the current HISTNUM value (a HISTNUM key
   with value 1 is created when missing), HISTNUM is advanced by the
   number of lines of the record, and the file is written back.
   """
   if not os.path.exists( file_name ) :
      return

   crate = read_file(file_name, mode="rw")

   histnum = crate.get_key("HISTNUM")
   if histnum is None:
      histnum = CrateKey()
      histnum.name = "HISTNUM"
      histnum.value = 1
      crate.add_key(histnum)

   startat = histnum.value
   if startat is None :
      startat = 1

   record = HistoryRecord(tool=tool_name,param_names=param_names,
                          param_values=param_values, asc_start=startat)
   add_history(crate, record)

   # Increment HISTNUM for new lines
   n_lines = len( record.as_ASC_FITS().split("\n") )
   histnum.value = startat+n_lines

   crate.write()
# end: add_history_to_file 


####
# --- main code
#
def _region_to_physical( region, wcsfile, tmpth, debug ) :
   """Return the region specification to parse, via dmmakereg when possible.

   dmmakereg (looked up on PATH) is run on *region* with *wcsfile* as the
   WCS reference and *tmpth* as its output file, and is given up to one
   minute.  On success 'region(<tmpth>)' is returned.  When it exits
   with an error, or has to be stopped after the minute, *region* is
   returned unchanged.
   """
   if (debug > 2) :
      display_msg("running dmmakereg on "+ region + " to ensure coordinates are in physical coords")

   pproc=subprocess.Popen(['dmmakereg', region, 'outfile='+tmpth, 'wcsfile='+wcsfile, 'kernel=fits','cl+'], shell=False)

   # Allow up to 1 minute for dmmakereg to complete
   remaining = 60
   while remaining > 0 and pproc.poll() is None :
      time.sleep(.5)
      remaining -= .5

   mysts = pproc.poll()
   if mysts is None :
      # Still running: stop it, and make sure it is really gone and
      # reaped so that no zombie is left and nothing keeps writing to
      # the temporary file while the rest of the tool runs.
      pproc.terminate()
      for _ in range( 50 ) :            # up to ~5 s for a clean exit
         if pproc.poll() is not None :
            break
         time.sleep(.1)
      if pproc.poll() is None :
         pproc.kill()
      pproc.wait()
      if (debug > 2):
         display_msg("dmmakereg failed to complete successfully in a reasonable amount of time- using the original region file.")
      return region

   if mysts == 0 :
      if (debug > 2):
         display_msg("using the output of dmmakereg as the region")
      return 'region('+tmpth+')'

   # dmmakereg reported an error: fall back to the region as given
   return region


def main() :
   """Run the arfcorr tool.

   Reads the parameters, converts the region with dmmakereg, reads the
   postage-stamp image, and for every sampled energy builds an
   approximate PSF image from the ECF radii and measures the fraction of
   it inside the source region.

   With a single energy the PSF image is written to the output file.
   With several energies (read from the ARF when the energy parameter is
   ~0) the fractions are interpolated to all ARF energies and a
   corrected ARF is written.  A history record is then added to the
   output file.

   The temporary region file is removed however the function is left,
   including when an error stops the tool.  Always ends with sys.exit().
   """

   tname = "arfcorr"     # tool_name
   ##  --- get parameters
   pp = get_params( "arfcorr" )
   debug  = int ( pp["verbose"] )

   if (debug > 1 ) :
      cmd =  "Input Parameters:\n"+"  pstampfile:   "+pp["infile"]+"\n"
      cmd += "   outfile:   " + pp["outfile"]+"\n"
      cmd += "    region:   " + pp["region"]+"\n"
      cmd += "    energy:   " + str(pp["energy"])+"\n"
      cmd += "         x:   " + str(pp["x"]) +"\n"
      cmd += "         y:   " + str(pp["y"]) +"\n"
      cmd += "       arf:   " + pp["arf"]+"\n"
      cmd += "    e_step:   " + str(pp["e_step"]) +"\n"
      cmd += "    radlim:   " + str(pp["radlim"]) +"\n"
      cmd += "   nsubpix:   " + str(pp["nsubpix"]) +"\n"
      cmd += "  nfracpix:   " + str(pp["nfracpix"]) +"\n"
      cmd += "   ecffile:   " + pp["ecffile"]+"\n"
      cmd += "    tmpdir:   " + pp["tmpdir"]+"\n"
      cmd += "   clobber:   " + pp["clobber"]+"\n"
      cmd += "     debug:   " + pp["verbose"]+"\n"
      display_msg( cmd )

   ##  check for outfile
   outfile = pp["outfile"]
   check_clob( pp["clobber"], outfile)

   if os.access(pp["tmpdir"], os.W_OK) == False:
      error_out("\nERROR: tmpdir setting ("+pp["tmpdir"]+") does not have write permissions- can not generate necessary temporary files.\n")

   # temporary region file written by dmmakereg; the pid keeps it unique
   tmpth = pp["tmpdir"] + "/arfcor_tmp_" + str(os.getpid())

   try :
      # convert wcs region to physical coords via dmmakereg
      rfile = _region_to_physical( pp["region"], pp["infile"], tmpth, debug )

      srcreg = None
      try :
         srcreg = regParse(rfile)
      except Exception :
         msg=tname+": ERROR:  Could not parse region parameter : '"+rfile+"'\n"
         msg+='         Note :  the syntax to use a file is :  region(sources.fits)\n'
         error_out( msg+"\n")

      ##  open infile
      pstmp = {}
      read_pstamp( pp, pstmp, debug )

      energies_samp = {}
      arf_in_F = pp["arf"]
      penergy  = pp["energy"]
      e_step   = pp["e_step"]
      inrow={}; ptr={}; col_type={}; unit={}; desc={}

      if ( penergy < 0.000001) :
         if (debug > 2) :
            display_msg("energy near 0, = "+str(penergy) )

         #  read energy array from arf
         energies_samp, inrow,ptr,col_type,unit,desc=read_ARF(arf_in_F,e_step,debug,tname)
      else :
         if (debug > 2) :
            display_msg("energy > 0, = "+repr(penergy))
         energies_samp = zeros( 1, float)
         energies_samp[0] = penergy

      #  loop over all energies for PSF lib
      rad =  zeros( 101, float)
      rad_prior = 0
      num_energies = len( energies_samp )
      if ( debug > 2 ) :
         display_msg("num_energies = "+repr(num_energies))
         display_msg("energies_samp are : "+repr(energies_samp))

      ecf = arange(0, 1.005, 0.01)    # 0 to 1 by .01, inclusive
      psf_frac = zeros ( num_energies, float)
      psf = psfInit( pp["ecffile"] )

      jjr =  range( num_energies )
      iir =  range( 100 )
      for jj in jjr :
         #  read the correct ECF radii for each energy
         for ii in iir :
            rad[ii]=psfSize(psf,energies_samp[jj],pstmp["theta"],pstmp["phi"],ecf[ii])
            if (debug > 3) :
               if (rad[ii] != 0.0) :
                  count_frac_density = 0.01/(3.14159*(rad[ii]*rad[ii]-rad_prior*rad_prior))
               else :
                  count_frac_density = 0.0
               rad_prior = rad[ii]

               ecf_s = "%.6f"%(ecf[ii])
               rad_s = "%.6f"%(rad[ii])
               cfd_s = "%.6f"%(count_frac_density)
               display_msg("ii,ecf,rad,ecf_density are: "+repr(ii)+" "+ecf_s+" "+rad_s+" "+cfd_s)
            # end: if debug > 3 :
         # end: for ii < 100

         if (debug > 3) :
            display_msg ("ecf, rad sizes are:  "+repr(len(ecf))+"  "+repr(len(rad)))

         #  set last element of radius array.  rad[100] is the finite
         #  approximation to the infinite tail-off of the PSF.  In the
         #  approximation, all of the PSF is within rad[100].
         rad[100] = rad[99] * pp["radlim"]

         #  calculate count fraction height array values
         height = calc_cf_heights( rad, ecf, debug )

         #  create count fraction image
         npix = pp["nsubpix"]
         img = fill_cf_img(pstmp, rad, height, npix, debug )

         #  write img when only one single energy was specified
         if (num_energies == 1) :
            write_image(outfile, img, pstmp)

         #  calculate psf fraction inside source region
         psf_frac[jj] = calc_psf_frac_in_reg(img, pstmp, srcreg, pp["nfracpix"], debug)
         if (debug > 0) :
            cmd="Energy, PSF fraction inside source region (psf_frac["+str(jj)+"]) are: "
            cmd+=str(energies_samp[jj]) + ",  "+str( psf_frac[jj] )
            display_msg( cmd )
      # end:  for jj < num_energies


      #  check for > 1 energy: need to interpolate PSF fractions and write ARF
      if (num_energies != 1) :
         #  interpolate PSF_FRAC values to all energies in ARF
         psf_frac_full_array = interpolate_psf_frac( psf_frac, energies_samp,
                               inrow["ehi"], debug, tname)

         #  write updated ARF with PSF_FRAC and SPECRESP cols when
         #  energies were read from ARF
         cmd = tname+' infile='+pp["infile"]+' '+'arf='+pp["arf"]+' '
         cmd += 'outfile='+outfile+' '+'region='+pp["region"]+' '
         cmd += 'x='+str(pp["x"])+' '+'y='+str(pp["y"])+' '+'energy='
         cmd += str(pp["energy"])+' '+'e_step='+str(pp["e_step"])+' '+'radlim='
         cmd += str(pp["radlim"])+' '+'nsubpix='+str(pp["nsubpix"])+' '+'nfracpix='
         cmd += str(pp["nfracpix"])+' '+'ecffile='+pp["ecffile"]+' '+'clobber='
         cmd += pp["clobber"]+' '+'verbose='+str(pp["verbose"])
         init_out_arf(inrow,outfile,pp["ecffile"],pp["region"],cmd,ptr,col_type,unit,desc)
         write_updated_arf( inrow, outfile, psf_frac_full_array, tname)
      # end: if (num_energies != 1)

      # Get parameter name/values in order stored in parameter file
      pnames = paramio.plist("arfcorr")
      pvals = [pp[x] for x in pnames]

      # Add history to either image or modified arf
      add_history_to_file(pp["outfile"], "arfcorr", pnames, pvals)

   finally :
      # clean up region file generated by dmmakereg - verify the filename
      # contains the process pid as a sanity check.  Being in "finally",
      # this also runs when error_out() stops the tool part-way.
      if tmpth.find(str(os.getpid())) != -1 and os.path.isfile(tmpth):
         os.remove(tmpth)

   ##  all done
   if debug > 0 :
      display_msg(".... done! ....\n")

   sys.exit(0)

# end: def main 

# Run the tool when this file is executed as a script (not when imported).
if __name__=="__main__":
    main()
