"""Centroids.
A module with utils to handle PSF centroids.
:Author: Tobias Liaudat <tobias.liaudat@cea.fr> and Jennifer Pollack <jennifer.pollack@cea.fr>
"""
import numpy as np
import scipy.signal as scisig
from wf_psf.utils.preprocessing import shift_x_y_to_zk1_2_wavediff
from typing import Optional
[docs]
def compute_zernike_tip_tilt(
star_images: np.ndarray,
star_masks: Optional[np.ndarray] = None,
pixel_sampling: float = 12e-6,
reference_shifts: list[float] = [-1/3, -1/3],
sigma_init: float = 2.5,
n_iter: int = 20,
) -> np.ndarray:
"""
Compute Zernike tip-tilt corrections for a batch of PSF images.
This function estimates the centroid shifts of multiple PSFs and computes
the corresponding Zernike tip-tilt corrections to align them with a reference.
Parameters
----------
star_images : np.ndarray
A batch of PSF images (3D array of shape `(num_images, height, width)`).
star_masks : np.ndarray, optional
A batch of masks (same shape as `star_postage_stamps`). Each mask can have:
- `0` to ignore the pixel.
- `1` to fully consider the pixel.
- Values in `(0,1]` as weights for partial consideration.
Defaults to None.
pixel_sampling : float, optional
The pixel size in meters. Defaults to `12e-6 m` (12 microns).
reference_shifts : list[float], optional
The target centroid shifts in pixels, specified as `[dy, dx]`.
Defaults to `[-1/3, -1/3]` (nominal Euclid conditions).
sigma_init : float, optional
Initial standard deviation for centroid estimation. Default is `2.5`.
n_iter : int, optional
Number of iterations for centroid refinement. Default is `20`.
Returns
-------
np.ndarray
An array of shape `(num_images, 2)`, where:
- Column 0 contains `Zk1` (tip) values.
- Column 1 contains `Zk2` (tilt) values.
Notes
-----
- This function processes all images at once using vectorized operations.
- The Zernike coefficients are computed in the WaveDiff convention.
"""
# Vectorize the centroid computation
centroid_estimator = CentroidEstimator(
im=star_images,
mask=star_masks,
sigma_init=sigma_init,
n_iter=n_iter
)
shifts = centroid_estimator.get_intra_pixel_shifts()
# Ensure reference_shifts is a NumPy array (if it's not already)
reference_shifts = np.array(reference_shifts)
# Reshape to ensure it's a column vector (1, 2)
reference_shifts = reference_shifts[None,:]
# Broadcast reference_shifts to match the shape of shifts
reference_shifts = np.broadcast_to(reference_shifts, shifts.shape)
# Compute displacements
displacements = (reference_shifts - shifts) #
# Ensure the correct axis order for displacements (x-axis, then y-axis)
displacements_swapped = displacements[:, [1, 0]] # Adjust axis order if necessary
# Call shift_x_y_to_zk1_2_wavediff directly on the vector of displacements
zk1_2_array = shift_x_y_to_zk1_2_wavediff(displacements_swapped.flatten() * pixel_sampling ) # vectorized call
# Reshape the result back to the original shape of displacements
zk1_2_array = zk1_2_array.reshape(displacements.shape)
return zk1_2_array
[docs]
class CentroidEstimator:
"""
Calculate centroids and estimate intra-pixel shifts for a batch of star images.
This class estimates the centroid of each star in a batch of images using an
iterative process that fits an elliptical Gaussian model to the star images.
The estimated centroids are returned along with the intra-pixel shifts, which
represent the difference between the estimated centroid and the center of the
image grid (or pixel grid).
The process is vectorized, allowing multiple star images to be processed in
parallel, which significantly improves performance when working with large batches.
Parameters
----------
im : numpy.ndarray
A 3D numpy array of star image stamps. The shape of the array should be
(n_images, height, width), where n_images is the number of stars, and
height and width are the dimensions of each star's image.
mask : numpy.ndarray, optional
A 3D numpy array of the same shape as `im`, representing the mask for each star image.
A mask value of `0` indicates that the pixel is fully considered (unmasked), while a value of `1` means the pixel is completely ignored (masked).
Values between `0` and `1` act as weights, allowing partial consideration of the pixel.
If not provided, no mask is applied.
sigma_init : float, optional
The initial guess for the standard deviation (sigma) of the elliptical Gaussian
that models the star. Default is 7.5.
n_iter : int, optional
The number of iterations for the iterative centroid estimation procedure.
Default is 5.
auto_run : bool, optional
If True, the centroid estimation procedure will be automatically run upon
initialization. Default is True.
xc : float, optional
The initial guess for the x-component of the centroid. If None, it is set
to the center of the image. Default is None.
yc : float, optional
The initial guess for the y-component of the centroid. If None, it is set
to the center of the image. Default is None.
Attributes
----------
xc : numpy.ndarray
The x-components of the estimated centroids for each image. Shape is (n_images,).
yc : numpy.ndarray
The y-components of the estimated centroids for each image. Shape is (n_images,).
Methods
-------
update_grid()
Updates the grid of pixel positions based on the current centroid estimates.
elliptical_gaussian(e1=0, e2=0)
Computes an elliptical 2D Gaussian with the specified shear parameters.
compute_moments()
Computes the first-order moments of the star images and updates the centroid estimates.
estimate()
Runs the iterative centroid estimation procedure for all images.
get_centroids()
Returns the estimated centroids for all images as a 2D numpy array.
get_intra_pixel_shifts()
Gets the intra-pixel shifts for all images as a list of x and y displacements.
Notes
-----
The iterative centroid estimation procedure fits an elliptical Gaussian to each
star image and computes the centroid by calculating the weighted moments. The
`estimate()` method performs the centroid calculation for a batch of images using
the iterative approach defined by the `n_iter` parameter. This class is designed
to be efficient and scalable when processing large batches of star images.
"""
def __init__(self, im, mask=None, sigma_init=7.5, n_iter=5, auto_run=True, xc=None, yc=None):
"""Initialize class attributes."""
self.im = im
self.mask = mask
if self.mask is not None:
self.im = self.im * (1 - self.mask)
self.stamp_size = im.shape[1:]
self.sigma_init = sigma_init
self.n_iter = n_iter
self.xc0, self.yc0 = (
float(self.stamp_size[0]) / 2,
float(self.stamp_size[1]) / 2,
)
self.xc = np.full((self.im.shape[0],), self.xc0)
self.yc = np.full((self.im.shape[0],), self.yc0)
if auto_run:
self.estimate()
[docs]
def update_grid(self):
"""Vectorized update of the grid coordinates for multiple star stamps."""
num_images, Nx, Ny = self.im.shape # Extract dimensions
x_range = np.arange(Nx)
y_range = np.arange(Ny)
# Correct subtraction without mixing axes
self.xx = (x_range - self.xc[:, None])
self.yy = (y_range - self.yc[:, None])
# Now, expand to the correct shape (num_images, Nx, Ny)
# Add the extra dimension for the number of stars
self.xx = self.xx[:, :, None] # Shape: (num_images, Nx, 1)
self.yy = self.yy[:, None, :] # Shape: (num_images, 1, Ny)
self.xx = np.broadcast_to(self.xx, (num_images, Nx, Ny))
self.yy = np.broadcast_to(self.yy, (num_images, Nx, Ny))
[docs]
def elliptical_gaussian(self, e1=0, e2=0):
"""Compute an elliptical 2D Gaussian with arbitrary centroid."""
# Shear the grid coordinates
gxx = (1 - e1) * self.xx - e2 * self.yy
gyy = (1 + e1) * self.yy - e2 * self.xx
# Compute elliptical Gaussian
return np.exp(-(gxx**2 + gyy**2) / (2 * self.sigma_init**2))
[docs]
def compute_moments(self):
"""Compute the moments for multiple PSFs at once."""
if self.mask is not None:
masked_im_window = self.im * self.window * (self.mask == 0)
else:
masked_im_window = self.im * self.window
Q0 = np.sum(masked_im_window, axis=(1, 2)) # Sum over images and their pixels
Q1 = np.array(
[
np.sum(np.sum(masked_im_window, axis=2 - i) * np.arange(self.stamp_size[i]), axis=1)
for i in range(2)
]
)
self.xc = Q1[0] / Q0
self.yc = Q1[1] / Q0
[docs]
def estimate(self):
"""Estimate centroids for all images."""
for _ in range(self.n_iter):
self.update_grid()
self.window = self.elliptical_gaussian()
# Calculate weighted moments.
self.compute_moments()
return self.xc, self.yc
[docs]
def get_centroids(self):
"""Return centroids for all images."""
return np.array([self.xc, self.yc])
[docs]
def get_intra_pixel_shifts(self):
"""Get intra-pixel shifts for all images.
Intra-pixel shifts are the differences between the estimated centroid and the center of the image stamp (or pixel grid). These shifts are calculated for all images in the batch.
Returns
-------
np.array
A 2D array of shape (num_of_images, 2), where each row corresponds to the x and y shifts for each image.
"""
shifts = np.stack([self.xc - self.xc0, self.yc - self.yc0], axis=-1)
return shifts
[docs]
def shift_ker_stack(shifts, upfact, lanc_rad=8):
r"""Generate shifting kernels and rotated shifting kernels."""
# lanc_rad = np.ceil(np.max(3*sigmas)).astype(int)
shap = shifts.shape
var_shift_ker_stack = np.zeros((2 * lanc_rad + 1, 2 * lanc_rad + 1, shap[0]))
var_shift_ker_stack_adj = np.zeros((2 * lanc_rad + 1, 2 * lanc_rad + 1, shap[0]))
for i in range(0, shap[0]):
uin = shifts[i, :].reshape((1, 2)) * upfact
var_shift_ker_stack[:, :, i] = lanczos(uin, n=lanc_rad)
var_shift_ker_stack_adj[:, :, i] = np.rot90(var_shift_ker_stack[:, :, i], 2)
return var_shift_ker_stack, var_shift_ker_stack_adj
[docs]
def lanczos(U, n=10, n2=None):
r"""Generate Lanczos kernel for a given shift."""
if n2 is None:
n2 = n
siz = np.size(U)
if siz == 2:
U_in = np.copy(U)
if len(U.shape) == 1:
U_in = np.zeros((1, 2))
U_in[0, 0] = U[0]
U_in[0, 1] = U[1]
H = np.zeros((2 * n + 1, 2 * n2 + 1))
if (U_in[0, 0] == 0) and (U_in[0, 1] == 0):
H[n, n2] = 1
else:
i = 0
j = 0
for i in range(0, 2 * n + 1):
for j in range(0, 2 * n2 + 1):
H[i, j] = (
np.sinc(U_in[0, 0] - (i - n))
* np.sinc((U_in[0, 0] - (i - n)) / n)
* np.sinc(U_in[0, 1] - (j - n))
* np.sinc((U_in[0, 1] - (j - n)) / n)
)
else:
H = np.zeros((2 * n + 1,))
for i in range(0, 2 * n):
H[i] = np.sinc(np.pi * (U - (i - n))) * np.sinc(np.pi * (U - (i - n)) / n)
return H
[docs]
def degradation_op(X, shift_ker, D):
r"""Shift and decimate fine-grid image."""
return decim(scisig.fftconvolve(X, shift_ker, mode="same"), D, av_en=0)
[docs]
def decim(im, d, av_en=1, fft=1):
r"""Decimate image to lower resolution."""
im_filt = np.copy(im)
im_d = np.copy(im)
if d > 1:
if av_en == 1:
siz = d + 1 - (d % 2)
mask = np.ones((siz, siz)) / siz**2
if fft == 1:
im_filt = scisig.fftconvolve(im, mask, mode="same")
else:
im_filt = scisig.convolve(im, mask, mode="same")
n1 = int(np.floor(im.shape[0] / d))
n2 = int(np.floor(im.shape[1] / d))
im_d = np.zeros((n1, n2))
i, j = 0, 0
for i in range(0, n1):
for j in range(0, n2):
im_d[i, j] = im[i * d, j * d]
if av_en == 1:
return im_filt, im_d
else:
return im_d