"""Utility functions for the PSF simulation and modeling.
:Authors: Tobias Liaudat <tobias.liaudat@cea.fr>
"""
import numpy as np
import tensorflow as tf
import PIL
import zernike as zk
from wf_psf.utils.interpolation import tfa_interpolate_spline_rbf
_HAS_CV2 = False
_HAS_SKIMAGE = False
try:
import cv2
_HAS_CV2 = True
except ImportError:
try:
from skimage.transform import downscale_local_mean
_HAS_SKIMAGE = True
except ImportError:
pass
[docs]
def scale_to_range(input_array, old_range, new_range):
    """Linearly map values from ``old_range`` onto ``new_range``.

    Parameters
    ----------
    input_array : array_like
        Values expressed in the ``old_range`` interval.
    old_range : sequence of float
        Source interval as (min, max).
    new_range : sequence of float
        Target interval as (min, max).

    Returns
    -------
    array_like
        Input rescaled so that ``old_range[0]`` maps to ``new_range[0]``
        and ``old_range[1]`` maps to ``new_range[1]``.
    """
    old_min, old_max = old_range
    new_min, new_max = new_range
    # Normalize to the unit interval, then stretch/shift onto the target range.
    unit = (input_array - old_min) / (old_max - old_min)
    return unit * (new_max - new_min) + new_min
[docs]
def ensure_batch(arr):
    """
    Ensure an array/tensor has a leading batch dimension.

    A 2D input of shape (M, N) becomes (1, M, N); 3D inputs pass through
    unchanged.

    Parameters
    ----------
    arr : np.ndarray or tf.Tensor
        Input 2D or 3D array/tensor.

    Returns
    -------
    np.ndarray or tf.Tensor
        Same data with a batch dimension prepended when it was missing.

    Raises
    ------
    TypeError
        If ``arr`` is neither a NumPy array nor a TensorFlow tensor.
    """
    if isinstance(arr, np.ndarray):
        if arr.ndim == 3:
            return arr
        return np.expand_dims(arr, axis=0)
    if isinstance(arr, tf.Tensor):
        if arr.ndim == 3:
            return arr
        return tf.expand_dims(arr, axis=0)
    raise TypeError(f"Expected np.ndarray or tf.Tensor, got {type(arr)}")
[docs]
def calc_wfe(zernike_basis, zks):
    """Compute a wavefront error (WFE) map from a Zernike expansion.

    Parameters
    ----------
    zernike_basis : np.ndarray, shape (n_zernikes, dim, dim)
        Stack of Zernike polynomial maps.
    zks : np.ndarray
        Zernike coefficients; flattened to length ``n_zernikes``.

    Returns
    -------
    np.ndarray, shape (dim, dim)
        Coefficient-weighted sum of the Zernike maps.
    """
    # Broadcast each coefficient over its map and sum over the mode axis.
    weights = zks.reshape(-1, 1, 1)
    return np.sum(zernike_basis * weights, axis=0)
[docs]
def calc_wfe_rms(zernike_basis, zks, pupil_mask):
    """Compute the RMS of the wavefront error inside the pupil.

    Parameters
    ----------
    zernike_basis : np.ndarray, shape (n_zernikes, dim, dim)
        Stack of Zernike polynomial maps.
    zks : np.ndarray
        Zernike coefficients.
    pupil_mask : np.ndarray of bool, shape (dim, dim)
        Mask selecting the pupil pixels entering the statistics.

    Returns
    -------
    float
        Root-mean-square of the mean-subtracted WFE over the pupil pixels.
    """
    pupil_wfe = calc_wfe(zernike_basis, zks)[pupil_mask]
    # np.std (ddof=0) is exactly sqrt(mean((x - mean(x)) ** 2)).
    return np.std(pupil_wfe)
[docs]
def generalised_sigmoid(x, max_val=1, power_k=1):
    """
    Apply a generalized sigmoid to the input.

    Produces a smooth, S-shaped curve generalizing the rational sigmoid,
    useful for bounding values while preserving sign and monotonicity.

    Parameters
    ----------
    x : array_like
        Input value(s).
    max_val : float, optional
        Asymptotic maximum of the output magnitude. Default is 1.
    power_k : float, optional
        Steepness parameter; larger values give sharper transitions.
        Default is 1, which reduces to the standard rational sigmoid.

    Returns
    -------
    array_like
        Values bounded between ``-max_val`` and ``max_val``.

    Notes
    -----
    The function is odd: ``generalised_sigmoid(-x) == -generalised_sigmoid(x)``.
    """
    # (1 + |x|^k)^(1/k) — grows like |x| for large |x|, so the ratio saturates.
    denom = np.power(1 + np.power(np.abs(x), power_k), 1 / power_k)
    return max_val * x / denom
[docs]
def single_mask_generator(shape):
    """Generate a single mask built from random 2D cosine waves.

    Note: these masks simulate the effect of cosmic rays on the observations.

    Parameters
    ----------
    shape: tuple
        Shape (rows, cols) of the mask to be generated.

    Returns
    -------
    np.ndarray
        Boolean 2D mask, True where the normalized wave sum is below 0.6.
    """
    # 2D meshgrid over slightly offset ranges: x in [0.7, 1.2], y in [0.6, 1.1].
    x, y = np.meshgrid(np.linspace(0.7, 1.2, shape[1]), np.linspace(0.6, 1.1, shape[0]))
    # Random (fx, fy, x-shift, y-shift, flip-flag) tuples, each entry in [0, 6).
    # NOTE: drawn all at once to keep the RNG call sequence deterministic.
    fxy_list = [np.random.random(5) * 6 for _ in range(100)]
    total = np.zeros(shape)
    for fxy in fxy_list:
        wave = np.cos(
            2 * np.pi * (fxy[0] * (x - fxy[2] / 50) + fxy[1] * (y - fxy[3] / 50))
        )
        # Flip flag < 3 happens with probability 1/2 → random orientation.
        if fxy[4] < 3:
            wave = np.flipud(wave)
        total += wave
    # Normalize by the peak, then threshold to produce the boolean mask.
    normalized = total / np.max(total)
    return normalized < 0.6
[docs]
def generate_n_mask(shape, n_masks=1):
    """Generate several random cosine-wave masks.

    Thin wrapper around single_mask_generator producing a stacked array.

    Parameters
    ----------
    shape: tuple
        Shape of each generated mask.
    n_masks: int
        Number of masks to generate. Default is 1.

    Returns
    -------
    np.ndarray
        Array of shape (n_masks, shape[0], shape[1]) with one boolean mask
        per entry.
    """
    masks = [single_mask_generator(shape) for _ in range(n_masks)]
    return np.array(masks)
[docs]
def generate_SED_elems(SED, psf_simulator, n_bins=20):
    """Generate SED elements for PSF modeling.

    Samples the spectral energy distribution (SED) over wavelength bins and
    derives, per bin, the feasible Zernike mode number and the normalized SED
    value. These are the inputs required by the TensorFlow PSF classes.

    Parameters
    ----------
    SED : np.ndarray
        The unfiltered SED with shape (n_wavelengths, 2): column 0 holds
        wavelength positions, column 1 the corresponding flux values.
    psf_simulator : PSFSimulator
        Simulator instance providing ``calc_SED_wave_values`` and
        ``feasible_N``.
    n_bins : int, optional
        Number of wavelength bins used to sample the SED. Default is 20.

    Returns
    -------
    tuple of (np.ndarray, np.ndarray, np.ndarray or float)
        - feasible_N : feasible Zernike mode numbers, one per bin.
        - feasible_wv : sampled wavelength values.
        - SED_norm : normalized SED values at the feasible wavelengths.

    See Also
    --------
    generate_SED_elems_in_tensorflow : TensorFlow version of this function.
    generate_packed_elems : Wrapper that converts output to TensorFlow tensors.
    """
    feasible_wv, SED_norm = psf_simulator.calc_SED_wave_values(SED, n_bins)
    # One feasible Zernike mode number per sampled wavelength.
    feasible_N = np.array(list(map(psf_simulator.feasible_N, feasible_wv)))
    return feasible_N, feasible_wv, SED_norm
[docs]
def generate_SED_elems_in_tensorflow(
    SED, psf_simulator, n_bins=20, tf_dtype=tf.float64
):
    """Generate SED elements as TensorFlow tensors.

    Same computation as generate_SED_elems, with the three results converted
    to tensors of the requested dtype for use with the TF_poly_PSF class.

    Parameters
    ----------
    SED : np.ndarray
        The unfiltered SED with shape (n_wavelengths, 2): column 0 holds
        wavelength positions, column 1 the corresponding SED values.
    psf_simulator : PSFSimulator object
        Simulator instance providing ``calc_SED_wave_values`` and
        ``feasible_N``.
    n_bins : int
        Number of wavelength bins. Default is 20.
    tf_dtype : tf.DType
        Target TensorFlow dtype for all outputs (e.g. ``tf.float32``).

    Returns
    -------
    list of tf.Tensor
        [feasible_N, feasible_wv, SED_norm], each cast to ``tf_dtype``.
    """
    feasible_wv, SED_norm = psf_simulator.calc_SED_wave_values(SED, n_bins)
    feasible_N = np.array(list(map(psf_simulator.feasible_N, feasible_wv)))
    # Cast all three elements to the requested TF dtype in one pass.
    return convert_to_tf([feasible_N, feasible_wv, SED_norm], tf_dtype)
[docs]
def convert_to_tf(data, tf_dtype):
    """
    Convert a sequence of array-like objects to TensorFlow tensors.

    Parameters
    ----------
    data : Iterable
        Iterable of array-like objects (NumPy arrays, lists/tuples,
        tf.Tensor, ...) to convert.
    tf_dtype : tf.DType
        Target dtype for every converted tensor (e.g. ``tf.float32``).

    Returns
    -------
    list of tf.Tensor
        One tensor per input element, each produced by
        ``tf.convert_to_tensor(..., dtype=tf_dtype)``. The top-level
        container is always returned as a list.

    Raises
    ------
    TypeError
        If ``data`` is not iterable, or if an element cannot be converted
        by ``tf.convert_to_tensor``.
    """
    return [tf.convert_to_tensor(item, dtype=tf_dtype) for item in data]
[docs]
def generate_packed_elems(SED, psf_simulator, n_bins=20):
    """
    Generate packed SED elements as TensorFlow tensors.

    Wrapper around generate_SED_elems(...) that converts the returned NumPy
    arrays into TensorFlow tensors with dtype=tf.float64.

    Parameters
    ----------
    SED : numpy.ndarray
        The unfiltered SED with shape (n_wavelengths, 2). The first column
        contains the wavelength positions, the second the corresponding SED
        flux values.
    psf_simulator : PSFSimulator object
        An instance providing calc_SED_wave_values and feasible_N.
    n_bins : int, optional
        Number of wavelength bins used to sample the SED (default 20).

    Returns
    -------
    list of tf.Tensor
        [feasible_N, feasible_wv, SED_norm]:
        - feasible_N : tf.Tensor, shape (n_bins,), dtype tf.float64
        - feasible_wv : tf.Tensor, shape (n_bins,), dtype tf.float64
        - SED_norm : tf.Tensor, scalar or array, dtype tf.float64
    """
    sed_elems = generate_SED_elems(SED, psf_simulator, n_bins=n_bins)
    # Consistency/dedup: delegate to the shared conversion helper instead of
    # three manual tf.convert_to_tensor calls, mirroring
    # generate_SED_elems_in_tensorflow. Returns the same packed list.
    return convert_to_tf(list(sed_elems), tf.float64)
[docs]
def calc_poly_position_mat(pos, x_lims, y_lims, d_max):
    r"""Build the polynomial position matrix.

    Positions are rescaled from the rectangle
    [x_lims[0], x_lims[1]] x [y_lims[0], y_lims[1]]
    to [-1, 1] x [-1, 1], then all monomials
    x^(d-p) * y^p for d = 0..d_max, p = 0..d are stacked.

    Returns
    -------
    tf.Tensor, dtype tf.float32
        Tensor of shape (n_monomials, n_positions).
    """
    # Rescale each coordinate to [-1, 1] (same two-step arithmetic as before).
    scaled_x = (pos[:, 0] - x_lims[0]) / (x_lims[1] - x_lims[0])
    scaled_x = (scaled_x - 0.5) * 2
    scaled_y = (pos[:, 1] - y_lims[0]) / (y_lims[1] - y_lims[0])
    scaled_y = (scaled_y - 0.5) * 2
    # All monomials of total degree 0..d_max, ordered by degree then y-power.
    monomials = [
        scaled_x ** (degree - p) * scaled_y**p
        for degree in range(d_max + 1)
        for p in range(degree + 1)
    ]
    return tf.convert_to_tensor(monomials, dtype=tf.float32)
[docs]
def decimate_im(input_im, decim_f):
    r"""Decimate an image by an integer factor.

    Resizes the image to (width // decim_f, height // decim_f) using PIL's
    default resize interpolator.

    Parameters
    ----------
    input_im : np.ndarray
        Input image.
    decim_f : int
        Decimation factor.

    Returns
    -------
    np.ndarray
        Decimated image.
    """
    pil_im = PIL.Image.fromarray(input_im)
    new_size = (pil_im.width // decim_f, pil_im.height // decim_f)
    return np.array(pil_im.resize(new_size))
[docs]
def downsample_im(input_im, output_dim):
    """Downsample an image to (output_dim, output_dim).

    Uses OpenCV's INTER_AREA resize when cv2 is importable, otherwise falls
    back to scikit-image's local-mean downscaling.

    Parameters
    ----------
    input_im : np.ndarray
        Input 2D image to be downsampled.
    output_dim : int
        Desired output dimension (both height and width).

    Returns
    -------
    np.ndarray
        Downsampled 2D image of shape (output_dim, output_dim).

    Raises
    ------
    ValueError
        If the skimage fallback computes a non-positive integer factor.
    ImportError
        If neither cv2 nor scikit-image is available.
    """
    if _HAS_CV2:
        target_size = (int(output_dim), int(output_dim))
        return cv2.resize(input_im, target_size, interpolation=cv2.INTER_AREA)
    if _HAS_SKIMAGE:
        # Integer block-averaging factors; truncation means non-divisible
        # shapes yield an output slightly larger than output_dim.
        factor_x = int(input_im.shape[0] / output_dim)
        factor_y = int(input_im.shape[1] / output_dim)
        if factor_x <= 0 or factor_y <= 0:
            raise ValueError("Invalid downsampling factors.")
        return downscale_local_mean(input_im, factors=(factor_x, factor_y))
    raise ImportError(
        "Neither OpenCV nor scikit-image is available for image downsampling."
    )
[docs]
def zernike_generator(n_zernikes, wfe_dim):
    r"""
    Generate Zernike maps.

    Based on the zernike GitHub repository:
    https://github.com/jacopoantonello/zernike

    Parameters
    ----------
    n_zernikes: int
        Number of Zernike modes desired.
    wfe_dim: int
        Dimension of each Zernike map [wfe_dim x wfe_dim].

    Returns
    -------
    zernikes: list of np.ndarray
        List containing the Zernike modes.
        The values outside the unit circle are filled with NaNs.
    """
    # Smallest radial order n (in the (n, m) convention) whose mode count
    # covers the requested number of Zernike coefficients.
    min_n = (-3 + np.sqrt(1 + 8 * n_zernikes)) / 2
    radial_order = int(np.ceil(min_n))
    cart = zk.RZern(radial_order)
    # Cartesian evaluation grid on [-1, 1] x [-1, 1].
    axis = np.linspace(-1.0, 1.0, wfe_dim)
    xv, yv = np.meshgrid(axis, axis)
    cart.make_cart_grid(xv, yv)
    coeffs = np.zeros(cart.nk)
    zernikes = []
    # Evaluate one mode at a time via a one-hot coefficient vector.
    for mode_idx in range(n_zernikes):
        coeffs.fill(0.0)
        coeffs[mode_idx] = 1.0
        zernikes.append(cart.eval_grid(coeffs, matrix=True))
    return zernikes
[docs]
def add_noise(image, desired_SNR):
    """Add white Gaussian noise to an image to reach a desired SNR.

    The noise standard deviation is chosen so that the ratio of the image's
    total squared signal to the expected total noise power equals
    ``desired_SNR``.

    Parameters
    ----------
    image : np.ndarray
        Input 2D image.
    desired_SNR : float
        Target signal-to-noise ratio.

    Returns
    -------
    np.ndarray
        Noisy copy of the image.
    """
    n_pixels = image.shape[0] * image.shape[1]
    sigma_noise = np.sqrt(np.sum(image**2) / (desired_SNR * n_pixels))
    noise = np.random.standard_normal(image.shape) * sigma_noise
    return image + noise
[docs]
class NoiseEstimator:
    """
    A class for estimating noise levels in an image.

    The estimator computes a robust (MAD-based) standard deviation of the
    pixels outside a central circular exclusion window, so that a bright
    central object does not bias the noise estimate.

    Parameters
    ----------
    img_dim : tuple of int
        The dimensions of the image as (height, width).
    win_rad : int
        The radius of the exclusion window (in pixels).
    """

    def __init__(self, img_dim: tuple[int, int], win_rad: int) -> None:
        """
        Initialize a NoiseEstimator instance.

        Stores the image dimensions and exclusion radius, then builds the
        boolean exclusion-window mask (see _init_window).

        Parameters
        ----------
        img_dim : tuple of int
            The dimensions of the image as (height, width).
        win_rad : int
            Radius of the central exclusion window in pixels. Pixels within
            this radius of the image center are excluded from the estimate.
        """
        self.img_dim: tuple[int, int] = img_dim
        self.win_rad: int = win_rad
        self._init_window()  # Initialize self.window

    def _init_window(self) -> None:
        """
        Build the boolean exclusion-window mask in self.window.

        The mask has shape `self.img_dim`. Pixels whose Euclidean distance
        from the image center (rows / 2, cols / 2) is less than or equal to
        `self.win_rad` are False (excluded); all others are True (included).

        Notes
        -----
        - The center is a floating-point position, so for even dimensions it
          lies between pixels.
        - Performance fix: the original per-pixel double Python loop is
          replaced by a vectorized NumPy computation. Comparing squared
          distances avoids the sqrt entirely and gives the same boundary
          behavior (d <= r  <=>  d^2 <= r^2 for non-negative values).
        """
        mid_x = self.img_dim[0] / 2
        mid_y = self.img_dim[1] / 2
        # Open grids broadcast to the full (rows, cols) distance map.
        rows, cols = np.ogrid[0 : self.img_dim[0], 0 : self.img_dim[1]]
        dist_sq = (rows - mid_x) ** 2 + (cols - mid_y) ** 2
        # True outside the exclusion radius, False inside (and on) it.
        self.window = dist_sq > self.win_rad**2

    def apply_mask(self, mask: np.ndarray = None) -> np.ndarray:
        """
        Combine the exclusion window with an optional extra mask.

        Parameters
        ----------
        mask : np.ndarray, optional
            A boolean mask to AND with the exclusion window. If None, the
            exclusion window is returned unmodified.

        Returns
        -------
        np.ndarray
            The boolean array resulting from combining window and mask.
        """
        if mask is None:
            return self.window  # Return just the window if no mask is provided
        return self.window & mask  # Otherwise, apply the mask as usual

    @staticmethod
    def sigma_mad(x):
        """
        Robustly estimate the standard deviation via the MAD.

        Computes MAD = median(|x - median(x)|) and scales by 1.4826, the
        factor that makes the estimator consistent with the standard
        deviation of a Gaussian distribution.

        Parameters
        ----------
        x : array-like
            Input data. NaNs are not specially handled and will propagate.

        Returns
        -------
        float
            Robust estimate of the standard deviation.
        """
        return 1.4826 * np.median(np.abs(x - np.median(x)))

    def estimate_noise(self, image: np.ndarray, mask: np.ndarray = None) -> float:
        """
        Estimate the noise level of an image using the MAD estimator.

        Parameters
        ----------
        image : np.ndarray
            The input image; must have shape `self.img_dim`.
        mask : np.ndarray, optional
            Boolean mask of pixels to include (combined with the exclusion
            window). If None, only the exclusion window is used.

        Returns
        -------
        float
            Estimated noise standard deviation of the selected pixels.
        """
        if mask is not None:
            return self.sigma_mad(image[self.apply_mask(mask)])
        # Use the default window if no mask is provided
        return self.sigma_mad(image[self.window])
[docs]
class ZernikeInterpolation:
    """Interpolate Zernike coefficients using K-nearest RBF splines.

    Interpolates Zernike-coefficient vectors defined at a set of source
    positions to arbitrary query positions, using a local RBF spline fitted
    to the K nearest source samples.

    Pipeline per query position:
    - Compute Euclidean distances to all source positions and select the K
      nearest neighbors (via tf.math.top_k on negated distances).
    - Fit an RBF/thin-plate spline (tfa_interpolate_spline_rbf, imported from
      wf_psf.utils.interpolation) on the neighbor positions and coefficient
      vectors, and evaluate it at the query location.

    Parameters
    ----------
    tf_pos : tf.Tensor, shape (n_sources, 2)
        Source/sample positions (x, y). Expected dtype float32 or convertible.
    tf_zks : tf.Tensor, shape (n_sources, n_zernikes)
        Zernike coefficient vectors at the source positions.
    k : int, default 50
        Number of nearest neighbors used for the local interpolation.
        NOTE(review): tf.math.top_k requires k <= n_sources; there is no
        fallback to "use all sources" in the code — confirm callers
        guarantee this.
    order : int, default 2
        Spline order passed to the RBF interpolation helper (2 corresponds
        to thin-plate-style interpolation).

    Attributes
    ----------
    tf_pos, tf_zks, k, order
        Stored references to the constructor inputs (no copies are made).

    Notes
    -----
    - The spline helper requires a leading batch dimension on its inputs;
      the implementation adds and removes it automatically.
    - For best numerical stability use float32 tensors when possible.
    - Two main methods are provided:
      - interpolate_zk(single_pos): interpolate one position -> 1D vector.
      - interpolate_zks(interp_positions): batched interpolation for many
        query positions (uses tf.map_fn under the hood).
    """

    def __init__(self, tf_pos, tf_zks, k=50, order=2):
        # Store inputs as-is; no validation or copying is performed here.
        self.tf_pos = tf_pos
        self.tf_zks = tf_zks
        self.k = k
        self.order = order
[docs]
def interpolate_zk(self, single_pos):
    """Interpolate Zernike coefficients at one query position using K-nearest neighbors.

    Finds the K nearest source positions to the query and fits an RBF spline
    on those neighbors to estimate the coefficients at the query location.

    Parameters
    ----------
    single_pos : tf.Tensor, shape (2,)
        Query position coordinates as (x, y).

    Returns
    -------
    tf.Tensor, shape (1, n_zernikes)
        Interpolated Zernike coefficient vector at the query position
        (the leading singleton comes from the single query point; callers
        such as interpolate_zks squeeze it away).
    """
    # Negate the distances so tf.math.top_k (which selects the LARGEST
    # values) returns the K nearest neighbours.
    dist = tf.math.reduce_euclidean_norm(self.tf_pos - single_pos, axis=1) * -1.0
    # Indices of the K nearest source positions.
    # NOTE(review): requires self.k <= n_sources, otherwise top_k raises.
    result = tf.math.top_k(dist, k=self.k)
    # Gather the neighbour positions and their coefficient vectors.
    rec_pos = tf.gather(
        self.tf_pos,
        result.indices,
        validate_indices=None,
        axis=0,
        batch_dims=0,
    )
    rec_zks = tf.gather(
        self.tf_zks,
        result.indices,
        validate_indices=None,
        axis=0,
        batch_dims=0,
    )
    # RBF spline interpolation; the helper expects a leading batch dimension
    # on all inputs, hence the expand_dims calls.
    interp_zk = tfa_interpolate_spline_rbf(
        train_points=tf.expand_dims(rec_pos, axis=0),
        train_values=tf.expand_dims(rec_zks, axis=0),
        query_points=tf.expand_dims(single_pos[tf.newaxis, :], axis=0),
        order=self.order,
        regularization_weight=0.0,
    )
    # Drop the batch dimension that was added for the spline helper.
    interp_zk = tf.squeeze(interp_zk, axis=0)
    return interp_zk
[docs]
def interpolate_zks(self, interp_positions):
    """Interpolate Zernike coefficient vectors at multiple query positions.

    Batched wrapper that applies self.interpolate_zk to each row of
    interp_positions using tf.map_fn.

    Parameters
    ----------
    interp_positions : tf.Tensor, shape (n_targets, 2)
        Query positions where Zernike coefficients should be interpolated.
        Each row is an (x, y) coordinate.

    Returns
    -------
    tf.Tensor, shape (n_targets, n_zernikes), dtype=tf.float32
        Interpolated Zernike coefficient vectors for each query position.

    Notes
    -----
    - interpolate_zk returns shape (1, n_zernikes) per position, so map_fn
      produces (n_targets, 1, n_zernikes); the singleton axis is removed
      by the final tf.squeeze.
    - fn_output_signature=tf.float32 fixes the output dtype;
      swap_memory=True lets TF offload intermediates if memory is tight.
    """
    interp_zks = tf.map_fn(
        self.interpolate_zk,
        interp_positions,
        parallel_iterations=10,
        fn_output_signature=tf.float32,
        swap_memory=True,
    )
    # Remove the per-query singleton dimension: (n_targets, 1, n_zk) -> (n_targets, n_zk).
    return tf.squeeze(interp_zks, axis=1)
[docs]
class IndependentZernikeInterpolation:
    """Interpolate each Zernike polynomial independently.

    The interpolation is done independently for each Zernike polynomial,
    using a global RBF spline over all source positions.

    Parameters
    ----------
    tf_pos: Tensor (n_sources, 2)
        Positions.
    tf_zks: Tensor (n_sources, n_zernikes)
        Zernike coefficients for each position.
    order: int
        Order of the RBF interpolation.
        Default is 2, corresponds to thin plate interp (r^2*log(r)).
    """

    def __init__(self, tf_pos, tf_zks, order=2):
        self.tf_pos = tf_pos
        self.tf_zks = tf_zks
        self.order = order
        # Set by interpolate_zks() before the per-Zernike map over interp_one_zk.
        self.target_pos = None

    def interp_one_zk(self, zk_prior):
        """
        Interpolate a single Zernike polynomial across the target positions.

        Parameters
        ----------
        zk_prior : tf.Tensor of shape (n_sources,)
            Coefficients of one Zernike polynomial at the source positions
            `self.tf_pos`.

        Returns
        -------
        tf.Tensor of shape (n_targets, 1)
            Interpolated coefficients at `self.target_pos` (the trailing
            singleton comes from the single-value train_values column and is
            removed by interpolate_zks).

        Notes
        -----
        Bug fix: the original called ``tfa.image.interpolate_spline`` but the
        module never imports ``tfa`` (tensorflow_addons), raising a NameError
        at runtime. It now uses ``tfa_interpolate_spline_rbf``, the helper
        imported at the top of the file and already used by
        ZernikeInterpolation.
        """
        # The spline helper requires a leading batch dimension on all inputs.
        interp_zk = tfa_interpolate_spline_rbf(
            train_points=tf.expand_dims(self.tf_pos, axis=0),
            train_values=tf.expand_dims(zk_prior[:, tf.newaxis], axis=0),
            query_points=tf.expand_dims(self.target_pos, axis=0),
            order=self.order,
            regularization_weight=0.0,
        )
        # Remove the batch dimension added for the spline helper.
        return tf.squeeze(interp_zk, axis=0)

    def interpolate_zks(self, target_pos):
        """Interpolate all Zernike polynomials at the target positions.

        Each Zernike is computed independently from the others via tf.map_fn
        over the transposed coefficient matrix.

        Parameters
        ----------
        target_pos: Tensor (n_targets, 2)
            Positions to interpolate to.

        Returns
        -------
        Tensor (n_targets, n_zernikes)
        """
        self.target_pos = target_pos
        interp_zks = tf.map_fn(
            self.interp_one_zk,
            tf.transpose(self.tf_zks, perm=[1, 0]),
            parallel_iterations=10,
            fn_output_signature=tf.float32,
            swap_memory=True,
        )
        # Drop the singleton value axis and transpose back so targets are on axis 0.
        return tf.transpose(tf.squeeze(interp_zks, axis=2), perm=[1, 0])
[docs]
def load_multi_cycle_params_click(args):
    """
    Load multiple cycle training parameters.

    For backwards compatibility, the training parameters are received as a
    space-separated string, parsed, and stored back into the args dictionary.
    A single value is broadcast to all cycles; otherwise exactly
    ``total_cycles`` values are required.

    Parameters
    ----------
    args: dictionary
        Command line arguments dictionary loaded with the click package.

    Returns
    -------
    args: dictionary
        The input dictionary with all multi-cycle training parameters
        correctly loaded.
    """
    # Bug fix: the original called sys.exit() but the module never imports
    # `sys`, so the error path raised NameError instead of exiting cleanly.
    import sys

    def _expand(key, multi_key, cast):
        # Parse the legacy space-separated string only when the plain option
        # was not given; broadcast a single value across all cycles.
        if args[key] is not None:
            return
        args[key] = list(map(cast, args[multi_key].split(" ")))
        if len(args[key]) == 1:
            args[key] = args[key] * args["total_cycles"]
        elif len(args[key]) != args["total_cycles"]:
            print(
                "Invalid argument: --{}. Expected 1 or {} values but {} were given.".format(
                    key, args["total_cycles"], len(args[key])
                )
            )
            sys.exit()

    # Deduplicated: the four parameters follow the exact same expansion rule.
    _expand("l_rate_param", "l_rate_param_multi_cycle", float)
    _expand("l_rate_non_param", "l_rate_non_param_multi_cycle", float)
    _expand("n_epochs_param", "n_epochs_param_multi_cycle", int)
    _expand("n_epochs_non_param", "n_epochs_non_param_multi_cycle", int)
    return args
[docs]
def compute_unobscured_zernike_projection(tf_z1, tf_z2, norm_factor=None):
    """Compute a Zernike projection for unobscured wavefronts (OPDs).

    Computes the inner product between two maps, optionally divided by a
    normalisation factor so that the Zernikes behave as an orthonormal set.

    Typical usage: first compute
    ``norm_factor = compute_unobscured_zernike_projection(tf_zernike, tf_zernike)``
    and for future calls use
    ``compute_unobscured_zernike_projection(OPD, tf_zernike_k, norm_factor)``.

    If the OPD has obscurations, or is not an unobscured circular aperture,
    the Zernike polynomials are no longer orthonormal; in that case consider
    using `tf_decompose_obscured_opd_basis`, which accounts for the
    obscurations in the projection.
    """
    if norm_factor is None:
        norm_factor = 1
    # Element-wise product summed over the whole map, evaluated eagerly.
    inner_product = np.sum(tf.math.multiply(tf_z1, tf_z2).numpy())
    return inner_product / norm_factor
[docs]
def decompose_tf_obscured_opd_basis(
    tf_opd, tf_obscurations, tf_zk_basis, n_zernike, iters=20
):
    """Decompose an obscured OPD into a basis using an iterative algorithm.

    TensorFlow implementation. At each iteration the OPD is projected onto
    every basis map (normalised by the number of good aperture pixels) and
    the projected component, restricted to the aperture, is subtracted.

    Parameters
    ----------
    tf_opd : tf.Tensor
        Input OPD to decompose on `tf_zk_basis`. Shape (opd_dim, opd_dim).
    tf_obscurations : tf.Tensor
        Obscuration map. Shape (opd_dim, opd_dim).
    tf_zk_basis : tf.Tensor
        Zernike polynomial maps. Shape (n_batch, opd_dim, opd_dim).
    n_zernike : int
        Number of Zernike polynomials to project on.
    iters : int
        Number of iterations of the algorithm.

    Returns
    -------
    obsc_coeffs : np.ndarray
        Array of size `n_zernike` with projected Zernike coefficients.

    Raises
    ------
    ValueError
        If `n_zernike` is bigger than tf_zk_basis.shape[0].
    """
    if n_zernike > tf_zk_basis.shape[0]:
        raise ValueError(
            "Number of Zernike polynomials to project (n_zernike) exceeds the available Zernike elements in the provided basis (tf_zk_basis). Please ensure that n_zernike is less than or equal to the number of Zernike elements in tf_zk_basis."
        )
    # Clone input OPD so the caller's tensor is left untouched.
    input_tf_opd = tf.identity(tf_opd)
    # Clone obscurations and take the real part (maps may be complex).
    input_tf_obscurations = tf.math.real(tf.identity(tf_obscurations))
    # Normalisation factor: (weighted) count of good pixels in the aperture.
    ngood = tf.math.reduce_sum(input_tf_obscurations, axis=None, keepdims=False).numpy()
    obsc_coeffs = np.zeros(n_zernike)
    new_coeffs = np.zeros(n_zernike)
    # Bug fix: only iterate over the first `n_zernike` basis maps. The
    # original enumerated the FULL basis, so `new_coeffs[i]` raised an
    # IndexError whenever tf_zk_basis held more than n_zernike elements
    # (the guard above only rejects the opposite case).
    basis = tf_zk_basis[:n_zernike]
    for _ in range(iters):
        # Project the residual OPD onto each basis map.
        for i, b in enumerate(basis):
            new_coeffs[i] = (
                tf.math.reduce_sum(
                    tf.math.multiply(input_tf_opd, b), axis=None, keepdims=False
                ).numpy()
                / ngood
            )
        # Subtract the projected components, restricted to the aperture.
        for i, b in enumerate(basis):
            input_tf_opd = input_tf_opd - tf.math.multiply(
                new_coeffs[i] * b, input_tf_obscurations
            )
        obsc_coeffs += new_coeffs
        new_coeffs = np.zeros(n_zernike)
    return obsc_coeffs