from pathlib import Path
from typing import List, Union
import numpy as np
import os

import pycrates

from ciao.tools.io import InputOutputFactory, FitsLikeFile, Column, AbstractBlock, Observation, Stack


# Adapted from https://stackoverflow.com/q/35751306
def pad(array, reference_shape):
    """
    Copy ``array`` into the leading corner of a NaN-filled array of ``reference_shape``.

    Parameters
    ----------
    array : array_like
        Array to be padded. It should have the same number of dimensions as
        ``reference_shape`` and be no larger than it along any axis.
    reference_shape : sequence of int
        Shape of the ndarray to create.

    Returns
    -------
    numpy.ndarray
        A float array of shape ``reference_shape`` whose leading ``array.shape``
        corner holds ``array``; every other element is NaN.

    Raises
    ------
    ValueError
        If ``array`` does not fit inside ``reference_shape``.
    """
    array = np.asarray(array)
    # Start from an all-NaN array so padded positions are distinguishable from real data.
    result = np.full(reference_shape, np.nan)
    # NumPy requires a *tuple* of slices for multi-dimensional indexing; indexing
    # with a list of slices was deprecated and is an error in recent NumPy.
    insert_here = tuple(slice(0, extent) for extent in array.shape)
    result[insert_here] = array
    return result


class CratesInputOutputFactory(InputOutputFactory):
    """Factory that opens FITS-like files through :class:`CratesFitsFile` (pycrates backend)."""

    def create_file(self, filename: Path, cls=None, clobber=False, mode='r', **kwargs):
        """
        Open (or create) ``filename`` as a :class:`CratesFitsFile`.

        Parameters
        ----------
        filename : Path or str
            Location of the file. Strings are accepted and converted to ``Path``.
        cls : callable, optional
            If given, the opened file is returned wrapped as ``cls(file, **kwargs)``.
        clobber : bool
            If true, delete any existing file at ``filename`` before opening it,
            so :class:`CratesFitsFile` starts from a freshly created file.
        mode : str
            Mode passed through to :class:`CratesFitsFile`.
        **kwargs
            Extra keyword arguments forwarded to ``cls``.

        Returns
        -------
        CratesFitsFile, or whatever ``cls(file, **kwargs)`` returns.
        """
        # Accept plain strings as well as Path objects; .exists()/.unlink() need a Path.
        filename = Path(filename)
        if clobber and filename.exists():
            filename.unlink()
        fits_file = CratesFitsFile(str(filename), mode=mode)
        if cls is None:
            return fits_file
        return cls(fits_file, **kwargs)


class CratesColumn(Column):
    """:class:`Column` adapter around a pycrates column object."""

    def __init__(self, crates_column):
        # The wrapped pycrates column; all data access goes through its ``values`` attribute.
        self.crates_column = crates_column

    @property
    def values(self):
        """The column data as exposed by the wrapped pycrates column's ``values``."""
        return self.crates_column.values

    @values.setter
    def values(self, values):
        self.crates_column.values = values

    # FIXME This is very simplistic and will only work when the new value is smaller in size than the rest of the column
    def set_variable_row_value(self, row_number, value):
        """
        Store ``value`` in row ``row_number`` of an array-valued column.

        ``value`` is NaN-padded up to the column's per-row shape. If
        ``row_number`` is past the last row, the column is first grown
        (with NaN-filled rows) to ``row_number + 1`` rows.

        Raises
        ------
        ValueError
            If ``value`` is larger than the column's per-row shape (see FIXME above).
        """
        current = self.values
        row_shape = current.shape[1:]
        padded_value = pad(value, row_shape)
        if row_number >= current.shape[0]:
            # Grow along the row axis only; the per-row shape is kept.
            current = pad(current, (row_number + 1,) + row_shape)
        current[row_number] = padded_value
        # Always assign back through the setter so the change is stored on the
        # pycrates column, even if its ``values`` getter returns a copy.
        self.values = current

    def add_value(self, value):
        """
        Append ``value`` to the end of the column.

        Note: ``np.append`` without ``axis`` flattens its inputs, so this is only
        meaningful for one-dimensional (scalar-per-row) columns.
        """
        self.values = np.append(self.values, value)


class CratesBlock(AbstractBlock):
    """One crate (HDU) of a :class:`CratesFitsFile`, exposed through the :class:`AbstractBlock` API."""

    def __init__(self, crate: pycrates.TABLECrate, parent: FitsLikeFile, header_class=None, header=None):
        super().__init__(header_class=header_class, header=header)
        # The underlying pycrates crate and the file object that owns this block.
        self.crate = crate
        self.parent = parent

    def get_header_value(self, name, default=None):
        """
        Look up header keyword ``name``.

        The crate's own keyword wins. If ``crate.get_key(name)`` yields something
        without ``.value`` (presumably ``None`` for a missing key — verify against
        pycrates), fall back to an attribute of ``self.header``, then to ``default``.
        """
        try:
            return self.crate.get_key(name).value
        except AttributeError:
            return getattr(self.header, name) if hasattr(self.header, name) else default

    def get_header_names(self):
        """Return the set union of the crate's keyword names and ``self.header.names``."""
        keynames = self.crate.get_keynames()
        extra_names = list(self.header.names)
        return set(keynames + extra_names)

    def get_history(self):
        """Return the crate's history records."""
        return self.crate.get_history_records()

    def get_comments(self):
        """Return the crate's comment records."""
        return self.crate.get_comment_records()

    def get_parent(self):
        """Return the file object this block belongs to."""
        return self.parent

    def get_column(self, name) -> Column:
        """Wrap the crate's column ``name`` in a :class:`CratesColumn`."""
        crates_column = self.crate.get_column(name)
        return CratesColumn(crates_column)

    def new_column(self, name) -> Column:
        """
        Return column ``name``; if looking it up raises ``ValueError`` (presumably:
        no such column), first add an empty ``pycrates.CrateData`` column with that name.
        """
        try:
            return self.get_column(name)
        except ValueError:
            pass
        blank_column = pycrates.CrateData()
        blank_column.name = name
        self.crate.add_column(blank_column)
        return self.get_column(name)

    def write(self):
        """Push every ``(key, value)`` pair of ``self.header`` onto the crate, updating or adding keys."""
        for key, value in self.header:
            if self.crate.key_exists(key):
                self.crate.get_key(key).value = value
            else:
                self.crate.add_key(pycrates.CrateKey((key, value, '', '')))

    def get_transform(self, name):
        """Return the crate's coordinate transform called ``name``."""
        return self.crate.get_transform(name)

    @property
    def observation(self):
        """Observation built from the crate's ``OBS_ID`` and ``OBI_NUM`` keywords."""
        obs_id = self.crate.get_key('OBS_ID').value
        obi_num = self.crate.get_key('OBI_NUM').value
        return Observation(obs_id, obi_num)

    @property
    def stack(self) -> Stack:
        """Stack built from the crate's ``STACK_ID`` keyword."""
        stack_id = self.crate.get_key('STACK_ID').value
        return Stack(stack_id)

    @property
    def column_names(self) -> List[str]:
        """Names of the crate's columns."""
        return self.crate.get_colnames()


class CratesFitsFile(FitsLikeFile):
    """A FITS-like file backed by a ``pycrates.CrateDataset``; each crate is exposed as a :class:`CratesBlock`."""

    def __init__(self, filename: Union[str, Path], mode='r'):
        """
        Open ``filename`` with pycrates, or create it as an empty dataset if it does not exist.

        Parameters
        ----------
        filename : str or Path
            Location of the file.
        mode : str
            Passed to ``pycrates.CrateDataset`` when opening an existing file.
        """
        self.filename = Path(filename)
        # Block name -> CratesBlock.
        self._blocks = {}
        if not self.filename.exists():
            self.crate = self._create()
        else:
            # Pass a plain string path, as _create() and write() do.
            self.crate = pycrates.CrateDataset(str(self.filename), mode=mode)
            # NOTE(review): confirm whether pycrates numbers crates from 0 or 1 in
            # get_crate(); this loop assumes indices 0 .. get_ncrates() - 1.
            for i in range(self.crate.get_ncrates()):
                ds = self.crate.get_crate(i)
                self._blocks[ds.name] = CratesBlock(ds, parent=self)

    def get_block(self, name, header_class=None):
        """
        Return the block called ``name``.

        When ``header_class`` is given, the cached block is rebuilt with a header
        dict (lower-cased keyword name -> value) read from the current block, and
        the rebuilt block replaces the cached one.

        Raises
        ------
        KeyError
            If no block called ``name`` is known.
        """
        if header_class is not None:
            block = self._blocks[name]
            header_dict = {key.lower(): block.get_header_value(key) for key in block.get_header_names()}
            self._blocks[name] = CratesBlock(self.crate.get_crate(name), parent=self, header_class=header_class,
                                             header=header_dict)
        return self._blocks[name]

    @property
    def blocks(self):
        """Mapping of block name to :class:`CratesBlock`."""
        return self._blocks

    def new_block(self, name, data='table', header_class=None, header=None):
        """
        Add a crate called ``name`` to the dataset and return it as a :class:`CratesBlock`.

        ``'PRIMARY'`` becomes an image crate holding an empty placeholder image;
        any other name becomes a table crate. ``data`` is currently unused.
        """
        # Compare by value: ``is`` on strings tests identity, which is unreliable
        # for names built at runtime (and triggers a SyntaxWarning on literals).
        if name != 'PRIMARY':
            block = pycrates.TABLECrate()
        else:
            block = pycrates.IMAGECrate()
            fake_image = pycrates.CrateData()
            fake_image.values = np.array([])
            block.add_image(fake_image)
        block.name = name
        self.crate.add_crate(block)
        block = CratesBlock(self.crate.get_crate(name), parent=self, header_class=header_class,
                            header=header)
        self._blocks[name] = block
        return block

    def _create(self):
        """Write an empty ``CrateDataset`` to ``self.filename`` and return it."""
        crate = pycrates.CrateDataset()
        crate.write(str(self.filename))
        return crate

    def write(self, fname=None):
        """
        Flush every block's header to its crate, then write the dataset.

        Writes to ``fname`` if given, otherwise to ``self.filename``, overwriting
        any existing file (``clobber=True``).
        """
        for block in self._blocks.values():
            block.write()
        if fname is None:
            fname = str(self.filename)
        self.crate.write(fname, clobber=True)

    def get_caldb_version(self):
        """
        Return the last ``CALDB_VER`` entry of ``$CALDB/docs/chandra/caldb_version/caldb_version.fits``.

        Returns ``"-999"`` when ``CALDB`` is not set or reading the file raises ``IOError``.
        """
        default = "-999"
        caldbvar = os.environ.get('CALDB')
        if caldbvar is None:
            return default
        try:
            caldb = pycrates.read_file(os.path.join(caldbvar, "docs/chandra/caldb_version/caldb_version.fits"), 'r')
            return str(caldb.get_column("CALDB_VER").values[-1])
        except IOError:
            return default
