#!/opt/conda/envs/ciao-4.17.0/bin/python3
# 
#  Copyright (C) 2011-2013,2016,2021-2022  Smithsonian Astrophysical Observatory
#
#
#  This program is free software; you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation; either version 3 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#
#  You should have received a copy of the GNU General Public License along
#  with this program; if not, write to the Free Software Foundation, Inc.,
#  51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#


# 1/2012 - initial python version 
# 3/2012 - add ddof ; replace calquiz w/ pixlib ;
# 4/2012 - (dph) change LETG/ACIS zo exclusion to a strip instead of a circle 
#                in step #11 :  Select streak ..."   (TBR)
# 4/2012 - get sky in double
# 3/2013 - add keys to the primary header
#
# 5/7/2013 - bug#13612 ( add a new way to compute c0 for y=c0+x )
# 8/20/2013 - add two new keys to the src1a's header.
#                angle_st ( readout/streak angle in degree )
#                angle_gt ( grating arm angle in degree )
#

import os
import sys
import glob
import paramio
import string
from math import sqrt
import numpy as np 
from cxcdm import *       
# 5/7/2013 - import sherpa.astro.ui as ui
import logging
from pixlib import *

'''Use the history module to add HISTORY records to crate'''
from pycrates import read_file 
from pycrates import add_history, CrateKey
from history import HistoryRecord


def _dmKeyRead (blk, name):
   '''
   Read header keyword *name* from block *blk* via cxcdm's dmKeyRead.

   Returns the (descriptor, value) pair produced by dmKeyRead.  Under
   Python 3 string keywords come back as bytes; those are decoded to str
   (UTF-8) so callers can treat them as text.  Values without a
   ``decode`` method (already a str, or numeric keywords) and bytes that
   are not valid UTF-8 are returned unchanged.  Errors raised by
   dmKeyRead itself (e.g. a missing key) propagate to the caller.
   '''

   descriptor, value  = dmKeyRead( blk, name )

   # Only bytes-like values have .decode(); anything else is passed
   # through.  Catch just those expected failures -- a bare except would
   # also swallow KeyboardInterrupt/SystemExit.
   try:
      value = value.decode ('utf8')
   except (AttributeError, UnicodeDecodeError):
      pass

   return descriptor, value


class _TZO() :
   '''
   Base class for tg_findzo: parameter handling, header-key checks,
   CALDB grating angles and output-file writing.

   Several attributes read here are never assigned in this class and
   must be set by a subclass or the caller before the corresponding
   method runs: the per-grating corrections (alpha_corr_*,
   streak_offset_*, acis_s_rotation_corr_*) used by set_corrections(),
   k_v / inBlk used by evtKeys(), and the result attributes (raOut,
   decOut, x_zo, y_zo, net_count, net_rate, radius_zo, streak_counts,
   grating_counts) written by write_tabFile() / update_keys().
   '''

   ## get parameters
   def get_params( s, tname) :
      '''
      Read the parameter file of tool *tname*.

      Sets s.pnames / s.pvals (every parameter name and value, in file
      order), s.infile, s.outfile, s.zo_pos_x, s.zo_pos_y (kept as
      strings here; upd_zo_pos_xy() converts them later), s.clob,
      s.debug (int value of 'verbose') and s.mode.  Exits the process if
      the parameter file cannot be opened.
      '''
      try :
         pfile = paramio.paramopen( tname, "rw", sys.argv)
      except Exception :
         s.error_out("ERROR: can't open parameter file "+tname)

      # Get parameter name/values in order stored in parameter file
      s.pnames= paramio.plist( pfile )
      s.pvals = [ paramio.pget(pfile, pp) for pp in s.pnames]

      s.infile = paramio.pgetstr(pfile,"infile")
      s.outfile = paramio.pgetstr(pfile,"outfile")
      s.zo_pos_x = paramio.pgetstr(pfile,"zo_pos_x")   #string; default
      s.zo_pos_y = paramio.pgetstr(pfile,"zo_pos_y")   #string; default
      s.clob  = paramio.pgetstr(pfile,"clobber")
      s.debug = int(paramio.pgetstr(pfile,"verbose"))
      s.mode = paramio.pgetstr(pfile,"mode")
      paramio.paramclose(pfile)

   ## execute task (cmd); if there's a problem, error_out w/ message
   #
   def exe_task (s, tool, cmd) :
      '''
      Run CIAO *tool* with the parameter string *cmd*; exit on failure.

      The tool's parameter file is reset with punlearn first.  *cmd* is
      split with shlex and handed to the tool directly rather than
      through a shell, so Data Model filter syntax in file names (for
      example "[", "]" or ">") reaches the tool unmodified instead of
      being treated as shell globbing or redirection.
      '''
      import shlex
      import subprocess

      paramio.punlearn( tool )
      try :
         status = subprocess.call( [tool] + shlex.split( cmd ) )
      except (OSError, ValueError) :
         # OSError: tool not found / not executable;
         # ValueError: unbalanced quotes in cmd.
         status = -1
      if status != 0  :
         s.error_out( "problem with "+tool+"\n")

   ## display message
   def display_msg( s, msg ) :
      '''Write *msg* followed by a newline to stderr.'''
      sys.stderr.write( msg +"\n" )

   ## error out w/ message
   def error_out( s, msg ) :
      '''
      Print *msg* to stderr and terminate the process with status 1.

      os._exit() is used, so no exception is raised, no finally blocks or
      atexit handlers run, and the caller never regains control.
      '''
      s.display_msg( msg )
      os._exit(1)

   ## check for the output file existence
   #
   @staticmethod
   def find_output_file( File ) :
      '''
      Return True if the path *File* exists, else False.

      Declared static because it uses no instance state; it can be called
      either as s.find_output_file(f) or as _TZO.find_output_file(f).
      '''
      return os.path.exists( File )

   ## If the outfile exists: error out when clobber is no, else delete it.
   #
   def check_clob( s, Clob, File ) :
      '''
      Enforce the clobber setting *Clob* for output file *File*.

      If *File* exists and *Clob* starts with 'n' or 'N' (or is empty),
      exit with an error; otherwise remove the existing file so it can be
      recreated.  Nothing happens when *File* does not exist.
      '''
      if s.find_output_file( File ) :
         # An empty clobber value is treated as "no" so nothing is deleted.
         if ( not Clob ) or ( Clob[0] in "nN" ) :
            s.error_out("\nERROR: outfile '"+File+"' exists and clobber=no.\n")
         else :
            s.remove_files( File )

   ## is it a blank string ?
   #
   def isBlank( s, inS ) :
      '''Return True if *inS* is empty or contains only space characters.'''
      return not inS.strip(" ")

   ## remove a list of files
   #
   def remove_files( s, List ) :
      '''
      Best-effort removal of the files named in the string *List*.

      Names are separated by any amount of whitespace and may contain
      glob wildcards (e.g. List=" f1 f2  any*  *any*  f4  ").  Files that
      cannot be removed are silently skipped.
      '''
      for FN in List.split() :
         for fn in glob.glob( FN ):
            if not s.isBlank(fn) :
               try:
                  os.unlink( fn )
               except OSError :
                  pass    # best effort: already gone, a directory, etc.

   ##
   #
   def evtKeys( s ) :
      '''
      Read the header keys the tool needs from s.inBlk into s.k_v.

      GRATING, DETNAM and READMODE are stored upper-cased; ROLL_NOM,
      RA_TARG and DEC_TARG are stored as read.  Exits if any of them
      cannot be read.  Then validates them (verify_keys), sets the grating
      corrections (set_corrections) and sets s.exptime from EXPOSURE,
      falling back to LIVETIME and finally to 1.0.
      '''
      list_s = ["GRATING","DETNAM","READMODE"]
      list_f = ["ROLL_NOM","RA_TARG","DEC_TARG"]

      for kp in list_s+list_f :
         try :
            _dd, _vv  = _dmKeyRead( s.inBlk, kp )
            if kp in list_f :
               s.k_v[ kp ] = _vv
            else :    # list_s :
               s.k_v[ kp ] = _vv.upper()
         except Exception :
            s.error_out("ERROR: "+kp+" is missing in '"+ s.infile +"'")

      s.verify_keys()
      s.set_corrections()

      # - set exptime
      try :
         _dd, _vv  = _dmKeyRead( s.inBlk, "EXPOSURE" )
      except Exception :
         try :
            _dd, _vv  = _dmKeyRead( s.inBlk, "LIVETIME" )
         except Exception :
            _vv = 1.0
      s.exptime = _vv
   # end: evtKeys

   ##
   # -- Error out if values not expected.
   #
   def verify_keys ( s ) :
      '''
      Exit with an error unless the data are grating / ACIS-S / TIMED.

      Rejects GRATING == NONE, any DETNAM not containing "ACIS" (reported
      as HRC), any ACIS DETNAM containing one of the digits 0-3 (reported
      as ACIS-I), and READMODE other than TIMED.  Finally resolves the
      zeroth-order starting position via upd_zo_pos_xy().
      '''
      ## 4.1 :   Error if GRATING is NONE
      if s.k_v[ "GRATING" ] == "NONE" :
         s.error_out("ERROR: GRATING key can't be NONE.\n" )

      ## 4.1 :   Error if DETNAM is HRC or ACIS-I
      # NOTE(review): any chip 0-3 in DETNAM rejects the file, including
      # mixed chip sets such as "ACIS-235678" -- confirm this is intended.
      uv = s.k_v[ "DETNAM" ]
      if "ACIS" in uv :
         if any( chip in uv for chip in "0123" ) :
            s.error_out("ERROR: tool is not for acis-I.\n")    #acis-I
      else :
         s.error_out("ERROR: tool is not for HRC.\n")   # hrc

      ## 4.1 :   Error if READMODE is not TIMED
      if s.k_v[ "READMODE" ] != "TIMED" :
         s.error_out("ERROR: READMODE is not TIMED.\n")

      ## 7 :
      # if parFile("zo_pos_x/_y") are not both positive numbers
      # (eg. ="default"), set s.zo_pos_x/_y to sky_targ[]
      # ( ie. from RA_TARG/DEC_TARG ); otherwise use the parFile values.
      s.upd_zo_pos_xy()
   # end: verify_keys()


   # ######
   ## 2 : Read grating angle alpha from CALDB geom file :
   #
   def set_corrections( s ) :
      '''
      Set the grating-dependent angles used by the zeroth-order search.

      Reads the MEG and LEG grating angles (grt_prop[1]) from the Pixlib
      "chandra"/"geom" data into s.alpha_meg / s.alpha_leg, then picks the
      MEG set of corrections for HETG and the LEG set otherwise, and
      derives (all angles in degrees unless suffixed _rad):
        rotang            = ROLL_NOM + acis_s_rotation_corr
        grating_angle     = alpha + alpha_corr - acis_s_rotation_corr
        grating_arm_angle = -(rotang + grating_angle), in [0, 360]
        readout_angle     = 90 - rotang, in [0, 360]
      '''
      p=Pixlib("chandra","geom")

      p.grating="meg"
      s.alpha_meg=p.grt_prop[1]

      p.grating="leg"
      s.alpha_leg=p.grt_prop[1]

      if s.k_v["GRATING"] == "HETG" :
         s.alpha         = s.alpha_meg
         s.alpha_corr    = s.alpha_corr_meg
         s.streak_offset = s.streak_offset_meg
         s.acis_s_rotation_corr = s.acis_s_rotation_corr_meg

      else :                # LETG
         s.alpha         = s.alpha_leg
         s.alpha_corr    = s.alpha_corr_leg
         s.streak_offset = s.streak_offset_leg
         s.acis_s_rotation_corr = s.acis_s_rotation_corr_leg


      s.rotang           =  s.k_v["ROLL_NOM"] + s.acis_s_rotation_corr
      s.grating_angle    =  s.alpha + s.alpha_corr - s.acis_s_rotation_corr
      s.rotang_rad       =  s.rotang * np.pi / 180
      s.grating_angle_rad = s.grating_angle * np.pi / 180.0

      # 8/20/2013 : values written to the ANGLE_TG / ANGLE_ST header keys
      s.grating_arm_angle =  - (s.rotang + s.grating_angle)
      s.grating_arm_angle = s.check_angle_range( s.grating_arm_angle )
      s.readout_angle =  90.0 - s.rotang
      s.readout_angle = s.check_angle_range( s.readout_angle )

   # end: set_corrections()


   #
   # 8/2013: check the angle range and adjust it to  '0 to 360.0'  degree if needed.
   #
   def check_angle_range (s, angle_deg ) :
      '''
      Shift *angle_deg* by whole turns into the closed range [0, 360].

      Note that 360 itself (and any positive multiple of 360) maps to
      360, not 0.
      '''
      deg360 = 360.0
      adj_angle = angle_deg
      while ( adj_angle < 0.0 ) :
         adj_angle += deg360

      while ( adj_angle > deg360 ) :
         adj_angle -= deg360

      return adj_angle

   # end: check_angle_range


   #
   ## 7 :  zo_pos_xy from params "zo_pos_x & _y" or from evt keys ra/dec_targ
   #
   def upd_zo_pos_xy (s) :
      '''
      Resolve s.zo_pos_x / s.zo_pos_y to numeric sky coordinates.

      Always converts RA_TARG/DEC_TARG to sky X/Y with dmcoords and stores
      them in s.sky_targ.  If both zo_pos_x and zo_pos_y parameters parse
      as positive numbers they are used; otherwise s.sky_targ is used.
      '''
      # -- check param("zo_pos_x/_y")
      def zo_xy_par ( parStr ) :
         # Return (value, 1) if parStr is a positive number, else (value
         # or 0, 0).
         flag = 0
         try:
            aa = float( parStr )   # numerical
            if ( aa > 0.0 )  :
               flag = 1            # parStr is a positive number
         except (TypeError, ValueError) :   # parStr is not numerical
            aa = 0
         return aa, flag

      # convert ra/dec to sky
      def sky_targ(file,radeg,decdeg) :
         cmd = "infile="+file+" asol= opt=cel ra="+str(radeg)
         cmd +=" dec="+str(decdeg)+" celfmt=deg"
         s.exe_task("dmcoords", cmd)
         s.sky_targ = np.zeros(2, float)
         s.sky_targ[0] = paramio.pget("dmcoords", "x")
         s.sky_targ[1] = paramio.pget("dmcoords", "y")

      sky_targ(s.infile, s.k_v["RA_TARG"], s.k_v["DEC_TARG"] )

      # -- check the params "zo_pos_x & _y"
      ax, flag1 = zo_xy_par( s.zo_pos_x )
      ay, flag2 = zo_xy_par( s.zo_pos_y )
      if ( flag1==1 ) and ( flag2==1) :
         # -- both params are positive numbers (ie. ax & ay)
         s.zo_pos_x = ax
         s.zo_pos_y = ay
      else :
         # set to evtfile ra/dec_targ
         s.zo_pos_x = s.sky_targ[0]
         s.zo_pos_y = s.sky_targ[1]
   # end: upd_zo_pos_xy

   ## initial output file
   #
   def init_outTab( s, outTab ):
      '''
      Create the output SRCLIST table *outTab* and return its block.

      Any existing file of that name is removed first.  The header of
      s.infile is copied into the new block; when the input block is not
      the first HDU, the input's primary header is also copied to the
      output's primary HDU.  The output columns are created but left
      empty.  Exits if s.infile cannot be opened.
      '''
      s.remove_files( outTab )
      o_blk = dmTableCreate( outTab+"[SRCLIST]" )
      # --- copy header & subspace info from input evt file
      try :
         tmpBlk = dmTableOpen( s.infile)
      except Exception :
         s.error_out("ERROR: can't open the table file "+s.infile+".\n" )
      dmBlockCopy( tmpBlk, o_blk, "HEADER")

      # 3/2013 - add keys to the primary header
      inDs = dmBlockGetDataset( tmpBlk )
      outDs = dmBlockGetDataset( o_blk )
      nnn = dmBlockGetNo( tmpBlk )
      if ( nnn > 1 ) :
         inPrimary = dmDatasetMoveToBlock( inDs, 1 )
         outPrimary = dmDatasetMoveToBlock( outDs, 1 )
         dmBlockCopy( inPrimary, outPrimary, "HEADER")
         # NOTE(review): moves the *output* dataset to the input's block
         # number and ignores the result -- confirm this is intended.
         dmDatasetMoveToBlock( outDs, nnn )

      dmDatasetClose( inDs )

      #---- create the (still empty) output columns
      dmColumnCreate(o_blk,"RA",np.float64,unit="deg",desc="right ascension")
      dmColumnCreate(o_blk,"DEC",np.float64,unit="deg",desc="declination")
      dmColumnCreate(o_blk,"POS",np.float64,cptnames=['X','Y'],unit="pixel",desc="0th order sky centroid")
      # np.bytes_ is the portable spelling; the np.string_ alias was
      # removed in NumPy 2.0.
      dmColumnCreate(o_blk,"SHAPE",np.bytes_,itemsize=10,unit="",desc="shape of the region")
      dmColumnCreate(o_blk,"R",np.float32,shape=[2],unit="pixel",desc="generalized radii")
      dmColumnCreate(o_blk,"NET_COUNTS",np.float32,unit="count",desc="cnts in 0th order region")
      dmColumnCreate(o_blk,"NET_RATE",np.float32,unit="count/s",desc="cnt rate in 0th order region")
      dmColumnCreate(o_blk,"ROTANG",np.float32,unit="deg",desc="rotate angle of src region")
      dmColumnCreate(o_blk,"COMPONENT",int,unit="",desc="index of region components")
      dmColumnCreate(o_blk,"TG_SRCID",np.short,unit="",desc="0th order source count")

      return o_blk
   # end: init_outTab


   ## update header keys
   #
   def update_keys ( s, o_blk ):
      '''
      Write the SRCLIST classification keys and the tool's result keys
      (COUNT_ST, COUNT_TG, ANGLE_ST, ANGLE_TG) into block *o_blk*.
      '''
      dmKeyWrite(o_blk,"CONTENT","TGSRC",unit="",desc="")
      dmKeyWrite(o_blk,"HDUCLASS","ASC",unit="",desc="")
      dmKeyWrite(o_blk,"HDUCLAS1","SRCLIST",unit="",desc="")
      dmKeyWrite(o_blk,"HDUCLAS2","CANDIDATES",unit="",desc="")
      dmKeyWrite(o_blk,"HDUCLAS3","FINDZO",unit="",desc="")
      dmKeyWrite(o_blk,"COUNT_ST",s.streak_counts,"count", "counts in ACIS CCD frame-shift streak")
      dmKeyWrite(o_blk,"COUNT_TG",s.grating_counts,"count", "counts in grating arm ")
      dmKeyWrite(o_blk,"ANGLE_ST",s.readout_angle,"degree", "readout angle") #streak angle
      dmKeyWrite(o_blk,"ANGLE_TG",s.grating_arm_angle,"degree", "grating arm angle")
   # end: update_keys


   ##
   #  output source table w/ extname=SRCLIST
   #
   def write_tabFile( s  )  :
      '''
      Write the one-row SRCLIST result table to s.outfile.

      Creates the file (init_outTab), writes header keys (update_keys),
      sets the column ranges, then writes RA/DEC, POS=(x_zo, y_zo),
      NET_COUNTS, NET_RATE, SHAPE="circle", R=(radius_zo, 0), ROTANG=0,
      COMPONENT=1 and TG_SRCID=1, and closes the file.
      '''
      o_blk = s.init_outTab( s.outfile)
      s.update_keys( o_blk )
      raP = dmTableOpenColumn( o_blk, "RA")
      decP = dmTableOpenColumn( o_blk, "DEC")
      posP = dmTableOpenColumn( o_blk, "POS")
      cntP = dmTableOpenColumn( o_blk, "NET_COUNTS")
      rateP = dmTableOpenColumn( o_blk, "NET_RATE")
      shapeP = dmTableOpenColumn( o_blk, "SHAPE")
      radiusP = dmTableOpenColumn( o_blk, "R")
      angP = dmTableOpenColumn( o_blk, "ROTANG")
      comP = dmTableOpenColumn( o_blk, "COMPONENT")
      idP = dmTableOpenColumn( o_blk, "TG_SRCID")

      # set columns ranges
      max_xy  = 8192.5
      dbl_max = sys.float_info.max               # 1.7976931348623157e+308
      dmDescriptorSetRange( raP,  0.0, 360.0 )
      dmDescriptorSetRange( decP, -90.0, 90.0 )
      dmDescriptorSetRange( posP, 0.5, max_xy)
      dmDescriptorSetRange( cntP, 0.0, dbl_max)  # 0:3.402823466E+38
      dmDescriptorSetRange( rateP, 0.0, dbl_max)
      dmDescriptorSetRange( radiusP,  0.0, max_xy)
      dmDescriptorSetRange( angP, 0.0, 180.0)
      dmDescriptorSetRange( comP, 1, 32767)
      dmDescriptorSetRange( idP, 1, 32767)

      # set data
      dmSetData( raP, s.raOut)        # RA  outCol
      dmSetData( decP, s.decOut)      # DEC outCol
      dmSetData( posP, [s.x_zo,s.y_zo])          # POS outCol
      dmSetData( cntP, s.net_count)   # NET_COUNTS outCol
      dmSetData( rateP, s.net_rate)   # NET_RATE outCol
      dmSetData( shapeP, "circle")    # SHAPE outCol
      radiV    = np.zeros(2, float )
      radiV[0] = s.radius_zo
      radiV[1] = 0.0
      dmSetData( radiusP, radiV )   # R outCol
      dmSetData( angP, 0 )          # ROTANG outCol
      dmSetData( comP, 1 )          # COMPONENT outCol
      dmSetData( idP,  1 )          # TG_SRCID outCol

      dmDatasetClose( dmBlockGetDataset( o_blk ))

   # end: write_tabFile

   ###
   #
   def write_tmpFile( s, tmpFile, tmpXdata, tmpYdata):
      '''
      Debug helper: write an EVENTS table with float64 X/Y columns.

      Any existing *tmpFile* is removed.  The table is created empty;
      the X/Y data are written only when *tmpXdata* is non-empty.
      '''
      s.remove_files( tmpFile )
      o_blk = dmTableCreate( tmpFile+"[EVENTS]" )
      dmColumnCreate(o_blk, "X", np.float64, unit="pix", desc="sky x" )
      dmColumnCreate(o_blk, "Y", np.float64, unit="pix", desc="sky y" )
      dmDatasetClose( dmBlockGetDataset( o_blk ))
      if len(tmpXdata) > 0 :
         o_blk = dmTableOpen( tmpFile+"[EVENTS]", update=True )
         tmp_xPtr = dmTableOpenColumn( o_blk, "X")
         tmp_yPtr = dmTableOpenColumn( o_blk, "Y")
         dmSetData( tmp_xPtr, tmpXdata )
         dmSetData( tmp_yPtr, tmpYdata )
         dmDatasetClose( dmBlockGetDataset( o_blk ))
   # end: write_tmpFile

#end: _TZO


class TZO(_TZO) :
   '''
   Zeroth-order finder state for tg_findzo.

   The constructor sets the hard-wired tuning constants, reads the
   parameter file *tname* (get_params) and loads the event list
   (evtData).  The remaining methods are numerical helpers used by main().
   '''
   def __init__(s, tname ) :
      # Per-grating empirical corrections (chosen in set_corrections)
      s.streak_offset_meg = -0.011       #pix
      s.streak_offset_leg = 0.000        #pix
      s.acis_s_rotation_corr_meg = 0.03  #deg
      s.acis_s_rotation_corr_leg = 0.00  #deg
      s.alpha_corr_meg = -0.006          #deg
      s.alpha_corr_leg = 0.00            #deg
      s.pixel_tolerance = 1.e-4 #pix; threshold for convergence to a pixel

      # Selection boxes (evtData shrinks the streak values for sub-arrays)
      s.box_width_streak     = 80.0      #pix, default: not sub-array
      s.box_length_streak    = 2000.0    #pix
      s.box_width_grating    = 50.0      #pix
      s.box_length_grating   = 5000.0    #pix

      s.bin_size_grating     = 20.0      #pix
      s.bin_size_streak      = 20.0      #pix

      s.radius_grating       = 50.0      #pix
      s.radius_streak        = 50.0      #pix, default: not sub-array
      s.radius_zo            = 30.0      #pix, outCol R[2]

      s.clipping_factor    = 1.3     #empirically determined
      s.energy_filter_low  = 500     #ev
      s.energy_filter_high = 4000    #ev

      # dph hardwired val.
      s.energy_m  = 8000.0     #ev
      s.yguess_m  = 4000.0     #pix, initial y guess
      s.mloop     = 20         #max iterate loop

      s.k_v = {} # keys values; must exist before evtData() -> evtKeys()
      s.get_params( tname )
      s.evtData()

   # end: __init__

   ##
   #
   def evtData ( s ) :
      '''
      Load X, Y, ENERGY and CHIP from the event file s.infile.

      Opens the table into s.inBlk, reads and validates the header keys
      (evtKeys), exits if the table has no rows or a column is missing,
      stores s.X_v / s.Y_v as float64 and s.E_v as read, then closes the
      file.  If the CHIP y-component spans fewer than 768 values the data
      are treated as a sub-array: the streak radius is scaled by 0.4 and
      the streak box width by 0.5.
      '''
      def colPtr ( colname ) :
         # Column descriptor for *colname*; exits if it cannot be opened.
         try :
            return dmTableOpenColumn( s.inBlk, colname)
         except Exception :
            cmd = "ERROR: unable to open the column "+colname
            cmd+= " in input file '"+s.infile+"'\n"
            s.error_out(cmd)

      try :
         s.inBlk = dmTableOpen( s.infile)
      except Exception :
         s.error_out("ERROR: can't open the table file "+s.infile+".\n" )

      s.evtKeys()
      nRows = dmTableGetNoRows( s.inBlk )

      #  Error if evtfile is empty
      if nRows < 1 :
         s.error_out("ERROR: No data in evt file "+s.infile+"\n")

      Xptr = colPtr("X")
      Yptr = colPtr("Y")
      Eptr = colPtr("ENERGY")
      Cptr = colPtr("CHIP")

      dmTableSetRow( s.inBlk, 1 )

      X_t  = dmGetData( Xptr,  1, nRows )    # float32; SkyX
      Y_t  = dmGetData( Yptr,  1, nRows )    # float32; SkyY
      s.X_v  = X_t.astype(np.float64)
      s.Y_v  = Y_t.astype(np.float64)
      s.E_v  = dmGetData( Eptr,  1, nRows )    # float32; Energy
      chipxy = dmGetData( Cptr,  1, nRows )    # int16  ; chipX,Y

      dmDatasetClose( dmBlockGetDataset( s.inBlk ))

      minChipY = chipxy[:,1].min()
      maxChipY = chipxy[:,1].max()
      if ((maxChipY - minChipY) < 768) :  #sub-array in use
         s.radius_streak    *= 0.4        #pix; 20=50*0.4
         s.box_width_streak *= 0.5        #pix; 40=80*0.5
   # end: evtData ( s ) :

   ##
   ## 10: Rotate the X,Y event sky coordinates around the target coords
   ##
   def rotxy( s, x1, y1, x0, y0, ang_rad) :
      '''
      Rotate point(s) (x1, y1) about (x0, y0) by *ang_rad* radians.

      With dx = x1-x0 and dy = y1-y0 this returns
         xrot =  dx*cos(a) + dy*sin(a) + x0
         yrot = -dx*sin(a) + dy*cos(a) + y0
      and works element-wise on numpy arrays as well as on scalars.
      '''
      ca = np.cos (ang_rad)
      sa = np.sin (ang_rad)
      xca = ( x1 - x0 ) *  ca
      ysa = ( y1 - y0 ) *  sa
      xsa = ( x1 - x0 ) *  sa
      yca = ( y1 - y0 ) *  ca
      xrot =   xca + ysa + x0
      yrot =  -xsa + yca + y0

      return xrot, yrot
   # end: rotxy()

   ##
   ## Iterative sigma clipping.
   #
   def clip_array( s, yy2, nsig) :
      '''
      Sigma-clip *yy2* iteratively and return the surviving indices.

      Starts from all non-NaN elements, then repeatedly keeps the
      elements of yy2 lying strictly within nsig * (sample std-dev,
      ddof=1) of the mean of the currently kept elements, until the
      number kept stops changing.  Each pass re-selects from the whole
      array, not only from the previous survivors.  Returns an empty
      index array if everything is clipped.
      '''
      idx = np.where( yy2 == yy2 )[0]    # NaN != NaN, so NaNs are dropped
      prevDim = 0

      while len(idx) != prevDim :
         if len(idx) == 0 :    # not error
            return idx
         prevDim = len(idx)

         ave  = np.mean(yy2[idx], dtype=np.float64)
         sdev = np.std(yy2[idx], dtype=np.float64, ddof=1)
         idx = np.where( abs( yy2 - ave ) < nsig * sdev )[0]

      return idx
   # end: clip_array

   @staticmethod
   def _lower_median( values ):
      '''
      Median of *values*, taking the lower middle element when the count
      is even (deliberately not numpy.median, which averages the two).
      '''
      sValues = sorted( values )
      return sValues[ (len(sValues) - 1) // 2 ]

   ##
   ## Bin the data along X, clip, and take the median of Y in each bin.
   ##
   def median_peak(s, xx, yy, binSz, thresh ) :
      '''
      Bin (xx, yy) along xx and return the clipped median of yy per bin.

      Returns (xp, yp, errSts):
        * errSts == 1 with xp = yp = -1 when xx is empty (the caller
          must error out);
        * otherwise errSts == 0 and xp / yp are arrays of bin centres and
          medians.  Bins are binSz wide starting at min(xx); only
          int((max-min)/binSz) whole bins are used, so events past the
          last whole bin are ignored.  A bin contributes only if it holds
          at least 2 events and at least 2 survive
          clip_array(..., thresh).  The arrays may be empty.
      '''
      if len(xx) == 0 :
         return -1, -1, 1  # errSts=1; Must error out in main();

      min_xx = min(xx)
      max_xx = max(xx)
      NN = int ( ( max_xx - min_xx ) / binSz )

      halfBin = binSz / 2.0
      xp = np.arange(NN) * binSz + min_xx + halfBin

      # from finzo.sl : bins without a median keep the sentinel -32768
      # and are dropped before returning.
      yp = np.full( NN, -32768. )

      for kk in range( NN ) :
         jj = np.where((xx >= (xp[kk]-halfBin)) & (xx < (xp[kk]+halfBin)))[0]
         if len( jj ) < 2 :     # not an error; bin just stays empty
            continue
         yyj = yy[ jj ]
         ll = s.clip_array( yyj, thresh )
         if len(ll) >= 2 :
            yp[kk] = s._lower_median( yyj[ll] )

      wt = np.where(yp != -32768.)[0]
      return xp[ wt ], yp[ wt ], 0     # errSts=0; no error;

   # end: median_peak

   ##
   ## 5/7/2013 - new way to compute c0 for the linear fit:  y=c0+c1*x
   ##            c1=1 ;
   ##            c0=[sum(y)-sum(x)]/N   where N is the number of x,y pairs.
   ##
   def fit_linear( s, xx, yy ) :
      '''
      Intercept c0 of y = c0 + x (slope fixed at 1): mean(yy - xx).

      Replaced (5/7/2013) an earlier Sherpa polynom1d fit with the slope
      frozen at 1.  Raises ZeroDivisionError if *xx* is empty.
      '''
      c0 = (sum(yy) - sum(xx)) / len(xx)
      return c0
   #end: fit_linear

   def add_history_to_file(s, file_name, tool_name, param_names, param_values):
      '''
      Append a HISTORY record for *tool_name* to *file_name* (pycrates).

      Does nothing if the file does not exist.  Creates HISTNUM = 1 when
      the key is absent, starts the record at HISTNUM, advances HISTNUM
      by the number of lines in the record and rewrites the file.
      '''
      if not os.path.exists( file_name ) :
         return
      crate = read_file(file_name, mode="rw")

      histnum = crate.get_key("HISTNUM")
      if histnum is None:
         histnum = CrateKey()
         histnum.name = "HISTNUM"
         histnum.value = 1
         crate.add_key(histnum)

      startat = histnum.value
      startat = 1 if startat is None else startat

      hh = HistoryRecord(tool=tool_name,param_names=param_names,
                         param_values=param_values, asc_start=startat)
      add_history(crate, hh)

      # Increment HISTNUM for new lines
      hlen = hh.as_ASC_FITS().split("\n")
      histnum.value = startat+len(hlen)

      crate.write()
   #end: add_history_to_file

#end: TZO

# --- main code 
#
def main() :
   '''
   Run tg_findzo: locate the zeroth order of an ACIS grating observation
   from the readout streak and the grating arm.

   Outline (step numbers match the comments below):
     10-17  rotate events so the streak is vertical and measure its x;
     18-27  repeatedly rotate about the current zeroth-order guess so the
            grating arm is horizontal and refit the arm's intercept until
            the zeroth-order y converges (at most p.mloop passes);
     28-32  count grating and zeroth-order events, rotate the result back
            to sky coordinates, convert to RA/Dec, write the SRCLIST table
            and a HISTORY record.
   Ends the process with os._exit(0).
   '''

   ##  --- get parameters
   p = TZO("tg_findzo")

   ## 
   ## 10: Rotate the X,Y event sky coordinates around the target coords
   ##
   ( xrot_streak, yrot_streak ) = p.rotxy ( p.X_v, p.Y_v, 
                                         p.zo_pos_x, p.zo_pos_y, -p.rotang_rad )

   # verbose >= 5: dump intermediate event sets to /tmp FITS files
   if (p.debug >=5 ) :
      p.write_tmpFile("/tmp/zzz10_rotxy.fits", xrot_streak, yrot_streak)

   ## 
   ## 11:  Select streak events via geometry and energy, excluding zeroth order
   ##
   if p.k_v["GRATING"] == "HETG" :
      # HETG: exclude a circle of radius_streak around the zo guess
      wsi = np.where( ( xrot_streak > (p.zo_pos_x - p.box_width_streak/2.0) )
                      & ( xrot_streak < (p.zo_pos_x + p.box_width_streak/2.0) )
                      & ( yrot_streak > (p.zo_pos_y - p.box_length_streak/2.0) )
                      & ( yrot_streak < (p.zo_pos_y + p.box_length_streak/2.0) )
                      & ( p.X_v > 0.0 )
                      & ( p.Y_v > 0.0 )
                      & (np.hypot(xrot_streak-p.zo_pos_x,yrot_streak-p.zo_pos_y)>p.radius_streak)
                      & ( p.E_v > p.energy_filter_low )
                      & ( p.E_v < p.energy_filter_high ))
   else : 
      wsi = np.where( ( xrot_streak > (p.zo_pos_x - p.box_width_streak/2.0) )
                      & ( xrot_streak < (p.zo_pos_x + p.box_width_streak/2.0) )
                      & ( yrot_streak > (p.zo_pos_y - p.box_length_streak/2.0) )
                      & ( yrot_streak < (p.zo_pos_y + p.box_length_streak/2.0) )
                      & ( p.X_v > 0.0 )
                      & ( p.Y_v > 0.0 )
                      # for LETG/ACIS, exclude a strip instead of a circle: (2012.04.09 dph)
                      #     & (np.hypot(xrot_streak-p.zo_pos_x,yrot_streak-p.zo_pos_y)>p.radius_streak)
                      & (abs(yrot_streak-p.zo_pos_y)>p.radius_streak)                   
                      & ( p.E_v > p.energy_filter_low )
                      & ( p.E_v < p.energy_filter_high ))

   streak_indices = wsi[0]

   ## 
   ## 12:  Determine streak position (assumed vertical)
   ##
   y_streak =  yrot_streak[ streak_indices ]
   x_streak =  xrot_streak[ streak_indices ]
   if (p.debug >=5 ) :
     p.write_tmpFile("/tmp/zzz12_rot_streak.fits", x_streak, y_streak)

   # Arguments are (y, x): bin along the streak (rotated y) and take the
   # clipped median x in each bin.
   y_median,x_median, errSts = p.median_peak( y_streak, x_streak, 
                               p.bin_size_streak, p.clipping_factor)
   
   if ( errSts == 1 ) :      # y_median,x_median are scalars 
      p.error_out("ERROR: no median_peak points found.\n")
   else :
      if (p.debug >=5 ) :
        p.write_tmpFile("/tmp/zzz12_median_peak.fits", x_median, y_median)

   ## 
   ## 13:  Find the mean of the x values
   ## 
   x_fit_streak = np.mean ( x_median )  #scalar; average of x_median elem.

   ## 
   ## 14:  Iterate, by discarding 2 'sigma' outliers and repeating.
   ##
   # NOTE(review): x_median is not modified inside this loop, so
   # clip_array returns the same indices on every pass and the loop ends
   # after at most two passes (immediately if x_fit_streak is NaN).
   old_x = 0.0
   while ( abs( old_x - x_fit_streak ) > p.pixel_tolerance ) :
      old_x = x_fit_streak
      indices = p.clip_array( x_median, 2.0)
      x_fit_streak = np.mean( x_median[ indices ] )
   #end: while loop

   ## 
   ## 15:  Repeat the filter to get streak counts (plus background).
   ##
   ## outside the while loop :
   wsi = np.where( ( xrot_streak > (x_fit_streak - p.box_width_streak/2.0) )
     & ( xrot_streak < (x_fit_streak + p.box_width_streak/2.0) )
     & ( yrot_streak > (p.zo_pos_y - p.box_length_streak/2.0) )
     & ( yrot_streak < (p.zo_pos_y + p.box_length_streak/2.0) )
     & ( p.X_v > 0.0 )
     & ( p.Y_v > 0.0 )
     & ( np.hypot(xrot_streak-x_fit_streak, yrot_streak-p.zo_pos_y) >p.radius_streak)
     & ( p.E_v > p.energy_filter_low )
     & ( p.E_v < p.energy_filter_high ))
   streak_indices = wsi[0]

   if (p.debug >=5 ) :
     p.write_tmpFile("/tmp/zzz15_filter_rotStk.fits",xrot_streak[streak_indices],yrot_streak[streak_indices] )

   ## 
   ## 16:  Save the number of streak counts (for output to tool param):
   ## 
   p.streak_counts = len( streak_indices )      # outKey  COUNT_ST

   ## 
   ## 17.  Apply empirical offset (compensates for MEG, HEG alignments):
   ##
   x_fit_streak += p.streak_offset

   ## 
   ## 18.  Set initial guess for the y coordinate of zeroth order, 
   ##           and a saved previous value; 
   ##
   num_iter = 0
   max_outer = p.mloop    # 20 
   prev_y_fit_zo = 0.0
   y_fit_zo  = p.zo_pos_y

   ## 
   ## begin: while :   outer loop
   ## 
   ## 19. While the y fit zo value changes, iterate, up to max_outer tries:
   ##
   # NOTE(review): xrot_grat / yrot_grat, used after the loop in step 28,
   # exist only if this loop runs at least once; it does whenever
   # p.zo_pos_y differs from 0.0 by more than p.pixel_tolerance.
   while ((num_iter < max_outer) and 
          (abs(y_fit_zo - prev_y_fit_zo) > p.pixel_tolerance)):

      num_iter += 1

      ## 
      ## 20. Copy the y-value :
      ## 
      prev_y_fit_zo = y_fit_zo
      ## 
      ## 21. Rotate event coordinates (xrot_streak, yrot_streak) about the
      ##  current center (y-coord is being iterated) so the grating arm is 
      ##  horizontal ( starting with events rotated for vertical streak ):
      ##
      ( xrot_grat, yrot_grat ) = p.rotxy( xrot_streak, yrot_streak,
                                 x_fit_streak, y_fit_zo, -p.grating_angle_rad )

      if (p.debug >=5 ) :
        p.write_tmpFile("/tmp/zzz21_rot_grat.fits", xrot_grat, yrot_grat )

      ## 
      ## 22. Index the grating arm events in the box region,
      ##      excluding zeroth order and energies above a limit:
      ## 
      wsi = np.where( ( xrot_grat > ( x_fit_streak - p.box_length_grating/2.0 ) )
         & ( xrot_grat < ( x_fit_streak + p.box_length_grating/2.0 ) )
         & ( yrot_grat > ( y_fit_zo - p.box_width_grating/2.0 ) )
         & ( yrot_grat < ( y_fit_zo + p.box_width_grating/2.0 ) )
         & ( p.X_v > 0.0 )
         & ( p.Y_v > 0.0 )
         & (np.hypot(xrot_grat-x_fit_streak,yrot_grat-y_fit_zo)>p.radius_grating)
         & ( p.E_v < p.energy_m ) )   # 8000.

      grating_indices = wsi[0]

      ## 
      ## 23. Select the filtered events (in x,y rot_streak coords)
      ## 
      x_grat = xrot_streak[ grating_indices ]
      y_grat = yrot_streak[ grating_indices ]

      if (p.debug >=5 ) :
        p.write_tmpFile("/tmp/zzz23_xy_grat"+str(num_iter)+".fits", x_grat, y_grat)

      ## 
      ## 24. Find the median of the grating arm in coarse bins:
      ##
      x_median, y_median, errSts = p.median_peak ( x_grat, y_grat,
                                   p.bin_size_grating, p.clipping_factor )
      if ( errSts == 1 ) :   # y_median,x_median are scalars 
        p.error_out("ERROR:  no median_peak points found.\n")
      else :
        if (p.debug >=5 ) :
          p.write_tmpFile("/tmp/zzz24_arm_median_peak_"+str(num_iter)+".fits",x_median,y_median)

      ## 
      ## 25. Do a linear fit, with slope assumed to be 1.0. Really fitting:
      ##     y = c0 + angle * x   (x scaled by angle, then slope-1 fit)
      ##
      angle = -p.grating_angle_rad
      x_values = angle * x_median   
      y_values = y_median
      # y_intercept_guess is unused since 5/7/2013 (only the commented-out
      # Sherpa call below took it).
      y_intercept_guess = p.yguess_m       # 4000.0 
#     y_intercept = p.fit_linear( x_values, y_values, y_intercept_guess)
      # NOTE(review): fit_linear divides by len(x_values); an empty
      # median_peak result (errSts 0 but no bins) would raise here.
      y_intercept = p.fit_linear( x_values, y_values )     # 5/7/2013

      ##
      ## 26. Iterate (within the outer iteration) on this by 
      ##  clipping > 2'sigma' outliers from the resulting line (may want 
      ##  some max iteration limit, say 20 tries):
      ##
      iterate_ = 0 
      max_inner = p.mloop    # 20

      ## begin: while :   inner iteration 
      while iterate_ < max_inner : 
         iterate_ += 1
         prev_y_intercept  =   y_intercept
         # clip_array clips about the mean, so adding the constant
         # y_intercept does not change which points survive.
         indices  =  p.clip_array(y_values - x_values + y_intercept, 2.0)  
         Vx = x_values[ indices ]
         Vy = y_values[ indices ]
         x_values = Vx 
         y_values = Vy

         if ( len(Vx) <= 1 ) or ( len(Vy) <= 1) :
            break
#        y_intercept = p.fit_linear( Vx, Vy, y_intercept)
         y_intercept = p.fit_linear( Vx, Vy)          # 5/7/2013

         if (abs(y_intercept-prev_y_intercept) < p.pixel_tolerance) :
            break
      ## end : while :   inner iteration

      ## 
      ## 27. When done, compute the zeroth order y value:
      ## 
      y_fit_zo = y_intercept + angle * x_fit_streak

   ## end: while :   outer loop

   ## 
   ## 28. Count the grating events using the final coordinates:
   ## 
   # NOTE(review): xrot_grat / yrot_grat come from the last loop pass and
   # were rotated about the previous y_fit_zo, while this box is centred
   # on the updated y_fit_zo.
   wsi = np.where( ( xrot_grat > ( x_fit_streak - p.box_length_grating/2.0 ) )
         & ( xrot_grat < ( x_fit_streak + p.box_length_grating/2.0 ) )
         & ( yrot_grat > ( y_fit_zo - p.box_width_grating/2.0 ) )
         & ( yrot_grat < ( y_fit_zo + p.box_width_grating/2.0 ) )
         & ( p.X_v > 0.0 )
         & ( p.Y_v > 0.0 )
         & ( np.hypot(xrot_grat-x_fit_streak,yrot_grat-y_fit_zo)>p.radius_grating)
         & ( p.E_v< p.energy_m ) )   # 8000.
   grating_indices = wsi[0]
   if (p.debug >=5 ) :
     p.write_tmpFile("/tmp/zzz28_rot_grat.fits", xrot_grat[grating_indices],
                      yrot_grat[grating_indices] )

   ##
   ## 29. Save the number of counts in the grating-selected events 
   ## (this is approximate, but will be useful for downstream heuristics 
   ##  which compare accuracy of findzo and celldetect):
   ##
   p.grating_counts = len( grating_indices )  # outKey  COUNT_TG

   ##
   ## 30. Rotate the values back to the original SKY frame:
   ##     (inverse of step 10: +rotang_rad about the same zo_pos center)
   ##
   # x_zo, y_zo   : outCol X,Y
   p.x_zo, p.y_zo = p.rotxy( x_fit_streak, y_fit_zo, p.zo_pos_x,
                       p.zo_pos_y, p.rotang_rad)

   ## 
   ## 31. Convert coords:
   ##         (RA, DEC) = f(x_zo, y_zo, ...)
   ##
   # raOut,decOut : outCol RA,DEC
   cmd=" infile="+p.infile+" asol= opt=sky x="+str(p.x_zo)+" y="+str(p.y_zo)+" celfmt=deg verb=0"
   p.exe_task("dmcoords", cmd)
   p.raOut = float(paramio.pgetstr("dmcoords","ra"))
   p.decOut = float(paramio.pgetstr("dmcoords","dec"))

   ## 
   ## 32. Determine the number of zeroth order counts by filtering 
   ##   in a circle centered on x_zo, y_zo (and other data needed to be
   ##   compatible w/ the src1a file):
   ##
   wsi = np.where( np.hypot( p.X_v - p.x_zo, p.Y_v - p.y_zo ) < p.radius_zo )
   zo_indices = wsi[0]  
   p.net_count = len(zo_indices)          # dph zo_counts; outCol NET_COUNT;
   # NOTE(review): exptime comes from EXPOSURE/LIVETIME (1.0 if neither
   # key exists); a zero value would raise ZeroDivisionError here.
   p.net_rate  = p.net_count/p.exptime  # outCol NET_RATE

   # write output file
   p.write_tabFile()

   # add history
   p.add_history_to_file(p.outfile, "tg_findzo", p.pnames, p.pvals)

   # os._exit skips normal interpreter shutdown (atexit handlers,
   # flushing of Python-level buffered output).
   os._exit(0)

if __name__ == "__main__":
   # Run the tool only when executed as a script; importing does nothing.
   main()
