"""PSF_Models.
A module which provides general utility methods
to manage the parameters of the psf model.
:Authors: Tobias Liaudat <tobiasliaudat@gmail.com> and Jennifer Pollack <jennifer.pollack@cea.fr>
"""
import glob
import logging
import os

import numpy as np
import tensorflow as tf

from wf_psf.sims.psf_simulator import PSFSimulator
from wf_psf.utils.optimizer import is_optimizer_instance, get_optimizer
from wf_psf.utils.utils import zernike_generator
logger = logging.getLogger(__name__)

# Registry of PSF factory classes keyed by model identifier; entries are
# added by register_psfclass and looked up by set_psf_model.
PSF_FACTORY: dict = dict()
[docs]
class PSFModelError(Exception):
    """Exception raised for problems with PSF model parameter settings.

    Parameters
    ----------
    message : str, optional
        Human-readable description of the problem. It is stored on the
        instance as ``message`` and forwarded to ``Exception``, so it is also
        what ``str(err)`` returns. Defaults to
        "An error with your PSF model parameter settings occurred."
    """

    def __init__(
        self, message="An error with your PSF model parameter settings occurred."
    ):
        super().__init__(message)
        self.message = message
[docs]
def register_psfclass(psf_factory_class):
    """Register PSF Factory Class.

    Store ``psf_factory_class`` in the module-level ``PSF_FACTORY`` dictionary
    under every identifier listed in its ``ids`` attribute.

    Parameters
    ----------
    psf_factory_class: type
        PSF factory class. It must expose an iterable ``ids`` attribute
        holding the model names it should be registered under.

    Returns
    -------
    type
        The same ``psf_factory_class``, unchanged, which allows this function
        to be used as a class decorator.
    """
    for model_id in psf_factory_class.ids:
        PSF_FACTORY[model_id] = psf_factory_class
        # Pass a real format string with lazy %-args. Previously the model id
        # itself was used as the format string and the registry was passed as
        # a stray argument, so the registered class never appeared in the log.
        logger.info("Registered PSF model '%s': %s", model_id, psf_factory_class)
    return psf_factory_class
[docs]
class PSFModelBaseFactory:
    """Base factory class for PSF models.

    This class serves as the base factory for instantiating PSF (Point Spread
    Function) models. Concrete factories subclass it and override
    `get_model_instance` to build a specific PSF model.

    Methods
    -------
    get_model_instance(model_params, training_params, data=None, coeff_matrix=None)
        Instantiate a PSF model with the provided parameters. Must be
        overridden by subclasses.
    """

    def get_model_instance(
        self, model_params, training_params, data=None, coeff_matrix=None
    ):
        """Instantiate a PSF model instance.

        Parameters
        ----------
        model_params: Recursive Namespace
            A Recursive Namespace object containing parameters for this PSF
            model class.
        training_params: Recursive Namespace
            A Recursive Namespace object containing training hyperparameters
            for this PSF model class.
        data: DataConfigHandler, optional
            A DataConfigHandler object that provides access to training and
            test datasets, as well as prior knowledge like Zernike
            coefficients.
        coeff_matrix: Tensor or None, optional
            Coefficient matrix defining the parametric PSF field model.

        Returns
        -------
        PSF model instance
            An instance of the PSF model (provided by subclasses).

        Raises
        ------
        NotImplementedError
            Always, on the base class. Previously this silently returned
            ``None``, which only failed later and far from the cause.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must override get_model_instance()."
        )
[docs]
def set_psf_model(model_name):
    """Look up the PSF factory class registered under ``model_name``.

    Parameters
    ----------
    model_name: str
        Name of the PSF model, i.e. one of the identifiers a factory class
        was registered under in ``PSF_FACTORY``.

    Returns
    -------
    psf_class: type
        The PSF factory class registered for ``model_name``.

    Raises
    ------
    PSFModelError
        If no factory class is registered under ``model_name``.
    """
    try:
        factory_cls = PSF_FACTORY[model_name]
    except KeyError as lookup_error:
        # Record the failed lookup (with traceback) before translating it
        # into the module's own exception type.
        logger.exception(lookup_error)
        raise PSFModelError("PSF model entered is invalid. Check your config settings.")
    else:
        return factory_cls
[docs]
def get_psf_model(*psf_model_params):
    """Get PSF Model Class Instance.

    Select the registered PSF factory from the ``model_name`` of the first
    parameter and use it to instantiate a PSF model.

    Parameters
    ----------
    *psf_model_params : tuple
        Positional arguments forwarded unchanged to the factory's
        ``get_model_instance``. The first one must expose a ``model_name``
        attribute (the model-parameter namespace).

    Returns
    -------
    PSF model class instance
        An instance of the PSF model class based on the provided parameters.

    Raises
    ------
    PSFModelError
        If no parameters are given, or if ``model_name`` is not registered.
    """
    # Guard against an empty call, which previously surfaced as an opaque
    # IndexError instead of the module's own configuration error.
    if not psf_model_params:
        raise PSFModelError(
            "No PSF model parameters were provided. Check your config settings."
        )
    model_name = psf_model_params[0].model_name
    # set_psf_model either returns the registered class or raises
    # PSFModelError, so no extra None check is needed here.
    psf_factory_class = set_psf_model(model_name)
    return psf_factory_class().get_model_instance(*psf_model_params)
[docs]
def compile_PSF_model(model_inst, optimizer=None, loss=None, metrics=None):
    """Compile a PSF model instance for training.

    Parameters
    ----------
    model_inst: PSF model instance
        The Keras-style model to compile (its ``compile`` method is called).
    optimizer: config object or Keras optimizer instance, optional
        If it is already an optimizer instance (per ``is_optimizer_instance``)
        it is used as-is. Otherwise it is treated as an optimizer
        configuration and passed to ``get_optimizer(optimizer_config=...)``.
        This includes ``None``; the resulting default is decided by
        ``get_optimizer``.
    loss: str or Keras loss instance, optional
        Loss function. Defaults to ``tf.keras.losses.MeanSquaredError()``.
    metrics: list of str or Keras metric instances, optional
        Metrics to evaluate during training. Defaults to
        ``[tf.keras.metrics.MeanSquaredError()]``.

    Returns
    -------
    PSF model instance
        The same ``model_inst``, compiled and ready for training.
    """
    loss = tf.keras.losses.MeanSquaredError() if loss is None else loss

    # Only resolve through get_optimizer when we were not handed a
    # ready-made optimizer object.
    if not is_optimizer_instance(optimizer):
        optimizer = get_optimizer(optimizer_config=optimizer)

    metrics = [tf.keras.metrics.MeanSquaredError()] if metrics is None else metrics

    compile_kwargs = dict(
        optimizer=optimizer,
        loss=loss,
        metrics=metrics,
        loss_weights=None,
        weighted_metrics=None,
        run_eagerly=False,
    )
    model_inst.compile(**compile_kwargs)
    return model_inst
[docs]
def get_psf_model_weights_filepath(weights_filepath):
    """Get PSF model weights filepath.

    Resolve a user-specified PSF model weights path (which may be a glob
    pattern) and strip the extension part from the matched file name. The
    result is the path prefix from which the weights are loaded.

    Parameters
    ----------
    weights_filepath: str
        Path or glob pattern of the PSF model weights to be loaded.

    Returns
    -------
    str
        The first match (in sorted order) with everything from the first "."
        of its file name onwards removed. The directory part is kept intact.

    Raises
    ------
    PSFModelError
        If ``weights_filepath`` matches no file.
    """
    try:
        # glob.glob returns matches in arbitrary order; sort so the selected
        # file does not depend on filesystem listing order.
        first_match = sorted(glob.glob(weights_filepath))[0]
    except IndexError:
        logger.exception(
            "PSF weights file not found. Check that you've specified the correct weights file in the your config file."
        )
        raise PSFModelError("PSF model weights error.")
    # Strip extensions from the file name only. Splitting the whole path on
    # "." truncated any path whose directories contain a dot ("./", "../",
    # "v1.0/"), returning a wrong or even empty prefix.
    directory, filename = os.path.split(first_match)
    return os.path.join(directory, filename.split(".")[0])
[docs]
def generate_zernike_maps_3d(n_zernikes, pupil_diam):
    """Build a 3D tensor stacking the Zernike maps.

    Parameters
    ----------
    n_zernikes : int
        Number of Zernike polynomials to generate.
    pupil_diam : float
        Pupil diameter, forwarded to ``zernike_generator`` as ``wfe_dim``.

    Returns
    -------
    tf.Tensor
        A float32 tensor of shape ``(n_maps, rows, cols)``, where each slice
        along the first axis is one Zernike map and NaN entries are set to 0.
    """
    zernike_maps = zernike_generator(n_zernikes=n_zernikes, wfe_dim=pupil_diam)

    # Allocate the cube from the number of maps and the 2D size of the first
    # map, then copy each map into its slice.
    n_maps = len(zernike_maps)
    n_rows, n_cols = zernike_maps[0].shape[0], zernike_maps[0].shape[1]
    cube = np.zeros((n_maps, n_rows, n_cols))
    for index, zernike_map in enumerate(zernike_maps):
        cube[index, :, :] = zernike_map

    # Zero out any NaN values before handing the cube to TensorFlow.
    nan_mask = np.isnan(cube)
    cube[nan_mask] = 0
    return tf.convert_to_tensor(cube, dtype=tf.float32)
[docs]
def tf_obscurations(pupil_diam, N_filter=2, rotation_angle=0):
    """Generate the Euclid pupil obscurations as a TensorFlow tensor.

    Parameters
    ----------
    pupil_diam: float
        Size of the pupil diameter (passed to the simulator as ``N_pix``).
    N_filter: int
        Number of filters. Defaults to 2.
    rotation_angle: int
        Rotation angle in degrees to apply to the obscuration pattern. It
        only supports 90 degree rotations; the rotation is counterclockwise.

    Returns
    -------
    tf.Tensor
        The obscurations as a ``tf.complex64`` tensor.
    """
    obscuration_kwargs = {
        "N_pix": pupil_diam,
        "N_filter": N_filter,
        "rotation_angle": rotation_angle,
    }
    obscuration_map = PSFSimulator.generate_euclid_pupil_obscurations(
        **obscuration_kwargs
    )
    return tf.convert_to_tensor(obscuration_map, dtype=tf.complex64)
[docs]
def simPSF(model_params):
    """Create and configure a simulated PSF model.

    Build a `PSFSimulator` from ``model_params``, draw random Zernike
    coefficients, normalize them by the simulator's ``max_wfe_rms``, store
    them back, and generate a monochromatic PSF.

    Parameters
    ----------
    model_params: Recursive Namespace
        Namespace holding the simulator settings, including
        ``param_hparams.n_zernikes``.

    Returns
    -------
    PSFSimulator
        The configured simulator instance.
    """
    n_zernikes = model_params.param_hparams.n_zernikes
    simulator = PSFSimulator(
        max_order=n_zernikes,
        pupil_diameter=model_params.pupil_diameter,
        output_dim=model_params.output_dim,
        oversampling_rate=model_params.oversampling_rate,
        output_Q=model_params.output_Q,
        SED_interp_pts_per_bin=model_params.sed_interp_pts_per_bin,
        SED_extrapolate=model_params.sed_extrapolate,
        SED_interp_kind=model_params.sed_interp_kind,
        SED_sigma=model_params.sed_sigma,
        pix_sampling=model_params.pix_sampling,
        tel_diameter=model_params.tel_diameter,
        tel_focal_length=model_params.tel_focal_length,
        euclid_obsc=model_params.euclid_obsc,
        LP_filter_length=model_params.LP_filter_length,
    )

    # Random coefficients, rescaled by the simulator's max WFE RMS.
    simulator.gen_random_Z_coeffs(max_order=n_zernikes)
    raw_coeffs = simulator.get_z_coeffs()
    normalized_coeffs = simulator.normalize_zernikes(raw_coeffs, simulator.max_wfe_rms)
    simulator.set_z_coeffs(normalized_coeffs)

    # Fixed observation wavelength of 0.7 (units as expected by
    # PSFSimulator -- presumably micrometres; confirm).
    simulator.generate_mono_PSF(lambda_obs=0.7, regen_sample=False)
    return simulator