"""Metrics.
A module containing the functions used to evaluate a trained
PSF model with several sets of metrics.
:Author: Tobias Liaudat <tobias.liaudat@cea.fr>
"""
import numpy as np
import tensorflow as tf
import galsim as gs
import wf_psf.utils.utils as utils
from wf_psf.psf_models.psf_models import build_PSF_model
from wf_psf.sims import psf_simulator as psf_simulator
import logging
logger = logging.getLogger(__name__)
[docs]
def compute_poly_metric(
    tf_semiparam_field,
    gt_tf_semiparam_field,
    simPSF_np,
    tf_pos,
    tf_SEDs,
    n_bins_lda=20,
    n_bins_gt=20,
    batch_size=16,
    dataset_dict=None,
    mask=False,
):
    """Calculate metrics for polychromatic reconstructions.

    The ``tf_semiparam_field`` should be the model to evaluate, and the
    ``gt_tf_semiparam_field`` should be loaded with the ground truth PSF field.

    Relative values returned in [%] (so multiplied by 100).

    Parameters
    ----------
    tf_semiparam_field: PSF field object
        Trained model to evaluate.
    gt_tf_semiparam_field: PSF field object
        Ground truth model to produce gt observations at any position
        and wavelength.
    simPSF_np: PSF simulator object
        Simulation object to be used by ``generate_packed_elems`` function.
    tf_pos: Tensor or numpy.ndarray [batch x 2] floats
        Positions to evaluate the model.
    tf_SEDs: numpy.ndarray [batch x SED_samples x 2]
        SED samples for the corresponding positions.
    n_bins_lda: int
        Number of wavelength bins to use for the polychromatic PSF.
    n_bins_gt: int
        Number of wavelength bins to use for the ground truth polychromatic PSF.
    batch_size: int
        Batch size for the PSF calculations.
    dataset_dict: dict
        Dictionary containing the dataset information. If provided, and if the `'stars'` key
        is present, the noiseless stars from the dataset are used to compute the metrics.
        Otherwise, the stars are generated from the gt model.
        Default is `None`.
    mask: bool
        If `True`, predictions are masked using the same mask as the target data, ensuring
        that metric calculations consider only unmasked regions. The masks are read from
        ``dataset_dict["masks"]``, so ``dataset_dict`` must be provided and contain that key.
        Default is `False`.

    Returns
    -------
    rmse: float
        RMSE value.
    rel_rmse: float
        Relative RMSE value. Values in %.
    std_rmse: float
        Standard deviation of RMSEs.
    std_rel_rmse: float
        Standard deviation of relative RMSEs. Values in %.

    Notes
    -----
    When the ground truth stars are regenerated from the ground truth model,
    ``simPSF_np.SED_interp_pts_per_bin`` and ``simPSF_np.SED_sigma`` are set
    to ``0`` on the object that was passed in and are not restored.
    """

    def _build_inputs(n_bins):
        # Pack every SED into `n_bins` wavelength bins and pair the packed
        # data with the positions, in the layout expected by `predict`.
        packed_SED_data = [
            utils.generate_packed_elems(_sed, simPSF_np, n_bins=n_bins)
            for _sed in tf_SEDs
        ]
        tf_packed_SED_data = tf.convert_to_tensor(packed_SED_data, dtype=tf.float32)
        tf_packed_SED_data = tf.transpose(tf_packed_SED_data, perm=[0, 2, 1])
        return [tf_pos, tf_packed_SED_data]

    # Model prediction
    preds = tf_semiparam_field.predict(
        x=_build_inputs(n_bins_lda), batch_size=batch_size
    )

    # Ground truth data preparation
    if dataset_dict is None or "stars" not in dataset_dict:
        logger.info(
            "No precomputed ground truth stars found. Regenerating from the ground truth model using configured interpolation settings."
        )
        # Change interpolation parameters for the ground truth simPSF.
        # This must happen before the ground truth SEDs are packed.
        simPSF_np.SED_interp_pts_per_bin = 0
        simPSF_np.SED_sigma = 0
        # Ground truth model prediction
        gt_preds = gt_tf_semiparam_field.predict(
            x=_build_inputs(n_bins_gt), batch_size=batch_size
        )
    else:
        logger.info("Using precomputed ground truth stars from dataset_dict['stars'].")
        gt_preds = dataset_dict["stars"]

    # If the data is masked, mask the predictions
    if mask:
        logger.info(
            "Applying masks to predictions. Only unmasked regions will be considered for metric calculations."
        )
        # Invert the dataset masks so that kept pixels carry a weight of 1
        masks = 1 - dataset_dict["masks"]
        # Ensure masks as float dtype
        masks = masks.astype(preds.dtype)
        # Weight the mse by the number of unmasked pixels
        weights = np.sum(masks, axis=(1, 2))
        # Avoid divide by zero
        weights = np.maximum(weights, 1e-7)
        # Mask the predictions and ground truth/observations
        preds = preds * masks
        gt_preds = gt_preds * masks
    else:
        # Without masks every star is normalised by its full pixel count
        weights = np.ones(gt_preds.shape[0]) * gt_preds.shape[1] * gt_preds.shape[2]

    # Calculate residuals
    residuals = np.sqrt(np.sum((gt_preds - preds) ** 2, axis=(1, 2)) / weights)
    gt_star_mean = np.sqrt(np.sum((gt_preds) ** 2, axis=(1, 2)) / weights)
    rel_residuals = residuals / gt_star_mean

    # RMSE calculations
    rmse = np.mean(residuals)
    rel_rmse = 100.0 * np.mean(rel_residuals)

    # STD calculations
    std_rmse = np.std(residuals)
    std_rel_rmse = 100.0 * np.std(rel_residuals)

    # Print RMSE values. A literal percent sign must be written as `%%`
    # because the logging module applies %-formatting to the message.
    logger.info("Absolute RMSE:\t %.4e \t +/- %.4e", rmse, std_rmse)
    logger.info("Relative RMSE:\t %.4e %% \t +/- %.4e %%", rel_rmse, std_rel_rmse)

    return rmse, rel_rmse, std_rmse, std_rel_rmse
[docs]
def compute_mono_metric(
    tf_semiparam_field,
    gt_tf_semiparam_field,
    simPSF_np,
    tf_pos,
    lambda_list,
    batch_size=32,
):
    """Evaluate monochromatic PSF reconstructions wavelength by wavelength.

    ``tf_semiparam_field`` is the model under evaluation and
    ``gt_tf_semiparam_field`` is the field holding the ground truth PSF.
    For every wavelength the positions are processed in batches, and one
    pixel RMSE per position is computed between both monochromatic PSFs.

    Relative values are returned in [%] (so multiplied by 100).

    Parameters
    ----------
    tf_semiparam_field: PSF field object
        Trained model to evaluate.
    gt_tf_semiparam_field: PSF field object
        Ground truth model to produce gt observations at any position
        and wavelength.
    simPSF_np: PSF simulator object
        Simulation object whose ``feasible_N`` method provides the
        ``phase_N`` value for a given wavelength.
    tf_pos: list of floats [batch x 2]
        Positions to evaluate the model.
    lambda_list: list of floats [wavelength_values]
        List of wavelength values in [um] to evaluate the model.
    batch_size: int
        Batch size to process the monochromatic PSF calculations.

    Returns
    -------
    rmse_lda: list of float
        Mean RMSE for each wavelength.
    rel_rmse_lda: list of float
        Mean relative RMSE for each wavelength. Values in %.
    std_rmse_lda: list of float
        Standard deviation of the RMSEs for each wavelength.
    std_rel_rmse_lda: list of float
        Standard deviation of the relative RMSEs for each wavelength.
        Values in %.
    """
    n_positions = tf_pos.shape[0]
    # The batch boundaries are the same for every wavelength
    batch_starts = range(0, n_positions, batch_size)

    rmse_lda, rel_rmse_lda = [], []
    std_rmse_lda, std_rel_rmse_lda = [], []

    for lambda_obs in lambda_list:
        # Wavefront dimension required at this wavelength
        phase_N = simPSF_np.feasible_N(lambda_obs)

        # Per-position mean squared error and mean squared gt intensity
        sq_error = np.zeros(n_positions)
        sq_gt = np.zeros(n_positions)

        for lo in batch_starts:
            # The last batch may be shorter than `batch_size`
            hi = min(lo + batch_size, n_positions)
            batch_pos = tf_pos[lo:hi, :]

            # Monochromatic PSFs: ground truth first, then the model
            gt_mono_psf = gt_tf_semiparam_field.predict_mono_psfs(
                input_positions=batch_pos, lambda_obs=lambda_obs, phase_N=phase_N
            )
            model_mono_psf = tf_semiparam_field.predict_mono_psfs(
                input_positions=batch_pos, lambda_obs=lambda_obs, phase_N=phase_N
            )

            num_pixels = gt_mono_psf.shape[1] * gt_mono_psf.shape[2]
            squared_diff = (gt_mono_psf - model_mono_psf) ** 2
            sq_error[lo:hi] = np.sum(squared_diff, axis=(1, 2)) / num_pixels
            sq_gt[lo:hi] = np.sum((gt_mono_psf) ** 2, axis=(1, 2)) / num_pixels

        # Turn the mean squares into root mean squares
        residuals = np.sqrt(sq_error)
        gt_star_mean = np.sqrt(sq_gt)

        # Mean values
        rmse_lda.append(np.mean(residuals))
        rel_rmse_lda.append(100.0 * np.mean(residuals / gt_star_mean))
        # Standard deviations
        std_rmse_lda.append(np.std(residuals))
        std_rel_rmse_lda.append(100.0 * np.std(residuals / gt_star_mean))

    return rmse_lda, rel_rmse_lda, std_rmse_lda, std_rel_rmse_lda
[docs]
def compute_opd_metrics(tf_semiparam_field, gt_tf_semiparam_field, pos, batch_size=16):
    """Compute the OPD metrics.

    Need to handle a batch size to avoid Out-Of-Memory errors with
    the GPUs. This is specially due to the fact that the OPD maps
    have a higher dimensionality than the observed PSFs.

    The OPD RMSE is computed after having removed the mean from the
    different reconstructions. The mean is taken over the whole OPD map,
    while the RMSE is computed only on the non-obscured elements of the
    OPD (where the obscuration map is strictly positive).

    Parameters
    ----------
    tf_semiparam_field: PSF field object
        Trained model to evaluate.
    gt_tf_semiparam_field: PSF field object
        Ground truth model to produce gt observations at any position
        and wavelength. Its ``obscurations`` attribute provides the
        obscuration map applied to both sets of OPDs.
    pos: numpy.ndarray [batch x 2]
        Positions at where to predict the OPD maps.
    batch_size: int
        Batch size to process the OPD calculations.

    Returns
    -------
    rmse: float
        Absolute RMSE value.
    rel_rmse: float
        Relative RMSE value. Values in %.
    rmse_std: float
        Absolute RMSE standard deviation.
    rel_rmse_std: float
        Relative RMSE standard deviation. Values in %.
    """
    # Get OPD obscurations
    np_obscurations = np.real(gt_tf_semiparam_field.obscurations.numpy())
    # The obscuration mask is the same for every batch, so build it once
    obsc_mask = np_obscurations > 0
    nb_mask_elems = np.sum(obsc_mask)

    # Define total number of samples
    n_samples = pos.shape[0]

    # Initialise result arrays (one value per position)
    rmse_vals = np.zeros(n_samples)
    rel_rmse_vals = np.zeros(n_samples)

    for start in range(0, n_samples, batch_size):
        # The last batch may be shorter than `batch_size`
        end_sample = min(start + batch_size, n_samples)
        batch_pos = pos[start:end_sample, :]

        # We calculate a batch of OPDs
        opd_batch = tf_semiparam_field.predict_opd(batch_pos).numpy()
        gt_opd_batch = gt_tf_semiparam_field.predict_opd(batch_pos).numpy()

        # Remove the mean of the OPD
        opd_batch -= np.mean(opd_batch, axis=(1, 2)).reshape(-1, 1, 1)
        gt_opd_batch -= np.mean(gt_opd_batch, axis=(1, 2)).reshape(-1, 1, 1)

        # Obscure the OPDs
        opd_batch *= np_obscurations
        gt_opd_batch *= np_obscurations

        # Keep only the non-obscured elements. Boolean indexing over the
        # two trailing axes gives arrays of shape [batch, nb_mask_elems].
        opd_vals = opd_batch[:, obsc_mask]
        gt_opd_vals = gt_opd_batch[:, obsc_mask]

        # Compute the OPD RMSE with the masked obscurations
        res_opd = np.sqrt(
            np.sum((opd_vals - gt_opd_vals) ** 2, axis=1) / nb_mask_elems
        )
        gt_opd_mean = np.sqrt(np.sum(gt_opd_vals**2, axis=1) / nb_mask_elems)

        # RMSE calculations
        rmse_vals[start:end_sample] = res_opd
        rel_rmse_vals[start:end_sample] = 100.0 * (res_opd / gt_opd_mean)

    # Calculate final values
    rmse = np.mean(rmse_vals)
    rel_rmse = np.mean(rel_rmse_vals)
    rmse_std = np.std(rmse_vals)
    rel_rmse_std = np.std(rel_rmse_vals)

    # Print RMSE values. A literal percent sign must be written as `%%`
    # because the logging module applies %-formatting to the message.
    # The absolute RMSE is not a percentage, so it carries no unit sign.
    logger.info("Absolute RMSE:\t %.4e \t +/- %.4e", rmse, rmse_std)
    logger.info("Relative RMSE:\t %.4e %% \t +/- %.4e %%", rel_rmse, rel_rmse_std)

    return rmse, rel_rmse, rmse_std, rel_rmse_std
[docs]
def compute_shape_metrics(
    tf_semiparam_field,
    gt_tf_semiparam_field,
    simPSF_np,
    SEDs,
    tf_pos,
    n_bins_lda,
    n_bins_gt,
    output_Q=1,
    output_dim=64,
    batch_size=16,
    opt_stars_rel_pix_rmse=False,
    dataset_dict=None,
):
    """Compute the pixel, shape and size RMSE of a PSF model.

    This is done at a specific sampling and output image dimension.
    It is done for polychromatic PSFs so SEDs are needed.

    Both models are temporarily switched to the requested ``output_Q`` and
    ``output_dim`` and rebuilt; their original settings are restored (and
    the models rebuilt again) before returning, even if an error occurs.

    Parameters
    ----------
    tf_semiparam_field: PSF field object
        Trained model to evaluate.
    gt_tf_semiparam_field: PSF field object
        Ground truth model to produce gt observations at any position
        and wavelength.
    simPSF_np: PSF simulator object
        Simulation object to be used by ``generate_packed_elems`` function.
    SEDs: numpy.ndarray [batch x SED_samples x 2]
        SED samples for the corresponding positions.
    tf_pos: Tensor [batch x 2]
        Positions at where to predict the PSFs.
    n_bins_lda: int
        Number of wavelength bins to use for the polychromatic PSF.
    n_bins_gt: int
        Number of wavelength bins to use for the ground truth polychromatic PSF.
    output_Q: int
        Downsampling rate to match the specified telescope's sampling. The value
        of `output_Q` should be equal to `oversampling_rate` in order to have
        the right pixel sampling corresponding to the telescope characteristics
        `pix_sampling`, `tel_diameter`, `tel_focal_length`. The final
        oversampling obtained is `oversampling_rate/output_Q`.
        Default is `1`, so the output psf will be super-resolved by a factor of
        `oversampling_rate`. TLDR: better use `1` and measure shapes on the
        super-resolved PSFs.
    output_dim: int
        Output dimension of the square PSF stamps.
    batch_size: int
        Batch size to process the PSF estimations.
    opt_stars_rel_pix_rmse: bool
        If `True`, the relative pixel RMSE of each star is added to the saving dictionary.
        The summary statistics are always computed.
        Default is `False`.
    dataset_dict: dict
        Dictionary containing the dataset information. If provided, and if the
        `'super_res_stars'` key (or, failing that, the `'SR_stars'` key) is present,
        the noiseless super resolved stars from the dataset are used to compute
        the metrics. Otherwise, the stars are generated from the gt model.
        Default is `None`.

    Returns
    -------
    result_dict: dict
        Dictionary with all the results: the measured ``e1``/``e2``/``R2``
        values of the predictions (``pred_*_HSM``) and of the ground truth
        (``gt_pred_*_HSM``) for the stars where both moment measurements
        succeeded, the shape and size summary statistics, the pixel RMSE
        summary statistics and the ``output_Q``, ``output_dim`` and
        ``n_bins_lda`` values used. The ground truth ``e2`` values are stored
        under ``"gt_pred_e2_HSM"`` and, for backward compatibility, also under
        the historical misspelt key ``"gt_ped_e2_HSM"``. If
        ``opt_stars_rel_pix_rmse`` is `True`, ``"stars_rel_pix_rmse"`` is added.

    Notes
    -----
    When the ground truth stars are regenerated from the ground truth model,
    ``simPSF_np.SED_interp_pts_per_bin`` and ``simPSF_np.SED_sigma`` are set
    to ``0`` on the object that was passed in and are not restored.
    """

    def _build_inputs(n_bins):
        # Pack every SED into `n_bins` wavelength bins and pair the packed
        # data with the positions, in the layout expected by `predict`.
        packed_SED_data = [
            utils.generate_packed_elems(_sed, simPSF_np, n_bins=n_bins)
            for _sed in SEDs
        ]
        tf_packed_SED_data = tf.convert_to_tensor(packed_SED_data, dtype=tf.float32)
        tf_packed_SED_data = tf.transpose(tf_packed_SED_data, perm=[0, 2, 1])
        return [tf_pos, tf_packed_SED_data]

    # Save original output_Q and output_dim
    original_out_Q = tf_semiparam_field.output_Q
    original_out_dim = tf_semiparam_field.output_dim
    gt_original_out_Q = gt_tf_semiparam_field.output_Q
    gt_original_out_dim = gt_tf_semiparam_field.output_dim

    try:
        # Set the required output_Q and output_dim parameters in the models
        tf_semiparam_field.set_output_Q(output_Q=output_Q, output_dim=output_dim)
        gt_tf_semiparam_field.set_output_Q(output_Q=output_Q, output_dim=output_dim)
        # Need to compile the models again
        tf_semiparam_field = build_PSF_model(tf_semiparam_field)
        gt_tf_semiparam_field = build_PSF_model(gt_tf_semiparam_field)

        # PSF model
        predictions = tf_semiparam_field.predict(
            x=_build_inputs(n_bins_lda), batch_size=batch_size
        )

        # Ground truth data preparation: prefer precomputed stars, looking
        # for 'super_res_stars' first and 'SR_stars' second.
        gt_key = None
        if dataset_dict is not None:
            gt_key = next(
                (key for key in ("super_res_stars", "SR_stars") if key in dataset_dict),
                None,
            )

        if gt_key is None:
            logger.info(
                "No pre-computed super-resolved ground truth stars found. Regenerating ground truth super resolved stars from the ground-truth model using configured interpolation settings."
            )
            # Change interpolation parameters for the ground truth simPSF.
            # This must happen before the ground truth SEDs are packed.
            simPSF_np.SED_interp_pts_per_bin = 0
            simPSF_np.SED_sigma = 0
            # Ground Truth model
            gt_predictions = gt_tf_semiparam_field.predict(
                x=_build_inputs(n_bins_gt), batch_size=batch_size
            )
        else:
            logger.info(
                "Using precomputed super-resolved ground truth stars from dataset."
            )
            gt_predictions = dataset_dict[gt_key]

        # Calculate residuals (one pixel RMSE per star)
        residuals = np.sqrt(np.mean((gt_predictions - predictions) ** 2, axis=(1, 2)))
        gt_star_mean = np.sqrt(np.mean((gt_predictions) ** 2, axis=(1, 2)))
        rel_residuals = residuals / gt_star_mean

        # Pixel RMSE for each star
        if opt_stars_rel_pix_rmse:
            stars_rel_pix_rmse = 100.0 * residuals / gt_star_mean

        # RMSE calculations
        pix_rmse = np.mean(residuals)
        rel_pix_rmse = 100.0 * np.mean(rel_residuals)

        # STD calculations
        pix_rmse_std = np.std(residuals)
        rel_pix_rmse_std = 100.0 * np.std(rel_residuals)

        # Print pixel RMSE values
        logger.info(
            f"\nPixel star absolute RMSE:\t {pix_rmse:.4e} \t +/- {pix_rmse_std:.4e} "
        )
        logger.info(
            f"Pixel star relative RMSE:\t {rel_pix_rmse:.4e} % \t +/- {rel_pix_rmse_std:.4e} %"
        )

        # Measure shapes of the reconstructions. With `strict=False` a failed
        # measurement is not raised; it is detected below through a non-zero
        # `moments_status`.
        pred_moments = [
            gs.hsm.FindAdaptiveMom(gs.Image(_pred), strict=False)
            for _pred in predictions
        ]
        # Measure shapes of the ground truth stars
        gt_pred_moments = [
            gs.hsm.FindAdaptiveMom(gs.Image(_pred), strict=False)
            for _pred in gt_predictions
        ]

        pred_e1_HSM, pred_e2_HSM, pred_R2_HSM = [], [], []
        gt_pred_e1_HSM, gt_pred_e2_HSM, gt_pred_R2_HSM = [], [], []

        # Keep only the stars where both measurements succeeded, so the
        # prediction and ground truth arrays stay aligned.
        for pred_mom, gt_mom in zip(pred_moments, gt_pred_moments):
            if pred_mom.moments_status == 0 and gt_mom.moments_status == 0:
                pred_e1_HSM.append(pred_mom.observed_shape.g1)
                pred_e2_HSM.append(pred_mom.observed_shape.g2)
                pred_R2_HSM.append(2 * (pred_mom.moments_sigma**2))

                gt_pred_e1_HSM.append(gt_mom.observed_shape.g1)
                gt_pred_e2_HSM.append(gt_mom.observed_shape.g2)
                gt_pred_R2_HSM.append(2 * (gt_mom.moments_sigma**2))

        pred_e1_HSM = np.array(pred_e1_HSM)
        pred_e2_HSM = np.array(pred_e2_HSM)
        pred_R2_HSM = np.array(pred_R2_HSM)

        gt_pred_e1_HSM = np.array(gt_pred_e1_HSM)
        gt_pred_e2_HSM = np.array(gt_pred_e2_HSM)
        gt_pred_R2_HSM = np.array(gt_pred_R2_HSM)

        # Calculate metrics
        # e1
        e1_res = gt_pred_e1_HSM - pred_e1_HSM
        e1_res_rel = (gt_pred_e1_HSM - pred_e1_HSM) / gt_pred_e1_HSM

        rmse_e1 = np.sqrt(np.mean(e1_res**2))
        rel_rmse_e1 = 100.0 * np.sqrt(np.mean(e1_res_rel**2))
        std_rmse_e1 = np.std(e1_res)
        std_rel_rmse_e1 = 100.0 * np.std(e1_res_rel)

        # e2
        e2_res = gt_pred_e2_HSM - pred_e2_HSM
        e2_res_rel = (gt_pred_e2_HSM - pred_e2_HSM) / gt_pred_e2_HSM

        rmse_e2 = np.sqrt(np.mean(e2_res**2))
        rel_rmse_e2 = 100.0 * np.sqrt(np.mean(e2_res_rel**2))
        std_rmse_e2 = np.std(e2_res)
        std_rel_rmse_e2 = 100.0 * np.std(e2_res_rel)

        # R2
        R2_res = gt_pred_R2_HSM - pred_R2_HSM

        rmse_R2_meanR2 = np.sqrt(np.mean(R2_res**2)) / np.mean(gt_pred_R2_HSM)
        std_rmse_R2_meanR2 = np.std(R2_res / gt_pred_R2_HSM)

        # Print shape/size errors
        logger.info(f"\nsigma(e1) RMSE =\t\t {rmse_e1:.4e} \t +/- {std_rmse_e1:.4e} ")
        logger.info(f"sigma(e2) RMSE =\t\t {rmse_e2:.4e} \t +/- {std_rmse_e2:.4e} ")
        logger.info(
            f"sigma(R2)/<R2> =\t\t {rmse_R2_meanR2:.4e} \t +/- {std_rmse_R2_meanR2:.4e} "
        )

        # Print relative shape/size errors
        logger.info(
            f"\nRelative sigma(e1) RMSE =\t {rel_rmse_e1:.4e} % \t +/- {std_rel_rmse_e1:.4e} %"
        )
        logger.info(
            f"Relative sigma(e2) RMSE =\t {rel_rmse_e2:.4e} % \t +/- {std_rel_rmse_e2:.4e} %"
        )

        # Print number of stars
        logger.info(f"\nTotal number of stars: \t\t {len(gt_pred_moments)}")
        logger.info(
            f"Problematic number of stars: \t {len(gt_pred_moments) - gt_pred_e1_HSM.shape[0]}"
        )

        # Moment results
        result_dict = {
            "pred_e1_HSM": pred_e1_HSM,
            "pred_e2_HSM": pred_e2_HSM,
            "pred_R2_HSM": pred_R2_HSM,
            "gt_pred_e1_HSM": gt_pred_e1_HSM,
            "gt_pred_e2_HSM": gt_pred_e2_HSM,
            # Historical misspelt key, kept so existing readers keep working
            "gt_ped_e2_HSM": gt_pred_e2_HSM,
            "gt_pred_R2_HSM": gt_pred_R2_HSM,
            "rmse_e1": rmse_e1,
            "std_rmse_e1": std_rmse_e1,
            "rel_rmse_e1": rel_rmse_e1,
            "std_rel_rmse_e1": std_rel_rmse_e1,
            "rmse_e2": rmse_e2,
            "std_rmse_e2": std_rmse_e2,
            "rel_rmse_e2": rel_rmse_e2,
            "std_rel_rmse_e2": std_rel_rmse_e2,
            "rmse_R2_meanR2": rmse_R2_meanR2,
            "std_rmse_R2_meanR2": std_rmse_R2_meanR2,
            "pix_rmse": pix_rmse,
            "pix_rmse_std": pix_rmse_std,
            "rel_pix_rmse": rel_pix_rmse,
            "rel_pix_rmse_std": rel_pix_rmse_std,
            "output_Q": output_Q,
            "output_dim": output_dim,
            "n_bins_lda": n_bins_lda,
        }

        if opt_stars_rel_pix_rmse:
            result_dict["stars_rel_pix_rmse"] = stars_rel_pix_rmse
    finally:
        # Reset the original output_Q and output_dim parameters in the
        # models, whether or not the evaluation succeeded.
        tf_semiparam_field.set_output_Q(
            output_Q=original_out_Q, output_dim=original_out_dim
        )
        gt_tf_semiparam_field.set_output_Q(
            output_Q=gt_original_out_Q, output_dim=gt_original_out_dim
        )
        # Need to compile the models again
        build_PSF_model(tf_semiparam_field)
        build_PSF_model(gt_tf_semiparam_field)

    return result_dict