#!/usr/bin/env python
#
#
#
# Copyright (C) 2012, 2014, 2016, 2018, 2020-2023
# Smithsonian Astrophysical Observatory
#
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#

import os
import sys
import subprocess

# Tool name (used to set up the logger and in user messages) and the
# script's revision date (also reported in messages).
toolname = "simulate_psf"
__revision__ = "15 July 2024"


from cxcdm import *
import paramio as pio
import stk as stk


import ciao_contrib.logger_wrapper as logWrap
# Initialize the CIAO contrib logger for this tool and keep handles to its
# verbose0..verbose3 methods as verb0..verb3 for use throughout the script.
logWrap.initialize_logger(toolname)
verb0 = logWrap.get_logger(toolname).verbose0
verb1 = logWrap.get_logger(toolname).verbose1
verb2 = logWrap.get_logger(toolname).verbose2
verb3 = logWrap.get_logger(toolname).verbose3
from ciao_contrib.logger_wrapper import handle_ciao_errors
from ciao_contrib.runtool import make_tool



class verboseCall( object ):
    """
    Decorator that logs when the wrapped function starts and finishes.

    The "Started <name>" / "Finished <name>" messages go through verb2
    or verb3 when the level given to the constructor is 2 or 3; any
    other level uses verb1.
    """
    def __init__(self, verblvl):
        """
        verblvl : int
            Verbose level (1, 2 or 3) at which the start/finish messages
            are reported.
        """
        self.verblvl = verblvl
        self.verb_print = verb1
        if verblvl == 2:
            self.verb_print = verb2
        if verblvl == 3:
            self.verb_print = verb3

    def __call__(self, func ):
        """
        Return a wrapper around func that logs entry and exit.

        functools.wraps keeps func's __name__ and __doc__ on the wrapper.
        Keyword arguments are passed through to func as well as
        positional ones.
        """
        import functools

        @functools.wraps(func)
        def wrapped_func( *args, **kwargs ):
            self.verb_print("Started "+func.__name__)
            pdq = func(*args, **kwargs)
            self.verb_print("Finished "+func.__name__)
            return pdq
        return wrapped_func


def gorm( infile ):
    """
    Delete a temporary file unless intermediate products are being kept.

    When the SAVE_ALL environment variable is set (to any value) nothing
    is deleted, so every temporary product survives for debugging.  A
    path that does not exist is silently ignored.
    """
    keep_everything = 'SAVE_ALL' in os.environ
    if keep_everything:
        return
    if not os.path.exists(infile):
        return
    os.remove(infile)


@verboseCall(1)
def check_setup( pars ):
    """
    Validate the SAOTrace / MARX installation settings held in pars.

    simulator=saotrace: the SAOTrace database and install directory must
    be given, and <install>/lib/perl5 must exist; PERL5LIB is then set
    to that directory.

    MARX used as simulator or projector: if marx_root is blank it is
    derived from the `marx` executable found on PATH (the parent of its
    bin directory).

    A pileup normalization warning is logged when pileup=yes.
    """
    simulator = pars["simulator"]

    if simulator == "saotrace":
        required = (
            ("_saotrace_db", "SAOTrace database not provided in parameter saotrace_db"),
            ("_saotrace_install", "SAOTrace path not provided in parameter saotrace_install"),
        )
        for par_name, message in required:
            if is_none(pars[par_name]):
                raise Exception(message)

        perl_dir = os.path.join(pars["_saotrace_install"], "lib", "perl5")
        if not os.path.exists(perl_dir):
            raise Exception("Missing SAOTrace perl libraries")
        os.environ["PERL5LIB"] = perl_dir

    uses_marx = "marx" in (simulator, pars["projector"])
    if uses_marx and is_none(pars["marx_root"]):
        from shutil import which
        marx_exe = which('marx')
        if marx_exe is None:
            raise Exception("ERROR: Cannot determine marx_root. Please specify marx installation directory.")
        # <marx_root>/bin/marx  ->  <marx_root>
        pars["marx_root"] = os.path.dirname(os.path.dirname(marx_exe))

    if pars["pileup"] == "yes":
        verb1("\n# {0} ({1}): Pileup warning:  PSF images are not correctly normalized when pileup=yes is used.\n".format(toolname,__revision__))

#
# set_env_vars routine
#
@verboseCall(2)
def set_env_vars( pars ):
    """
    Set up the environment variables needed by SAOTrace and MARX.

    Appends, in this order:
      PFILES += <_saotrace_install>/share/uparm
      PATH   += <_saotrace_install>/bin
      PATH   += <marx_root>/bin
      PFILES += <marx_root>/share/marx/pfiles
    and sets SAOTRACE_DB to pars["_saotrace_db"].

    If PATH or PFILES is not set at all, the new directory becomes its
    only entry.  Previously this case raised a TypeError because None
    was concatenated with a string.
    """

    def append_to_path(envvar, newdir):
        # Append newdir to the ':'-separated list held in envvar.
        current = os.environ.get(envvar)
        if current is None:
            os.environ[envvar] = newdir
        else:
            os.environ[envvar] = current + ":" + newdir

    # SAOTrace / trace-nest
    append_to_path("PFILES", pars["_saotrace_install"] + "/share/uparm")
    append_to_path("PATH", pars["_saotrace_install"] + "/bin")
    os.environ["SAOTRACE_DB"] = pars["_saotrace_db"]

    # MARX
    append_to_path("PATH", pars["marx_root"] + "/bin")
    append_to_path("PFILES", pars["marx_root"] + "/share/marx/pfiles")

    verb3( "PFILES = {0}".format( os.environ["PFILES"] ))
    verb3( "PATH   = {0}".format( os.environ["PATH"] ))


#
# routine to call each command line tool
#
def run_task(command, params, punlearn=False):
    """
    Run a parameter-file (key=value) style tool and wait for it to finish.

    Parameters
    ----------
    command : str
        Name or path of the tool to run.
    params : dict
        Parameter names and values.  Each pair becomes a ``key=value``
        argument, in dictionary order, followed by ``mode=hl``.
    punlearn : bool
        If True, reset the tool's parameter file first with
        paramio.punlearn.

    The tool is executed directly (no shell), so values need no
    quoting.  A quoted, copy-and-paste form of the command is used only
    in the error message.

    Raises
    ------
    Exception
        If the tool cannot be started or exits with a non-zero status.
        The underlying error is chained as ``__cause__``.

    TODO: Could create a custom runtool; parameters are long though
    """
    cmd_name = str(command)
    if punlearn:
        pio.punlearn(cmd_name)

    sub_cmd = [cmd_name]
    shown = [cmd_name]
    for key, value in params.items():
        sub_cmd.append('{}={}'.format(key, value))
        shown.append('{}="{}"'.format(key, str(value)))
    sub_cmd.append("mode=hl")
    shown.append('mode="hl"')
    cmd = " ".join(shown)

    try:
        subprocess.check_call( sub_cmd )
    except (subprocess.CalledProcessError, OSError) as err:
        raise Exception("Problem running {0}".format(cmd)) from err


#
# cleanup all the temp files left behind by saotrace
#
@verboseCall(3)
def cleanup_saotrace( outroot ):
    """
    Remove the temporary files trace-nest leaves behind.

    outroot is the tag given to trace-nest.  For each extension in the
    list below, files matching the glob pattern built from outroot and
    the extension are deleted.  For every extension except log, lua and
    gi, <outroot>.<extn> is also removed.  All deletions go through gorm,
    so setting SAVE_ALL keeps everything.
    """
    import glob as glob

    # Escape outroot so glob metacharacters in the user's output path
    # (e.g. '[' or '*') are matched literally.  The pattern tail is a raw
    # string: the backslash in a plain literal is an invalid escape
    # (SyntaxWarning on Python 3.12+); the resulting value is unchanged.
    prefix = glob.escape(outroot)
    for extn in [ "log", "block", "lua", "summary.yml", "gi", "totwt-in", "totwt-out"]:
        for tt in glob.glob( prefix + r"_00[1346][\._]*" + extn ):
            gorm( tt )
        if extn not in ["log", "lua", "gi"]:
            gorm( outroot + "." + extn )


def is_none( val ):
    """
    Return True if val is None, an empty string, or the word "none" in
    any letter case; otherwise return False.

    val is expected to be None or a str.  A non-empty object without a
    .lower() method raises AttributeError.
    """
    if val is None:
        return True
    if len(val) == 0:
        return True
    return val.lower() == "none"


@verboseCall(2)
def resolve_parameter_dependencies( pars ):
    """
    Fill in the derived entries of pars and check that the combination
    of parameters is allowed by SAOTrace / MARX.

    Side effects on pars:
      - initialises __has_asp, __has_spect, __use_ra_dec, __marxmodel,
        __usemarx, _rayfile_, __xygrid, __imgstk, __raystk, __spectrum
        and __flux;
      - if asolfile is blank or INDEF, tries to locate the aspect
        solution(s) for infile and falls back to "none";
      - without an aspect file, clears __has_asp and __use_ra_dec;
      - without a spectrum file, requires monoenergy and flux (unless
        simulator=file); otherwise converts the spectrum via
        make_spectrum and picks __flux;
      - for simulator=marx, requires projector=marx and selects the
        MARX POINT model;
      - for simulator=file, sets numiter to the number of ray files,
        checks them for consistency, and takes ra/dec from the first
        ray file when both are INDEF;
      - replaces a non-positive random_seed with a random 29-bit one;
      - reconciles numiter and numrays (numiter wins if both are set).

    Raises Exception / RuntimeError for inconsistent parameters.
    """
    pars["__has_asp"]     = True
    pars["__has_spect"]   = True
    pars["__use_ra_dec"]  = True
    pars["__marxmodel"]   = "SAOSAC"
    pars["__usemarx"]     = False
    pars["_rayfile_"]     = None
    pars["__xygrid"]      = None
    pars["__imgstk"]      = []
    pars["__raystk"]      = []
    pars["__spectrum"]    = None
    pars["__flux"]        = 1.0

    if pars["asolfile"].strip() in [ "", "INDEF"]:
        from ciao_contrib.ancillaryfiles import find_ancillary_files
        try:
            ff = find_ancillary_files(pars["infile"], ["asol", ])
            if ff is None:
                raise ValueError("")
            ff = ff[0]
            if any(map( is_none, ff)):
                raise ValueError("")
            pars["asolfile"] = ",".join(ff)
        except Exception:
            # Best effort only: any failure means "no aspect file".
            sys.stderr.write("Cannot locate aspect solution file.  Will proceed without it.\n")
            pars["asolfile"]="none"

    if is_none( pars["asolfile"] ):
        pars["__has_asp"] = False
        pars["__use_ra_dec"] = False

    if is_none( pars["spectrumfile"] ):
        pars["__has_spect"] = False

        if (pars['monoenergy']=="INDEF") and ("file"!=pars["simulator"]):
            raise Exception("energy must be set if spectrum left blank")

        if (pars["flux"] == "INDEF")  and ("file"!=pars["simulator"]):
            raise Exception("flux must be set if spectrum left blank")

        pars["__flux"] = pars["flux"]

    else: # spectrum is not none
        pars["__spectrum"] = make_spectrum( pars )
        if pars["flux"] == "INDEF":
            # No explicit flux: MARX gets "-1", SAOTrace a scale of "1.0".
            if "marx" == pars["simulator"]:
                pars["__flux"] = "-1"
            else:
                pars["__flux"] = "1.0"
        else:
            pars["__flux"] = pars["flux"]

    if "marx" == pars["simulator"]:
        if "marx" != pars["projector"]:
            raise Exception("If using simulator='marx' you must also use projector='marx'")
        pars["__marxmodel"] = "POINT"
        pars["__usemarx"]   = True

    if "marx" == pars["projector"] and "yes" == pars["pileup"] and "yes" == pars["extended"]:
        raise Exception("ACIS pileup cannot be simulated with extended detectors")

    if "file" == pars["simulator"]:
        raystk = stk.build( pars["rayfile"] )
        nrayfiles = len(raystk)
        # numiter may still be INDEF here (it is resolved further down),
        # so only compare it when it is a number.  Calling int("INDEF")
        # used to raise ValueError.
        if pars["numiter"] != "INDEF" and int(pars["numiter"]) != nrayfiles:
            verb1("{} rayfiles provided; ignoring numiter={} parameter".format(nrayfiles, pars["numiter"] ))
        pars["numiter"] = str(nrayfiles)

        check_rayfiles(raystk)

        if ("INDEF" == pars["ra"]) and ("INDEF" == pars["dec"]):
            from pycrates import read_file
            f0=read_file(raystk[0])
            pars["ra"]=f0.get_key_value("SRC_RA")
            pars["dec"]=f0.get_key_value("SRC_DEC")

    if int(pars["random_seed"]) <= 0:
        import random as random
        pars["random_seed"] = str( random.getrandbits(29) )

    if pars["numiter"] == "INDEF":
        if pars["numrays"] == "INDEF":
            raise RuntimeError("Please specify the number of iterations or number of rays")
        pars["numiter"] = "1"   # start w/ 1 will increase as needed
    else:
        if pars["numrays"] != "INDEF":
            verb0("WARNING: Cannot specify both number of iterations "+
                  "and number of rays. Using number of iterations")
        pars["numrays"] = "INDEF"  # numiter wins
            


def check_rayfiles( raystk ):
    """
    Check that all ray files in raystk share the same source setup.

    The SRC_RA, SRC_DEC, CONFFILE and ASOLFILE header keywords of each
    file are compared with those of the first file.  Every file is
    opened once; previously each file was opened once per keyword.
    A list with zero or one entry needs no check; an empty list used to
    raise IndexError.

    Raises
    ------
    RuntimeError
        On the first mismatch found, scanning keyword by keyword in the
        order above and, within a keyword, file by file.
    """
    if len(raystk) <= 1:
        return

    from pycrates import read_file

    keys_to_check = [ "SRC_RA", "SRC_DEC", "CONFFILE", "ASOLFILE" ]

    headers = []
    for rayfile in raystk:
        crate = read_file(rayfile)
        headers.append({k: crate.get_key_value(k) for k in keys_to_check})

    reference = headers[0]
    for k in keys_to_check:
        for rayfile, hdr in zip(raystk[1:], headers[1:]):
            if reference[k] != hdr[k]:
                raise RuntimeError("Mismatch in ray file keyword {} in {}".format(k, rayfile))


@verboseCall(2)
def make_spectrum( pars ):
    """
    Convert the user's spectrum file into the text file read by the
    simulator and return the new file's name.

    The input, pars["spectrumfile"], is read with pycrates, so its name
    cannot end in .rdb.  Expected columns (energies in keV, flux in
    photons/cm^2/sec):

      saotrace : energy_lo, energy_hi, flux  (exactly 3 columns)
      marx     : energy, flux                (2 columns), or
                 energy_lo, energy_hi, flux  (3 columns)

    Output, written to pars["outroot"] + "spectrum.rdb":

      saotrace : tab-separated rdb table with columns emin, emax, flux
      marx     : two columns, energy (energy_hi for 3-column input) and
                 flux divided by the bin width (1 for 2-column input)

    For any other simulator the spectrum is ignored and None is returned.
    The temporary file is expected to be removed once the simulator has
    run.

    Raises
    ------
    ValueError / Exception
        If the file name ends in .rdb, the column count is wrong, the
        table has no rows, values are not real, energies fall outside
        0.1-10 keV, energy_lo > energy_hi, a 3-column MARX bin has zero
        width, or any flux is <= 0.
    """
    import pycrates as pyc
    import numpy as np

    infile=pars["spectrumfile"]
    outfile=pars["outroot"]+"spectrum.rdb"

    if infile.lower().endswith(".rdb"):
        # DM/crates will not read .rdb files
        raise ValueError("Input spectrum file name cannot end with '.rdb'")

    tab = pyc.read_file( infile )
    ncols = len( tab.get_colnames() )

    if "saotrace" == pars["simulator"]:
        if ncols != 3:
            raise Exception("SAOTrace spectrum file '{0}' must have 3 columns: energy_lo, energy_hi, and flux".format(infile))
        elo = tab.get_column(0).values
        ehi = tab.get_column(1).values
        flx = tab.get_column(2).values
        de = None

    elif "marx" == pars["simulator"]:
        if ncols == 2:
            elo = tab.get_column(0).values
            ehi = elo
            flx = tab.get_column(1).values
            de = np.ones_like( elo )
        elif ncols == 3:
            elo = tab.get_column(0).values
            ehi = tab.get_column(1).values
            flx = tab.get_column(2).values
            de = ehi - elo
        else:
            raise Exception("MARX spectrum must have either 2 or 3 columns")
    else:
        verb1("Spectrum is ignored for rayfile input")
        return None

    if len(elo) == 0:
        raise ValueError("Spectrum file '{0}' contains no rows".format(infile))

    if not np.all( np.isreal( elo )):
        raise Exception("energy_low values must be real values")

    if not np.all( np.isreal( ehi )):
        raise Exception("energy_hi values must be real values")

    if not np.all( np.isreal( flx )):
        raise Exception("flux values must be real values")

    if np.min(elo) < 0.1 or np.min(ehi) < 0.1:
        raise Exception("Minimum allowed energy is 0.1 keV")

    if np.max(elo) > 10 or np.max(ehi) > 10:
        raise Exception("Maximum allowed energy is 10 keV")

    if np.any(elo > ehi):
        raise ValueError("Max energy must be >= min energy")

    # A zero-width bin would make flux/width infinite in the MARX file.
    if de is not None and np.any(de <= 0):
        raise ValueError("Energy bins must have non-zero width (energy_hi > energy_lo)")

    if np.min(flx) <= 0:
        raise Exception("Flux cannot be <= 0")

    with open( outfile, "w" ) as fp:
        if "saotrace" == pars["simulator"]:
            # SAOTrace: 3 column rdb (header line, column-type line, rows)
            fp.write("emin\temax\tflux\n")
            fp.write("N\tN\tN\n")
            for lo, hi, ff in zip( elo, ehi, flx ):
                fp.write( "{0}\t{1}\t{2}\n".format( lo, hi, ff ))
        else:
            # MARX: 2 columns, energy and flux per unit energy
            for hi, ff, width in zip( ehi, flx, de ):
                fp.write( "{0}\t{1}\n".format( hi, ff/width ))

    return outfile




#
# Build srcpars parameter string
#
@verboseCall(2)
def build_source_pars( pars, obi_info, src_info):
    """
    Build the Lua source description passed to trace-nest as srcpars.

    The returned string is the concatenation of:
      - a dither_asol{...}; entry (aspect file plus pointing), only when
        an aspect solution is used;
      - a point{...} source positioned by RA/Dec (with the pointing as
        aimpoint) or, without RA/Dec, by theta/phi;
      - the source spectrum: an rdb file scaled by __flux, or a single
        monoenergy line with flux __flux.
    """
    pieces = []

    if pars['__has_asp']:
        pieces.extend([
            "dither_asol{ ",
            "  file = '" + pars["asolfile"] + "', ",
            "  ra = {0!s}, ".format(obi_info["ra_pnt"]),
            "  dec = {0!s}, ".format(obi_info["dec_pnt"]),
            "  roll = {0!s}".format(obi_info["roll_pnt"]),
            "};",
        ])

    pieces.append("point{ ")
    pieces.append("  position = { ")
    if pars['__use_ra_dec']:
        pieces.extend([
            "    ra = '{0!s}', ".format(src_info["ra_src"]),
            "    dec = '{0!s}', ".format(src_info["dec_src"]),
            "    ra_aimpt={0!s}, ".format(obi_info["ra_pnt"]),
            "    dec_aimpt={0!s}".format(obi_info["dec_pnt"]),
        ])
    else:
        pieces.extend([
            "    theta = {0!s}, ".format(src_info["theta"]),
            "    phi = {0!s}, ".format(src_info["phi"]),
        ])
    pieces.append("  },")

    if pars['__has_spect']:
        pieces.extend([
            "  spectrum={ { ",
            "    file='" + pars["__spectrum"] + "',",
            "    units='photons/s/cm2',",
            "    format='rdb',",
            "    scale={0!s}".format(pars["__flux"]),
            "  }}",
        ])
    else:
        pieces.extend([
            "  spectrum = { { ",
            "    {0!s},".format(pars["monoenergy"]),
            "    {0!s}".format(pars["__flux"]),
            "  }}",
        ])
    pieces.append("}")

    return "".join(pieces)



#
# run_saotrace routine
#
@verboseCall(1)
def run_saotrace(pars, obi_info, src_info ):
    """
    Generate rays with SAOTrace (trace-nest; written for v 2.0.4).

    The simulated interval starts 2.5 s after the observation TSTART and
    lasts TSTOP - TSTART - 5 s; both offsets are empirical fudge
    factors.  The source description comes from build_source_pars.
    trace-nest runs in a private PFILES environment so that parallel
    runs cannot collide.  Afterwards the trace-nest temporary files are
    removed, the TDMIN/TDMAX keywords are deleted from pars["_rayfile_"],
    and simulation metadata is added to it.

    Notes:
      - the spectrum must already be in SAOTrace's rdb format;
      - a single aspect solution file is needed (multiple files would
        have to be merged first).
    """
    outroot = pars["_outroot_"] + "_rays"

    obs_start = float(obi_info["tstart"])
    obs_stop = float(obi_info["tstop"])

    saotrace = {
        "limit_type":       "sec",
        "tstart":           obs_start + 2.5,   # fudge factor
        "limit":            obs_stop - obs_start - 5,
        "srcpars":          build_source_pars( pars, obi_info, src_info ),
        "tag":              outroot,
        "shells":           "all",
        "z":                "10079.771666816",
        "src":              "default",
        "output":           "default",
        "output_fmt":       "fits-axaf",
        "output_coord":     "hrma",
        "output_fields":    "min",
        "seed1":            str(pars["random_seed"]),
        "seed2":            "1",
        "block":            "1",
        "block_inc":        "100",
        "focus":            "no",
        "tally":            "0",
        "throttle":         "0",
        "throttle_poisson": "no",
        "config_dir":       pars["_saotrace_db"] + "/ts_config",
        "config_db":        pars["_saotrace_mirror"],
        "clean":            "rays",
        "debug":            "",
    }

    # Private PFILES so that concurrent runs do not collide.
    from ciao_contrib.runtool import new_pfiles_environment
    with new_pfiles_environment():
        run_task(pars["_saotrace_install"] + "/bin/trace-nest", saotrace)

    cleanup_saotrace( outroot )

    # Remove problematic keywords, then record how the rays were made.
    del_tdmin_max( pars["_rayfile_"] )
    add_saotrace_metadata( pars, obi_info, src_info )



#
# del_tdmin_max
#
@verboseCall(3)
def del_tdmin_max(raysfile):
    """
    Delete the TDMINn and TDMAXn keywords that saotrace writes.

    DM does not deal properly with a column that has both TDMIN and
    TLMIN.  raysfile is opened for update and, for columns 1..N, each
    keyword is removed if it can be opened.

    Each keyword is handled independently, so a missing TDMINn no longer
    prevents TDMAXn from being deleted.  The dataset is closed even if
    an unexpected error escapes.
    """
    block = dmTableOpen(raysfile, update=True )
    try:
        num_cols = dmTableGetNoCols(block)

        # Keyword column numbers run from 1 to num_cols inclusive.
        for col in range(1, num_cols + 1):
            for prefix in ("TDMIN", "TDMAX"):
                try:
                    key = dmKeyOpen(block, prefix + str(col))
                    dmDescriptorDelete( key )
                except Exception:
                    # Presumably the keyword is absent for this column;
                    # there is nothing to delete.
                    pass
    finally:
        dmDatasetClose(dmBlockGetDataset(block))


def add_saotrace_metadata( pars, obi_info, src_info ):
    """
    Write simulation metadata keywords and history into the SAOTrace ray
    file pars["_rayfile_"].

    The keywords record:
      - the source energy, flux, spectrum and position;
      - the exposure (TSTOP - TSTART - 5 s);
      - the random seed;
      - the SAOTrace configuration and aspect file name;
      - standard FITS/ASC HDU keywords.
    Values that are None are written as the string "none".  History is
    then appended with add_history_to_file.

    DATE is now written in UTC, as the FITS standard requires.  It was
    previously the local time.
    """
    from datetime import datetime, timezone
    import pycrates as pc

    exposure = float(obi_info["tstop"]) - float(obi_info["tstart"]) - 5
    created = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S")

    metadata = [ ('SRC_E'    ,pars["monoenergy"], "[keV] Energy used to generate rays"),
                 ('SRC_DENS' ,0, "[r/mm^2] Ray density"),
                 ('SRC_FLUX' ,pars["flux"], "Source flux"),
                 ('SRC_SPEC' ,pars["spectrumfile"], "Spectrum used to generate rays"),
                 ('SRC_EXPT' ,exposure, "[sec] Exposure time of simulation"),
                 ('TOTCTS'   ,0, "Total counts (weight) in PSF image"),
                 ('SRC_THET' ,src_info['theta'], "[arcmin] input off axis angle"),
                 ('SRC_PHI'  ,src_info['phi'], "[deg] input azimuthal angle"),
                 ('SRC_RA'   ,src_info['ra_src'], "[deg] Input source Right Ascension"),
                 ('SRC_DEC'  ,src_info['dec_src'], "[deg] Input source Declination"),
                 ('TELESCOP' ,'CHANDRA' , "Telescope"),
                 ('INSTRUME' ,'TEL', "Instrument"),
                 ('DETNAM'   ,'HRMA', "Detector"),
                 ('CREATOR'  ,'saotrace', "Software that created this file"),
                 ('PSFSEED2' ,pars["random_seed"], "Secondary random seed"),
                 ('PSFBLOCK' ,1, "Block for saosac"),
                 ('CONTENT'  ,'RAYS', 'File contains simulated data'),
                 ('ORIGIN'   ,'Unknown', 'The origin of the FITS file'),
                 ('DATE'     ,created, "Date FITS file was created"),
                 ('DATE-OBS' ,obi_info["date-obs"], "Date simulated"),
                 ('HDUNAME'  ,'PSFRAYS', "" ),
                 ('HDUCLASS' ,'ASC', ""),
                 ('HDUCLAS1' ,'RESPONSE', ""),
                 ('HDUCLAS2' ,'PSF', ""),
                 ('HDUCLAS3' ,'PSFRAY', ""),
                 ('HDUVERS1' ,'1.0.0', ""),
                 ('RAYMETH'  ,'SAOTrace', ""),
                 ('CONFFILE' ,pars["_saotrace_mirror"], ""),
                 ('ASOLFILE' ,os.path.basename(pars["asolfile"]), ""),
              ]

    tab = pc.read_file(pars["_rayfile_"], mode="rw")

    for key, val, com in metadata:
        if val is None:
            val = "none"
        pc.set_key( tab, key, val, desc=com)

    tab.write()
    add_history_to_file( pars["_rayfile_"], pars )



#
#  What marx detector will have the ccd on?
#
@verboseCall(2)
def determine_marx_detector( obi_info, src_info ):
    """
    Choose the MARX aimpoint name and detector for the source's chip.

    MARX only turns ACIS-0..3 on when the detector is ACIS-I, so the
    SIM_Z position cannot be used for this (e.g. ACIS-7 may be on while
    SIM_Z is at the ACIS-I aimpoint).  The chip id decides instead:

      ACIS, chip <= 3      -> ("AI2", "ACIS-I")
      ACIS, chip > 3       -> ("AS1", "ACIS-S")
      non-ACIS, chip == 0  -> ("HI1", "HRC-I")
      non-ACIS, otherwise  -> ("HS1", "HRC-S")

    Returns (aimpoint, detector).
    """
    chip = int( src_info['chip_id'] )
    is_acis = obi_info["instrume"] == "ACIS"

    if is_acis:
        return ("AI2", "ACIS-I") if chip <= 3 else ("AS1", "ACIS-S")
    return ("HI1", "HRC-I") if chip == 0 else ("HS1", "HRC-S")


#
# Lookup the aimpoint in the pixlib cal file and get
# the default SIM location.  These are used to offset
# to match observation
#
@verboseCall(2)
def determine_marx_sim( pars,aimpoint ):
    """
    Return the nominal SIM position MARX uses for an aimpoint.

    Reads the AIMPOINT column from the row of MARX's
    telD1999-07-23aimptsN0002.fits calibration file (under
    <marx_root>/share/marx/data/caldb) whose aimpoint_name matches
    `aimpoint`.  Returns the column's .values.
    """
    from pycrates import read_file

    caldb_file = "".join([
        pars["marx_root"],
        "/share/marx/data/caldb/telD1999-07-23aimptsN0002.fits",
    ])
    row_filter = "[aimpoint_name=" + aimpoint + "]"

    aimpoint_row = read_file(caldb_file + row_filter)
    return aimpoint_row.get_column("AIMPOINT").values

#
# Run the marx program
#
@verboseCall(2)
def run_marx_exe( pars, obi_info, src_info, detector, sim ):
    """
    Run the marx executable to create simulated events.

    Parameters
    ----------
    pars : dict
        Tool parameters plus the derived "__" entries.
    obi_info : dict
        Observation values: pointing, SIM position, times, instrument,
        grating, and ACIS exposure / frame-transfer times.
    src_info : dict
        Source position; ra_src and dec_src are used.
    detector : str
        MARX DetectorType.
    sim : array
        AIMPOINT values for the chosen MARX aimpoint.  sim[0] holds the
        nominal SIM x, y, z that the observed SIM position is offset
        from.

    For simulator=marx, MARX generates the source itself (FILE spectrum,
    or FLAT with MinEnergy == MaxEnergy == monoenergy).  Otherwise the
    SAOTrace ray file is fed in via SAOSACFile.  Dither comes from the
    aspect file when one is used; otherwise MARX's INTERNAL model is
    used with fixed amplitudes and periods.  With keepiter=yes, a copy
    of marx.par holding the parameters used is written next to the
    output.

    Returns the MARX output directory.

    Raises RuntimeError if MARX wrote no time.dat (no events).
    """
    marx = {}

    # MARX puts the SIM at its nominal aimpoint for the detector; offset
    # it to the observed SIM position.
    marx["DetOffsetX"] = float(obi_info["sim_x"]) - sim[0][0]
    marx["DetOffsetY"] = float(obi_info["sim_y"]) - sim[0][1]
    marx["DetOffsetZ"] = float(obi_info["sim_z"]) - sim[0][2]

    # Source position and pointing; MARX requires decimal RA/Dec.
    marx["SourceRA"]  = src_info["ra_src"]
    marx["SourceDEC"] = src_info["dec_src"]
    marx["RA_Nom"]    = obi_info["ra_pnt"]
    marx["Dec_Nom"]   = obi_info["dec_pnt"]
    marx["Roll_Nom"]  = obi_info["roll_pnt"]

    # Source model
    marx["SourceType"]         = pars["__marxmodel"]

    if "marx" != pars["simulator"]:
        marx["SAOSACFile"]         = pars["_rayfile_"]
    else:
        if pars["__has_spect"]:
            marx["SpectrumType"] = "FILE"
            marx["SpectrumFile"] = pars["__spectrum"]
        else:
            # Monoenergetic source: flat spectrum of zero width.
            marx["SpectrumType"] = "FLAT"
            marx["MinEnergy"] = pars["monoenergy"]
            marx["MaxEnergy"] = pars["monoenergy"]

        marx["SourceFlux"] = pars["__flux"]

    marx["OutputDir"]          = pars["_outroot_"]+"_marx.dir"

    # Dither: from the aspect solution if available, else MARX's
    # internal model.
    if pars["__has_asp"]:
        marx["DitherModel"]        = "FILE"
        marx["DitherFile"]         = pars["asolfile"]
    else:
        marx["DitherModel"] = "INTERNAL"
        if "ACIS" == obi_info["instrume"]:
            marx["DitherAmp_RA"]     = 8.0
            marx["DitherAmp_Dec"]    = 8.0
            marx["DitherPeriod_RA"]  = 1000
            marx["DitherPeriod_Dec"] = 707
        else: # HRC, taken from POG
            marx["DitherAmp_RA"]     = 20.0
            marx["DitherAmp_Dec"]    = 20.0
            marx["DitherPeriod_RA"]  = 1087
            marx["DitherPeriod_Dec"] = 768

    marx["AspectBlur"]         = pars["blur"]
    marx["RandomSeed"]         = int(pars["random_seed"])

    marx["DetectorType"]       = detector
    marx["GratingType"]        = obi_info["grating"]
    marx["Verbose"]            = "no" if int(pars["verbose"]) < 2 else "yes"
    marx["ACIS_Exposure_Time"] = obi_info["exptime"]
    marx["ExposureTime"]       = float(obi_info["tstop"]) - float(obi_info["tstart"]) - 5

    # Readout streak: readout_streak=no sets the frame transfer time to 0.
    if "yes" == pars["readout_streak"]:
        marx["ACIS_Frame_Transfer_Time"] = obi_info["__xfer"]
    else:
        marx["ACIS_Frame_Transfer_Time"] = 0

    # Ideal / extended detector
    marx["DetIdeal"] = pars["ideal"]
    marx["DetExtendFlag"] = pars["extended"]

    # As per the MARX docs, TStart values > 2100 are interpreted as
    # Chandra time, while smaller values are a year (e.g. 2009.5 is
    # mid 2009).
    marx['TStart'] = obi_info["tstart"]

    run_task("marx", marx, punlearn=True)

    if not os.path.exists( marx["OutputDir"]+"/time.dat"):
        raise RuntimeError("The output from MARX contains no events.  The flux value(s) is (are) probably too small.")

    if pars["keepiter"] == "yes":
        import stat
        import shutil

        # Save a marx.par with the parameters actually used.
        parfile = pars["_outroot_"] + "_marx.par"
        shutil.copyfile( pars["marx_root"] + "/share/marx/pfiles/marx.par",
                         parfile )
        # Ensure the copy is user read/write so pset can update it
        # (presumably the installed file may be read-only).
        mode = os.stat( parfile ).st_mode
        os.chmod( parfile, mode | stat.S_IRUSR | stat.S_IWUSR )
        for name, value in marx.items():
            pio.pset( parfile, name, str(value) )

    return marx["OutputDir"]


#
#  Run marx2fits w/ options for pileup
#
@verboseCall(2)
def run_marx2fits( pars, obi_info, output_dir ):
    """
    Convert the MARX output directory into <_outroot_>_marx.fits.

    When pileup=yes, marxpileup is first run with FrameTime set to the
    ACIS exposure time, and marx2fits then reads <output_dir>/pileup
    with --pileup.  --pixadj is RANDOMIZE for HRC or grating data and
    EDSER for ACIS without gratings.  Tool stdout is discarded.

    Commands are passed to subprocess as argument lists, so paths
    containing spaces are no longer split.  stdout goes to
    subprocess.DEVNULL instead of a /dev/null handle that was never
    closed.

    Raises Exception, chained to the underlying error, if either tool
    cannot be run or exits with an error.
    """
    use_pileup = "yes" == pars["pileup"]

    # Optional pileup module
    if use_pileup:
        pileup_cmd = ['marxpileup', 'MarxOutputDir='+output_dir,
                      'FrameTime={}'.format(obi_info["exptime"])]
        cmd = "marxpileup MarxOutputDir={0} FrameTime={1} > /dev/null".format(output_dir, obi_info["exptime"])
        verb3(">>> "+cmd)

        try:
            subprocess.check_call( pileup_cmd, stdout=subprocess.DEVNULL )
        except (subprocess.CalledProcessError, OSError) as err:
            raise Exception("Problem running {0}".format(cmd)) from err

    # Pixel randomization depends on the configuration.
    if "HRC" == obi_info["instrume"] or "NONE" != obi_info["grating"]:
        pixadj = "RANDOMIZE"
    else: # ACIS, no gratings
        pixadj = "EDSER"

    sub_cmd = ["marx2fits", "--pixadj="+pixadj]
    if use_pileup:
        sub_cmd += ["--pileup", output_dir+"/pileup"]
    else:
        sub_cmd.append(output_dir)
    sub_cmd.append(pars["_outroot_"]+ "_marx.fits")

    cmd = " ".join(sub_cmd)
    verb3(">>> "+cmd)

    try:
        subprocess.check_call( sub_cmd, stdout=subprocess.DEVNULL )
    except (subprocess.CalledProcessError, OSError) as err:
        raise Exception("Problem running {0}".format(cmd)) from err



#
# Run reproject events to get marx data back to same tan plane
# as original event file.
#
@verboseCall(2)
def run_marx_reproject_events( pars ):
    """
    Reproject the MARX events onto the tangent plane of the observation.

    MARX uses the _NOM values from the aspect solution for its tangent
    plane, so if the observed events were reprojected, the simulated ones
    must be matched to them.  <_outroot_>_marx.fits, filtered to the
    TLMIN/TLMAX range of its x and y columns, is run through
    reproject_events (match=infile, random=-1 so no randomization) to
    produce <_outroot_>_projrays.fits.
    """

    def tlmin_tlmax_filter(evtfile):
        """
        Return a DM filter limiting x and y to their TLMIN/TLMAX range.

        MARX can put events WAY outside the columns' TLMIN/TLMAX.  Left
        in, they mess up the limits, and the ray file is then displayed
        centred in a bad place (e.g. far from 4096,4096 for ACIS).
        """
        from pycrates import read_file
        crate = read_file(evtfile)
        xcol = crate.get_column("x")
        ycol = crate.get_column("y")
        return "[x={}:{},y={}:{}]".format(xcol.get_tlmin(), xcol.get_tlmax(),
                                          ycol.get_tlmin(), ycol.get_tlmax())

    marx_evt = pars["_outroot_"]+ "_marx.fits"
    xy_filter = tlmin_tlmax_filter(marx_evt)

    reproj = make_tool("reproject_events")
    reproj.punlearn()
    reproj.infile  = marx_evt + xy_filter
    reproj.outfile = pars["_outroot_"]+ "_projrays.fits"
    reproj.random  = "-1"       # make sure no randomization
    reproj.match   = pars["infile"]
    reproj.clobber = "yes"

    reproj()

#
# Cleanup marx intermediate dir
#
@verboseCall(3)
def cleanup_marx( pars, output_dir ):
    """
    Remove the MARX output directory and <_outroot_>_marx.fits.

    As with the file removal done through gorm, nothing is removed when
    the SAVE_ALL environment variable is set, so all MARX intermediate
    products can be kept for debugging.  Previously the directory was
    always deleted.
    """
    import shutil as sh
    if 'SAVE_ALL' not in os.environ:
        sh.rmtree( output_dir )
    gorm( pars["_outroot_"] + "_marx.fits" )


#
# run_marx routine
#
@verboseCall(1)
def run_marx(pars, obi_info, src_info ):
    """
    Project rays onto the detector with MARX to match the observation.

    Steps:
      1. pick the MARX aimpoint and detector for the source's chip;
      2. read the nominal SIM position for that aimpoint, from which the
         observed SIM position is offset;
      3. run marx and marx2fits (optionally with pileup) inside a
         private PFILES environment;
      4. reproject the events to the observation's tangent plane;
      5. add metadata to <_outroot_>_projrays.fits;
      6. remove the MARX intermediate products.
    """
    from ciao_contrib.runtool import new_pfiles_environment

    aimpoint, detector = determine_marx_detector( obi_info, src_info )
    nominal_sim = determine_marx_sim( pars, aimpoint )

    # Own PFILES so that MARX runs can happen in parallel.
    with new_pfiles_environment():
        marx_dir = run_marx_exe( pars, obi_info, src_info, detector, nominal_sim )
        run_marx2fits( pars, obi_info, marx_dir )

    run_marx_reproject_events( pars )

    projected_rays = pars["_outroot_"] + "_projrays.fits"
    add_metadata_to_rays( projected_rays, pars, obi_info, src_info )

    cleanup_marx( pars, marx_dir )


#
# run_psf_project routine
#
@verboseCall(1)
def run_psf_project(pars, obi_info, src_info ):
    """
    Project the SAOTrace rays onto the detector with psf_project_ray.

    Uses the observed event file and aspect solution together with the
    blur and random seed parameters.  Writes (clobbering)
    <_outroot_>_projrays.fits, logs the tool output at verbose level 2,
    then adds metadata to the output file.
    """
    projected_rays = pars["_outroot_"] + "_projrays.fits"

    tool = make_tool("psf_project_ray")
    settings = (
        ("infile",   pars["_rayfile_"]),
        ("evt",      pars["infile"]),
        ("asolfile", pars["asolfile"]),
        ("outfile",  projected_rays),
        ("xblur",    pars["blur"]),
        ("randseed", pars["random_seed"]),
        ("clobber",  True),
    )
    for par_name, par_value in settings:
        setattr(tool, par_name, par_value)

    tool_output = tool()
    verb2( tool_output )

    add_metadata_to_rays( projected_rays, pars, obi_info, src_info )




#
# Wrapper around dmKeyRead to get value and toss descriptor
#
def get_dm_key( tab, key ):
    """
    Return the value of header keyword `key` from the DM block `tab`.

    dmKeyRead returns a (descriptor, value) pair; the descriptor is
    discarded.  Byte-string values are decoded as ASCII.  If decoding
    fails, or the value is not a byte string, it is returned unchanged.
    """
    _, raw = dmKeyRead( tab, key )
    if isinstance(raw, (bytes, bytearray)):
        try:
            return raw.decode("ascii")
        except UnicodeDecodeError:
            return raw
    return raw

#
# provide info about ACIS subarray data
#

def check_acis_subarray( tab, pars):
    """
    Warn if the ACIS observation used a subarray (NROWS != 1024).

    The shorter subarray frame time should already be accounted for, but
    the different spatial extent is not.  Pileup and the readout streak,
    and the spatial extent in general, may therefore be modelled
    incorrectly.  If NROWS cannot be read, a full chip is assumed.

    pars is currently unused.
    """
    try:
        nrow = get_dm_key( tab, "NROWS")
    except Exception:
        # Keyword not available: we have to assume a full chip.
        return

    if 1024 != nrow:
        verb0("WARNING: This observation used a {} row subarray.  The effects of this may not be correctly modeled especially if trying to simulate pileup and/or the readout streak.".format(nrow))



def check_acis_interleaved( tab, pars ):
    """
    Warn if the ACIS observation used interleaved mode.

    Interleaved mode has a large dead time/gap that is not simulated
    correctly.  It is detected by a non-zero TIMEDELB keyword.  If
    TIMEDELB cannot be read, the data are assumed not to be interleaved.

    pars is currently unused.
    """
    try:
        tbd = get_dm_key( tab, "TIMEDELB")
    except Exception:
        # Keyword not found: assume the data are not interleaved.
        return

    if 0.0 != tbd:
        verb0("WARNING: Interleaved mode dataset detected.  The simulation may not be accurate in this mode especially if trying to simulate pileup and/or the readout streak.")


#
# Get info about the observation
#
@verboseCall(2)
def get_obi_info( pars ):
    """
    Collect observation information from the event file and, when used,
    the aspect solution file.

    From the event file header:
        RA_PNT, DEC_PNT, ROLL_PNT, SIM_X, SIM_Y, SIM_Z (stored as strings
        with 20 significant digits), GRATING, INSTRUME, DATE-OBS.
        ACIS: EXPTIME and the transfer time __xfer = TIMEDEL - EXPTIME,
              plus subarray / interleaved-mode warnings.
        Otherwise (HRC): exptime and __xfer are "0".

    TSTART, TSTOP:
        From the aspect solution when __has_asp is set (the event file
        generally covers a longer interval), otherwise from the event
        file.

    Files are closed even if reading fails.  Previously the ValueError
    for continuous clocking left the event file open.

    NOTE(review): asolfile may be a comma-separated list of files when it
    was located automatically; confirm dmTableOpen accepts that here.

    Raises ValueError for ACIS continuous clocking (READMODE=CONTINUOUS).
    """

    def read_times(block, info):
        # TSTART/TSTOP stored as strings with 20 significant digits.
        for key in [ "tstart", "tstop" ]:
            info[key] = "{0:.20g}".format(get_dm_key(block, key))

    obi_info = {}
    tab = dmBlockOpen( pars["infile"] )
    try:
        for key in [ "ra_pnt", "dec_pnt", "roll_pnt", "sim_x", "sim_y", "sim_z" ]:
            val = get_dm_key( tab, key )
            obi_info[key] = "{0:.20g}".format(val)
        obi_info["grating"]  = get_dm_key( tab, "grating")
        obi_info["instrume"] = get_dm_key( tab, "instrume")
        obi_info["date-obs"] = get_dm_key( tab, "date-obs")

        if "ACIS" == obi_info["instrume"]:

            if get_dm_key(tab, "readmode") == "CONTINUOUS":
                raise ValueError("ERROR: Unable to simulate a PSF for Continuous Clocking Mode observations")

            # This should be OK for interleaved mode since TIMEDEL
            # should equal A or B depending on P/S.
            obi_info["exptime"] = get_dm_key(tab, "exptime")
            tdel = get_dm_key(tab, "TIMEDEL")
            obi_info["__xfer"] = float(tdel) - float( obi_info["exptime"] )

            check_acis_subarray( tab, pars )
            check_acis_interleaved( tab, pars )

        else:  #HRC
            obi_info["exptime"] = "0"
            obi_info["__xfer"] = "0"

        # Without an aspect solution the times come from the event file.
        if not pars["__has_asp"]:
            read_times(tab, obi_info)
    finally:
        dmDatasetClose(dmBlockGetDataset( tab ))

    # Prefer the aspect solution's TSTART/TSTOP when one is used.
    if pars["__has_asp"]:
        asol = dmTableOpen( pars["asolfile"] )
        try:
            read_times(asol, obi_info)
        finally:
            dmDatasetClose(dmBlockGetDataset( asol ))

    return obi_info


#
# Get coordinates
#
@verboseCall(2)
def get_src_info( pars ):
    """
    Convert the source RA/Dec into the other coordinate systems.

    dmcoords is run in "cel" mode on pars["infile"] (with pars["asolfile"])
    at (pars["ra"], pars["dec"]) in decimal degrees.  Returns a dict with
    keys ra_src, dec_src, x, y, theta, phi, chip_id, chipx and chipy, taken
    from the corresponding dmcoords parameters after the run.
    """
    coords = make_tool("dmcoords")
    coords.punlearn()

    inputs = (
        ("infile", pars["infile"]),
        ("asol",   pars["asolfile"]),
        ("opt",    "cel"),
        ("ra",     pars["ra"]),
        ("dec",    pars["dec"]),
        ("celfmt", "deg"),
    )
    for name, value in inputs:
        setattr(coords, name, value)

    coords()

    # result key -> dmcoords parameter that holds it
    outputs = (
        ("ra_src",  "ra"),
        ("dec_src", "dec"),
        ("x",       "x"),
        ("y",       "y"),
        ("theta",   "theta"),
        ("phi",     "phi"),
        ("chip_id", "chip_id"),
        ("chipx",   "chipx"),
        ("chipy",   "chipy"),
    )
    return {key: getattr(coords, name) for key, name in outputs}


#
# Different ways to get grids
#
@verboseCall(2)
def get_xygrid( pars, use_root, src_info):
    """
    Return the sky x,y binning specification for the PSF image.

    The result is cached in pars["__xygrid"]; when that is already set it
    is returned unchanged.  Otherwise dmstat (sigma-clipped) is run on the
    x,y columns of <use_root>projrays.fits and a square grid is centred on
    the sky pixel containing the source (src_info["x"], src_info["y"]).

    Half-width = int(numsig * int(sigma + 1) + 0.5) using the larger of the
    x and y results; it is then raised to minsize/2 and lowered to maxsize/2
    (in that order) when those parameters are not "INDEF".

    Returns "x=lo:hi:binsize,y=lo:hi:binsize".
    """
    cached = pars["__xygrid"]
    if cached:
        return cached

    stats = make_tool("dmstat")
    stats.punlearn()
    stats.infile = use_root+"projrays.fits[cols x,y]"
    stats.verbose = 0
    stats.sig = "yes"
    stats.clip = "yes"
    stats.med = "no"
    stats()

    # Centre on the middle of the sky pixel that contains the source.
    x_center = int(float(src_info["x"])+0.5)+0.5
    y_center = int(float(src_info["y"])+0.5)+0.5

    sigmas = [float(tok) for tok in stats.out_sigma.split(",")]
    nsig = float(pars["numsig"])
    # Square images (people like those): use the wider of the two axes.
    half = max(int(nsig * int(sigmas[0] + 1)+0.5),
               int(nsig * int(sigmas[1] + 1)+0.5))

    for limit_name, pick in (("minsize", max), ("maxsize", min)):
        limit = pars[limit_name]
        if "INDEF" != limit:
            half = pick(half, int(int(limit) / 2.0))

    # TODO integer / half-pixel ize?
    binsize = pars["binsize"]
    grid = "x={0}:{1}:{2},y={3}:{4}:{5}".format(
        x_center - half, x_center + half, binsize,
        y_center - half, y_center + half, binsize)

    pars["__xygrid"] = grid
    return grid

#
# Determine normalization factor, may need eg ray weights/etc
#
@verboseCall(2)
def get_normalization( pars, src_info ):
    """
    Return the number of good x values in this iteration's ray file.

    Runs dmstat (no sigma, median or clipping) on the x column of
    <_outroot_>_projrays.fits and returns its out_good count as an int.
    src_info is accepted but not used.
    """
    stat = make_tool("dmstat")
    stat.punlearn()
    stat.infile = pars["_outroot_"]+"_projrays.fits[cols x]"
    for name, value in (("verbose", 0), ("sig", "no"), ("med", "no"), ("clip", "no")):
        setattr(stat, name, value)
    stat()
    return int(stat.out_good)



@verboseCall(2)
def add_history_to_file( infile, pars ):
    """
    Add the tool history for this run to infile.

    Only user-visible parameters are recorded; internal entries (names
    starting with "_") are left out.  Uses
    ciao_contrib.runtool.add_tool_history with this tool's name and revision.
    """
    from ciao_contrib.runtool import add_tool_history

    public_pars = {name: pars[name] for name in pars if not name.startswith("_")}
    add_tool_history( infile, toolname, public_pars, toolversion=__revision__)



@verboseCall(2)
def add_metadata_to_rays( infile, pars, obi_info, src_info ):
    """
    Write a TOTCTS keyword (type long) into infile.

    The value is the output of `dmlist infile counts`; create_average_image
    later sums these values.  pars, obi_info and src_info are accepted but
    not used.
    """
    lister = make_tool("dmlist")
    nrows = lister( infile, "counts")

    editor = make_tool("dmhedit")
    editor( infile, file="", op="add", key="TOTCTS", value=nrows, datatype="long" )


#
# create_psf_image
#
@verboseCall(1)
def create_psf_image(pars, src_info ):
    """
    Create this iteration's PSF image from the projected ray file.

    The rays in <_outroot_>_projrays.fits are binned (integer type i4) on
    the grid from get_xygrid and written to <_outroot_>.psf with block
    name PSF.  The image is NOT normalized here; create_average_image
    normalizes the combined result using the TOTCTS keyword.

    The image name and the ray-file name are appended to pars["__imgstk"]
    and pars["__raystk"] for later combination/cleanup.
    """
    rayfile = pars["_outroot_"]+"_projrays.fits"
    xygrid = get_xygrid( pars, pars["_outroot_"]+"_", src_info )
    outfile = pars["_outroot_"]+".psf[PSF]"

    dmcopy = make_tool("dmcopy")
    dmcopy.infile = rayfile+"[2][bin {0}][opt type=i4][PSF]".format(xygrid)
    dmcopy.outfile = outfile
    dmcopy.clobber = True
    dmcopy()

    # Record the (non-internal) tool parameters in the image's history.
    add_history_to_file( outfile, pars )

    # Remember the per-iteration products for averaging and cleanup.
    pars["__imgstk"].append( outfile )
    pars["__raystk"].append( rayfile )


@verboseCall(1)
def create_average_image( pars, src_info ):
    """
    Combine all iterations into the final normalized PSF image.

    The per-iteration projected ray files (pars["__raystk"]) are merged
    into <outroot>projrays.fits and its TOTCTS keyword is set to the sum of
    the per-file TOTCTS values.  The merged rays are binned on a freshly
    computed grid (get_xygrid) and divided by TOTCTS, so each pixel holds
    the fraction of all projected rays that landed in it.

    Returns the output image name (including the "[PSF]" block name), or
    None when pars["__imgstk"] is empty (no per-iteration images).
    """
    if not pars["__imgstk"]:
        return None

    from pycrates import read_file

    merged = pars["outroot"]+"projrays.fits"
    dmmerge = make_tool("dmmerge")
    dmmerge( pars["__raystk"], merged, clobber=True)

    cts = sum(read_file(i).get_key_value("TOTCTS") for i in pars["__raystk"])

    # Same keyword type as the per-iteration files (add_metadata_to_rays).
    dmhedit = make_tool("dmhedit")
    dmhedit( merged, file="", op="add", key="TOTCTS", value=cts, datatype="long")

    # Drop the cached per-iteration grid so it is recomputed from the
    # merged rays.
    pars["__xygrid"] = None
    pars["_outroot_"] = pars["outroot"]
    xygrid = get_xygrid( pars, pars["outroot"], src_info )

    root = pars["outroot"]
    if root.endswith("/"):
        # outroot is a directory: the image is <dir>/psf
        outfile = root+"psf[PSF]"
    else:
        # main() appended exactly one "_" separator; remove only that one
        # so a user root that itself ends in "_" is preserved.
        outfile = root[:-1]+".psf[PSF]"

    dmimgcalc = make_tool("dmimgcalc")
    dmimgcalc.infile = merged+"[2][bin {0}][opt type=i4][PSF]".format(xygrid)
    dmimgcalc.infile2 = ""
    dmimgcalc.outfile = outfile
    dmimgcalc.operation = "imgout=((float)(img1))/img1_totcts"
    dmimgcalc.clobber = True
    dmimgcalc()

    add_history_to_file( outfile, pars )

    return outfile



@verboseCall(3)
def cleanup_iterations( pars ):
    """
    Remove per-iteration intermediate files and any temporary spectrum.

    When keepiter is "no", the names <outroot>iNNNN.psf and
    <outroot>iNNNN_projrays.fits (plus <outroot>iNNNN_rays.fits for the
    saotrace simulator), for NNNN = 0 .. numiter-1, are each passed to
    gorm().  pars["__spectrum"], when set, is always passed to gorm().
    """
    suffixes = [".psf", "_projrays.fits"]
    if pars["simulator"] == "saotrace":
        suffixes.append( "_rays.fits" )

    if pars["keepiter"] == "no":
        for suffix in suffixes:
            for idx in range( int(pars["numiter"]) ):
                gorm( "{0}i{1:04d}{2}".format( pars["outroot"], idx, suffix ))

    spectrum = pars["__spectrum"]
    if spectrum:
        gorm( spectrum )


def check_numrays(numiter, numrays, pars, rayfile, fudge_factor=1.0):
    """
    Advance the iteration counter and keep track of how many rays exist.

    Returns (numiter + 1, running ray total).  When pars["numrays"] is
    "INDEF" the ray total is returned unchanged and rayfile is not read.
    Otherwise the row count of rayfile is added to the total, and if
    total * fudge_factor is still below the requested numrays,
    pars["numiter"] is increased by one (stored as a string) so the caller
    runs another iteration.
    """
    next_iter = numiter + 1
    wanted = pars["numrays"]
    if wanted == "INDEF":
        return next_iter, numrays

    from pycrates import read_file
    total = numrays + read_file(rayfile).get_nrows()

    if (total * fudge_factor) < int(wanted):
        pars["numiter"] = str(int(pars["numiter"]) + 1)

    return next_iter, total


def iteration_loop( pars, obi_info, src_info, max_iter=10000 ):
    """
        Loop over the iterations needed to simulate the PSF.

        Each iteration uses the output root <outroot>iNNNN.  Rays come from
        saotrace or from the user's ray-file stack (simulator="file"), are
        projected with marx or psf_project_ray, and binned into an image;
        create_psf_image records each image / projected-ray file name in
        pars["__imgstk"] / pars["__raystk"] for later combination.

        pars["numiter"] is re-read on every pass because check_numrays()
        may increase it until the requested number of rays ("numrays") has
        been produced.  At most max_iter iterations are run; a warning is
        printed only when that cap is what stopped the loop.

        Parameters
        ----------
        pars : dict
            Tool parameters (modified in place: _outroot_, _rayfile_,
            random_seed and possibly numiter).
        obi_info, src_info : dict
            Observation and source information passed to the simulators.
        max_iter : int, optional
            Safety cap on the number of iterations (default 10000).
    """
    rayfiles = None
    if "file" == pars["simulator"]:
        # Build the user's ray-file stack once instead of every iteration.
        rayfiles = stk.build(pars["rayfile"])

    nn = 0
    numrays = 0
    while nn < int(pars["numiter"]) and nn < max_iter:

        pars["_outroot_"] = pars["outroot"]+"i{0:04d}".format(nn)
        # NOTE: the offset accumulates (seed, seed+1, seed+3, seed+6, ...);
        # seeds stay distinct, and this is kept for reproducibility.
        pars["random_seed"] = int( pars["random_seed"] )+nn

        verb1( "Performing iteration {0} of {1}".format(nn+1, pars["numiter"]))
        verb2( "  output root is {0}".format( pars["_outroot_"] ))
        verb3( "  with seed={0}".format(pars["random_seed"]))

        # run saotrace
        if "saotrace" == pars["simulator"]:
            pars["_rayfile_"] = pars["_outroot_"]+"_rays.fits"
            run_saotrace( pars, obi_info, src_info )
            if pars["projector"] == "none":
                # No projection: count the raw saotrace rays, discounted by
                # a 0.7 fudge factor (presumably for rays that would not
                # land on the detector -- TODO confirm).
                nn, numrays = check_numrays(nn, numrays, pars, pars["_rayfile_"], fudge_factor=0.7)
                continue

        elif "file" == pars["simulator"]:
            pars["_rayfile_"] = rayfiles[nn]

        if pars["projector"] == "marx":
            # run marx
            run_marx( pars, obi_info, src_info )
        else:
            # run psf_project_ray
            run_psf_project( pars, obi_info, src_info )
        create_psf_image(pars, src_info)
        nn, numrays = check_numrays(nn, numrays, pars, pars["_outroot_"]+"_projrays.fits")

    # Warn only if the cap, not the requested iteration count, ended the loop.
    if nn >= max_iter and nn < int(pars["numiter"]):
        verb0("WARNING: Could not get requested number of rays in {0} iterations. Stopping now.".format(max_iter))
 
#
# Main Routine
#
@handle_ciao_errors( toolname, __revision__)
def main():
    """
    Tool entry point.

    Reads the parameters, normalizes outroot (a directory gets a trailing
    "/", anything else gets a "_" separator), runs every simulation
    iteration, combines them into the final PSF image, cleans up the
    intermediate files and reports the output file name.
    """
    from ciao_contrib.param_soaker import get_params

    required = ("infile","outroot","ra","dec",
                "spectrumfile","monoenergy","flux",
                "simulator","rayfile","projector",
                "random_seed","blur","readout_streak",
                "pileup","ideal","extended","binsize",
                "numsig","minsize","maxsize","numiter",
                "numrays","keepiter","asolfile",
                "marx_root","verbose",)
    pars = get_params("simulate_psf", "rw", sys.argv,
                      verbose={"set":logWrap.set_verbosity, "cmd":verb1},
                      musthave=required)

    # Hard code these for now, restore w/o underscore when ready
    pars["_saotrace_install"] = ""
    pars["_saotrace_db"] = ""
    pars["_saotrace_mirror"] = "orbit-200809-01f-a"

    # check saotrace and marx setup parameters
    check_setup( pars )

    root = pars["outroot"]
    if not os.path.isdir( root ):
        # plain file root: "_" separates it from the generated suffixes
        pars["outroot"] = root+"_"
    elif not root.endswith("/"):
        # directory given without a trailing slash
        pars["outroot"] = root+"/"

    # check parameter combinations
    resolve_parameter_dependencies( pars )

    # set up the environment vars for saotrace and marx
    set_env_vars( pars )

    # observation and source (coordinate) information
    obi_info = get_obi_info( pars )
    src_info = get_src_info( pars )

    # Do n-many simulations, then combine them into the final image
    iteration_loop( pars, obi_info, src_info )
    outfile = create_average_image( pars, src_info )

    # Cleanup temp files (rays, projrays, and images)
    cleanup_iterations( pars )

    if outfile:
        verb1( "\nFinal output PSF image is : "+outfile )



if __name__ == "__main__":
    # Last-resort handler: report any uncaught error on stderr and exit
    # with status 1; a clean run exits with status 0.
    try:
        main()
    except Exception as err:
        message = "\n# {0} ({1}): ERROR {2}\n".format(toolname, __revision__, str(err))
        print(message, file=sys.stderr)
        sys.exit(1)
    sys.exit(0)

