"""Metrics Interface.
A module which defines the classes and methods
to manage metrics evaluation of the trained psf model.
:Author: Jennifer Pollack <jennifer.pollack@cea.fr>
"""
import logging
import sys
import time
from typing import Any

import numpy as np

from wf_psf.metrics import metrics as wf_metrics
from wf_psf.psf_models import psf_models
logger = logging.getLogger(__name__)
class MetricsParamsHandler:
    """Metrics Parameters Handler.

    A class to handle metrics-related parameters.

    Parameters
    ----------
    metrics_params: Recursive Namespace object
        Recursive Namespace object containing metrics input parameters
    trained_model: Recursive Namespace object
        Recursive Namespace object containing trained model input parameters
    """

    def __init__(self, metrics_params, trained_model):
        self.metrics_params = metrics_params
        self.trained_model = trained_model

    def _ground_truth_psf_model(self, data: Any, dataset: dict[str, Any]) -> Any:
        """Build the ground-truth PSF model used as reference by the metrics.

        Factored out of the individual metric evaluators, which previously
        duplicated this construction verbatim.

        Parameters
        ----------
        data : object
            A DataConfigHandler object containing training and test datasets.
        dataset : dict
            Dataset dictionary. The ``C_poly`` Zernike coefficient matrix is
            forwarded when present; otherwise ``None`` is passed and the model
            proceeds without it.

        Returns
        -------
        object
            The instantiated ground-truth PSF model.
        """
        return psf_models.get_psf_model(
            self.metrics_params.ground_truth_model.model_params,
            self.metrics_params.metrics_hparams,
            data,
            dataset.get("C_poly", None),  # Extract C_poly or default to None
        )

    def evaluate_metrics_polychromatic_lowres(
        self, psf_model: Any, simPSF: Any, data: Any, dataset: dict[str, Any]
    ) -> dict[str, float]:
        """Evaluate RMSE metrics for low-resolution polychromatic PSF.

        This method computes Root Mean Square Error (RMSE) metrics for a
        low-resolution polychromatic Point Spread Function (PSF) model.

        Parameters
        ----------
        psf_model : object
            An instance of the PSF model selected for metrics evaluation.
        simPSF : object
            An instance of the PSFSimulator.
        data : object
            A DataConfigHandler object containing training and test datasets.
        dataset : dict
            Dictionary containing dataset details, including:

            - ``SEDs`` (Spectral Energy Distributions)
            - ``positions`` (Star positions)
            - ``C_poly`` (Tensor or None, optional)
              The Zernike coefficient matrix used in generating simulations of
              the PSF model. It may be present in some datasets or only
              required for some classes. If not present or required, the model
              will proceed without it.

        Returns
        -------
        dict
            A dictionary containing the RMSE, relative RMSE, and their
            corresponding standard deviation values:

            - ``rmse`` : float
              Root Mean Square Error (RMSE).
            - ``rel_rmse`` : float
              Relative RMSE.
            - ``std_rmse`` : float
              Standard deviation of RMSE.
            - ``std_rel_rmse`` : float
              Standard deviation of relative RMSE.
        """
        logger.info("Computing polychromatic metrics at low resolution.")

        # Testing predictions are masked when the model was trained with a
        # masked MSE loss.
        mask = self.trained_model.training_hparams.loss == "mask_mse"

        # Compute metrics
        rmse, rel_rmse, std_rmse, std_rel_rmse = wf_metrics.compute_poly_metric(
            tf_semiparam_field=psf_model,
            gt_tf_semiparam_field=self._ground_truth_psf_model(data, dataset),
            simPSF_np=simPSF,
            tf_pos=dataset["positions"],
            tf_SEDs=dataset["SEDs"],
            n_bins_lda=self.trained_model.model_params.n_bins_lda,
            n_bins_gt=self.metrics_params.ground_truth_model.model_params.n_bins_lda,
            batch_size=self.metrics_params.metrics_hparams.batch_size,
            dataset_dict=dataset,
            mask=mask,
        )

        return {
            "rmse": rmse,
            "rel_rmse": rel_rmse,
            "std_rmse": std_rmse,
            "std_rel_rmse": std_rel_rmse,
        }

    def evaluate_metrics_mono_rmse(
        self, psf_model: Any, simPSF: Any, data: Any, dataset: dict[str, Any]
    ) -> dict[str, float]:
        """Evaluate RMSE metrics for Monochromatic PSF.

        This method computes Root Mean Square Error (RMSE) metrics for a
        monochromatic Point Spread Function (PSF) model across a range of
        wavelengths.

        Parameters
        ----------
        psf_model : object
            An instance of the PSF model selected for metrics evaluation.
        simPSF : object
            An instance of the PSFSimulator.
        data : object
            A DataConfigHandler object containing training and test datasets.
        dataset : dict
            Dictionary containing dataset details, including:

            - ``positions`` (Star positions)
            - ``C_poly`` (Tensor or None, optional)
              The Zernike coefficient matrix used in generating simulations of
              the PSF model. If not present or required, the model will
              proceed without it.

        Returns
        -------
        dict
            A dictionary containing RMSE, relative RMSE, and their
            corresponding standard deviation values computed over a
            wavelength range:

            - ``rmse_lda`` : float
              Root Mean Square Error (RMSE) over wavelengths.
            - ``rel_rmse_lda`` : float
              Relative RMSE over wavelengths.
            - ``std_rmse_lda`` : float
              Standard deviation of RMSE over wavelengths.
            - ``std_rel_rmse_lda`` : float
              Standard deviation of relative RMSE over wavelengths.
        """
        logger.info("Computing monochromatic metrics.")

        # Wavelength range in microns: 0.55 to 0.89 inclusive, in 10nm steps
        # (np.arange excludes the 0.9 endpoint).
        lambda_list = np.arange(0.55, 0.9, 0.01)

        (
            rmse_lda,
            rel_rmse_lda,
            std_rmse_lda,
            std_rel_rmse_lda,
        ) = wf_metrics.compute_mono_metric(
            tf_semiparam_field=psf_model,
            gt_tf_semiparam_field=self._ground_truth_psf_model(data, dataset),
            simPSF_np=simPSF,
            tf_pos=dataset["positions"],
            lambda_list=lambda_list,
        )

        return {
            "rmse_lda": rmse_lda,
            "rel_rmse_lda": rel_rmse_lda,
            "std_rmse_lda": std_rmse_lda,
            "std_rel_rmse_lda": std_rel_rmse_lda,
        }

    def evaluate_metrics_opd(
        self, psf_model: Any, simPSF: Any, data: Any, dataset: dict[str, Any]
    ) -> dict[str, float]:
        """Evaluate Optical Path Difference (OPD) metrics.

        This method computes Root Mean Square Error (RMSE) and relative RMSE
        for Optical Path Differences (OPD), along with their standard
        deviations.

        Parameters
        ----------
        psf_model : object
            An instance of the PSF model selected for metrics evaluation.
        simPSF : object
            An instance of the PSFSimulator.
        data : object
            A DataConfigHandler object containing training and test datasets.
        dataset : dict
            Dictionary containing dataset details, including:

            - ``positions`` (Star positions)
            - ``C_poly`` (Tensor or None, optional)
              The Zernike coefficient matrix used in generating simulations of
              the PSF model. If not present or required, the model will
              proceed without it.

        Returns
        -------
        dict
            A dictionary containing RMSE, relative RMSE, and their
            corresponding standard deviation values for OPD metrics:

            - ``rmse_opd`` : float
              Root Mean Square Error (RMSE) for OPD.
            - ``rel_rmse_opd`` : float
              Relative RMSE for OPD.
            - ``rmse_std_opd`` : float
              Standard deviation of RMSE for OPD.
            - ``rel_rmse_std_opd`` : float
              Standard deviation of relative RMSE for OPD.
        """
        logger.info("Computing OPD metrics.")

        (
            rmse_opd,
            rel_rmse_opd,
            rmse_std_opd,
            rel_rmse_std_opd,
        ) = wf_metrics.compute_opd_metrics(
            tf_semiparam_field=psf_model,
            gt_tf_semiparam_field=self._ground_truth_psf_model(data, dataset),
            pos=dataset["positions"],
            batch_size=self.metrics_params.metrics_hparams.batch_size,
        )

        return {
            "rmse_opd": rmse_opd,
            "rel_rmse_opd": rel_rmse_opd,
            "rmse_std_opd": rmse_std_opd,
            "rel_rmse_std_opd": rel_rmse_std_opd,
        }

    def evaluate_metrics_shape(
        self, psf_model: Any, simPSF: Any, data: Any, dataset: dict[str, Any]
    ) -> dict[str, float]:
        """Evaluate PSF Shape Metrics.

        Computes shape-related metrics for the PSF model, including RMSE,
        relative RMSE, and their standard deviations.

        Parameters
        ----------
        psf_model : object
            Instance of the PSF model selected for evaluation.
        simPSF : object
            Instance of the PSFSimulator.
        data : object
            A DataConfigHandler object containing training and test datasets.
        dataset : dict
            Dictionary containing dataset details, including:

            - ``SEDs`` (Spectral Energy Distributions)
            - ``positions`` (Star positions)
            - ``C_poly`` (Tensor or None, optional)
              The Zernike coefficient matrix used in generating simulations of
              the PSF model. If not present or required, the model will
              proceed without it.

        Returns
        -------
        shape_results : dict
            Dictionary containing RMSE, Relative RMSE values, and
            corresponding Standard Deviation values for PSF Shape metrics.
        """
        logger.info("Computing Shape metrics.")

        shape_results = wf_metrics.compute_shape_metrics(
            tf_semiparam_field=psf_model,
            gt_tf_semiparam_field=self._ground_truth_psf_model(data, dataset),
            simPSF_np=simPSF,
            SEDs=dataset["SEDs"],
            tf_pos=dataset["positions"],
            n_bins_lda=self.trained_model.model_params.n_bins_lda,
            n_bins_gt=self.metrics_params.ground_truth_model.model_params.n_bins_lda,
            batch_size=self.metrics_params.metrics_hparams.batch_size,
            output_Q=self.metrics_params.metrics_hparams.output_Q,
            output_dim=self.metrics_params.metrics_hparams.output_dim,
            opt_stars_rel_pix_rmse=self.metrics_params.metrics_hparams.opt_stars_rel_pix_rmse,
            dataset_dict=dataset,
        )
        return shape_results
def evaluate_model(
    metrics_params,
    trained_model_params,
    data,
    psf_model,
    weights_path,
    metrics_output,
):
    """Evaluate the trained model on both training and test datasets by computing various metrics.

    The metrics to evaluate are determined by the configuration in
    `metrics_params` and `metric_evaluation_flags`. Metrics are computed for
    both the training and test datasets, and results are stored in a
    dictionary that is also saved to disk as a ``.npy`` file.

    Parameters
    ----------
    metrics_params: Recursive Namespace object
        Recursive Namespace object containing metrics input parameters
    trained_model_params: Recursive Namespace object
        Recursive Namespace object containing trained model input parameters
    data: DataHandler object
        DataHandler object containing training and test data
    psf_model: object
        PSF model object
    weights_path: str
        Directory location of model weights
    metrics_output: str
        Directory location of metrics output

    Returns
    -------
    dict
        Mapping ``{"test_metrics": {...}, "train_metrics": {...}}``; metrics
        whose evaluation flag is disabled are stored as ``None``.
    """
    # Start measuring elapsed time
    starting_time = time.time()

    try:
        ## Load datasets
        # -----------------------------------------------------
        # Get training data
        logger.info("Fetching and preprocessing training and test data...")

        # Initialize metrics_handler
        metrics_handler = MetricsParamsHandler(metrics_params, trained_model_params)

        ## Prepare models
        # Prepare np input
        simPSF_np = data.training_data.simPSF

        ## Load the model's weights
        try:
            logger.info(f"Loading PSF model weights from {weights_path}")
            psf_model.load_weights(weights_path)
        except Exception as e:
            logger.exception("An error occurred with the weights_path file: %s", e)
            # Exit with a non-zero status so callers/schedulers see the
            # failure (the original bare `exit()` exited with status 0).
            sys.exit(1)

        # Define datasets
        datasets = {"test": data.test_data.dataset, "train": data.training_data.dataset}

        # Initialise dictionary to store metrics
        all_metrics = {}

        # Define metric names and their corresponding evaluation flags.
        # The polychromatic metric is always computed; the others are gated
        # by the metrics_params configuration.
        metric_evaluation_flags = {
            "poly_metric": {
                "test": True,
                "train": True,
            },
            "mono_metric": {
                "test": metrics_params.eval_mono_metric,
                "train": metrics_params.eval_mono_metric,
            },
            "opd_metric": {
                "test": metrics_params.eval_opd_metric,
                "train": metrics_params.eval_opd_metric,
            },
            "shape_results_dict": {
                "test": metrics_params.eval_test_shape_results_dict,
                "train": metrics_params.eval_train_shape_results_dict,
            },
        }

        # Define the metric evaluation functions
        metric_functions = {
            "poly_metric": metrics_handler.evaluate_metrics_polychromatic_lowres,
            "mono_metric": metrics_handler.evaluate_metrics_mono_rmse,
            "opd_metric": metrics_handler.evaluate_metrics_opd,
            "shape_results_dict": metrics_handler.evaluate_metrics_shape,
        }

        for dataset_type, dataset in datasets.items():
            ## Metric evaluation
            logger.info(
                f"\n***\nMetric evaluation on the {dataset_type} dataset\n***\n"
            )

            # Create dictionary to store metrics for the current dataset
            dataset_metrics = {}

            # Evaluate metrics based on evaluation flags; disabled metrics
            # are recorded as None so the output schema stays uniform.
            for metric_name, metric_function in metric_functions.items():
                if metric_evaluation_flags[metric_name][dataset_type]:
                    dataset_metrics[metric_name] = metric_function(
                        psf_model,
                        simPSF_np,
                        data,
                        dataset,
                    )
                else:
                    dataset_metrics[metric_name] = None

            # Store dataset metrics in the overall metrics dictionary
            all_metrics[f"{dataset_type}_metrics"] = dataset_metrics

        run_id_name = (
            trained_model_params.model_params.model_name + trained_model_params.id_name
        )
        output_path = metrics_output + "/" + "metrics-" + run_id_name
        np.save(output_path, all_metrics, allow_pickle=True)

        ## Print final time
        final_time = time.time()
        logger.info("\nTotal elapsed time: %f" % (final_time - starting_time))

        ## Close log file
        logger.info("\n Good bye..")
        return all_metrics

    except Exception as e:
        # Log with full traceback (the original logged at INFO, losing the
        # stack trace) and re-raise for the caller to handle.
        logger.exception("Error: %s", e)
        raise