#!/opt/conda/envs/ciao-4.17.0/bin/python3
# 
#  Copyright (C) 2011,2012,2013,2016,2022,2023  Smithsonian Astrophysical Observatory
#
#
#  This program is free software; you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation; either version 3 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#
#  You should have received a copy of the GNU General Public License along
#  with this program; if not, write to the Free Software Foundation, Inc.,
#  51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#


# 4/2011 - initial python version
# rhain [01NOV12] - rotate line segment used to determine
#   rotation angle by 90 deg. to align with streak direction
# rhain [18JAN13] - exit cleanly with no chip data


import os
import sys
import glob
import paramio
import string
import math
import copy
import numpy as np
from cxcdm import *       
from region import *     

'''Use the history module to add HISTORY records to crate'''
from pycrates import read_file
from pycrates import add_history, CrateKey
from history import HistoryRecord



# Coordinate system used for the output region file: "PHYS" or "WCS".
REG_COORD_TYPE = "PHYS"

# Multiply an angle in degrees by this to get radians.
DEG2RAD = np.pi/180.0

# SRC_ON_STRK_RATIO is the assumed maximum fraction of a streak's length
# that is covered by sources (0.10 => at most 10% of the streak contains
# sources).  It is used to exclude sources along a streak when computing
# the average summed-over-rows counts along that streak.
SRC_ON_STRK_RATIO = 0.10




# Indices into the last axis of the region-corner arrays (reg_pts[i, j, k]),
# following the original S-Lang struct layout: rx, ry, ra, dec, xx, yy.
(cRX, cRY, cRA, cDEC, cXX, cYY) = range(6)

## get parameters
#
def get_params( toolpar ) :
   """Read the tool parameters from the parameter file `toolpar`.

   Returns a dict keyed by parameter name.  nsigma, msigma, ssigma, k1
   and k2 are converted with float(); dither and binsize with int(); all
   other values are returned exactly as paramio.pgetstr() gives them.
   Exits via error_out() if the parameter file cannot be opened.
   """
   try :
      pfile = paramio.paramopen(toolpar, "rw", sys.argv)
   except Exception :
      # narrower than a bare except: Ctrl-C / SystemExit are not masked
      error_out("ERROR: can't open parameter file "+toolpar )

   # (name, converter) in the exact order the parameters are read; the
   # order matters because reading a parameter may prompt the user.
   # A converter of None keeps the pgetstr() value unchanged.
   spec = (
      ("infile",  None),
      ("fovfile", None),
      ("bkgroot", None),
      ("regfile", None),
      ("nsigma",  float),
      ("msigma",  float),
      ("ssigma",  float),
      ("dither",  int),
      ("binsize", int),
      ("k1",      float),
      ("k2",      float),
      ("tmppath", None),
      ("clobber", None),
      ("verbose", None),
      ("mode",    None),
   )

   pp = {}
   try :
      for name, convert in spec :
         value = paramio.pgetstr(pfile, name)
         pp[name] = value if convert is None else convert(value)
   finally :
      # close the parameter file even if a conversion fails
      paramio.paramclose(pfile)
   return pp


## display message
#
def display_msg( msg ) :
    """Write the string msg, terminated by a newline, to standard error."""
    sys.stderr.write( "".join( (msg, "\n") ) )


## error out w/ message
#
def error_out( msg ) :
    """Print msg on standard error, then terminate with exit status 1."""
    display_msg( msg )
    raise SystemExit(1)


## execute task (cmd); if there's a problem, error_out w/ message
#
def exe_task (tool, cmd) :
   """Run `tool cmd` through the shell after resetting the tool's parameters.

   The tool's parameter file is punlearn'ed first.  A non-zero status
   from os.system() terminates the program via error_out().
   NOTE(review): cmd is handed to the shell unquoted, so values containing
   spaces or shell metacharacters are interpreted by the shell.
   """
   paramio.punlearn( tool )
   command_line = " ".join( (tool, cmd) )
   if os.system( command_line ) != 0 :
      error_out( "problem with "+tool+"\n")


## get the key value from File header
#
def get_key_value ( File, Key ) :
   """Return the value of header keyword Key in File, as reported by dmkeypar.

   RA_NOM and DEC_NOM are reformatted as strings with 3 decimal places;
   any other keyword is returned exactly as paramio.pget() gives it.
   """
   exe_task("dmkeypar", File+" "+Key+r" echo-" )
   keyword_value = paramio.pget("dmkeypar", "value")

   if Key in ("RA_NOM", "DEC_NOM") :
      keyword_value = "%.3f" % float(keyword_value)

   return keyword_value


## check for the output file existence :
## If found, return True 
## Else return False
#
def find_output_file( File ) :
   """Return True when a file (or directory) named File exists, else False."""
   return os.path.exists( File )


## if clobber is no and the outfile exists, error out; otherwise remove the existing file.
#
def check_clob( Clob, File ) :
   """Handle an already-existing output file according to the clobber setting.

   If File does not exist, nothing happens.  If it exists and Clob starts
   with "n"/"N", the program exits via error_out(); otherwise the existing
   file is removed.
   """
   if not find_output_file( File ) :
      return

   if Clob[0] in ("n", "N") :
      error_out("\nERROR: outfile '"+File+"' exists and clobber=no.\n")
   else :
      remove_files( File )


## is it a blank string ?
#
def isBlank( ss ) :
   """Return True when ss is empty or made only of space characters.

   Only the space character counts as blank; tabs and newlines do not.
   """
   return not ss.strip(" ")


## remove a list of files
#
# allow any number of spaces between files.
# allow syntax w/ '*' ( eg.  *file*  or file* ).
#     eg. List=" f1 f2  any*  *any*  f4  "
#
def remove_files( List ) :
   """Delete every file matching the names/glob patterns listed in List.

   List is a whitespace-separated string, e.g. " f1 f2  any*  *any*  f4  ".
   This is best-effort cleanup: entries that cannot be removed (already
   gone, permission denied, directories, ...) are silently skipped.
   """
   for pattern in List.split() :
      for fn in glob.glob( pattern ) :
         if not fn.strip(" ") :     # same test as isBlank(): skip empty names
            continue
         try :
            os.unlink( fn )
         except OSError :
            # deliberate best-effort: only filesystem errors are ignored,
            # so KeyboardInterrupt and programming errors still propagate
            pass


## check for the input file existence :
## If not found, error out.
## Else return 0.
#
def find_input_file( File ) :
   """Check that the file part of File exists; return 0 if it does.

   Anything from the first "[" on (a data-model filter/block spec such as
   "evt.fits[EVENTS]") is ignored.  A missing file terminates the program
   via error_out().
   """
   path = File.partition("[")[0]
   if not os.path.exists( path ) :
      error_out("ERROR: input file '"+path+"' not found.\n" )
   return 0


## Open an input table file :
## If fail to open, error out. 
## Else return dmBlock
#
def open_input_tabFile( File ) :
   """Open table File with dmTableOpen() and return the resulting block.

   Exits via error_out() if the file does not exist or cannot be opened.
   """
   find_input_file( File )     # error out if not found

   in_block = None
   try :
      in_block = dmTableOpen( File )
   except Exception :
      # narrower than a bare except: Ctrl-C / SystemExit are not masked
      error_out("ERROR: can't open the file "+File+".\n" )

   return in_block


## Open an input image file "chip_img" :
## If fail to open, error out.
## Else return dmBlock, img_dd
#
def open_input_imgFile( File ) :
   """Open image File; return (image block, image data descriptor).

   Exits via error_out() if the file does not exist, cannot be opened as an
   image, or its data descriptor cannot be obtained.
   """
   find_input_file( File )     # error out if not found

   img_block = None
   img_dd    = None
   try :
      img_block = dmImageOpen( File )
   except Exception :
      # narrower than a bare except: Ctrl-C / SystemExit are not masked
      error_out("acis_streak_map: ERROR: can't open image file "+File+".\n" )

   try :
      img_dd = dmImageGetDataDescriptor( img_block )
   except Exception :
      error_out("acis_streak_map: ERROR: can't open image data descriptor for "+File+".\n")

   return img_block, img_dd

# 
#  looking for the first/last  row's/column's index that contains
#  non-zero element in that row||column
#
#  depending on the passed variables, it will return  :
#     xmin = min_max_index( 0,        xx_dim,  1, sss, 0, chip )
#     ymin = min_max_index( 0,        yy_dim,  1, sss, 1, chip )
#     xmax = min_max_index( xx_dim-1,     -1, -1, sss, 0, chip )
#     ymax = min_max_index( yy_dim-1,     -1, -1, sss, 1, chip )
#
#  variable  x_or_y :   0 for xlim||row ;   1 for ylim||column ;
#
def  min_max_index( beg, end, step, sss, x_or_y, chip ) :
  """Return the first index in range(beg, end, step) whose line has data.

  With x_or_y == 0 the line is the row sss[ii, :]; otherwise it is the
  column sss[:, ii].  A line "has data" when any element is non-zero.
  If no such line exists, error_out() is called (reporting `chip`); -999
  is returned only if error_out() does not terminate.
  """
  for ii in range(beg, end, step) :
     line = sss[ii, :] if x_or_y == 0 else sss[:, ii]
     if np.any( line != 0 ) :
        return ii

  error_out("ERROR: No non-zero data found on chip {0}".format(chip))
  return -999
# --- end: min_max_index()

#
## image_limits()
#     Input img, xlen, ylen
#     Output  xmin,xmax,ymin,ymax 
#
#
def image_limits( img, xlen, ylen, chip ) :
  """Return (xmin, xmax, ymin, ymax): bounding indices of non-zero data in img.

  xmin/xmax are the first/last rows img[ii, :] (ii < xlen) containing a
  non-zero element; ymin/ymax the first/last columns img[:, jj] (jj < ylen).
  Exits via min_max_index()/error_out() if img has no non-zero data.
  """
  work = img.copy()

  first_x = min_max_index( 0, xlen, 1, work, 0, chip )
  first_y = min_max_index( 0, ylen, 1, work, 1, chip )
  last_x  = min_max_index( xlen-1, -1, -1, work, 0, chip )
  last_y  = min_max_index( ylen-1, -1, -1, work, 1, chip )

  return first_x, last_x, first_y, last_y
## --- end: image_limits()


# 
# -- create _rotimg_1_sfr.txt 
# 
def src_free_rows(row_sum_all,yminrow,ymaxrow,nsigma,method,binsize,k1,k2,infile,debug) :
   """Select source-free rows from the summed-over-columns row profile.

   row_sum_all : 1-D array, counts in each row (y) summed over columns (x).
   yminrow, ymaxrow : inclusive row range analysed by the "histo" method.
                      (0, 0) means "no usable rows" and returns at once.
   nsigma  : acceptance width, in Poisson sigmas, for the "min" method.
   method  : "histo", or anything else for the "min" method.
   binsize : bin width of the row-count histogram ("histo" method).
   k1, k2  : "histo" cut is rows with counts < min(median + k1*sig, k2*mode2),
             where mode2 is the average lower edge of the most populated
             histogram bins.
   infile, debug : with debug > 0 ("histo"), the histogram is written to
             "<infile before .fits>_<binsize>_sfr.txt" and a summary printed.

   Returns (ww, sig): ww is an array of absolute row indices judged source
   free; sig is the std. dev. used for the cut.  The early exit returns
   (np.zeros(1), 0).
   """
   ## --- skip calculation when img is 0 events
   if yminrow == 0 and ymaxrow == 0 :
      return np.zeros(shape=1), 0

   ## --- summed counts of the rows being analysed (a view, not a copy)
   row_sum = row_sum_all[yminrow:(ymaxrow+1)]

   if method == "histo" :
      ## --- population std. dev. (divide by N) and true median.
      ##     np.median replaces a hand-rolled median that picked the element
      ##     one above the middle and indexed past the end for 1-2 rows.
      mean   = np.sum(row_sum) / len(row_sum)
      sig    = np.sqrt(np.sum((row_sum - mean)**2) / len(row_sum))
      median = np.median(row_sum)

      ## --- histogram of summed row counts; negative counts go in bin 0.
      ##     bin index = floor(counts/binsize), so floor(max/binsize)+1 bins
      ##     always suffice (int(max/binsize) was one short whenever max was
      ##     not a multiple of binsize).
      rc_max  = max(np.max(row_sum), 0)
      rc_bins = int(np.floor(rc_max/binsize)) + 1
      bin_idx = np.floor(np.clip(row_sum, 0, None)/binsize).astype(int)
      rc_histo = np.bincount(bin_idx, minlength=rc_bins)

      ## --- mode2: average lower bin edge over all bins tied for the max
      modes = np.flatnonzero(rc_histo == rc_histo.max())
      mode2 = np.sum(modes) * binsize / len(modes)

      ## --- candidate source free row limits, select smaller one
      s_max1 = median + k1*sig
      s_max2 = k2 * mode2
      s_max  = min(s_max1, s_max2)

      ## --- indices are relative to yminrow; shift to absolute rows
      ww = np.flatnonzero(row_sum < s_max) + yminrow

      if debug > 0 :
         sfr_file = infile.split(".fits")[0] +"_"+str(binsize)+"_sfr.txt"
         with open(sfr_file, 'w') as fp :
            for ii, count in enumerate(rc_histo) :
               fp.write( str(ii)+"\t"+str(count)+"\n" )
         display_msg("# src_free_rows: INFO: sfr_file="+sfr_file+" median="+str(median)+" mode2="+repr(mode2)+"  binsize="+repr(binsize))

   else :        # ...method = "min"
      zz = min(row_sum)
      # -------------------------------------------------------
      #  Assume Poisson distribution for summed-over-columns
      #  background, so that the variance (sig squared) = mean
      #  summed-over-columns background value (zz).
      #  Note: this branch selects from the full row_sum_all.
      # -------------------------------------------------------
      sig = np.sqrt(zz)
      ww = np.where((row_sum_all>=zz)&(row_sum_all<=(zz+nsigma*sig)))[0]

   return ww, sig    # ww=array; sig=std.dev=scalar

# --- end: src_free_rows


def get_streak_xcoords( src_free,xmin,xmax,xlen,dither,msigma,debug):
  """Find the x-ranges (columns) of streaks in the source-free column profile.

  src_free : 1-D array; src_free[ii] = counts in column ii summed over the
             source-free rows.  Columns 0..xlen-1 are scanned.
  A column belongs to a streak when src_free exceeds
      streak_limit = median(src_free) + msigma * sig_src_free
  where sig_src_free is the RMS of src_free about mean_src_free, and
  mean_src_free = sum(src_free) / (number of columns from xmin+dither to
  xmax-dither).  Adjacent streak columns are merged into one streak.

  Returns (i_streak, streak_xlo, streak_xhi): the number of streaks and two
  int arrays of length xlen whose first i_streak entries are the first and
  last column of each streak.
  """
  xmincol = xmin + dither
  xmaxcol = xmax - dither

  # mean normalized by the dithered column range, not by len(src_free)
  mean_src_free = np.sum(src_free)/(xmaxcol - xmincol + 1)

  # true median (the previous hand-rolled version picked the element one
  # above the middle and indexed past the end for 1 or 2 columns)
  median_src_free = np.median(src_free)

  # RMS about mean_src_free, dividing by len(src_free)
  sig_src_free = np.sqrt(np.sum((src_free - mean_src_free)**2) / len(src_free))

  # limit above which a streak is found
  streak_limit = median_src_free + msigma*sig_src_free

  if (debug > 0) :
    display_msg("median_src_free,sig_src_free,streak_limit= "+str(median_src_free)+", "+str(sig_src_free)+",  "+str(streak_limit)+"\n")

  # the width (short dim.) of each streak region rectangle runs from
  # streak_xlo to streak_xhi
  streak_xlo = np.zeros( xlen, int )
  streak_xhi = np.zeros( xlen, int )

  i_streak  = 0
  in_streak = False
  for ii in range(0, xlen) :
    if src_free[ii] > streak_limit :
      if not in_streak :
        # first column of a new streak
        streak_xlo[i_streak] = ii
        i_streak += 1
        in_streak = True
      # extend the current streak to this column
      streak_xhi[i_streak-1] = ii
    else :
      in_streak = False

  if (debug > 1) :
    display_msg("total streaks = "+str(i_streak))
    for ii in range(0, i_streak) :
       display_msg("streak "+repr(ii)+" : xlo="+repr(streak_xlo[ii])+"  xhi="+repr(streak_xhi[ii]))

  return (i_streak,streak_xlo,streak_xhi)

# --- end: get_streak_xcoords 


def get_single_streak_regions(src_free,xmin,xmax,ymin,ymax,xlen,ylen,dither,msigma,debug) :
  """Return one full-height box region per streak.

  Streak column ranges come from get_streak_xcoords().  Each box spans
  its streak's columns and rows ymin..ymax, converted to 1-based pixel
  edges (index + 1 -/+ 0.5).  The result has shape (n_streaks, 4, 6):
  corners start at lower left and go counter-clockwise, and only the
  cRX/cRY slots are filled.  With no streaks, np.zeros([1,1,1]) is
  returned.
  """
  (n_streaks, x_lo, x_hi) = get_streak_xcoords(src_free,xmin,xmax,xlen,dither,msigma,debug)

  if n_streaks <= 0 :
     return np.zeros([1,1,1], float)

  left   = x_lo[:n_streaks] + 1.0 - 0.5
  right  = x_hi[:n_streaks] + 1.0 + 0.5
  bottom = ymin + 1.0 - 0.5
  top    = ymax + 1.0 + 0.5

  boxes = np.zeros([n_streaks, 4, 6], float)
  boxes[:, 0, cRX] = left
  boxes[:, 0, cRY] = bottom
  boxes[:, 1, cRX] = right
  boxes[:, 1, cRY] = bottom
  boxes[:, 2, cRX] = right
  boxes[:, 2, cRY] = top
  boxes[:, 3, cRX] = left
  boxes[:, 3, cRY] = top
  return boxes

# --- end: get_single_streak_regions

def get_multi_streak_regions(arr_T,src_free,xmin,xmax,ymin,ymax,xlen,ylen,dither,msigma,ssigma,debug) :
  """Return box regions along each streak, split around sources on the streak.

  Streak column ranges come from get_streak_xcoords().  For each streak the
  counts of its columns are summed (arr_T indexed [x, y]) to give one value
  per row.  streak_mean is the mean of those row sums over rows
  dither..ylen-dither-1, after dropping values above the one at position
  int((1 - SRC_ON_STRK_RATIO) * n) of the sorted sums.  Rows whose sum
  exceeds streak_mean + ssigma*sqrt(streak_mean) (Poisson sigma) are
  treated as sources, and each contiguous run of remaining rows becomes
  one box.

  Returns a float array of shape (n_boxes, 4, 6): corners start at lower
  left and go counter-clockwise, and only the cRX/cRY slots are filled.
  With no streaks, np.zeros([1,1,1]) is returned.
  (ymin/ymax are unused here; they are kept for a signature parallel to
  get_single_streak_regions.)
  """
  ## --- find streak box start and stop x-coords
  (i_streak,streak_xlo,streak_xhi) = get_streak_xcoords(src_free,xmin,xmax,xlen,dither,msigma,debug)

  box_corners = []      # one (4, 6) corner array per source-free box

  for ii in range(0, i_streak) :

     ##   sum the counts across the streak's columns for every row:
     ##   a 1-D array of ylen summed values for this streak
     streak_sum = arr_T[streak_xlo[ii]:streak_xhi[ii]+1,:].sum(axis=0).astype(float)

     ##   rows dither..ylen-dither-1, ordered by summed counts, smallest first
     idx = dither + streak_sum[dither:ylen-dither].argsort()

     ##   streak_max: the largest summed value kept, (hopefully) excluding srcs
     streak_max  = streak_sum[idx[int((1.0-SRC_ON_STRK_RATIO)*len(idx))]]
     streak_clip = streak_sum[idx[np.where(streak_sum[idx]<=streak_max)[0]]]

     ##   avg. summed-over-columns counts along the streak, excluding sources
     streak_mean = np.sum(streak_clip)/len(streak_clip)

     ##   Use Poisson approximation for std. dev.
     streak_stddev = np.sqrt(streak_mean)
     cutoff = streak_mean+ssigma*streak_stddev

     ##   rows belonging to the streak proper (not a source on the streak)
     streak_box = np.where(streak_sum<=cutoff)[0]
     if len(streak_box) == 0 :
        continue

     ##   Collect start/stop row indices of the boxes along the streak:
     ##   [start_1, stop_1, start_2, stop_2, ..., start_N, stop_N].
     ##   A new box starts after every gap in streak_box (rows above cutoff),
     ##   so a streak with 3 sources on it yields 4 boxes.
     boxes = [streak_box[0]]
     ##   every adjacent pair is checked, including the final one (the old
     ##   range(len-2) skipped it, merging a source just before the last
     ##   row into the last box)
     for jj in range(0, len(streak_box)-1) :
        if streak_box[jj+1] - streak_box[jj] != 1 :
           boxes.append(streak_box[jj])      ## stop index for box n
           boxes.append(streak_box[jj+1])    ## start index for box n+1
           if debug > 2 :
              print("streak {0}, sum larger than {1} for indices {2} to {3}".format(ii,cutoff,streak_box[jj],streak_box[jj+1]))
              ##   slice so rows near the top edge cannot index out of range
              values = streak_sum[streak_box[jj]:streak_box[jj]+5]
              print("--> values (starting at first index) are: "+" ".join(str(v) for v in values))
     boxes.append(streak_box[-1])             ## stop index for box N

     ##   one box per (start, stop) pair; 1-based pixel edges (index+1 -/+ 0.5)
     for jj in range(0, len(boxes), 2) :
        corners = np.zeros([4, 6], float)
        corners[0, cRX] = streak_xlo[ii]+1.0-0.5
        corners[0, cRY] = boxes[jj]+1.0-0.5
        corners[1, cRX] = streak_xhi[ii]+1.0+0.5
        corners[1, cRY] = boxes[jj]+1.0-0.5
        corners[2, cRX] = streak_xhi[ii]+1.0+0.5
        corners[2, cRY] = boxes[jj+1]+1.0+0.5
        corners[3, cRX] = streak_xlo[ii]+1.0-0.5
        corners[3, cRY] = boxes[jj+1]+1.0+0.5
        box_corners.append(corners)

  if not box_corners :
     return np.zeros([1,1,1], float)

  return np.array(box_corners, float)

# --- end: get_multi_streak_regions

#
## Create img_bkg with the same size as arr_T
#
def bkg_img(arr_T, bkg, xmin,xmax,ymin,ymax,ydither) :
  """Build a background image with the same shape as arr_T.

  For columns ii in xmin..xmax and rows jj in ymin..ymax (inclusive) the
  pixel value is bkg[ii].  Near the ends of the row range it is ramped:
    jj <= ymin+(ydither-2):  v = ((jj-ymin)/(ydither-1.0))*v + bkg[xmax]
    jj >= ymax-(ydither-2):  v = ((ymax-jj)/(ydither-1.0))*v + bkg[xmax]
  Both ramps apply, bottom first, where they overlap.  All other pixels
  are 0.  Vectorized with numpy, using the same arithmetic order as the
  former per-pixel loop.
  NOTE(review): the ramps add bkg[xmax] (one fixed column's value) to
  every column; kept as-is, but confirm this offset is intended.
  """
  bkg = np.asarray(bkg)
  xlen, ylen = arr_T.shape[0], arr_T.shape[1]
  img_bkg = np.zeros( [ xlen, ylen ], float )

  if xmax < xmin or ymax < ymin :
     return img_bkg

  rows = np.arange(ymin, ymax+1)
  # (ncols, nrows) block with bkg[ii] repeated along each column's rows
  vals = np.repeat(bkg[xmin:xmax+1, np.newaxis].astype(float), len(rows), axis=1)
  ramp = ydither - 1.0

  lo = rows <= ymin+(ydither-2)
  if np.any(lo) :
     vals[:, lo] = ((rows[lo]-ymin)/ramp)*vals[:, lo] + bkg[xmax]

  hi = rows >= ymax-(ydither-2)
  if np.any(hi) :
     vals[:, hi] = ((ymax-rows[hi])/ramp)*vals[:, hi] + bkg[xmax]

  img_bkg[xmin:xmax+1, ymin:ymax+1] = vals
  return img_bkg

# --- end: bkg_img

#
## write an output image (used for _rotbkg)
#
def write_image(outfile, out_img, img_block, tool_name) :
   """Write out_img to a new image file outfile, in a block named "STREAKMAP".

   Any existing outfile is removed first.  The output block is created as
   a copy of img_block without its data (dmBlockCreateCopy(..., False)),
   then out_img is written as the image data.  Exits via error_out() if
   the output dataset cannot be created.  Returns 0.
   """
   img_name = "STREAKMAP"

   # if outfile exists, remove it
   remove_files( outfile )

   # create output image dataset
   out_ds = dmDatasetCreate( outfile )
   if out_ds is None :
      # no trailing comma: msg must be a str, not a 1-tuple, so that
      # display_msg() can append "\n" to it
      msg = tool_name+": ERROR: unable to create temporary file '"+outfile+"'\n"
      error_out( msg )

   # copy input image block to output image dataset, omitting img data
   out_block = dmBlockCreateCopy(out_ds,img_name,img_block,False)

   # Get output image data descriptor, write image data
   out_dd = dmImageGetDataDescriptor(out_block)
   dmSetData(out_dd,out_img)

   # Close output image file
   dmDatasetClose( dmBlockGetDataset( out_block))

   return 0

# --- end: write_image

#
##  make_bkg_map() in acis_streak_map (slang version )
#     infile :    _rotimg.fits
#     outfiles :  _rotbkg.fits
#                 _sfr.txt ( non-source-containing; source free row; )
#     input parameters : nsigma, msigma, dither, binsize, k1, k2
# 
def make_bkg_map(infile,nsigma,msigma,ssigma,dither,binsize,k1,k2,
                 outfile,chip,debug,tool_name) :
  """Build the streak background image for one chip and return streak regions.

  infile  : input image (the rotated chip image, _rotimg.fits).
  outfile : output background image (_rotbkg.fits), written via write_image().
  nsigma, binsize, k1, k2 : passed to src_free_rows().
  msigma, ssigma, dither  : passed to the streak-region functions; ssigma <=
            machine epsilon selects single full-height streak boxes,
            otherwise boxes are split around sources on the streak.
  chip    : chip id, used in error messages.
  debug   : verbosity; > 0 prints a summary line.

  Returns reg_pts from get_single/get_multi_streak_regions(), or
  np.zeros([1,1,1]) when no source-free rows are found (the background
  image is then all zeros).
  """
  img_block = dmImageOpen(infile)
  img_dd = dmImageGetDataDescriptor(img_block)
  in_data = dmGetData(img_dd)

  # work in transposed orientation: arr_T = slang's img_data
  arr_T = np.array( in_data.transpose() )
  xlen, ylen = arr_T.shape

  (xmin,xmax,ymin,ymax) = image_limits( arr_T,xlen,ylen,chip)

  # counts in each row (y), summed over all columns (x)
  row_sum = np.zeros( ylen, float )
  for ii in range(0, ylen) :
      row_sum[ii] = np.sum( arr_T[:,ii])

  yminrow = ymin + int(dither)
  ymaxrow = ymax - int(dither)
  if yminrow > ymaxrow :
    # (0, 0) tells src_free_rows() there are no usable rows
    yminrow = 0
    ymaxrow = 0

  # "histo" -> use summed-over-rows counts histogram statistics to
  #            determine src free rows
  # "min"   -> use summed-over-rows counts min value + n*sigma to
  #            determine src free rows
  method = "histo"
  (ww,sig) = src_free_rows(row_sum,yminrow,ymaxrow,nsigma,method,binsize,k1,k2,infile,debug)

  # src_free_rows() signals "no rows" with a one-element array and sig == 0
  nrows = len(ww)
  if nrows == 1 and sig == 0 :
     nrows = 0

  if (debug > 0) :
    display_msg("# acis_streak_map: INFO: infile="+infile+" nrows="+repr(nrows)+" sigma="+repr(sig))

  if nrows > 0 :  # img has non-zero pixels
     # ------------------------------------------------------
     # src_free[ii]: counts in column ii summed over the
     # source-free rows
     # ------------------------------------------------------
     src_free = np.zeros( xlen, float )
     for ii in range(0,xlen) :
        src_free[ii] = sum ( arr_T[ii,ww])

     # ----------------------------------------------------------------------
     #...calculate regions around streaks, excluding srcs embedded in streaks
     # ----------------------------------------------------------------------
     if ssigma <= sys.float_info.epsilon: #ssigma == -1 => don't exclude srcs from strks
        reg_pts = get_single_streak_regions(src_free,xmin,xmax,ymin,ymax,
                                            xlen,ylen,dither,msigma,debug)
     else:
        reg_pts = get_multi_streak_regions(arr_T,src_free,xmin,xmax,ymin,ymax,
                                           xlen,ylen,dither,msigma,ssigma,debug)

     bkg = src_free/(nrows*1.0)
     new_img = bkg_img(arr_T,bkg,xmin,xmax,ymin,ymax,dither)

  else :    # no source-free rows found
     new_img = np.zeros ( [ xlen, ylen ], float )
     reg_pts = np.zeros ( [1,1,1 ], float )

  # Transpose back to the input orientation.  (A reshape to the swapped
  # dimensions reinterprets the buffer and scrambles non-square images;
  # for square images the result is the same.)
  out_img = np.ascontiguousarray( new_img.transpose() )

  # create _rotbkg
  write_image(outfile, out_img, img_block, tool_name)

  # release the input image now that its header has been copied
  dmDatasetClose( dmBlockGetDataset( img_block ) )

  return reg_pts

# --- end: make_bkg_map

def rotate_pnt(xx, yy, ang, midx, midy ) :
  """Rotate point(s) (xx, yy) by ang degrees about (midx, midy).

  Applies the standard rotation matrix [[cos, -sin], [sin, cos]]
  (counter-clockwise for positive ang in an x-right, y-up frame).
  Works element-wise for numpy arrays.  The inputs are not modified: the
  former in-place `xx -= midx` mutated caller-supplied arrays.
  Returns (xp, yp).
  """
  theta = ang * (np.pi / 180.0)       # same factor as DEG2RAD
  ca = np.cos(theta)
  sa = np.sin(theta)

  # offsets from the rotation centre (new objects, never in place)
  dx = xx - midx
  dy = yy - midy

  xp = ca*dx - sa*dy + midx
  yp = sa*dx + ca*dy + midy

  return (xp, yp)

# --- end: rotate_pnt

def calc_rot_ang( top_xy, bot_xy, debug ) :
   """Return the rotation angle, in degrees, derived from the line bot_xy -> top_xy.

   The direction of (top_xy - bot_xy), measured from +x, plus 90 deg, is
   normalized to [0, 360), rounded to 0.001 deg and returned negated.
   math.atan2 resolves the quadrant exactly like the former atan(dy/dx)
   plus the dx < 0 correction, and also handles vertical lines (dx == 0),
   where dy/dx raised ZeroDivisionError.  Identical points give -90.
   """
   if (debug > 2) :
      display_msg("input: top_xy="+str(top_xy[0])+"; "+str(top_xy[1]))
      display_msg("input: bot_xy="+str(bot_xy[0])+"; "+str(bot_xy[1]))

   dx = top_xy[0] - bot_xy[0]     #delta_x
   dy = top_xy[1] - bot_xy[1]     #delta_y
   ang = math.atan2(dy, dx) * 180.0 / math.pi     # in [-180, 180]

   #...rhain [01NOV12] [top|bot]_xy now define a line rotated by 90 deg
   #   from prior definition.  To maintain same angle, add 90 deg to the
   #   computed angle.
   ang += 90.0

   if ang < 0 :
      ang += 360.0

   ang = float( "%.3f" % ang )

   return -ang

# --- end: calc_rot_ang
 
def calc_logical_pixel( img_dd, mid_sky) :
  """Convert a SKY (physical) position to LOGICAL image coordinates.

  Uses axis group 1 of the image data descriptor img_dd (the physical
  coordinate descriptor) and returns dmCoordInvert()'s result for mid_sky.
  The unused element-dimension / WCS lookups and the placeholder array
  were dropped.
  """
  phys = dmArrayGetAxisGroup(img_dd,1)
  return dmCoordInvert(phys,mid_sky)

# --- end:  calc_logical_pixel

#
## convert chip coords to sky
#
def  calc_sky_pixel_ ( evtfile, chipId_num, chipx, chipy, debug) :
   """Convert chip coordinates (chipx, chipy) on chip chipId_num to sky pixels.

   RA_NOM/DEC_NOM are read from evtfile's header and passed to dmcoords
   (opt=chip).  Returns a 3-element float array whose first two entries are
   dmcoords' sky x and y; the third stays 0.  `debug` is unused.
   """
   ra_nom  = get_key_value( evtfile, "RA_NOM" )
   dec_nom = get_key_value( evtfile, "DEC_NOM" )

   cmd = ("infile="+evtfile+" asol= opt=chip  chipx="+str(chipx)
          +" chipy="+str(chipy)+" chip_id="+str(chipId_num)+"  verb=0"
          +" ra_nom="+ra_nom+" dec_nom="+dec_nom+" celfmt=deg")
   exe_task( "dmcoords", cmd )

   sky = np.zeros( 3, float )
   sky[0] = paramio.pget( "dmcoords", "x" )
   sky[1] = paramio.pget( "dmcoords", "y" )
   return sky

# --- end: calc_sky_pixel_
  
#
## Convert logical coord to physical/wcs coord if coordType is physical/wcs
#
def calc_raDec_from_imgpix_( image_file, imgpix, coordType) :
   """Convert logical pixel (imgpix[0], imgpix[1]) of image_file to sky or RA/Dec.

   dmcoords (opt=logical) gives the sky (physical) position.  Unless
   coordType is "WCS", that 2-element float array [x, y] is returned.  For
   "WCS", the sky position is converted with dmcoords (opt=sky), using the
   file's RA_NOM/DEC_NOM, and [ra, dec] in degrees is returned instead.
   """
   cmd = ("infile="+image_file+"  asol=  opt=logical logicalx="+str(imgpix[0])
          +" logicaly="+str(imgpix[1])+" verb=0")
   exe_task("dmcoords", cmd)

   sky = np.zeros( 2, float )
   sky[0] = paramio.pget( "dmcoords", "x" )
   sky[1] = paramio.pget( "dmcoords", "y" )

   if coordType != "WCS" :
      return sky

   ra_nom  = get_key_value( image_file, "RA_NOM")    # string
   dec_nom = get_key_value( image_file, "DEC_NOM")   # string

   cmd = ("infile="+image_file+"  asol=  opt=sky x="+str(sky[0])
          +" y="+str(sky[1])+" ra_nom="+ra_nom
          +" dec_nom="+dec_nom+" celfmt=deg verb=0")
   exe_task("dmcoords", cmd)

   ra_dec = np.zeros( 2, float )
   ra_dec[0] = paramio.pget( "dmcoords", "ra" )
   ra_dec[1] = paramio.pget( "dmcoords", "dec" )
   return ra_dec

# --- end: calc_raDec_from_imgpix_


def add_history_to_file( file_name, tool_name, param_names, param_values):
   """Append a HISTORY record for tool_name to file_name and update HISTNUM.

   Does nothing if file_name does not exist.  A HISTNUM keyword (value 1)
   is added when missing.  The record's ASC numbering starts at HISTNUM
   (1 if its value is None), and HISTNUM is then advanced by the number of
   "\n"-separated pieces in the record's ASC FITS text.
   """
   if not os.path.exists( file_name ) :
      return

   crate = read_file(file_name, mode="rw")

   histnum = crate.get_key("HISTNUM")
   if histnum is None:
      histnum = CrateKey()
      histnum.name = "HISTNUM"
      histnum.value = 1
      crate.add_key(histnum)

   first_line = histnum.value
   if first_line is None :
      first_line = 1

   record = HistoryRecord(tool=tool_name,param_names=param_names,
                          param_values=param_values, asc_start=first_line)
   add_history(crate, record)

   # Increment HISTNUM for new lines
   n_pieces = len( record.as_ASC_FITS().split("\n") )
   histnum.value = first_line + n_pieces

   crate.write()
# end: add_history_to_file  

## Write the streak regions of one chip to the open region file.
#
def _write_chip_regions( regFP, csys, reg_pts, ang, mid_log, chip_img ) :
   """Append one DS9 polygon per streak found on a chip to regFP.

   reg_pts is the array returned by make_bkg_map.  When streaks were found it
   is (num_streaks, 4 box corners, 6); the last axis is indexed by
   cRX/cRY (corner as returned, presumably logical pixels of the rotated
   image - TODO confirm), cXX/cYY (corner rotated back by -ang about
   mid_log) and cRA/cDEC (output coords from calc_raDec_from_imgpix_).
   Any other shape means no streak regions: nothing is written.
   reg_pts is updated in place (cXX/cYY/cRA/cDEC columns).
   """
   # rank check first: a 1-D "no streaks" result has no shape[1]
   if reg_pts.ndim != 3 or reg_pts.shape[1] != 4 :
      return

   # 3 decimals is sub-pixel for physical coords, but 0.001 deg (~3.6")
   # would visibly distort WCS polygons, so keep more digits for degrees
   fmt = "%.7f" if "WCS" == REG_COORD_TYPE else "%.3f"

   for jj in range( reg_pts.shape[0] ) :
      for kk in range( 4 ) :
         (reg_pts[jj,kk,cXX], reg_pts[jj,kk,cYY]) = rotate_pnt(reg_pts[jj,kk,cRX], reg_pts[jj,kk,cRY], -1.0*ang, mid_log[0], mid_log[1])
         imgpix = [ reg_pts[jj,kk,cXX], reg_pts[jj,kk,cYY] ]
         raDec  = calc_raDec_from_imgpix_( chip_img, imgpix, REG_COORD_TYPE )
         reg_pts[jj,kk,cRA]  = raDec[0]
         reg_pts[jj,kk,cDEC] = raDec[1]

      coords = []
      for kk in range( 4 ) :
         coords.append( fmt % reg_pts[jj,kk,cRA] )
         coords.append( fmt % reg_pts[jj,kk,cDEC] )
      regFP.write( csys+"Polygon("+",".join(coords)+")\n" )

# --- end: _write_chip_regions


## Build the background map (and optional streak regions) for one chip.
#
def _make_chip_bkg( pp, ii, chip_number_str, regFP, csys, debug ) :
   """Create <bkgroot><chip>.fits for one ACIS chip and return its name.

   Steps:
   - Bin the chip's events into <tmppath>/zzzchip_<chip>_img.fits.
   - Rotate that image by the angle calc_rot_ang derives from a line
     parallel to the streak direction.
   - Run make_bkg_map on the rotated image.
   - Rotate the resulting map back to the original orientation.
   When regFP is not None the chip's streak regions are appended to it.

   Parameters
      pp              : dict of tool parameters from get_params()
      ii              : position of the chip in the DETNAM chip list
      chip_number_str : the chip number as a one-character string
      regFP           : open region file, or None when no regfile is wanted
      csys            : region coordinate prefix ("fk5;" or "physical;")
      debug           : verbosity level
   """
   if ( debug > 2 ) :
      display_msg("Making chip image "+repr(ii))

   chip_number = int(chip_number_str)

   # tmp files
   path_string   = pp["tmppath"]+"/zzzchip_"+chip_number_str
   chip_img      = path_string+"_img.fits"
   chip_rotimg   = path_string+"_rotimg.fits"
   chip_rotbkg   = path_string+"_rotbkg.fits"

   bkgfile = pp["bkgroot"]+chip_number_str+".fits"   # _bkg#.fits
   check_clob( pp["clobber"], bkgfile )

   if ( debug > 1 ) :         # processing chip 0 ...
      display_msg("\n processing chip "+chip_number_str+" ...")

   ## --- creating _img.fits: a box centred on the chip's FOV bounds
   regstr1 = "bounds(region("+pp["fovfile"]+"[ccd_id="+chip_number_str+"]))"
   xl,yl,xh,yh = regExtent(regParse(regstr1))    # region lib; float
   xmid = (xl+xh)/2.0
   ymid = (yl+yh)/2.0

   # box size = FOV extent + dither; the x size is never below 1024+2*dither
   dither_pix = int(pp["dither"])
   box_w = abs(xh-xl) + dither_pix
   box_h = abs(yh-yl) + dither_pix
   if ( box_w < ( 1024.0+2.0*dither_pix ) ) :
      box_w = 1024.0+2.0*dither_pix

   s_xmid ="%.3f"%xmid
   s_ymid ="%.3f"%ymid
   regstr2="box("+s_xmid+","+s_ymid+","+str(int(box_w))+","+str(int(box_h))+")"

   # always set clob=yes
   cmd = r'"'+pp["infile"]+r'[ccd_id='+chip_number_str+r',sky='+regstr2+r'][bin sky=1][opt type=i4]"  '+chip_img+" clob+"
   if (debug > 2) :
      display_msg("\n dmcopy command is : dmcopy "+cmd+"\n")
   exe_task("dmcopy", cmd)

   ## --- checking _img.fits
   find_output_file( chip_img )

   ## --- open image data descriptor for header keywords, axis dim's
   img_block,img_dd = open_input_imgFile(chip_img)  # /tmp/zzzchip_#_img.fits

   #...rhain [01NOV12] top_xy and bot_xy define a line rotated 90 deg from
   #   the prior definition, now in the same direction as the potential
   #   streaks (FOV regions are not square, and streak is aligned with FOV
   #   segments parallel to streak).  Chip pixels (512,1) and (512,1024).
   top_xy = calc_sky_pixel_( pp["infile"], chip_number, 512, 1, debug)
   if (debug > 2 ):
      display_msg("top_x and top_y are: "+str(top_xy[0])+"  "+str(top_xy[1]))

   bot_xy = calc_sky_pixel_( pp["infile"], chip_number, 512, 1024,debug)
   if (debug > 2 ):
      display_msg("bot_x and bot_y are: "+str(bot_xy[0])+"  "+str(bot_xy[1]))

   ## --- calculate chip rotation angle
   ang = calc_rot_ang(top_xy,bot_xy,debug)
   if (debug > 3) :
      display_msg("ang =  " + str(ang))

   ## --- centre of the chip box in logical (image) pixels: rotation centre
   mid_sky = [xmid, ymid]
   mid_log = calc_logical_pixel( img_dd, mid_sky )
   if (debug > 2) :
      display_msg("mid_sky x,y   = %.3f "%mid_sky[0]+"%.3f"%mid_sky[1])
      display_msg("mid_logic x,y = "+str(mid_log[0])+"  "+str(mid_log[1]))

   ## --- image axis lengths (only reported)
   if (debug > 3 ) :
      arrlen = dmGetArrayDimensions(img_dd)
      display_msg("chip image axis lengths = "+str(arrlen[0])+"  "+str(arrlen[1]))

   ## --- header no longer needed; close it like the evt file in main()
   dmDatasetClose( dmBlockGetDataset( img_block ) )

   ## --- rotate chip image and create _rotimg.fits
   smid_x = "%.3f"%mid_log[0]
   smid_y = "%.3f"%mid_log[1]
   cmd = "infile="+chip_img+" outfile="+chip_rotimg+" coord_sys=logical theta="+str(ang)+" rotxcen="+smid_x+" rotycen="+smid_y+" xoffset=0 yoffset=0  clob+ res=0"
   if (debug > 2 ) :
      display_msg("\n Rotate image dmregrid command is :  dmregrid2 "+cmd+"\n")
   exe_task("dmregrid2", cmd )
   find_output_file( chip_rotimg )

   ## --- make background map
   tool_name = "acis_streak_map"
   reg_pts = make_bkg_map(chip_rotimg,pp["nsigma"],pp["msigma"],pp["ssigma"],pp["dither"],pp["binsize"],pp["k1"],pp["k2"],chip_rotbkg,chip_number,debug,tool_name)
   # 2-D slicing below needs at least 2 axes ("no streaks" may be 1-D)
   if (debug > 1 and reg_pts.ndim >= 2) :
      display_msg("len(regpts[:,0]) is: "+str(len(reg_pts[:,0])))
      display_msg("len(regpts[0,:]) is: "+str(len(reg_pts[0,:])))

   ## --- rotate background map back to original orientation ; creating _bkg
   cmd = "infile="+chip_rotbkg+" outfile="+bkgfile+" coord_sys=logical theta="+str(-1.0*ang)+" rotxcen="+smid_x+" rotycen="+smid_y+" xoffset=0 yoffset=0  clob+ res=0"
   if (debug > 2 ) :
      display_msg("\nRotate image back dmregrid command is: \n dmregrid2 "+cmd)
   exe_task("dmregrid2", cmd )

   ## --- check success of dmregrid command
   find_output_file( bkgfile )

   ## --- streak regions, rotated back and converted to REG_COORD_TYPE
   if regFP is not None :
      _write_chip_regions( regFP, csys, reg_pts, ang, mid_log, chip_img )

   return bkgfile

# --- end: _make_chip_bkg


# --- main code
#
def main() :
   """Entry point for acis_streak_map.

   Makes one background map per ACIS chip listed in the event file's DETNAM
   keyword (see _make_chip_bkg), optionally writes the streak regions to
   regfile, adds HISTORY to every background file, removes the
   <tmppath>/zzzchip_* temporary files and exits with status 0.
   """
   ##  --- get parameters
   pp = get_params( "acis_streak_map" )
   debug = int ( pp["verbose"] )

   ##  --- open evt input file
   evt_block = open_input_tabFile( pp["infile"] )  # error out if fail to open

   ##  --- check for output regfile
   regfile = pp["regfile"]          # don't add history to regfile
   check_clob( pp["clobber"], regfile )

   # ------------------------------------------------------------
   #   open output streak region file to store streak regions,
   #   write first 2 lines, to output either WCS or PHYS coords,
   #   based on REG_COORD_TYPE
   # ------------------------------------------------------------
   regFP = None
   csys  = ""
   if ( regfile != "") :
      regFP = open(regfile,"w")      # outside chipId loop
      regFP.write( "# Region file format: DS9 version 3.0\n")
      regFP.write("global color=green select=1 highlite=1 edit=1 move=1 delete=1 include=1 fixed=0 source\n")
      if ("WCS" == REG_COORD_TYPE) :
         csys="fk5;"
      else :
         csys="physical;"

   bkgfilelist = []     # output background maps; they get HISTORY below
   try :
      ## --- chip numbers from evtfile's DETNAM, e.g. b"ACIS-7"
      _, detnam = dmKeyRead(evt_block,"detnam")
      if (debug > 2) :
         display_msg("DETNAM is: "+detnam.decode('utf-8') )

      ## --- close evt input file
      dmDatasetClose( dmBlockGetDataset(evt_block))

      detnam_parts = detnam.split(b"-")
      if len(detnam_parts) < 2 :
         error_out("ERROR: no chip numbers found in DETNAM of "+pp["infile"])
      chipnums = detnam_parts[1].decode("utf-8")

      ## --- loop over chip numbers
      for ii, chip_number_str in enumerate( chipnums ) :
         bkgfilelist.append( _make_chip_bkg( pp, ii, chip_number_str, regFP, csys, debug ) )
   finally :
      # also release the region file if a step errors out mid-loop
      if regFP is not None :
         regFP.close()

   ## add history
   # Get parameter name/values in order stored in parameter file
   pnames= paramio.plist( "acis_streak_map" )
   pvals = [ pp[x] for x in pnames]
   for ofile in bkgfilelist :
      add_history_to_file( ofile, "acis_streak_map", pnames, pvals)

   ##  clean up
   remove_files ( pp["tmppath"]+"/zzzchip_*" )

   ##  all done
   if debug > 0 :
      display_msg(".... done! ....\n")

   sys.exit(0)

if __name__=="__main__":
    main()
