Source code for wf_psf.psf_models.psf_model_semiparametric

"""PSF Model Semi-Parametric.

A module which defines the classes and methods
to manage the parameters of the psf semi-parametric model.

:Authors: Tobias Liaudat and Jennifer Pollack

"""

import logging

import numpy as np
import tensorflow as tf
from tensorflow.python.keras.engine import data_adapter

from wf_psf.psf_models import psf_models as psfm
from wf_psf.psf_models import tf_layers as tfl
from wf_psf.psf_models.tf_layers import (
    TF_batch_mono_PSF,
    TF_batch_poly_PSF,
)
from wf_psf.utils.utils import PI_zernikes, zernike_generator

logger = logging.getLogger(__name__)

@psfm.register_psfclass
class TF_SemiParam_field(tf.keras.Model):
    """PSF field forward model.

    Semi parametric model based on the Zernike polynomial basis.

    Parameters
    ----------
    ids: tuple
        A tuple storing the string attribute of the PSF model class
    model_params: Recursive Namespace
        Recursive Namespace object containing parameters for this PSF model class
    training_params: Recursive Namespace
        Recursive Namespace object containing training hyperparameters for this
        PSF model class
    coeff_mat: Tensor or None
        Initialization of the coefficient matrix defining the parametric psf
        field model

    """

    ids = ("poly",)

    def __init__(self, model_params, training_params, coeff_mat=None):
        """Read the hyperparameters and build the layers of the forward model.

        See the class docstring for the meaning of the arguments.
        """
        super().__init__()

        # Inputs: random seed for Tensor Flow initialization
        self.random_seed = model_params.param_hparams.random_seed

        # Inputs: pupil diameter
        self.pupil_diam = model_params.pupil_diameter

        # Inputs: oversampling used
        self.output_Q = model_params.output_Q

        # Inputs: TF_poly_Z_field
        self.n_zernikes = model_params.param_hparams.n_zernikes
        self.d_max = model_params.param_hparams.d_max
        self.x_lims = model_params.x_lims
        self.y_lims = model_params.y_lims

        # Inputs: TF_NP_poly_OPD
        self.d_max_nonparam = model_params.nonparam_hparams.d_max_nonparam
        self.zernike_maps = psfm.tf_zernike_cube(self.n_zernikes, self.pupil_diam)
        # The OPD dimension is taken from axis 1 of the Zernike cube.
        self.opd_dim = tf.shape(self.zernike_maps)[1].numpy()

        # Inputs: TF_batch_poly_PSF
        self.batch_size = training_params.batch_size
        self.obscurations = psfm.tf_obscurations(self.pupil_diam)
        self.output_dim = model_params.output_dim

        # Inputs: Loss
        self.l2_param = model_params.param_hparams.l2_param

        # Inputs: Project DD model features
        self.project_dd_features = model_params.nonparam_hparams.project_dd_features

        # Inputs: Reset DD model features
        self.reset_dd_features = model_params.nonparam_hparams.reset_dd_features

        # Inputs: Save optimiser history Parametric model features
        self.save_optim_history_param = (
            model_params.param_hparams.save_optim_history_param
        )

        # Inputs: Save optimiser history NonParametric model features
        self.save_optim_history_nonparam = (
            model_params.nonparam_hparams.save_optim_history_nonparam
        )

        # Initialize the first layer
        self.tf_poly_Z_field = tfl.TF_poly_Z_field(
            x_lims=self.x_lims,
            y_lims=self.y_lims,
            random_seed=self.random_seed,
            n_zernikes=self.n_zernikes,
            d_max=self.d_max,
        )

        # Initialize the zernike to OPD layer
        self.tf_zernike_OPD = tfl.TF_zernike_OPD(zernike_maps=self.zernike_maps)

        # Initialize the non-parametric (np) layer
        self.tf_np_poly_opd = tfl.TF_NP_poly_OPD(
            x_lims=self.x_lims,
            y_lims=self.y_lims,
            random_seed=self.random_seed,
            d_max=self.d_max_nonparam,
            opd_dim=self.opd_dim,
        )

        # Initialize the batch opd to batch polychromatic PSF layer
        self.tf_batch_poly_PSF = tfl.TF_batch_poly_PSF(
            obscurations=self.obscurations,
            output_Q=self.output_Q,
            output_dim=self.output_dim,
        )

        # Initialize the model parameters with non-default value
        if coeff_mat is not None:
            self.assign_coeff_matrix(coeff_mat)

    def get_coeff_matrix(self):
        """Get coefficient matrix.

        A function to get the coefficient matrix for parametric model.

        Returns
        -------
        coefficient matrix: float
            Tensor Flow coefficient matrix for the parametric PSF field model
        """
        return self.tf_poly_Z_field.get_coeff_matrix()

    def assign_coeff_matrix(self, coeff_mat):
        """Assign coefficient matrix.

        A function to set the coefficient matrix.

        Parameters
        ----------
        coeff_mat: float
            Tensor Flow coefficient matrix for the parametric PSF field model
        """
        self.tf_poly_Z_field.assign_coeff_matrix(coeff_mat)

    def set_zero_nonparam(self):
        """Set to zero the non-parametric part.

        A function to set non-parametric alpha parameters equal to zero
        (delegates to ``set_alpha_zero`` of the non-parametric layer).
        """
        self.tf_np_poly_opd.set_alpha_zero()

    def set_nonzero_nonparam(self):
        """Set to non-zero the non-parametric part.

        A function to set non-parametric alpha parameters equal to non-zero
        values (delegates to ``set_alpha_identity`` of the non-parametric layer).
        """
        self.tf_np_poly_opd.set_alpha_identity()

    def set_trainable_layers(self, param_bool=True, nonparam_bool=True):
        """Set Trainable Layers.

        A function to set the layers to be trainable or not.

        Parameters
        ----------
        param_bool: bool
            Boolean flag for the parametric layers
        nonparam_bool: bool
            Boolean flag for the non-parametric layers
        """
        self.tf_np_poly_opd.trainable = nonparam_bool
        self.tf_poly_Z_field.trainable = param_bool

    def set_output_Q(self, output_Q, output_dim=None):
        """Set the value of the output_Q parameter.

        Useful for generating/predicting PSFs at a different sampling wrt the
        observation sampling.

        Parameters
        ----------
        output_Q: float
            Oversampling factor
        output_dim: int
            Output dimension. When None, the current output dimension is kept.
        """
        self.output_Q = output_Q
        if output_dim is not None:
            self.output_dim = output_dim

        # Reinitialize the PSF batch poly generator. The layer is referenced
        # through the ``tfl`` module alias, exactly as in ``__init__``.
        self.tf_batch_poly_PSF = tfl.TF_batch_poly_PSF(
            obscurations=self.obscurations,
            output_Q=self.output_Q,
            output_dim=self.output_dim,
        )

    def predict_mono_psfs(self, input_positions, lambda_obs, phase_N):
        """Predict a set of monochromatic PSF at desired positions.

        Parameters
        ----------
        input_positions: Tensor(batch_dim x 2)
            Positions at which the PSFs are predicted.
        lambda_obs: float
            Observed wavelength in um.
        phase_N: int
            Required wavefront dimension.
            Should be calculated with as:
            ``simPSF_np = wf.SimPSFToolkit(...)``
            ``phase_N = simPSF_np.feasible_N(lambda_obs)``

        Returns
        -------
        mono_psf_batch
            Output of the monochromatic PSF batch layer applied to the sum of
            the parametric and non-parametric OPD maps.
        """
        # TODO: clean up
        # Initialise the monochromatic PSF batch calculator. The layer is
        # referenced through the ``tfl`` module alias, as in ``__init__``.
        tf_batch_mono_psf = tfl.TF_batch_mono_PSF(
            obscurations=self.obscurations,
            output_Q=self.output_Q,
            output_dim=self.output_dim,
        )
        # Set the lambda_obs and the phase_N parameters
        tf_batch_mono_psf.set_lambda_phaseN(phase_N, lambda_obs)

        # Parametric + non-parametric OPD at the requested positions
        opd_maps = self.predict_opd(input_positions)

        # Compute the monochromatic PSFs
        mono_psf_batch = tf_batch_mono_psf(opd_maps)

        return mono_psf_batch

    def predict_opd(self, input_positions):
        """Predict the OPD at some positions.

        Parameters
        ----------
        input_positions: Tensor(batch_dim x 2)
            Positions to predict the OPD.

        Returns
        -------
        opd_maps : Tensor [batch x opd_dim x opd_dim]
            OPD at requested positions.
        """
        # Calculate parametric part
        zernike_coeffs = self.tf_poly_Z_field(input_positions)
        param_opd_maps = self.tf_zernike_OPD(zernike_coeffs)
        # Calculate the non parametric part
        nonparam_opd_maps = self.tf_np_poly_opd(input_positions)
        # Add the estimations
        opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps)

        return opd_maps

    def assign_S_mat(self, S_mat):
        """Assign DD features matrix.

        Parameters
        ----------
        S_mat
            New DD features matrix, forwarded to the non-parametric layer.
        """
        self.tf_np_poly_opd.assign_S_mat(S_mat)

    def project_DD_features(self, tf_zernike_cube):
        """Project DD features onto the parametric model.

        Project non-parametric wavefront onto first n_z Zernikes and transfer
        their parameters to the parametric model. Both the parametric
        coefficient matrix and the DD features matrix are updated in place
        through ``assign_coeff_matrix`` and ``assign_S_mat``.

        Parameters
        ----------
        tf_zernike_cube
            Cube of Zernike maps indexed as ``[zernike, :, :]``. Its first axis
            is contracted against the ``n_zernikes`` axis of the projection
            coefficients, so it must hold ``self.n_zernikes`` maps.
        """
        alpha_mat = self.tf_np_poly_opd.alpha_mat
        # Number of columns of the parametric coefficient matrix; only that
        # many mixed DD features are transferred to the parametric model.
        n_poly_param = self.tf_poly_Z_field.coeff_mat.shape[1]

        # Compute Zernike norm for projections
        n_pix_zernike = PI_zernikes(tf_zernike_cube[0, :, :], tf_zernike_cube[0, :, :])

        # Multiply Alpha matrix with DD features matrix S
        inter_res_v2 = tf.tensordot(
            alpha_mat[:n_poly_param, :],
            self.tf_np_poly_opd.S_mat,
            axes=1,
        )

        # Project over first n_z Zernikes
        # TODO: clean up
        delta_C_poly = tf.constant(
            np.array(
                [
                    [
                        PI_zernikes(
                            tf_zernike_cube[i, :, :],
                            inter_res_v2[j, :, :],
                            n_pix_zernike,
                        )
                        for j in range(n_poly_param)
                    ]
                    for i in range(self.n_zernikes)
                ]
            ),
            dtype=tf.float32,
        )
        old_C_poly = self.tf_poly_Z_field.coeff_mat

        # Corrected parametric coeff matrix
        new_C_poly = old_C_poly + delta_C_poly
        self.assign_coeff_matrix(new_C_poly)

        # Remove extracted features from non-parametric model
        # Mix DD features with matrix alpha
        S_tilde = tf.tensordot(alpha_mat, self.tf_np_poly_opd.S_mat, axes=1)

        # TODO: clean up
        # Get beta tilde as the projection of the first n_param_poly_terms
        # (6 for d_max=2) onto the first n_zernikes.
        beta_tilde_inner = np.array(
            [
                [
                    PI_zernikes(tf_zernike_cube[j, :, :], S_tilde_slice, n_pix_zernike)
                    for j in range(self.n_zernikes)
                ]
                for S_tilde_slice in S_tilde[:n_poly_param, :, :]
            ]
        )

        # Only pad in the first dimension so we get a matrix of size
        # (d_max_nonparam_terms)x(n_zernikes) --> 21x15 or 21x45.
        beta_tilde = np.pad(
            beta_tilde_inner,
            [(0, S_tilde.shape[0] - beta_tilde_inner.shape[0]), (0, 0)],
            mode="constant",
        )

        # Unmix beta tilde with the inverse of alpha
        beta = tf.constant(np.linalg.inv(alpha_mat) @ beta_tilde, dtype=tf.float32)

        # Get the projection for the unmixed features.
        # Since beta.shape[1] = n_zernikes we can take the whole beta matrix.
        S_mat_projected = tf.tensordot(beta, tf_zernike_cube, axes=[1, 0])

        # Subtract the projection from the DD features
        S_new = self.tf_np_poly_opd.S_mat - S_mat_projected
        self.assign_S_mat(S_new)

    def call(self, inputs):
        """Define the PSF field forward model.

        [1] From positions to Zernike coefficients
        [2] From Zernike coefficients to OPD maps
        [3] From OPD maps and SED info to polychromatic PSFs

        OPD: Optical Path Differences

        Parameters
        ----------
        inputs
            Indexable pair: ``inputs[0]`` holds the input positions and
            ``inputs[1]`` the packed SEDs.

        Returns
        -------
        poly_psfs
            Output of the polychromatic PSF batch layer.
        """
        # Unpack inputs
        input_positions = inputs[0]
        packed_SEDs = inputs[1]

        # Forward model
        # Calculate parametric part
        zernike_coeffs = self.tf_poly_Z_field(input_positions)
        param_opd_maps = self.tf_zernike_OPD(zernike_coeffs)

        # Add l2 loss on the parametric OPD
        self.add_loss(
            self.l2_param * tf.math.reduce_sum(tf.math.square(param_opd_maps))
        )

        # Calculate the non parametric part
        nonparam_opd_maps = self.tf_np_poly_opd(input_positions)

        # Add the estimations
        opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps)

        # Compute the polychromatic PSFs
        poly_psfs = self.tf_batch_poly_PSF([opd_maps, packed_SEDs])

        return poly_psfs