"""TensorFlow layers for PSF modelling.
This module contains TensorFlow layers to model PSF variations across
the field of view.
:Author: Tobias Liaudat <tobias.liaudat@cea.fr>
"""
import tensorflow as tf
import tensorflow_addons as tfa
from wf_psf.psf_models.tf_modules import TFMonochromaticPSF
from wf_psf.utils.utils import calc_poly_position_mat
import wf_psf.utils.utils as utils
import logging
logger = logging.getLogger(__name__)
class TFPolynomialZernikeField(tf.keras.layers.Layer):
    """Polynomial field-of-view model of the Zernike coefficients.

    The layer owns a trainable coefficient matrix with one row per Zernike
    polynomial and one column per term of a 2D polynomial of degree
    ``d_max``. Calling the layer multiplies this matrix with the position
    polynomial matrix returned by ``calc_poly_position_mat``.

    Parameters
    ----------
    x_lims:
        Limits of the x axis, forwarded to ``calc_poly_position_mat``.
    y_lims:
        Limits of the y axis, forwarded to ``calc_poly_position_mat``.
    random_seed: int or None
        Value handed to ``tf.random.set_seed`` right before the coefficient
        matrix is drawn.
    n_zernikes: int
        Number of Zernike polynomials to consider.
    d_max: int
        Max degree of polynomial determining the FoV variations.
    name: str
        Name of the Keras layer.
    """

    def __init__(
        self,
        x_lims,
        y_lims,
        random_seed=None,
        n_zernikes=45,
        d_max=2,
        name="TF_poly_Z_field",
    ):
        super().__init__(name=name)
        # Field-of-view description
        self.x_lims = x_lims
        self.y_lims = y_lims
        # Model size
        self.n_zernikes = n_zernikes
        self.d_max = d_max
        # Trainable state, created by ``init_coeff_matrix`` below
        self.random_seed = random_seed
        self.coeff_mat = None
        self.init_coeff_matrix()

    def get_poly_coefficients_shape(self):
        """Return the shape ``(n_zernikes, n_poly_terms)`` of the coefficient matrix."""
        # Number of terms of a 2D polynomial of degree ``d_max``
        n_poly_terms = int((self.d_max + 1) * (self.d_max + 2) / 2)
        return (self.n_zernikes, n_poly_terms)

    def assign_coeff_matrix(self, coeff_mat):
        """Overwrite the values of the coefficient matrix in place."""
        self.coeff_mat.assign(coeff_mat)

    def get_coeff_matrix(self):
        """Return the coefficient matrix variable."""
        return self.coeff_mat

    def init_coeff_matrix(self):
        """Create the trainable coefficient matrix.

        The TensorFlow global seed is set to ``random_seed`` and the matrix
        is filled with uniform values drawn in ``[-0.01, 0.01]``.
        """
        tf.random.set_seed(self.random_seed)
        uniform_init = tf.random_uniform_initializer(minval=-0.01, maxval=0.01)
        starting_coeffs = uniform_init(self.get_poly_coefficients_shape())
        self.coeff_mat = tf.Variable(
            initial_value=starting_coeffs,
            trainable=True,
            dtype=tf.float32,
        )

    def call(self, positions):
        """Calculate the zernike coefficients for a given position.

        Parameters
        ----------
        positions: Tensor(batch, 2)
            First element is x-axis, second is y-axis.

        Returns
        -------
        zernikes_coeffs: Tensor(batch, n_zernikes, 1, 1)
        """
        # NOTE(review): presumably (n_poly_terms, batch) -- confirm against
        # ``calc_poly_position_mat``.
        position_terms = calc_poly_position_mat(
            positions, self.x_lims, self.y_lims, self.d_max
        )
        coeffs_per_zernike = tf.linalg.matmul(self.coeff_mat, position_terms)
        # Put the batch dimension first, then append two singleton axes
        coeffs_per_position = tf.transpose(coeffs_per_zernike)
        return coeffs_per_position[:, :, tf.newaxis, tf.newaxis]
class TFZernikeOPD(tf.keras.layers.Layer):
    """Build OPD maps as a weighted sum of Zernike maps.

    The Zernike maps are given once at construction time; the Zernike
    coefficients are given on every call.

    Parameters
    ----------
    zernike_maps: Tensor (Num_coeffs, x_dim, y_dim)
        Zernike maps that are weighted by the coefficients.
    name: str
        Name of the Keras layer.
    """

    def __init__(self, zernike_maps, name="TF_zernike_OPD"):
        super().__init__(name=name)
        self.zernike_maps = zernike_maps

    def call(self, z_coeffs):
        """Perform the weighted sum of Zernikes coeffs and maps.

        Parameters
        ----------
        z_coeffs: Tensor (batch_size, n_zernikes, 1, 1)
            Zernike coefficients; the trailing singleton axes let them
            broadcast against the maps.

        Returns
        -------
        opd: Tensor (batch_size, x_dim, y_dim)
        """
        weighted_maps = tf.math.multiply(self.zernike_maps, z_coeffs)
        # Collapse the Zernike axis
        return tf.math.reduce_sum(weighted_maps, axis=1)
class TFBatchPolychromaticPSF(tf.keras.layers.Layer):
    """Calculate a polychromatic PSF from an OPD and stored SED values.

    The calculation of the packed values with the respective SED is done
    with the PSFSimulator class but outside the TF class.

    Parameters
    ----------
    obscurations: Tensor [opd_dim, opd_dim]
        Obscurations to apply to the wavefront.
    output_Q:
        Forwarded as ``output_Q`` to ``TFMonochromaticPSF``.
    output_dim: int
        Forwarded as ``output_dim`` to ``TFMonochromaticPSF``.
    name: str
        Name of the Keras layer.

    Notes
    -----
    The layer is called with ``[opd_batch, packed_SED_data]`` where
    ``packed_SED_data`` is a Tensor [batch_size, 3, n_bins_lda]. Each packed
    element holds the parameters of one monochromatic PSF:

    - ``packed_elems[0]``: phase_N
    - ``packed_elems[1]``: lambda_obs
    - ``packed_elems[2]``: SED_norm_val

    The SED data is constant in a FoV.
    """

    def __init__(self, obscurations, output_Q, output_dim=64, name="TF_batch_poly_PSF"):
        super().__init__(name=name)
        self.obscurations = obscurations
        self.output_Q = output_Q
        self.output_dim = output_dim
        # OPD of the star currently being processed; written by
        # ``calculate_polychromatic_PSF`` and read by
        # ``calculate_monochromatic_PSF``.
        self.current_opd = None

    def calculate_monochromatic_PSF(self, packed_elems):
        """Calculate one SED-weighted monochromatic PSF of the current OPD.

        Parameters
        ----------
        packed_elems:
            ``packed_elems[0]`` is phase_N, ``packed_elems[1]`` is lambda_obs
            and ``packed_elems[2]`` is SED_norm_val.

        Returns
        -------
        Monochromatic PSF of ``self.current_opd`` multiplied by SED_norm_val.
        """
        phase_N = packed_elems[0]
        lambda_obs = packed_elems[1]
        sed_weight = packed_elems[2]

        # A generator is built for every wavelength bin
        mono_psf_layer = TFMonochromaticPSF(
            phase_N,
            lambda_obs,
            self.obscurations,
            output_Q=self.output_Q,
            output_dim=self.output_dim,
        )
        # The generator works on a batch of one OPD; drop that axis afterwards
        mono_psf = tf.squeeze(mono_psf_layer(self.current_opd), axis=0)
        return tf.math.scalar_mul(sed_weight, mono_psf)

    def calculate_polychromatic_PSF(self, packed_elems):
        """Calculate a polychromatic PSF.

        Parameters
        ----------
        packed_elems:
            ``packed_elems[0]`` is the OPD of one star and ``packed_elems[1]``
            its packed SED data.

        Returns
        -------
        Sum over the wavelength bins of the SED-weighted monochromatic PSFs.
        """
        # Stored on the instance so the per-wavelength function can read it
        self.current_opd = packed_elems[0][tf.newaxis, :, :]
        sed_bins = packed_elems[1]

        psf_per_bin = tf.map_fn(
            self.calculate_monochromatic_PSF,
            sed_bins,
            parallel_iterations=10,
            fn_output_signature=tf.float32,
            swap_memory=True,
        )
        return tf.math.reduce_sum(psf_per_bin, axis=0)

    def call(self, inputs):
        """Calculate the batch polychromatic PSFs.

        Parameters
        ----------
        inputs:
            ``inputs[0]`` is the batch of OPDs and ``inputs[1]`` the packed
            SED data of the same batch.

        Returns
        -------
        One polychromatic PSF per element of the batch.
        """
        batch_elems = (inputs[0], inputs[1])
        return tf.map_fn(
            self.calculate_polychromatic_PSF,
            batch_elems,
            parallel_iterations=10,
            fn_output_signature=tf.float32,
            swap_memory=True,
        )
class TFBatchMonochromaticPSF(tf.keras.layers.Layer):
    """Calculate a monochromatic PSF from a batch of OPDs.

    The calculation of the ``phase_N`` variable is done
    with the PSFSimulator class but outside the TF class.

    Parameters
    ----------
    obscurations: Tensor [opd_dim, opd_dim]
        Obscurations to apply to the wavefront.
    output_Q: int
        Output oversampling value.
    output_dim: int
        Output PSF stamp dimension.
    name: str
        Name of the Keras layer.
    """

    def __init__(self, obscurations, output_Q, output_dim=64, name="Pbatch_mono_PSF"):
        super().__init__(name=name)
        self.obscurations = obscurations
        self.output_Q = output_Q
        self.output_dim = output_dim
        # Filled in by ``set_lambda_phaseN``
        self.phase_N = None
        self.lambda_obs = None
        # Built by ``init_mono_PSF``
        self.tf_mono_psf_gen = None
        self.current_opd = None

    def calculate_monochromatic_PSF(self, current_opd):
        """Calculate the monochromatic PSF of a single OPD."""
        # The generator works on a batch: add a leading axis, then drop it
        single_opd_batch = current_opd[tf.newaxis, :, :]
        return tf.squeeze(self.tf_mono_psf_gen(single_opd_batch), axis=0)

    def init_mono_PSF(self):
        """Initialise or restart the PSF generator from the current settings."""
        self.tf_mono_psf_gen = TFMonochromaticPSF(
            self.phase_N,
            self.lambda_obs,
            self.obscurations,
            output_Q=self.output_Q,
            output_dim=self.output_dim,
        )

    def set_lambda_phaseN(self, phase_N=914, lambda_obs=0.7):
        """Set the lambda value for monochromatic PSFs and the phaseN.

        The PSF generator is rebuilt with the new values.
        """
        self.lambda_obs = lambda_obs
        self.phase_N = phase_N
        self.init_mono_PSF()

    def set_output_params(self, output_Q, output_dim):
        """Set the output parameters, Q and dimension.

        The PSF generator is rebuilt with the new values.
        """
        self.output_dim = output_dim
        self.output_Q = output_Q
        self.init_mono_PSF()

    def call(self, opd_batch):
        """Calculate the batch of monochromatic PSFs.

        If ``phase_N`` was never set, ``set_lambda_phaseN`` is called with
        its default values first.
        """
        if self.phase_N is None:
            self.set_lambda_phaseN()

        return tf.map_fn(
            self.calculate_monochromatic_PSF,
            opd_batch,
            parallel_iterations=10,
            fn_output_signature=tf.float32,
            swap_memory=True,
        )
class TFNonParametricPolynomialVariationsOPD(tf.keras.layers.Layer):
    """Non-parametric OPD generation with polynomial variations.

    Parameters
    ----------
    x_lims: [int, int]
        Limits of the x axis.
    y_lims: [int, int]
        Limits of the y axis.
    random_seed: int
        Random seed initialization for Tensor Flow
    d_max: int
        Max degree of polynomial determining the FoV variations.
    opd_dim: int
        Dimension of the OPD maps. Same as pupil diameter.
    name: str
        Name of the Keras layer.
    """

    def __init__(
        self,
        x_lims,
        y_lims,
        random_seed=None,
        d_max=3,
        opd_dim=256,
        name="TF_NP_poly_OPD",
    ):
        super().__init__(name=name)

        # Field-of-view description
        self.x_lims = x_lims
        self.y_lims = y_lims

        # Model size
        self.d_max = d_max
        self.opd_dim = opd_dim
        # Number of terms of a 2D polynomial of degree ``d_max``
        self.n_poly = int((d_max + 1) * (d_max + 2) / 2)

        # Trainable state, created by ``init_vars`` below
        self.random_seed = random_seed
        self.S_mat = None
        self.alpha_mat = None
        self.init_vars()

    def init_vars(self):
        """Initialize trainable variables.

        Basic initialization. Random uniform for S and identity for alpha.
        A seed that is not ``None`` is incremented afterwards.
        """
        tf.random.set_seed(self.random_seed)
        uniform_init = tf.random_uniform_initializer(minval=-0.001, maxval=0.001)

        features_shape = [self.n_poly, self.opd_dim, self.opd_dim]
        self.S_mat = tf.Variable(
            initial_value=uniform_init(shape=features_shape),
            trainable=True,
            dtype=tf.float32,
        )

        identity = tf.eye(self.n_poly)
        self.alpha_mat = tf.Variable(
            initial_value=identity, trainable=True, dtype=tf.float32
        )

        # Update random seed for next call
        if self.random_seed is not None:
            self.random_seed += 1

    def set_alpha_zero(self):
        """Set alpha matrix to zero."""
        zeros = tf.zeros_like(self.alpha_mat, dtype=tf.float32)
        self.alpha_mat.assign(zeros)

    def set_alpha_identity(self):
        """Set alpha matrix to the identity."""
        identity = tf.eye(self.n_poly, dtype=tf.float32)
        self.alpha_mat.assign(identity)

    def assign_S_mat(self, S_mat):
        """Assign DD features matrix."""
        self.S_mat.assign(S_mat)

    def call(self, positions):
        """Calculate the OPD maps for the given positions.

        Calculating: Pi(pos) x alpha x S

        Parameters
        ----------
        positions: Tensor(batch, 2)
            First element is x-axis, second is y-axis.

        Returns
        -------
        opd_maps: Tensor(batch, opd_dim, opd_dim)
        """
        position_terms = calc_poly_position_mat(
            positions, self.x_lims, self.y_lims, self.d_max
        )
        # Transposed so that the batch dimension comes first
        batch_first_terms = tf.transpose(position_terms, perm=[1, 0])
        feature_weights = tf.linalg.matmul(batch_first_terms, self.alpha_mat)
        return tf.tensordot(feature_weights, self.S_mat, axes=1)
class TFNonParametricMCCDOPDv2(tf.keras.layers.Layer):
    """Non-parametric OPD generation with hybrid-MCCD variations.

    The OPD is the sum of a polynomial contribution and a graph-constraint
    contribution. ``call`` is the training path: it looks up the dictionary
    rows of the observed stars. ``predict`` evaluates the model at arbitrary
    positions.

    Parameters
    ----------
    obs_pos: tensor(n_stars, 2)
        Observed positions of the `n_stars` in the dataset. The indexing of the
        positions has to correspond to the indexing in the `spatial_dic`.
    spatial_dic: sequence of two tensors
        Spatial-constraint dictionaries. ``spatial_dic[0]`` is the polynomial
        dictionary and ``spatial_dic[1]`` the graph dictionary. The first
        axis of both is indexed by star; the second axis of the graph
        dictionary gives its number of elements, not to be confounded with
        the number of non-parametric features of the wavefront-PSF.
    x_lims: [int, int]
        Limits of the x axis.
    y_lims: [int, int]
        Limits of the y axis.
    random_seed: int or None
        Value handed to ``tf.random.set_seed`` before the variables are drawn.
    d_max: int
        Max degree of polynomial determining the FoV variations. The number of
        wavefront-PSF features of the polynomial part is
        computed `(d_max+1)*(d_max+2)/2`.
    graph_features: int
        Number of wavefront-PSF features corresponding to the graph constraint.
    l1_rate: float
        Weight of the regularisation loss added on ``alpha_graph`` in ``call``.
    opd_dim: int
        Dimension of the OPD maps. Same as pupil diameter.
    name: str
        Name of the Keras layer.
    """

    def __init__(
        self,
        obs_pos,
        spatial_dic,
        x_lims,
        y_lims,
        random_seed=None,
        d_max=2,
        graph_features=6,
        l1_rate=1e-5,
        opd_dim=256,
        name="TF_NP_MCCD_OPD_v2",
    ):
        super().__init__(name=name)
        # Parameters
        self.x_lims = x_lims
        self.y_lims = y_lims
        self.random_seed = random_seed
        logger.info(type(self.random_seed))
        self.d_max = d_max
        self.opd_dim = opd_dim

        # L1 regularisation parameter
        self.l1_rate = l1_rate

        self.obs_pos = obs_pos
        self.poly_dic = spatial_dic[0]
        self.graph_dic = spatial_dic[1]

        self.n_stars = self.poly_dic.shape[0]
        self.n_graph_elems = self.graph_dic.shape[1]
        self.poly_features = int((self.d_max + 1) * (self.d_max + 2) / 2)
        self.graph_features = graph_features

        # Variables, created by ``init_vars`` below
        self.S_poly = None
        self.S_graph = None
        self.alpha_poly = None
        self.alpha_graph = None
        self.init_vars()

    def init_vars(self):
        """Initialize trainable variables.

        Basic initialization. Random uniform for S and identity for alpha.
        A seed that is not ``None`` is incremented afterwards.
        """
        # S initialization
        tf.random.set_seed(self.random_seed)
        random_init = tf.random_uniform_initializer(minval=-0.001, maxval=0.001)
        self.S_poly = tf.Variable(
            initial_value=random_init(
                shape=[self.poly_features, self.opd_dim, self.opd_dim]
            ),
            trainable=True,
            dtype=tf.float32,
        )
        self.S_graph = tf.Variable(
            initial_value=random_init(
                shape=[self.graph_features, self.opd_dim, self.opd_dim]
            ),
            trainable=True,
            dtype=tf.float32,
        )

        # Alpha initialization
        self.alpha_poly = tf.Variable(
            initial_value=tf.eye(
                num_rows=self.poly_features, num_columns=self.poly_features
            ),
            trainable=True,
            dtype=tf.float32,
        )
        self.alpha_graph = tf.Variable(
            initial_value=tf.eye(
                num_rows=self.n_graph_elems, num_columns=self.graph_features
            ),
            trainable=True,
            dtype=tf.float32,
        )

        # Update random seed for next call
        if self.random_seed is not None:
            self.random_seed += 1

    def set_alpha_zero(self):
        """Set both alpha matrices to zero."""
        self.alpha_poly.assign(tf.zeros_like(self.alpha_poly, dtype=tf.float32))
        self.alpha_graph.assign(tf.zeros_like(self.alpha_graph, dtype=tf.float32))

    def set_alpha_identity(self):
        """Set both alpha matrices to the identity."""
        self.alpha_poly.assign(
            tf.eye(
                num_rows=self.poly_features,
                num_columns=self.poly_features,
                dtype=tf.float32,
            )
        )
        self.alpha_graph.assign(
            tf.eye(
                num_rows=self.n_graph_elems,
                num_columns=self.graph_features,
                dtype=tf.float32,
            )
        )

    def predict(self, positions):
        """Prediction step.

        The polynomial part is evaluated directly at ``positions``. The graph
        part is obtained by interpolating, from the observed positions, the
        product of the graph dictionary with ``alpha_graph``.

        Parameters
        ----------
        positions: Tensor(batch, 2)
            First element is x-axis, second is y-axis.

        Returns
        -------
        opd_maps: Tensor(batch, opd_dim, opd_dim)
        """
        ## Polynomial part
        # Calculate the Pi matrix
        poly_mat = calc_poly_position_mat(
            positions, self.x_lims, self.y_lims, self.d_max
        )
        # We need to transpose it here to have the batch dimension at first
        A_poly = tf.linalg.matmul(tf.transpose(poly_mat, perm=[1, 0]), self.alpha_poly)
        interp_poly_opd = tf.tensordot(A_poly, self.S_poly, axes=1)

        ## Graph part
        A_graph_train = tf.linalg.matmul(self.graph_dic, self.alpha_graph)
        # RBF interpolation
        # Order 2 means a thin_plate RBF interpolation
        # All tensors need to expand one dimension to fulfil requirement in
        # the tfa's interpolate_spline function
        A_interp_graph = tfa.image.interpolate_spline(
            train_points=tf.expand_dims(self.obs_pos, axis=0),
            train_values=tf.expand_dims(A_graph_train, axis=0),
            query_points=tf.expand_dims(positions, axis=0),
            order=2,
            regularization_weight=0.0,
        )
        # Remove extra dimension required by tfa's interpolate_spline
        A_interp_graph = tf.squeeze(A_interp_graph, axis=0)
        interp_graph_opd = tf.tensordot(A_interp_graph, self.S_graph, axes=1)

        return tf.math.add(interp_poly_opd, interp_graph_opd)

    def _lookup_obs_indices(self, positions):
        """Return, for each position, the index of the identical row of ``obs_pos``."""

        def calc_index(idx_pos):
            # A star is identified by BOTH coordinates. Without the reduction
            # over the coordinate axis, a star sharing only its x (or only
            # its y) with an earlier star would be matched to that star.
            same_star = tf.math.reduce_all(tf.equal(self.obs_pos, idx_pos), axis=1)
            return tf.where(same_star)[0, 0]

        return tf.map_fn(calc_index, positions, fn_output_signature=tf.int64)

    def call(self, positions):
        """Calculate the OPD maps for the given positions.

        Calculating: batch(spatial_dict) x alpha x S

        Every call also adds a regularisation loss on ``alpha_graph``.

        Parameters
        ----------
        positions: Tensor(batch, 2)
            First element is x-axis, second is y-axis. Each position has to
            be exactly equal to one row of ``obs_pos``; the lookup fails for a
            position that is not part of ``obs_pos``.

        Returns
        -------
        opd_maps: Tensor(batch, opd_dim, opd_dim)
        """
        # Lp norm of the graph alpha matrix with p=1.1, used in place of a
        # plain L1 loss
        p = 1.1
        self.add_loss(
            self.l1_rate
            * tf.math.pow(
                tf.math.reduce_sum(tf.math.pow(tf.math.abs(self.alpha_graph), p)), 1 / p
            )
        )

        # Calculate the indices of the input batch
        indices = self._lookup_obs_indices(positions)

        # Recover the spatial dict from the batch indexes
        # Matrix multiplication dict*alpha
        # Tensor product to calculate the contribution
        # Polynomial contribution
        batch_poly_dict = tf.gather(
            self.poly_dic, indices=indices, axis=0, batch_dims=0
        )
        intermediate_poly = tf.linalg.matmul(batch_poly_dict, self.alpha_poly)
        contribution_poly = tf.tensordot(intermediate_poly, self.S_poly, axes=1)

        # Graph contribution
        batch_graph_dict = tf.gather(
            self.graph_dic, indices=indices, axis=0, batch_dims=0
        )
        intermediate_graph = tf.linalg.matmul(batch_graph_dict, self.alpha_graph)
        contribution_graph = tf.tensordot(intermediate_graph, self.S_graph, axes=1)

        return tf.math.add(contribution_poly, contribution_graph)
class TFNonParametricGraphOPD(tf.keras.layers.Layer):
    """Non-parametric OPD generation with only graph-constraint variations.

    ``call`` is the training path: it looks up the dictionary rows of the
    observed stars. ``predict`` evaluates the model at arbitrary positions.

    Parameters
    ----------
    obs_pos: tensor(n_stars, 2)
        Observed positions of the `n_stars` in the dataset. The indexing of the
        positions has to correspond to the indexing in the `spatial_dic`.
    spatial_dic: sequence of two tensors
        Spatial-constraint dictionaries. ``spatial_dic[0]`` is the polynomial
        dictionary (only its first dimension, the number of stars, is read
        here) and ``spatial_dic[1]`` the graph dictionary. The second axis of
        the graph dictionary gives its number of elements, not to be
        confounded with the number of non-parametric features of the
        wavefront-PSF.
    x_lims: [int, int]
        Limits of the x axis.
    y_lims: [int, int]
        Limits of the y axis.
    random_seed: int or None
        Value handed to ``tf.random.set_seed`` before the variables are drawn.
    graph_features: int
        Number of wavefront-PSF features corresponding to the graph constraint.
    l1_rate: float
        Weight of the regularisation loss added on ``alpha_graph`` in ``call``.
    opd_dim: int
        Dimension of the OPD maps. Same as pupil diameter.
    name: str
        Name of the Keras layer.
    """

    def __init__(
        self,
        obs_pos,
        spatial_dic,
        x_lims,
        y_lims,
        random_seed=None,
        graph_features=6,
        l1_rate=1e-5,
        opd_dim=256,
        name="TF_NP_GRAPH_OPD",
    ):
        super().__init__(name=name)
        # Parameters
        self.x_lims = x_lims
        self.y_lims = y_lims
        self.random_seed = random_seed
        self.opd_dim = opd_dim

        # L1 regularisation parameter
        self.l1_rate = l1_rate

        self.obs_pos = obs_pos
        self.poly_dic = spatial_dic[0]
        self.graph_dic = spatial_dic[1]

        self.n_stars = self.poly_dic.shape[0]
        self.n_graph_elems = self.graph_dic.shape[1]
        self.graph_features = graph_features

        # Variables, created by ``init_vars`` below
        self.S_graph = None
        self.alpha_graph = None
        self.init_vars()

    def init_vars(self):
        """Initialize trainable variables.

        Basic initialization. Random uniform for S and identity for alpha.
        A seed that is not ``None`` is incremented afterwards.
        """
        # S initialization
        tf.random.set_seed(self.random_seed)
        random_init = tf.random_uniform_initializer(minval=-0.001, maxval=0.001)
        self.S_graph = tf.Variable(
            initial_value=random_init(
                shape=[self.graph_features, self.opd_dim, self.opd_dim]
            ),
            trainable=True,
            dtype=tf.float32,
        )

        # Alpha initialization
        self.alpha_graph = tf.Variable(
            initial_value=tf.eye(
                num_rows=self.n_graph_elems, num_columns=self.graph_features
            ),
            trainable=True,
            dtype=tf.float32,
        )

        # Update random seed for next call
        if self.random_seed is not None:
            self.random_seed += 1

    def set_alpha_zero(self):
        """Set alpha matrix to zero."""
        self.alpha_graph.assign(tf.zeros_like(self.alpha_graph, dtype=tf.float32))

    def set_alpha_identity(self):
        """Set alpha matrix to the identity."""
        self.alpha_graph.assign(
            tf.eye(
                num_rows=self.n_graph_elems,
                num_columns=self.graph_features,
                dtype=tf.float32,
            )
        )

    def predict(self, positions):
        """Prediction step.

        The product of the graph dictionary with ``alpha_graph`` is
        interpolated from the observed positions to ``positions``.

        Parameters
        ----------
        positions: Tensor(batch, 2)
            First element is x-axis, second is y-axis.

        Returns
        -------
        opd_maps: Tensor(batch, opd_dim, opd_dim)
        """
        ## Graph part
        A_graph_train = tf.linalg.matmul(self.graph_dic, self.alpha_graph)
        # RBF interpolation
        # Order 2 means a thin_plate RBF interpolation
        # All tensors need to expand one dimension to fulfil requirement in
        # the tfa's interpolate_spline function
        A_interp_graph = tfa.image.interpolate_spline(
            train_points=tf.expand_dims(self.obs_pos, axis=0),
            train_values=tf.expand_dims(A_graph_train, axis=0),
            query_points=tf.expand_dims(positions, axis=0),
            order=2,
            regularization_weight=0.0,
        )
        # Remove extra dimension required by tfa's interpolate_spline
        A_interp_graph = tf.squeeze(A_interp_graph, axis=0)
        interp_graph_opd = tf.tensordot(A_interp_graph, self.S_graph, axes=1)

        return interp_graph_opd

    def _lookup_obs_indices(self, positions):
        """Return, for each position, the index of the identical row of ``obs_pos``."""

        def calc_index(idx_pos):
            # A star is identified by BOTH coordinates. Without the reduction
            # over the coordinate axis, a star sharing only its x (or only
            # its y) with an earlier star would be matched to that star.
            same_star = tf.math.reduce_all(tf.equal(self.obs_pos, idx_pos), axis=1)
            return tf.where(same_star)[0, 0]

        return tf.map_fn(calc_index, positions, fn_output_signature=tf.int64)

    def call(self, positions):
        """Calculate the OPD maps for the given positions.

        Calculating: batch(spatial_dict) x alpha x S

        Every call also adds a regularisation loss on ``alpha_graph``.

        Parameters
        ----------
        positions: Tensor(batch, 2)
            First element is x-axis, second is y-axis. Each position has to
            be exactly equal to one row of ``obs_pos``; the lookup fails for a
            position that is not part of ``obs_pos``.

        Returns
        -------
        opd_maps: Tensor(batch, opd_dim, opd_dim)
        """
        # Lp norm of the graph alpha matrix with p=1.1, used in place of a
        # plain L1 loss
        p = 1.1
        self.add_loss(
            self.l1_rate
            * tf.math.pow(
                tf.math.reduce_sum(tf.math.pow(tf.math.abs(self.alpha_graph), p)), 1 / p
            )
        )

        # Calculate the indices of the input batch
        indices = self._lookup_obs_indices(positions)

        # Recover the spatial dict from the batch indexes
        # Matrix multiplication dict*alpha
        # Tensor product to calculate the contribution
        # Graph contribution
        batch_graph_dict = tf.gather(
            self.graph_dic, indices=indices, axis=0, batch_dims=0
        )
        intermediate_graph = tf.linalg.matmul(batch_graph_dict, self.alpha_graph)
        contribution_graph = tf.tensordot(intermediate_graph, self.S_graph, axes=1)

        return contribution_graph
class TFPhysicalLayer(tf.keras.layers.Layer):
    """The Zernike physical layer.

    This layer gives the Zernike contribution of the physical layer.
    It is fixed and not trainable. It can interpolate the Zernike coefficients
    at the input positions using different interpolation schemes.

    Parameters
    ----------
    obs_pos: Tensor (n_stars, 2)
        Observed positions of the `n_stars` in the dataset. The indexing of the
        positions has to correspond to the indexing in the `zks_prior`.
    zks_prior: Tensor (n_stars, n_zernikes)
        Zernike coefficients for each position
    interpolation_type: str
        Type of interpolation used by ``predict``.
        Options are: None, 'all', 'top_K', 'independent_Zk'.
        Default is None. Any value that is not one of the three named options
        makes ``predict`` the same exact lookup as ``call``.
    interpolation_args: dict
        Interpolation hyper-parameters. The order of the RBF interpolation
        (key ``"order"``), and the K elements in the `top_K` interpolation
        (key ``"K"``). Defaults to ``{"order": 2, "K": 50}`` when None.
    name: str
        Name of the Keras layer.
    """

    def __init__(
        self,
        obs_pos,
        zks_prior,
        interpolation_type=None,
        interpolation_args=None,
        name="TF_physical_layer",
    ):
        super().__init__(name=name)
        self.obs_pos = obs_pos
        self.zks_prior = zks_prior

        # The attribute has to be set in every case: the interpolation
        # methods read it, so user-supplied arguments must not be dropped.
        if interpolation_args is None:
            interpolation_args = {"order": 2, "K": 50}
        self.interpolation_args = interpolation_args

        # Define the prediction routine by default
        self.predict = self.call
        # Define the prediction routine
        if interpolation_type == "all":
            self.predict = self.interpolate_all
        elif interpolation_type == "top_K":
            self.predict = self.interpolate_top_K
        elif interpolation_type == "independent_Zk":
            self.predict = self.interpolate_independent_Zk

    def interpolate_all(self, positions):
        """Interpolate using all the input elements.

        The TensorFlow Addons function `tfa.image.interpolate_spline` is used
        to perform the RBF interpolation of the Zernike coefficients.

        Parameters
        ----------
        positions : tf.Tensor
            Tensor of shape (batch_size, 2) representing the positions.
            The first element represents the x-axis, and the second element represents the y-axis.

        Returns
        -------
        interp_zks : tf.Tensor
            Tensor of shape (batch_size, n_zernikes, 1, 1) containing the interpolated Zernike coefficients
            corresponding to the input positions.
        """
        # RBF interpolation of prior Zernikes
        # Order 2 means a thin_plate RBF interpolation
        # All tensors need to expand one dimension to fulfil requirement in
        # the tfa's interpolate_spline function
        interp_zks = tfa.image.interpolate_spline(
            train_points=tf.expand_dims(self.obs_pos, axis=0),
            train_values=tf.expand_dims(self.zks_prior, axis=0),
            query_points=tf.expand_dims(positions, axis=0),
            order=self.interpolation_args["order"],
            regularization_weight=0.0,
        )
        # Remove extra dimension required by tfa's interpolate_spline
        interp_zks = tf.squeeze(interp_zks, axis=0)

        return interp_zks[:, :, tf.newaxis, tf.newaxis]

    def interpolate_top_K(self, positions):
        """Interpolate using only the K closest elements.

        The class wf.utils.ZernikeInterpolation allows to interpolate the Zernike
        coefficients using only the K closest points to build the interpolant.

        Parameters
        ----------
        positions : tf.Tensor
            Tensor of shape (batch_size, 2) representing the positions.
            The first element represents the x-axis, and the second element represents the y-axis.

        Returns
        -------
        interp_zks : tf.Tensor
            Tensor of shape (batch_size, n_zernikes, 1, 1) containing the interpolated Zernike coefficients
            corresponding to the input positions.
        """
        zk_interpolator = utils.ZernikeInterpolation(
            self.obs_pos,
            self.zks_prior,
            k=self.interpolation_args["K"],
            order=self.interpolation_args["order"],
        )
        interp_zks = zk_interpolator.interpolate_zks(positions)

        return interp_zks[:, :, tf.newaxis, tf.newaxis]

    def interpolate_independent_Zk(self, positions):
        """Interpolate each Zernike independently.

        The class wf.utils.IndependentZernikeInterpolation allows to interpolate each
        order of the Zernike polynomials independently using all the points available to build
        the interpolant.

        Parameters
        ----------
        positions : tf.Tensor
            Tensor of shape (batch_size, 2) representing the positions.
            The first element represents the x-axis, and the second element represents the y-axis.

        Returns
        -------
        interp_zks : tf.Tensor
            Tensor of shape (batch_size, n_zernikes, 1, 1) containing the interpolated Zernike coefficients
            corresponding to the input positions.
        """
        zk_interpolator = utils.IndependentZernikeInterpolation(
            self.obs_pos, self.zks_prior, order=self.interpolation_args["order"]
        )
        interp_zks = zk_interpolator.interpolate_zks(positions)

        return interp_zks[:, :, tf.newaxis, tf.newaxis]

    def call(self, positions):
        """Calculate the prior Zernike coefficients for a batch of positions.

        The coefficients are not interpolated: each input position is looked
        up in ``obs_pos`` and the row of ``zks_prior`` with the same index is
        returned.

        Parameters
        ----------
        positions : tf.Tensor
            Tensor of shape (batch_size, 2) representing the positions.
            The first element represents the x-axis, and the second element represents the y-axis.
            Each position has to be exactly equal to one row of ``obs_pos``.

        Returns
        -------
        zernike_coeffs : tf.Tensor
            Tensor of shape (batch_size, n_zernikes, 1, 1) containing the prior Zernike coefficients
            corresponding to the input positions.

        Notes
        -----
        The match is an exact equality test on both coordinates, not a
        nearest-neighbour search. When several rows of ``obs_pos`` are equal to
        a position the first one is used. The lookup fails for a position
        that is not part of ``obs_pos``; use one of the ``interpolate_*``
        methods for such positions.
        """

        def calc_index(idx_pos):
            # A star is identified by BOTH coordinates. Without the reduction
            # over the coordinate axis, a star sharing only its x (or only
            # its y) with an earlier star would be matched to that star.
            same_star = tf.math.reduce_all(tf.equal(self.obs_pos, idx_pos), axis=1)
            return tf.where(same_star)[0, 0]

        # Calculate the indices of the input batch
        indices = tf.map_fn(calc_index, positions, fn_output_signature=tf.int64)

        # Recover the prior zernikes from the batch indexes
        batch_zks = tf.gather(self.zks_prior, indices=indices, axis=0, batch_dims=0)

        return batch_zks[:, :, tf.newaxis, tf.newaxis]