"""PSF Model Physical Semi-Parametric Polychromatic.
A module which defines the classes and methods
to manage the parameters of the psf physical polychromatic model.
:Authors: Tobias Liaudat <tobias.liaudat@cea.fr> and Jennifer Pollack <jennifer.pollack@cea.fr>
"""
from typing import Optional
import tensorflow as tf
from tensorflow.python.keras.engine import data_adapter
from wf_psf.psf_models import psf_models as psfm
from wf_psf.utils.read_config import RecursiveNamespace
from wf_psf.utils.configs_handler import DataConfigHandler
from wf_psf.data.training_preprocessing import get_obs_positions, get_zernike_prior
from wf_psf.psf_models.tf_layers import (
TFPolynomialZernikeField,
TFZernikeOPD,
TFBatchPolychromaticPSF,
TFBatchMonochromaticPSF,
TFNonParametricPolynomialVariationsOPD,
TFPhysicalLayer,
)
import logging
logger = logging.getLogger(__name__)
[docs]
@psfm.register_psfclass
class PhysicalPolychromaticFieldFactory(psfm.PSFModelBaseFactory):
"""Factory class for the TensorFlow Physical Polychromatic PSF Field Model.
This factory class is responsible for instantiating instances of the TensorFlow Physical Polychromatic PSF Field Model. It is registered with the PSF model factory registry.
Parameters
----------
ids : tuple
A tuple containing identifiers for the factory class.
Methods
-------
get_model_instance(model_params, training_params, data=None, coeff_mat=None)
Instantiates an instance of the TensorFlow Physical Polychromatic Field class with the provided parameters.
"""
ids = ("physical_poly",)
[docs]
def get_model_instance(self, model_params, training_params, data, coeff_mat=None):
"""Create an instance of the TensorFlow Physical Polychromatic Field model.
This method instantiates a `TFPhysicalPolychromaticField` object with the given model and training parameters, and data containing prior information like Zernike coefficients. Optionally, a coefficient matrix can be provided.
Parameters
----------
model_params: Recursive Namespace
A Recursive Namespace object containing parameters for this PSF model class.
training_params: Recursive Namespace
A Recursive Namespace object containing training hyperparameters for this PSF model class.
data: DataConfigHandler
A DataConfigHandler object that provides access to training and test datasets, as well as prior knowledge like Zernike coefficients.
coeff_mat: Tensor or None, optional
Coefficient matrix defining the parametric PSF field model.
Returns
-------
TFPhysicalPolychromaticField
An instance of the TensorFlow Physical Polychromatic Field model.
"""
return TFPhysicalPolychromaticField(
model_params, training_params, data, coeff_mat
)
[docs]
class TFPhysicalPolychromaticField(tf.keras.Model):
"""TensorFlow Physical Polychromatic PSF Field class.
This class represents a polychromatic PSF field model with a physical layer.
It incorporates parametric and non-parametric modeling approaches to accurately reconstruct the point spread function (PSF) across multiple wavelengths.
The model provides functionalities for:
- Initializing model parameters and defining the physical PSF layer.
- Performing forward passes and computing wavefront transformations.
- Handling Zernike parameterization and coefficient matrices.
- Evaluating model performance and saving optimization history.
See individual method docstrings for more details.
"""
def __init__(self, model_params, training_params, data, coeff_mat=None):
"""Initialize the TFPhysicalPolychromaticField instance.
Parameters
----------
model_params: Recursive Namespace
A Recursive Namespace object containing parameters for this PSF model class.
training_params: Recursive Namespace
A Recursive Namespace object containing training hyperparameters for this PSF model class.
data: DataConfigHandler
A DataConfigHandler object that provides access to training and test datasets, as well as prior knowledge like Zernike coefficients.
coeff_mat: Tensor or None, optional
Coefficient matrix defining the parametric PSF field model.
Returns
-------
TFPhysicalPolychromaticField
Initialized instance of the TFPhysicalPolychromaticField class.
"""
super().__init__(model_params, training_params, coeff_mat)
self._initialize_parameters_and_layers(
model_params, training_params, data, coeff_mat
)
def _initialize_parameters_and_layers(
self,
model_params: RecursiveNamespace,
training_params: RecursiveNamespace,
data: DataConfigHandler,
coeff_mat: Optional[tf.Tensor] = None,
):
"""Initialize Parameters of the PSF model.
This method sets up the PSF model parameters, observational positions,
Zernike coefficients, and components required for the automatically
differentiable optical forward model.
Parameters
----------
model_params: Recursive Namespace
A Recursive Namespace object containing parameters for this PSF model class.
training_params: Recursive Namespace
A Recursive Namespace object containing training hyperparameters for this PSF model class.
data: DataConfigHandler object
A DataConfigHandler object providing access to training and tests datasets, as well as prior knowledge like Zernike coefficients.
coeff_mat: Tensor or None, optional
Initialization of the coefficient matrix defining the parametric psf field model.
Notes
-----
- Initializes Zernike parameters based on dataset priors.
- Configures the PSF model layers according to `model_params`.
- If `coeff_mat` is provided, the model coefficients are updated accordingly.
"""
self.output_Q = model_params.output_Q
self.obs_pos = get_obs_positions(data)
self.l2_param = model_params.param_hparams.l2_param
# Inputs: Save optimiser history Parametric model features
self.save_optim_history_param = (
model_params.param_hparams.save_optim_history_param
)
# Inputs: Save optimiser history NonParameteric model features
self.save_optim_history_nonparam = (
model_params.nonparam_hparams.save_optim_history_nonparam
)
self._initialize_zernike_parameters(model_params, data)
self._initialize_layers(model_params, training_params)
# Initialize the model parameters with non-default value
if coeff_mat is not None:
self.assign_coeff_matrix(coeff_mat)
def _initialize_zernike_parameters(self, model_params, data):
"""Initialize the Zernike parameters.
Parameters
----------
model_params: Recursive Namespace
A Recursive Namespace object containing parameters for this PSF model class.
data: DataConfigHandler object
A DataConfigHandler object providing access to training and tests datasets, as well as prior knowledge like Zernike coefficients.
"""
self.zks_prior = get_zernike_prior(model_params, data, data.batch_size)
self.n_zks_total = max(
model_params.param_hparams.n_zernikes,
tf.cast(tf.shape(self.zks_prior)[1], tf.int32),
)
self.zernike_maps = psfm.generate_zernike_maps_3d(
self.n_zks_total, model_params.pupil_diameter
)
def _initialize_layers(self, model_params, training_params):
"""Initialize the layers of the PSF model.
This method initializes the layers of the PSF model, including the physical layer, polynomial Zernike field, batch polychromatic layer, and non-parametric OPD layer.
Parameters
----------
model_params: Recursive Namespace
A Recursive Namespace object containing parameters for this PSF model class.
training_params: Recursive Namespace
A Recursive Namespace object containing training hyperparameters for this PSF model class.
coeff_mat: Tensor or None, optional
Initialization of the coefficient matrix defining the parametric psf field model.
"""
self._initialize_physical_layer(model_params)
self._initialize_polynomial_Z_field(model_params)
self._initialize_Zernike_OPD(model_params)
self._initialize_batch_polychromatic_layer(model_params, training_params)
self._initialize_nonparametric_opd_layer(model_params, training_params)
def _initialize_physical_layer(self, model_params):
"""Initialize the physical layer of the PSF model.
This method initializes the physical layer of the PSF model using parameters
specified in the `model_params` object.
Parameters
----------
model_params: Recursive Namespace
A Recursive Namespace object containing parameters for this PSF model class.
"""
self.tf_physical_layer = TFPhysicalLayer(
self.obs_pos,
self.zks_prior,
interpolation_type=model_params.interpolation_type,
interpolation_args=model_params.interpolation_args,
)
def _initialize_polynomial_Z_field(self, model_params):
"""Initialize the polynomial Zernike field of the PSF model.
This method initializes the polynomial Zernike field of the PSF model using
parameters specified in the `model_params` object.
Parameters
----------
model_params: Recursive Namespace
A Recursive Namespace object containing parameters for this PSF model class.
"""
self.tf_poly_Z_field = TFPolynomialZernikeField(
x_lims=model_params.x_lims,
y_lims=model_params.y_lims,
random_seed=model_params.param_hparams.random_seed,
n_zernikes=model_params.param_hparams.n_zernikes,
d_max=model_params.param_hparams.d_max,
)
def _initialize_Zernike_OPD(self, model_params):
"""Initialize the Zernike OPD field of the PSF model.
This method initializes the Zernike Optical Path Difference
field of the PSF model using parameters specified in the `model_params` object.
Parameters
----------
model_params: Recursive Namespace
A Recursive Namespace object containing parameters for this PSF model class.
"""
# Initialize the zernike to OPD layer
self.tf_zernike_OPD = TFZernikeOPD(zernike_maps=self.zernike_maps)
def _initialize_batch_polychromatic_layer(self, model_params, training_params):
"""Initialize the batch polychromatic PSF layer.
This method initializes the batch opd to batch polychromatic PSF layer
using the provided `model_params` and `training_params`.
Parameters
----------
model_params: Recursive Namespace
A Recursive Namespace object containing parameters for this PSF model class.
training_params: Recursive Namespace
A Recursive Namespace object containing training hyperparameters for this PSF model class.
"""
self.batch_size = training_params.batch_size
self.obscurations = psfm.tf_obscurations(
pupil_diam=model_params.pupil_diameter,
N_filter=model_params.LP_filter_length,
rotation_angle=model_params.obscuration_rotation_angle,
)
self.output_dim = model_params.output_dim
self.tf_batch_poly_PSF = TFBatchPolychromaticPSF(
obscurations=self.obscurations,
output_Q=self.output_Q,
output_dim=self.output_dim,
)
def _initialize_nonparametric_opd_layer(self, model_params, training_params):
"""Initialize the non-parametric OPD layer.
This method initializes the non-parametric OPD layer using the provided
`model_params` and `training_params`.
Parameters
----------
model_params: Recursive Namespace
A Recursive Namespace object containing parameters for this PSF model class.
training_params: Recursive Namespace
A Recursive Namespace object containing training hyperparameters for this PSF model class.
"""
# self.d_max_nonparam = model_params.nonparam_hparams.d_max_nonparam
# self.opd_dim = tf.shape(self.zernike_maps)[1].numpy()
self.tf_np_poly_opd = TFNonParametricPolynomialVariationsOPD(
x_lims=model_params.x_lims,
y_lims=model_params.y_lims,
random_seed=model_params.param_hparams.random_seed,
d_max=model_params.nonparam_hparams.d_max_nonparam,
opd_dim=tf.shape(self.zernike_maps)[1].numpy(),
)
[docs]
def get_coeff_matrix(self):
"""Get coefficient matrix."""
return self.tf_poly_Z_field.get_coeff_matrix()
[docs]
def assign_coeff_matrix(self, coeff_mat: Optional[tf.Tensor]) -> None:
"""Assign a coefficient matrix to the parametric PSF field model.
This method updates the coefficient matrix used by the parametric PSF field model,
allowing for customization or modification of the model's parameters.
If `coeff_mat` is `None`, the model will revert to using its default coefficient matrix.
Parameters
----------
coeff_mat : Optional[tf.Tensor]
A TensorFlow tensor representing the coefficient matrix for the PSF field model.
If `None`, the model will use the default coefficient matrix.
Returns
-------
None
"""
self.tf_poly_Z_field.assign_coeff_matrix(coeff_mat)
[docs]
def set_output_Q(self, output_Q: float, output_dim: Optional[int] = None) -> None:
"""Set the output sampling rate (output_Q) for PSF generation.
This method updates the `output_Q` parameter, which defines the
resampling factor for generating PSFs at different resolutions
relative to the telescope's native sampling. It also allows optionally
updating `output_dim`, which sets the output resolution of the PSF model.
If `output_dim` is provided, the PSF model's output resolution is updated.
The method then reinitializes the batch polychromatic PSF generator
to reflect the updated parameters.
Parameters
----------
output_Q : float
The resampling factor that determines the output PSF resolution
relative to the telescope's native sampling.
output_dim : Optional[int], default=None
The new output dimension for the PSF model. If `None`, the output
dimension remains unchanged.
Returns
-------
None
"""
self.output_Q = output_Q
if output_dim is not None:
self.output_dim = output_dim
# Reinitialize the PSF batch poly generator
self.tf_batch_poly_PSF = TFBatchPolychromaticPSF(
obscurations=self.obscurations,
output_Q=self.output_Q,
output_dim=self.output_dim,
)
[docs]
def set_zero_nonparam(self):
"""Set the non-parametric part of the OPD (Optical Path Difference) to zero.
This method sets the non-parametric component of the Optical Path Difference (OPD)
to zero, effectively removing its contribution from the overall PSF (Point Spread Function).
"""
self.tf_np_poly_opd.set_alpha_zero()
[docs]
def set_nonzero_nonparam(self):
"""Set the non-parametric part to non-zero values.
This method sets the non-parametric component of the Optical Path Difference (OPD)
to non-zero values, allowing it to contribute to the overall PSF (Point Spread Function).
"""
self.tf_np_poly_opd.set_alpha_identity()
[docs]
def set_trainable_layers(self, param_bool=True, nonparam_bool=True):
"""Set the layers to be trainable.
A method to set layers to be trainable.
Parameters
----------
param_bool: bool
Boolean flag for parametric model layers
nonparam_bool: bool
Boolean flag for non-parametric model layers
"""
self.tf_np_poly_opd.trainable = nonparam_bool
self.tf_poly_Z_field.trainable = param_bool
[docs]
def pad_zernikes(self, zk_param, zk_prior):
"""Pad the Zernike coefficients to match the maximum length.
Pad the input Zernike coefficient tensors to match the length of the
maximum number of Zernike coefficients among the parametric and prior parts.
Parameters
----------
zk_param: tf.Tensor
Zernike coefficients for the parametric part. Shape [batch, n_zks_param, 1, 1].
zk_prior: tf.Tensor
Zernike coefficients for the prior part. Shape [batch, n_zks_prior, 1, 1].
Returns
-------
padded_zk_param: tf.Tensor
Padded Zernike coefficients for the parametric part. Shape [batch, n_zks_total, 1, 1].
padded_zk_prior: tf.Tensor
Padded Zernike coefficients for the prior part. Shape [batch, n_zks_total, 1, 1].
"""
# Calculate the number of Zernikes to pad for parametric and prior parts
pad_num_param = self.n_zks_total - tf.shape(zk_param)[1]
pad_num_prior = self.n_zks_total - tf.shape(zk_prior)[1]
# Pad the Zernike coefficients for parametric and prior parts
padded_zk_param = tf.cond(
tf.not_equal(pad_num_param, 0),
lambda: tf.pad(zk_param, [(0, 0), (0, pad_num_param), (0, 0), (0, 0)]),
lambda: zk_param,
)
padded_zk_prior = tf.cond(
tf.not_equal(pad_num_prior, 0),
lambda: tf.pad(zk_prior, [(0, 0), (0, pad_num_prior), (0, 0), (0, 0)]),
lambda: zk_prior,
)
return padded_zk_param, padded_zk_prior
[docs]
def predict_step(self, data, evaluate_step=False):
"""Predict (inference) step.
A method to enable a special type of
interpolation (different from training) for
the physical layer.
Parameters
----------
data : NOT SURE
evaluate_step : bool
Boolean flag to evaluate step
Returns
-------
poly_psfs TFBatchPolychromaticPSF
Instance of TFBatchPolychromaticPSF class containing computed polychromatic PSFs.
"""
if evaluate_step:
input_data = data
else:
# Format input data
data = data_adapter.expand_1d(data)
input_data, _, _ = data_adapter.unpack_x_y_sample_weight(data)
# Unpack inputs
input_positions = input_data[0]
packed_SEDs = input_data[1]
# Compute zernikes from parametric model and physical layer
zks_coeffs = self.predict_zernikes(input_positions)
# Propagate to obtain the OPD
param_opd_maps = self.tf_zernike_OPD(zks_coeffs)
# Calculate the non parametric part
nonparam_opd_maps = self.tf_np_poly_opd(input_positions)
# Add the estimations
opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps)
# Compute the polychromatic PSFs
poly_psfs = self.tf_batch_poly_PSF([opd_maps, packed_SEDs])
return poly_psfs
[docs]
def predict_mono_psfs(self, input_positions, lambda_obs, phase_N):
"""Predict a set of monochromatic Point Spread Functions (PSFs) at desired positions.
This method calculates monochromatic PSFs based on the provided input positions,
observed wavelength, and required wavefront dimension.
Parameters
----------
input_positions : Tensor [batch_dim, 2]
Positions at which to compute the PSFs.
lambda_obs : float
Observed wavelength in micrometers (um).
phase_N : int
Required wavefront dimension. This should be calculated using a SimPSFToolkit
instance. Example:
```
simPSF_np = wf.SimPSFToolkit(...)
phase_N = simPSF_np.feasible_N(lambda_obs)
```
Returns
-------
mono_psf_batch : Tensor
Batch of monochromatic PSFs.
"""
# Initialise the monochromatic PSF batch calculator
tf_batch_mono_psf = TFBatchMonochromaticPSF(
obscurations=self.obscurations,
output_Q=self.output_Q,
output_dim=self.output_dim,
)
# Set the lambda_obs and the phase_N parameters
tf_batch_mono_psf.set_lambda_phaseN(phase_N, lambda_obs)
# Predict zernikes from parametric model and physical layer
zks_coeffs = self.predict_zernikes(input_positions)
# Propagate to obtain the OPD
param_opd_maps = self.tf_zernike_OPD(zks_coeffs)
# Calculate the non parametric part
nonparam_opd_maps = self.tf_np_poly_opd(input_positions)
# Add the estimations
opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps)
# Compute the monochromatic PSFs
mono_psf_batch = tf_batch_mono_psf(opd_maps)
return mono_psf_batch
[docs]
def predict_opd(self, input_positions):
"""Predict the OPD at some positions.
Parameters
----------
input_positions: Tensor [batch_dim, 2]
Positions to predict the OPD.
Returns
-------
opd_maps : Tensor [batch, opd_dim, opd_dim]
OPD at requested positions.
"""
# Predict zernikes from parametric model and physical layer
zks_coeffs = self.predict_zernikes(input_positions)
# Propagate to obtain the OPD
param_opd_maps = self.tf_zernike_OPD(zks_coeffs)
# Calculate the non parametric part
nonparam_opd_maps = self.tf_np_poly_opd(input_positions)
# Add the estimations
opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps)
return opd_maps
[docs]
def compute_zernikes(self, input_positions):
"""Compute Zernike coefficients at a batch of positions.
This method computes the Zernike coefficients for a batch of input positions
using both the parametric model and the physical layer.
Parameters
----------
input_positions: Tensor [batch_dim, 2]
Positions for which to compute the Zernike coefficients.
Returns
-------
zernike_coefficients : Tensor [batch, n_zks_total, 1, 1]
Computed Zernike coefficients for the input positions.
Notes
-----
This method combines the predictions from both the parametric model and
the physical layer to obtain the final Zernike coefficients.
"""
# Calculate parametric part
zernike_params = self.tf_poly_Z_field(input_positions)
# Calculate the physical layer
zernike_prior = self.tf_physical_layer.call(input_positions)
# Pad and sum the zernike coefficients
padded_zernike_params, padded_zernike_prior = self.pad_zernikes(
zernike_params, zernike_prior
)
zernike_coeffs = tf.math.add(padded_zernike_params, padded_zernike_prior)
return zernike_coeffs
[docs]
def predict_zernikes(self, input_positions):
"""Predict Zernike coefficients at a batch of positions.
This method predicts the Zernike coefficients for a batch of input positions
using both the parametric model and the physical layer. During training,
the prediction from the physical layer is typically not used.
Parameters
----------
input_positions: Tensor [batch_dim, 2]
Positions for which to predict the Zernike coefficients.
Returns
-------
zernike_coeffs : Tensor [batch, n_zks_total, 1, 1]
Predicted Zernike coefficients for the input positions.
Notes
-----
At training time, the prediction from the physical layer may not be utilized,
as the model might be trained to rely solely on the parametric part.
"""
# Calculate parametric part
zernike_params = self.tf_poly_Z_field(input_positions)
# Calculate the prediction from the physical layer
physical_layer_prediction = self.tf_physical_layer.predict(input_positions)
# Pad and sum the Zernike coefficients
padded_zernike_params, padded_physical_layer_prediction = self.pad_zernikes(
zernike_params, physical_layer_prediction
)
zernike_coeffs = tf.math.add(
padded_zernike_params, padded_physical_layer_prediction
)
return zernike_coeffs
[docs]
def call(self, inputs, training=True):
"""Define the PSF (Point Spread Function) field forward model.
This method defines the forward model of the PSF field, which involves several steps:
1. Transforming input positions into Zernike coefficients.
2. Converting Zernike coefficients into Optical Path Difference (OPD) maps.
3. Combining OPD maps with Spectral Energy Distribution (SED) information to generate
polychromatic PSFs.
Parameters
----------
inputs : list
List containing input data required for PSF computation. It should contain two
elements:
- input_positions: Tensor [batch_dim, 2]
Positions at which to compute the PSFs.
- packed_SEDs: Tensor [batch_dim, ...]
Packed Spectral Energy Distributions (SEDs) for the corresponding positions.
training : bool, optional
Indicates whether the model is being trained or used for inference. Defaults to True.
Returns
-------
poly_psfs : Tensor
Polychromatic PSFs generated by the forward model.
Notes
-----
- The `input_positions` tensor should have a shape of [batch_dim, 2], where each row
represents the x and y coordinates of a position.
- The `packed_SEDs` tensor should have a shape of [batch_dim, ...], containing the SED
information for each position.
- During training, this method computes the Zernike coefficients from the input positions
and calculates the corresponding OPD maps. Additionally, it adds an L2 loss term based on
the parametric OPD maps.
- During inference, this method generates predictions using precomputed OPD maps or by
propagating through the forward model.
Examples
--------
# Usage during training
inputs = [input_positions, packed_SEDs]
poly_psfs = psf_model(inputs)
# Usage during inference
inputs = [input_positions, packed_SEDs]
poly_psfs = psf_model(inputs, training=False)
"""
# Unpack inputs
input_positions = inputs[0]
packed_SEDs = inputs[1]
# For the training
if training:
# Compute zernikes from parametric model and physical layer
zks_coeffs = self.compute_zernikes(input_positions)
# Propagate to obtain the OPD
param_opd_maps = self.tf_zernike_OPD(zks_coeffs)
# Add l2 loss on the parametric OPD
self.add_loss(
self.l2_param * tf.math.reduce_sum(tf.math.square(param_opd_maps))
)
# Calculate the non parametric part
nonparam_opd_maps = self.tf_np_poly_opd(input_positions)
# Add the estimations
opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps)
# Compute the polychromatic PSFs
poly_psfs = self.tf_batch_poly_PSF([opd_maps, packed_SEDs])
# For the inference
else:
# Compute predictions
poly_psfs = self.predict_step(inputs, evaluate_step=True)
return poly_psfs