"""MetricsConfigHandler.
A module containing the MetricsConfigHandler class, which is responsible for managing the parameters of the metrics configuration file and running the
metrics evaluation of a trained PSF model
:Authors: Jennifer Pollack <jennifer.pollack@cea.fr>
"""
import os
from wf_psf.utils.configs_handler import (
ConfigHandler,
register_configclass,
ConfigParameterError,
)
from wf_psf.data.factory import DataAdapterFactory
from wf_psf.data.data_adapter import StructureState, RepresentationState
from wf_psf.data.data_config_handler import DataConfigHandler
from wf_psf.data.schemas import DatasetMode
from wf_psf.utils.read_config import read_conf
from wf_psf.metrics.metrics_interface import evaluate_model
from wf_psf.psf_models import psf_models
from wf_psf.psf_models.psf_model_loader import load_trained_psf_model
from wf_psf.plotting.plotting_config_handler import PlottingConfigHandler
import logging
logger = logging.getLogger(__name__)
[docs]
@register_configclass
class MetricsConfigHandler(ConfigHandler):
    """MetricsConfigHandler.

    A class to handle metrics configuration parameters and to run the
    metrics evaluation of a trained PSF model.

    Parameters
    ----------
    ids: tuple
        A tuple containing a string id for the Configuration Class
    metrics_conf: str
        Path to the metrics configuration file
    file_handler: object
        An instance of the FileIOHandler
    training_conf: RecursiveNamespace object, optional
        RecursiveNamespace object containing the training configuration
        parameters. If None, the training configuration is read from the
        file located via the metrics configuration.

    """

    ids = ("metrics_conf",)

    def __init__(self, metrics_conf, file_handler, training_conf=None):
        """Initialise the handler from the metrics configuration file.

        Parameters
        ----------
        metrics_conf: str
            Path to the metrics configuration file
        file_handler: object
            An instance of the FileIOHandler
        training_conf: RecursiveNamespace object, optional
            Training configuration parameters. If None, they are loaded
            by the ``training_conf`` setter.

        """
        # NOTE: the statement order below matters.
        # ``_metrics_conf`` and ``_file_handler`` must exist before the
        # ``training_conf`` setter runs, because the setter reads both when
        # it has to infer the training config path.
        self._metrics_conf = read_conf(metrics_conf)
        self._file_handler = file_handler
        self.training_conf = training_conf
        # Requires ``training_conf`` (reads ``training_conf.data_config``).
        self.data_adapter = self._load_data_conf()
        self.simPSF = psf_models.simPSF(self.training_conf.model_params)
        self.n_bins_lambda = self.training_conf.model_params.n_bins_lambda
        self.metrics_dir = self._file_handler.get_metrics_dir(
            self._file_handler._run_output_dir
        )
        # Requires ``data_adapter`` (reads ``data_adapter.complete_data``).
        self.trained_psf_model = self._load_trained_psf_model()

    @property
    def metrics_conf(self):
        """Get Metrics Conf.

        A function to return the loaded metrics configuration.

        Returns
        -------
        RecursiveNamespace
            The metrics configuration read from the metrics
            configuration file.

        """
        return self._metrics_conf

    @property
    def training_conf(self):
        """Returns the loaded training configuration."""
        return self._training_conf

    @training_conf.setter
    def training_conf(self, training_conf):
        """Set the training configuration.

        If None is provided, attempts to load it from the training
        configuration file located via the metrics configuration
        (see ``_get_training_conf_path_from_metrics``).

        Parameters
        ----------
        training_conf: RecursiveNamespace object or None
            Training configuration parameters, or None to load them
            from file.

        Raises
        ------
        Exception
            Any exception raised while locating or reading the training
            configuration file is logged and re-raised unchanged.

        """
        if training_conf is None:
            try:
                training_conf_path = self._get_training_conf_path_from_metrics()
                logger.info(
                    f"Loading training config from inferred path: {training_conf_path}"
                )
                self._training_conf = read_conf(training_conf_path).training
            except Exception as e:
                # Log for visibility, then let the original error propagate.
                logger.error(f"Failed to load training config: {e}")
                raise
        else:
            self._training_conf = training_conf

    @property
    def plotting_conf(self):
        """Get Plotting Conf.

        A function to return the plotting configuration file name.

        Returns
        -------
        str
            Name of plotting configuration file, as given by the
            ``plotting_config`` entry of the metrics configuration.

        """
        return self.metrics_conf.metrics.plotting_config

    def _load_trained_psf_model(self):
        """Load Trained PSF Model.

        Build the weights path pattern from the metrics and training
        configurations and pass it, together with the training
        configuration and the complete dataset, to
        ``load_trained_psf_model``.

        Returns
        -------
        The object returned by ``load_trained_psf_model``.

        Raises
        ------
        KeyError
            If ``model_save_path`` or ``saved_training_cycle`` is missing
            from the metrics configuration.
        ConfigParameterError
            If the trained model path given in the metrics configuration
            is not a valid directory.

        """
        trained_model_path = self._get_trained_model_path()

        try:
            model_subdir = self.metrics_conf.metrics.model_save_path
            cycle = self.metrics_conf.metrics.saved_training_cycle
        except AttributeError as e:
            raise KeyError("Missing required model config fields.") from e

        model_name = self.training_conf.model_params.model_name
        id_name = self.training_conf.id_name

        # The "*" wildcards are left in the path; presumably the loader
        # resolves them as a glob pattern -- confirm in psf_model_loader.
        weights_path_pattern = os.path.join(
            trained_model_path,
            model_subdir,
            (f"{model_subdir}*_{model_name}*{id_name}_cycle{cycle}*"),
        )

        return load_trained_psf_model(
            self.training_conf,
            self.data_adapter.complete_data,
            weights_path_pattern,
        )

    def _get_training_conf_path_from_metrics(self):
        """Get training config path from metrics config.

        Retrieves the full path to the training config based on the
        metrics configuration.

        Returns
        -------
        str
            Full path to the training configuration file.

        Raises
        ------
        KeyError
            If 'trained_model_config' key is missing.
        FileNotFoundError
            If the file does not exist at the constructed path.
        ConfigParameterError
            If the trained model path given in the metrics configuration
            is not a valid directory.

        """
        trained_model_path = self._get_trained_model_path()

        try:
            training_conf_filename = self._metrics_conf.metrics.trained_model_config
        except AttributeError as e:
            raise KeyError(
                "Missing 'trained_model_config' key in metrics configuration."
            ) from e

        training_conf_path = os.path.join(
            self._file_handler.get_config_dir(trained_model_path),
            training_conf_filename,
        )

        if not os.path.exists(training_conf_path):
            raise FileNotFoundError(
                f"Training config file not found: {training_conf_path}"
            )

        return training_conf_path

    def _get_trained_model_path(self):
        """Get trained model path.

        Determine the trained model path from either:

        1. The metrics configuration file (i.e., for metrics-only runs
           after training), or
        2. The runtime-generated file handler paths (i.e., for single runs
           that perform both training and evaluation).

        The second option is used whenever ``trained_model_path`` is
        absent from the metrics configuration or is empty.

        Returns
        -------
        str
            Path to the trained model directory.

        Raises
        ------
        ConfigParameterError
            If the path specified in the metrics config is not an
            existing directory.

        """
        trained_model_path = getattr(
            self._metrics_conf.metrics, "trained_model_path", None
        )

        if trained_model_path:
            if not os.path.isdir(trained_model_path):
                raise ConfigParameterError(
                    f"The trained model path provided in the metrics config is not a valid directory: {trained_model_path}"
                )
            logger.info(
                f"Using trained model path from metrics config: {trained_model_path}"
            )
            return trained_model_path

        # Fallback for single-run training + metrics evaluation mode
        fallback_path = os.path.join(
            self._file_handler.output_path,
            self._file_handler.parent_output_dir,
            self._file_handler.workdir,
        )
        logger.info(
            f"Using fallback trained model path from runtime file handler: {fallback_path}"
        )
        return fallback_path

    def _load_data_conf(self):
        """Load Data Conf.

        A method to load the data configuration file named in the
        training configuration, build a data adapter from it, and join
        the data if the adapter reports it as split.

        Returns
        -------
        The data adapter built by ``DataAdapterFactory.build``.

        Raises
        ------
        ConfigParameterError
            If a TypeError occurs while loading the data configuration
            or building the adapter. The TypeError is attached as the
            cause.

        """
        try:
            data_params = DataConfigHandler(
                os.path.join(
                    self._file_handler.config_path,
                    self.training_conf.data_config,
                ),
            )
            adapter = DataAdapterFactory.build(data_params)

            # Join data, if not already complete
            if adapter.structure_state == StructureState.SPLIT:
                logger.info("Joining split datasets...")
                adapter.join_data()

            return adapter
        except TypeError as e:
            logger.exception(e)
            # Chain explicitly, as the other methods of this class do, so
            # the TypeError is reported as the direct cause.
            raise ConfigParameterError("Data configuration loading error.") from e

    def call_plot_config_handler_run(self, model_metrics):
        """Make Metrics Plots.

        A function to call the PlottingConfigHandler run
        command to generate metrics plots.

        Parameters
        ----------
        model_metrics: dict
            A dictionary storing the metrics
            output generated during evaluation
            of the trained PSF model.

        """
        # Full path of the plotting config file inside the config directory;
        # kept on the instance as ``_plotting_conf``.
        self._plotting_conf = os.path.join(
            self._file_handler.config_path,
            self.plotting_conf,
        )

        plots_config_handler = PlottingConfigHandler(
            self._plotting_conf,
            self._file_handler,
        )

        # Update metrics_confs dict with latest result
        plots_config_handler.metrics_confs[self._file_handler.workdir] = (
            self.metrics_conf
        )

        # Update metric results dict with latest result
        plots_config_handler.list_of_metrics_dict[self._file_handler.workdir] = [
            {
                self.training_conf.model_params.model_name
                + self.training_conf.id_name: [model_metrics]
            }
        ]

        plots_config_handler.run()

    def run(self):
        """Run.

        A function to run the metrics evaluation of the trained PSF
        model according to the input configuration, followed by the
        metrics plots when a plotting configuration is set.

        """
        logger.info("Running metrics evaluation on trained PSF model...")

        # Split dataset for metrics evaluation, idempotent
        if self.data_adapter.structure_state == StructureState.COMPLETE:
            self.data_adapter.split_data()

        # Convert to TF required for PSF model generation
        if self.data_adapter.representation_state == RepresentationState.NUMPY:
            self.data_adapter.convert_to_tensorflow(
                self.simPSF,
                self.n_bins_lambda,
                mode=DatasetMode.EVALUATION,
            )

        model_metrics = evaluate_model(
            self.metrics_conf.metrics,
            self.training_conf,
            self.data_adapter,
            self.simPSF,
            self.trained_psf_model,
            self.metrics_dir,
        )

        # Plotting is optional: only run it when a plotting config is named.
        if self.plotting_conf is not None:
            self._file_handler.copy_conffile_to_output_dir(
                self.metrics_conf.metrics.plotting_config
            )
            self.call_plot_config_handler_run(model_metrics)