# This is a CamiTK python action
#
# Apply a Gaussian filter to an ImageComponent using ITK and SciPy

import camitk

import itk
import numpy as np
from scipy import ndimage
from PySide2.QtWidgets import *

def gaussian_filter_scipy(input_array: np.ndarray, sigma: float = 1.0):
    """Return *input_array* smoothed by SciPy's N-dimensional Gaussian filter.

    Args:
        input_array: image voxels as a NumPy array (any dimensionality).
        sigma: standard deviation of the Gaussian kernel, in voxel units.
            A scalar applies the same sigma along every axis.

    Returns:
        A new array with the same shape as *input_array*. All other options
        use SciPy's defaults, so the output keeps the input dtype (an integer
        image therefore yields an integer, rounded result).
    """
    return ndimage.gaussian_filter(input_array, sigma=sigma)

def gaussian_filter_itk(input_array: np.ndarray, sigma: float = 1.0):
    """Smooth *input_array* with ITK's SmoothingRecursiveGaussianImageFilter.

    The array is converted to a float32 ITK image and filtered, and the result
    is converted back to NumPy. The ITK image is built from the raw array
    without any spacing information. Its spacing is therefore ITK's default
    (1 per axis), so *sigma* is effectively in voxel units, matching
    ``gaussian_filter_scipy``.

    Args:
        input_array: image voxels as a NumPy array. The image dimension is taken
            from ``input_array.ndim``. Only dimensions for which ITK's Python
            wrapping provides ``itk.Image[itk.F, N]`` are usable (commonly 2
            and 3).
        sigma: standard deviation of the Gaussian kernel.

    Returns:
        A float32 NumPy array with the same shape as *input_array*.
    """
    #-- 1. Convert NumPy array to ITK image
    # ascontiguousarray (unlike astype, which keeps Fortran order) guarantees the
    # C-ordered float32 buffer expected for an itk.F image.
    float_array = np.ascontiguousarray(input_array, dtype=np.float32)
    # image_from_array creates an image whose dimension follows the array's ndim,
    # so the filter must be instantiated for that same dimension (a hard-coded 3
    # made 2-D images fail).
    image_type = itk.Image[itk.F, float_array.ndim]
    itk_image = itk.image_from_array(float_array)

    #-- 2. Apply Gaussian filter
    gaussian_filter = itk.SmoothingRecursiveGaussianImageFilter[image_type, image_type].New()
    gaussian_filter.SetInput(itk_image)
    camitk.info("Sigma: " + str(sigma))
    # SetSigma is bound to a C++ double; coerce in case the UI hands back an int
    gaussian_filter.SetSigma(float(sigma))
    gaussian_filter.Update()

    #-- 3. Convert back to NumPy (array_from_image copies the pixel buffer)
    smoothed_array = itk.array_from_image(gaussian_filter.GetOutput())
    return smoothed_array

def process(self:camitk.Action):
    """Smooth the target image with ITK and with SciPy, show both results and compare them.

    The Gaussian sigma comes from the action's "Sigma" parameter. For each
    backend (ITK first, then SciPy) a new image component is created. Each one
    reuses the source image's spacing and frame. After refreshing the
    application, the mean and variance of each result are printed. The
    correlations with the original image and with each other are also
    printed. A CamiTK warning is emitted when the two smoothed images
    correlate below 0.999.
    """
    # Work on the last selected target (expected to be the same as the first one)
    source = self.getTargets()[-1]
    voxels = source.getImageDataAsNumpy()

    sigma = self.getParameterValue("Sigma")
    camitk.info("Sigma: " + str(sigma))

    # Filter with each backend and wrap the output in a new image component;
    # the original spacing and frame must be carried over explicitly.
    backends = (
        (" (gaussian itk)", gaussian_filter_itk),
        (" (gaussian scipy)", gaussian_filter_scipy),
    )
    created = []
    for suffix, smoother in backends:
        smoothed = smoother(input_array=voxels, sigma=sigma)
        component = camitk.newImageComponentFromNumpy(smoothed, source.getName() + suffix, source.getSpacing())
        component.setFrameFrom(source)
        created.append(component)

    # Display the new images (camitk.refresh() is an alternative)
    self.refreshApplication()

    # Per-backend statistics, read back from the newly created components
    flat_by_label = {}
    for label, component in zip(("ITK", "SciPy"), created):
        pixels = component.getImageDataAsNumpy()
        print(f"{label}: mean={np.mean(pixels)} var={np.var(pixels)}")
        flat_by_label[label] = pixels.flatten()

    original_flat = voxels.flatten()
    for label in ("ITK", "SciPy"):
        score = np.corrcoef(original_flat, flat_by_label[label])[0, 1]
        print(f"Correlation between original and {label}: {score}")

    agreement = np.corrcoef(flat_by_label["ITK"], flat_by_label["SciPy"])[0, 1]
    print(f"Correlation between ITK and Scipy: {agreement}")

    if agreement < 0.999:
        camitk.warning("Gaussian images are different.")

def targetDefined(self:camitk.Action):
    """Show a modal information dialog reporting the ITK version in use.

    Presumably invoked by CamiTK once the action's targets are set (confirm
    against the CamiTK Python action API). The dialog offers Ok and Cancel
    buttons, but the button the user clicks is not used.
    """
    dialog = QMessageBox()
    dialog.setWindowTitle("ITK VERSION")
    dialog.setIcon(QMessageBox.Information)
    dialog.setText(f"Itk version: {itk.__version__}")
    dialog.setStandardButtons(QMessageBox.Ok | QMessageBox.Cancel)
    # exec_() blocks until the dialog is closed; its return code is discarded
    dialog.exec_()
    