Source code for wf_psf.plotting.plotting_config_handler

"""PlottingConfigHandler.

A module containing the PlottingConfigHandler class, which is responsible for managing the parameters of the plotting configuration file and for running the
plotting of the metrics of a trained PSF model.

:Authors: Jennifer Pollack <jennifer.pollack@cea.fr>
"""

import os
import glob
import numpy as np
from typing import Any
from wf_psf.utils.configs_handler import ConfigHandler, register_configclass
from wf_psf.utils.read_config import read_conf
from wf_psf.plotting.plots_interface import plot_metrics
import logging

# Module-level logger named after this module; used by PlottingConfigHandler below.
logger = logging.getLogger(__name__)


@register_configclass
class PlottingConfigHandler(ConfigHandler):
    """PlottingConfigHandler.

    A class to handle plotting config settings.

    Parameters
    ----------
    plotting_conf: str
        Name of plotting configuration file
    file_handler: obj
        An instance of the FileIOHandler class

    Attributes
    ----------
    ids: tuple
        A tuple containing a string id for the Configuration Class
    plotting_conf: object
        The plotting configuration as returned by ``read_conf``
    metrics_confs: dict
        Metrics configurations keyed by wf-psf run output directory name
    list_of_metrics_dict: dict
        Loaded metrics keyed by wf-psf run output directory name; empty
        when no ``metrics_dir`` is configured
    plots_dir: object
        Value returned by ``file_handler.get_plots_dir`` for the current
        run output directory

    """

    ids = ("plotting_conf",)

    def __init__(self, plotting_conf, file_handler):
        self.plotting_conf = read_conf(plotting_conf)
        self.file_handler = file_handler
        # Order matters: metrics_confs must exist and be populated before
        # make_dict_of_metrics() iterates over it.
        self.metrics_confs = {}
        self.check_and_update_metrics_confs()
        self.list_of_metrics_dict = self.make_dict_of_metrics()
        self.plots_dir = self.file_handler.get_plots_dir(
            self.file_handler._run_output_dir
        )

    def check_and_update_metrics_confs(self):
        """Check and Update Metrics Confs.

        A function to check if the user provided a metrics dir in the
        plotting parameters and, if so, to add the corresponding metrics
        configurations to the metrics_confs dictionary.

        """
        if self.plotting_conf.plotting_params.metrics_dir:
            self._update_metrics_confs()

    def make_dict_of_metrics(self):
        """Make dictionary of metrics.

        A function to create a dictionary of the metrics for each run.

        Returns
        -------
        dict
            A dictionary containing metrics, or an empty dictionary when
            no metrics dir is configured.

        """
        if not self.plotting_conf.plotting_params.metrics_dir:
            return {}
        return self.load_metrics_into_dict()

    def _update_metrics_confs(self):
        """Update Metrics Configurations.

        A method to update the metrics_confs dictionary with each set of
        metrics configuration parameters provided as input.

        Notes
        -----
        ``metrics_dir`` and ``metrics_config`` are paired positionally with
        ``zip``; if their lengths differ, the surplus entries of the longer
        one are ignored.

        """
        plotting_params = self.plotting_conf.plotting_params
        for wf_dir, metrics_conf in zip(
            plotting_params.metrics_dir,
            plotting_params.metrics_config,
        ):
            self.metrics_confs[wf_dir] = read_conf(
                os.path.join(
                    plotting_params.metrics_output_path,
                    self.file_handler.get_config_dir(wf_dir),
                    metrics_conf,
                )
            )

    def _metrics_run_id_name(self, wf_outdir, metrics_params):
        """Get Metrics Run ID Name.

        A function to generate the run id name for the metrics of a
        trained model.

        The training configuration referenced by the metrics parameters is
        tried first. If that lookup raises ``TypeError`` or
        ``FileNotFoundError``, every file matching ``training*`` in the
        config directory of ``wf_outdir`` is read instead.

        Parameters
        ----------
        wf_outdir: str
            Name of the wf-psf run output directory
        metrics_params: RecursiveNamespace Object
            RecursiveNamespace object containing the metrics parameters used
            to evaluate the trained model.

        Returns
        -------
        metrics_run_id_name: list
            List containing the model name and id for each training run.
            The list is empty when the fallback search finds no training
            configuration file.

        """
        try:
            training_conf = read_conf(
                os.path.join(
                    metrics_params.metrics.trained_model_path,
                    metrics_params.metrics.trained_model_config,
                )
            )
            id_name = training_conf.training.id_name
            model_name = training_conf.training.model_params.model_name
            return [model_name + id_name]
        except (TypeError, FileNotFoundError):
            # This handler covers both failure modes of the primary lookup.
            # A second `except FileNotFoundError` clause after this one
            # could never run (the first matching clause wins) and would
            # not protect the code below, so none is present; errors raised
            # by the fallback propagate to the caller.
            logger.info("Trained model path not provided...")
            logger.info(
                f"Trying to retrieve training config file from workdir: {wf_outdir}"
            )

            training_conf_pattern = os.path.join(
                self.plotting_conf.plotting_params.metrics_output_path,
                self.file_handler.get_config_dir(wf_outdir),
                "training*",
            )
            training_confs = [
                read_conf(training_conf_file)
                for training_conf_file in glob.glob(training_conf_pattern)
            ]
            return [
                training_conf.training.model_params.model_name
                + training_conf.training.id_name
                for training_conf in training_confs
            ]

    def load_metrics_into_dict(self):
        """Load Metrics into Dictionary.

        A method to load the metrics files of trained models from previous
        runs into a dictionary.

        Returns
        -------
        metrics_dict: dict
            A dictionary keyed by wf-psf run output directory name. Each
            value is a list with one ``{run_id_name: [metrics]}`` entry per
            metrics file that was found; missing files are logged and
            skipped.

        """
        metrics_dict = {}

        for wf_dir, metrics_conf in self.metrics_confs.items():
            run_id_names: list[Any] = self._metrics_run_id_name(wf_dir, metrics_conf)
            metrics_dict[wf_dir] = []

            for run_id_name in run_id_names:
                output_path = os.path.join(
                    self.plotting_conf.plotting_params.metrics_output_path,
                    wf_dir,
                    "metrics",
                    "metrics-" + run_id_name + ".npy",
                )
                logger.info(f"Attempting to read in metrics file...{output_path}")

                try:
                    # NOTE(security): allow_pickle=True unpickles the file
                    # contents, which can execute arbitrary code; only load
                    # metrics files from trusted runs.
                    # Indexing with `[()]` unwraps the single object held by
                    # a 0-d array (presumably the saved metrics dict).
                    metrics_dict[wf_dir].append(
                        {run_id_name: [np.load(output_path, allow_pickle=True)[()]]}
                    )
                except FileNotFoundError:
                    logger.error(
                        "The required file for the plots was not found. Please check your configs settings."
                    )

        return metrics_dict

    def run(self):
        """Run.

        A function to generate the metrics plots according to the input
        plotting configuration.

        """
        logger.info("Generating metric plots...")
        plot_metrics(
            self.plotting_conf.plotting_params,
            self.list_of_metrics_dict,
            self.metrics_confs,
            self.plots_dir,
        )