#!/usr/bin/env python
#
#  Copyright (C) 2015,2020  Smithsonian Astrophysical Observatory
#
#
#  This program is free software; you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation; either version 3 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#
#  You should have received a copy of the GNU General Public License along
#  with this program; if not, write to the Free Software Foundation, Inc.,
#  51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
 

"""
Create custom color lookup tables based on color names

This package provides a set of functions to create custom
color maps based on named colors (e.g. 'orange', 'yellow', 'slateblue').

A simple example looks like this:

>>> r,g,b = white_to_color("darkgreen")

which creates a custom color lookup table that fades from white to 
the color 'darkgreen'.  

The named colors are taken from the colors.par parameter file

>>> from paramio import plist
>>> plist("colors.par")

"""

__all__ = [ 'white_to_color', 'black_to_color', 'black_to_color_to_white', 'lut_colors']

import numpy as np

def _color_to_tripple( color ):
    """
    Convert a color specification into a normalized RGB triple

    >>> _color_to_tripple("red")
    [1.0, 0.0, 0.0]

    The color may be given in one of three forms, tried in this order:

      - a name found in the 'colors.par' parameter file, which
        includes a superset of the named colors available in chips.
      - three whitespace separated numbers, "R G B".  If any value
        is larger than 1 then all three are divided by 255.
      - an 8 character hex string, "0xRRGGBB".

    The return value is a list of three floats.  A ValueError is
    raised if the color cannot be interpreted or if a numeric
    value is outside the allowed range.
    """
    from paramio import pget, plist

    # Named colors take precedence over the numeric forms.
    if color in plist("colors.par"):
        return [float(x) for x in pget('colors.par', color).split()]

    unparsable = "Unable to parse color value and cannot locate color='{}' in colors.par".format(color)

    fields = color.split()
    if len(fields) == 3:
        try:
            rgb = [float(x) for x in fields]
        except ValueError as err:
            # eg an unknown three word color name; report it with the
            # same message as any other unrecognized color.
            raise ValueError(unparsable) from err
        if max(rgb) > 1.0:
            rgb = [x/255.0 for x in rgb]
        if max(rgb) > 1.0 or min(rgb) < 0.0:
            raise ValueError("Color values must be between 0 and 255")
        return rgb

    if len(color) == 8 and color.lower().startswith('0x'):
        try:
            channels = [int('0x'+color[i:i+2], base=16) for i in (2, 4, 6)]
        except ValueError as err:
            # Non-hex digits after the 0x prefix.
            raise ValueError(unparsable) from err
        return [c/255.0 for c in channels]

    raise ValueError(unparsable)



def white_to_color(color, **kwargs):
    """
    Build a color lookup table that starts at white and ends at
    the requested color.

    >>> r,g,b = white_to_color("darkgreen")

    Any keyword arguments are handed on to lut_colors, so the
    interpolation can be done in the RGB (default), HSV, or HLS
    color system

    >>> r,g,b = white_to_color("red", colorsys="rgb")
    >>> r,g,b = white_to_color("red", colorsys="hsv")
    >>> r,g,b = white_to_color("red", colorsys="hls")

    and the number of entries in the table (256 by default) can be
    set with the num_colors parameter

    >>> r,g,b = white_to_color("yellow", num_colors=16)

    """

    ramp = ['white', color]
    return lut_colors(ramp, **kwargs)


def black_to_color(color, **kwargs):
    """
    Build a color lookup table that starts at black and ends at
    the requested color.

    >>> r,g,b = black_to_color("yellow")

    Any keyword arguments are handed on to lut_colors, so the
    interpolation can be done in the RGB (default), HSV, or HLS
    color system

    >>> r,g,b = black_to_color("red", colorsys="rgb")
    >>> r,g,b = black_to_color("red", colorsys="hsv")
    >>> r,g,b = black_to_color("red", colorsys="hls")

    and the number of entries in the table (256 by default) can be
    set with the num_colors parameter

    >>> r,g,b = black_to_color("yellow", num_colors=16)

    """

    ramp = ['black', color]
    return lut_colors(ramp, **kwargs)


def black_to_color_to_white( color, **kwargs):
    """
    Build a color lookup table that starts at black, passes through
    the requested color, and ends at white.

    >>> r,g,b = black_to_color_to_white("firebrick")

    Any keyword arguments are handed on to lut_colors, so the
    interpolation can be done in the RGB (default), HSV, or HLS
    color system

    >>> r,g,b = black_to_color_to_white("red", colorsys="rgb")
    >>> r,g,b = black_to_color_to_white("red", colorsys="hsv")
    >>> r,g,b = black_to_color_to_white("red", colorsys="hls")

    and the number of entries in the table (256 by default) can be
    set with the num_colors parameter

    >>> r,g,b = black_to_color_to_white("yellow", num_colors=16)

    """
    ramp = ['black', color, 'white']
    return lut_colors(ramp, **kwargs)



def _circular_interp( x0, xx, hh ):
    """
    Interpolate cyclical (hue-like) values, taking the short way
    around the circle.

    x0 : positions at which to evaluate; each must be >= xx[0]
    xx : increasing positions of the known values
    hh : known values, one per element of xx, in the range 0 to 1

    Returns a numpy array with one interpolated value per element
    of x0.  The inputs are not modified.
    """
    # Hue values are cyclical going from 0 to 1.  To interpolate
    # we pick the shortest path (length < 0.5) and shift the
    # values so that we can interpolate them.

    knots = np.asarray(xx, dtype=float)
    # Work on a private copy: slicing a numpy array gives a view, so
    # shifting values in place would corrupt the caller's data (and
    # every later sample) when hh is an array rather than a list.
    hues = np.array(hh, dtype=float)
    last = len(hues) - 1

    h_out = []
    for x in x0:
        # Index of the knot at or immediately to the left of x
        idx = np.where(knots <= x)[0][-1]
        if idx == last:
            hh_o = hues[-1]
        else:
            step = hues[idx+1] - hues[idx]
            if step > 0.5:
                # Shorter to go down through 0: lift the left end by 1
                shifted = hues.copy()
                shifted[idx] += 1
                hh_o = np.mod(np.interp(x, knots, shifted), 1.0)
            elif step < -0.5:
                # Shorter to go up through 1: lift the right end by 1
                shifted = hues.copy()
                shifted[idx+1] += 1
                hh_o = np.mod(np.interp(x, knots, shifted), 1.0)
            else:
                hh_o = np.interp(x, knots, hues)
        h_out.append(hh_o)

    return np.array( h_out )

    

def lut_colors( colors, num_colors=256, colorsys="rgb"):
    """
    Create a custom color lookup table that fades through each
    color listed, in order.

    >>> r,g,b = lut_colors( ["red","white","blue"] )
    >>> r,g,b = lut_colors( ["black","red","orange", "yellow","white"])
    >>> r,g,b = lut_colors( ["cyan", "magenta"] )

    The colors can be interpolated either in the RGB color system (default)
    or in the HSV or HLS color systems.

    >>> r,g,b = lut_colors(["pink", "cadetblue"], colorsys="rgb")
    >>> r,g,b = lut_colors(["pink", "cadetblue"], colorsys="hsv")
    >>> r,g,b = lut_colors(["pink", "cadetblue"], colorsys="hls")

    The default is to create a lookup table with 256 colors.
    The number of colors can be changed using the num_colors parameter

    >>> r,g,b = lut_colors(["yellow","mediumblue"], num_colors=16)

    The input colors are spaced evenly along the table.  A single
    color gives a constant table, and num_colors=1 gives a table
    holding just the first color.

    Returns three lists (red, green, blue), each with num_colors
    values.  A ValueError is raised if no colors are given, if a color
    cannot be interpreted, or if colorsys is not one of "rgb", "hsv",
    or "hls".
    """

    from colorsys import rgb_to_hsv, hsv_to_rgb, rgb_to_hls, hls_to_rgb

    def _noop(x, y, z):
        return x, y, z

    def _unit_ramp(n):
        # n evenly spaced values from 0 to 1.  With a single point the
        # span is zero; divide by 1 instead so we get [0.0], not NaN.
        span = n - 1.0
        return np.arange(n) / (span if span != 0 else 1.0)

    rgbs = [_color_to_tripple(c) for c in colors]
    if len(rgbs) == 0:
        raise ValueError("At least one color is required")

    # For each color system: conversion from RGB, the interpolator for
    # the first channel (hue is cyclical in HSV/HLS), and the
    # conversion back to RGB.
    if "rgb" == colorsys:
        forward, interp, invert = _noop, np.interp, _noop
    elif "hsv" == colorsys:
        forward, interp, invert = rgb_to_hsv, _circular_interp, hsv_to_rgb
    elif "hls" == colorsys:
        forward, interp, invert = rgb_to_hls, _circular_interp, hls_to_rgb
    else:
        raise ValueError("Unknown value of colorsys='{}'".format(colorsys))

    triples = [forward(r, g, b) for r, g, b in rgbs]

    chan0 = [t[0] for t in triples]
    chan1 = [t[1] for t in triples]
    chan2 = [t[2] for t in triples]

    xx = _unit_ramp(len(triples))   # positions of the input colors
    x0 = _unit_ramp(num_colors)     # positions of the table entries

    chan0_i = interp(x0, xx, chan0)
    chan1_i = np.interp(x0, xx, chan1)
    chan2_i = np.interp(x0, xx, chan2)

    rgbs_i = [invert(a, b, c) for a, b, c in zip(chan0_i, chan1_i, chan2_i)]

    rr_i = [x[0] for x in rgbs_i]
    gg_i = [x[1] for x in rgbs_i]
    bb_i = [x[2] for x in rgbs_i]

    return rr_i, gg_i, bb_i
        

import sys
import os
import subprocess as sp
from tempfile import NamedTemporaryFile    


def check_for_single(ds9):
    """
    Check for multiple ds9 instances; there can be only 1 #Highlander

    Runs "xpaget <ds9> version" and inspects what it prints.  If the
    reply has more than one line, or the first line contains 'BEGIN',
    a RuntimeError is raised.  The exit status of xpaget is not
    checked, so an empty reply passes silently.
    """

    # Build the argument list directly rather than splitting a
    # formatted string, so a target containing whitespace stays a
    # single argument.
    runcmd = ["xpaget", str(ds9), "version"]
    reply = sp.run(runcmd, check=False, stdout=sp.PIPE).stdout
    lines = reply.decode().strip().split("\n")

    # NOTE(review): presumably xpaget brackets the replies with
    # BEGIN/END markers when several servers answer - confirm
    # against the XPA documentation.
    if len(lines) > 1 or 'BEGIN' in lines[0]:
        raise RuntimeError("Multiple ds9's are running.  Please close the others and try again.")


    
if __name__ == "__main__":
    # Usage: <script> xpa_target num_colors colorsys color [color ...]
    #
    # Builds the lookup table, writes it as a text file in the
    # directory named by the DAX_OUTDIR environment variable (which
    # must be set), and asks ds9 to load it via xpaset.
    xpa = sys.argv[1]
    numcolors = int(sys.argv[2])
    colorsys = sys.argv[3]
    # Empty arguments are dropped
    colors = [x for x in sys.argv[4:] if len(x) > 0]

    check_for_single(xpa)

    rgb = lut_colors(colors, num_colors=numcolors, colorsys=colorsys)

    # Colors may be given as "R G B"; keep whitespace out of the file
    # name so it cannot be broken apart when passed on the xpaset
    # command line.  Plain color names are unaffected.
    stem = "_".join("_".join(c.split()) for c in colors)
    lut_filename = os.path.join(os.environ["DAX_OUTDIR"], stem + ".lut")

    # One "r g b" line per table entry
    with open(lut_filename, "w") as fp:
        for r, g, b in zip(*rgb):
            fp.write("{:.3f} {:.3f} {:.3f}\n".format(r,g,b))

    # Pass the arguments as a list instead of splitting a formatted
    # string on spaces.
    cmd = ["xpaset", "-p", xpa, "cmap", "load", lut_filename]
    sp.run(cmd, check=False)




