#!/opt/conda/envs/ciao-4.17.0/bin/python3
# 
#  Copyright (C) 2007-2008,2011-2012,2016  Smithsonian Astrophysical Observatory
#
#
#  This program is free software; you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation; either version 3 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#
#  You should have received a copy of the GNU General Public License along
#  with this program; if not, write to the Free Software Foundation, Inc.,
#  51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#

from __future__ import print_function

from pytransform import *
from numpy import array
import numpy
from os import environ


#-----------------------------------------------
#  Setup input/output directories
#-----------------------------------------------
# Input directory: "<TESTIN>/test_transform" when the TESTIN environment
# variable is set; otherwise whatever PWD holds (None if PWD is unset too).
testin = environ.get('TESTIN')
if testin is None:
    indir = environ.get('PWD')
else:
    indir = testin + '/test_transform'

# Output directory: "<TESTOUT>/transform" when the TESTOUT environment
# variable is set; otherwise the value of PWD.
testout = environ.get('TESTOUT')
if testout is None:
    outdir = environ.get('PWD')
else:
    outdir = testout + '/transform'

#-----------------------------------------------
# Test and verify copy_transform()
#-----------------------------------------------

# Build a WCSTANTransform named "UNIT" and assign CRPIX, CRVAL and CDELT,
# in that order.
tr = WCSTANTransform("UNIT")
for par_name, par_value in (("CRPIX", [50.0, 50.0]),
                            ("CRVAL", [50.0, 50.0]),
                            ("CDELT", [1.0, 1.0])):
    par = tr.get_parameter(par_name)
    par.set_value(par_value)

# Report the starting state of the transform.
print("Initial Transform definition")
print("Transform Type: ", get_transform_type(tr))
print(tr.print_parameter_list())

# Duplicate the transform and rename only the duplicate.
for text in ("", "Copy Transform", "  - change name of copy"):
    print(text)
ntr = copy_transform(tr)
ntr.set_name("UNIT_COPY")

# Report the copy's definition.
for text in ("", "Copy Transform definition"):
    print(text)
print("Transform Type: ", get_transform_type(ntr))
print(ntr.print_parameter_list())

# Change the name, CRVAL and CROTA of the original only. The two reports
# that follow print both transforms so the output shows whether the copy
# is linked to the original.
for text in ("", "Modify Initial Transform", "  - NAME", "  - CRVAL", "  - CROTA"):
    print(text)
tr.set_name("UNIT_MOD")
for par_name, par_value in (("CRVAL", [88.0, 88.0]),
                            ("CROTA", 20.0)):
    par = tr.get_parameter(par_name)
    par.set_value(par_value)

for text in ("", "Initial Transform definition (modified)"):
    print(text)
print("Transform Type: ", get_transform_type(tr))
print(tr.print_parameter_list())

for text in ("", "Copy Transform definition (unchanged)"):
    print(text)
print("Transform Type: ", get_transform_type(ntr))
print(ntr.print_parameter_list())


print("")

# Build a LINEAR2DTransform named "LIN" and assign SCALE, ROTATION and
# OFFSET, in that order.
lintr = LINEAR2DTransform("LIN")
for par_name, par_value in (("SCALE", [-0.136667, 0.136667]),
                            ("ROTATION", [0.000000]),
                            ("OFFSET", [4145.697897, 4098.235660])):
    par = lintr.get_parameter(par_name)
    par.set_value(par_value)

# Report the starting state of the transform.
for text in ("", "Initial LINEAR2DTransform definition"):
    print(text)
print("Transform Type: ", get_transform_type(lintr))
print(lintr.print_parameter_list())

# Duplicate the transform and rename only the duplicate.
for text in ("", "Copy LINEAR2DTransform", "  - change name of copy"):
    print(text)
nlintr = copy_transform(lintr)
nlintr.set_name("LIN_COPY")

# Report the copy's definition.
for text in ("", "Copy LINEAR2DTransform definition"):
    print(text)
print("Transform Type: ", get_transform_type(nlintr))
print(nlintr.print_parameter_list())

# Change the name, SCALE and OFFSET of the original only. The two reports
# that follow print both transforms so the output shows whether the copy
# is linked to the original.
for text in ("", "Modify Initial LINEAR2DTransform", "  - NAME", "  - SCALE", "  - OFFSET"):
    print(text)
lintr.set_name("LIN_MOD")
# NOTE(review): OFFSET receives a scalar here but a two-element list when the
# transform was first defined -- confirm this is the intended test input.
for par_name, par_value in (("SCALE", [0.333333, -0.222222]),
                            ("OFFSET", 5.0)):
    par = lintr.get_parameter(par_name)
    par.set_value(par_value)

for text in ("", "Initial LINEAR2DTransform definition (modified)"):
    print(text)
print("Transform Type: ", get_transform_type(lintr))
print(lintr.print_parameter_list())

for text in ("", "Copy LINEAR2DTransform definition (unchanged)"):
    print(text)
print("Transform Type: ", get_transform_type(nlintr))
print(nlintr.print_parameter_list())

#-----------------------------------------------
# Test and verify apply_transform()
#  - extracted from low-level interface.
#-----------------------------------------------
# Input coordinates handed to apply_transform(): ten (x, y) rows stored as a
# double-precision array.
sky = array([[ 984.5, 1020.5],
             [ 994.5, 1021.5],
             [1004.5, 1022.5],
             [1014.5, 1023.5],
             [1024.5, 1024.5],
             [1034.5, 1025.5],
             [1044.5, 1026.5],
             [1054.5, 1027.5],
             [1064.5, 1028.5],
             [1074.5, 1029.5]], dtype=numpy.double)

for text in ("",
             "----------------------------------------------------------------------",
             "Define a 'real' WCS transform"):
    print(text)

# Define the sky -> celestial transform: a default-constructed
# WCSTANTransform whose parameters are assigned in the order listed.
# NOTE(review): the two CDELT magnitudes differ (0.000585778 vs
# 0.0005857778) -- confirm against the expected test output before changing.
sky2cel = WCSTANTransform()
for par_name, par_value in (
        ("CRPIX", [1.0245000000000E+03, 1.0245000000000E+03]),
        ("CRVAL", [2.0152004149394E+02, -4.2949051379390E+01]),
        ("CDELT", [-0.000585778, 0.0005857778]),
        ("CROTA", 0.0),
        ("EQUINOX", 2000.0),
        ("EPOCH", 2000.0)):
    par = sky2cel.get_parameter(par_name)
    par.set_value(par_value)

# The transform is applied through a copy rather than the original object.
sky2cel1 = copy_transform(sky2cel)

print("Transform Type: ", get_transform_type(sky2cel1))
print(sky2cel1.print_parameter_list())

for text in ("", "Apply Transform to data array.."):
    print(text)
cel = apply_transform(sky2cel1, sky)

# Summarise the input array, then the returned array.
print("")
print("  Input array length.. ", len(sky))
print("  Input array ndim.. ", sky.ndim)
print("  Input array shape.. ", sky.shape)
print("  Input array dtype.. ", sky.dtype)
print("")
print("  Output cel values are type... ", type(cel).__name__)
print("  Output array length.. ", len(cel))
print("  Output array ndim.. ", cel.ndim)
print("  Output array shape.. ", cel.shape)
print("  Output array dtype.. ", cel.dtype)
print("")
print("     ( x,y )     ->    ( RA, DEC )")

# One output line per input row; the result row is looked up by position.
for row, xy in enumerate(sky):
    radec = cel[row]
    print("%7.1f, %7.1f -> %7.3f, %7.3f " % (xy[0], xy[1], radec[0], radec[1]))
