Source code for wf_psf.plotting.plots_interface

"""Plot Interface.

An interface module with classes to handle different
plot configurations.

:Author: Jennifer Pollack <jennifer.pollack@cea.fr>

"""

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.ticker as mtick
import seaborn as sns
import os
import re
import logging

logger = logging.getLogger(__name__)
logging.getLogger("matplotlib").setLevel(logging.WARNING)


def define_plot_style():  # type: ignore
    """Define Plot Style.

    A function to set plot_style parameters.

    The seaborn default theme is applied first and the project specific
    ``rcParams`` are layered on top of it afterwards. The order matters:
    ``sns.set()`` itself rewrites several rcParams (font sizes, line
    width, marker size), so calling it last would silently discard
    part of ``plot_style``.

    """
    plot_style = {
        "figure.figsize": (12, 8),
        "figure.dpi": 200,
        "figure.autolayout": True,
        "lines.linewidth": 2,
        "lines.linestyle": "-",
        "lines.marker": "o",
        "lines.markersize": 10,
        "legend.fontsize": 20,
        "legend.loc": "best",
        "axes.titlesize": 24,
        "font.size": 16,
    }

    # Use seaborn style as the baseline ...
    sns.set()

    # ... then apply the project settings so that they actually take effect.
    mpl.rcParams.update(plot_style)
def make_plot(
    x_axis,
    y_axis,
    y_axis_err,
    y2_axis,
    label,
    plot_title,
    x_axis_label,
    y_right_axis_label,
    y_left_axis_label,
    filename,
    plot_show=False,
):
    """Make Plot.

    A function to generate metrics plots.

    The left axis shows the values of ``y_axis`` with error bars taken
    from ``y_axis_err``; the right (twin) axis shows the values of
    ``y2_axis`` as dashed curves. The figure is always written to
    ``filename``.

    Parameters
    ----------
    x_axis: list
        x-axis values, one entry per entry of ``y_axis``
    y_axis: list
        y-axis values; a list of dictionaries, every value of every
        dictionary is drawn
    y_axis_err: list
        Error values for y-axis points; indexed like ``y_axis`` and
        keyed with the same dictionary keys
    y2_axis: list
        y2-axis values for right axis; indexed like ``y_axis`` and
        keyed with the same dictionary keys
    label: list
        Labels for the points, one entry per entry of ``y_axis``
    plot_title: str
        Name of plot
    x_axis_label: str
        Label for x-axis
    y_right_axis_label: str
        Label for right vertical axis of plot
    y_left_axis_label: str
        Label for left vertical axis of plot
    filename: str
        Name of file to save plot
    plot_show: bool
        Boolean flag to set plot display

    """
    define_plot_style()

    fig = plt.figure(figsize=(12, 8))
    ax1 = fig.add_subplot(111)

    ax1.set_title(plot_title)
    ax1.set_xlabel(x_axis_label)
    ax1.set_ylabel(y_left_axis_label)
    ax1.yaxis.set_major_formatter(mtick.FormatStrFormatter("%.1e"))
    ax2 = ax1.twinx()
    plt.minorticks_on()

    ax2.set_ylabel(y_right_axis_label)
    ax2.grid(False)

    # Line style shared by every right-axis curve; it does not depend on
    # the loop variables, so it is built once.
    right_axis_style = dict(
        linewidth=2, linestyle="dashed", markersize=4, marker="^", alpha=0.5
    )

    has_entries = False
    for idx, run_values in enumerate(y_axis):  # type: ignore
        for key, values in run_values.items():
            ax1.errorbar(
                x=x_axis[idx],
                y=values,
                yerr=y_axis_err[idx][key],
                label=label[idx],
                alpha=0.75,
            )
            ax2.plot(x_axis[idx], y2_axis[idx][key], **right_axis_style)
            has_entries = True

    # A single legend call after all artists exist is enough. It is skipped
    # when nothing was drawn, as no legend was created in that case before.
    if has_entries:
        ax1.legend()

    fig.savefig(filename)

    if plot_show is True:
        plt.show()
    else:
        # Release the figure: this function is called once per metric and
        # dataset, and pyplot keeps every figure alive until it is closed.
        plt.close(fig)
class MetricsPlotHandler:
    """MetricsPlotHandler class.

    A class to handle plot parameters for various metrics results.

    Parameters
    ----------
    plotting_params: Recursive Namespace object
        Object holding the plotting parameters; its ``plot_show``
        attribute is read when plotting
    metrics: dict
        Dictionary mapping a key to a list of per-run metrics dictionaries
    list_of_stars: list
        List containing the number of training stars per run
    metric_name: str
        Name of metric
    rmse: str
        Root-mean square error label
    std_rmse: str
        Standard error on root-mean square error label
    rel_rmse: str
        Relative root-mean square error label
    plot_title: str
        Text appended to the title of every plot
    plots_dir: str
        Output directory for metrics plots

    """

    ids = ("poly_metrics", "opd_metrics", "poly_pixel")

    def __init__(
        self,
        plotting_params,
        metrics,
        list_of_stars,
        metric_name,
        rmse,
        std_rmse,
        rel_rmse,
        plot_title,
        plots_dir,
    ):
        # Plot configuration
        self.plotting_params = plotting_params
        self.plot_title = plot_title
        self.plots_dir = plots_dir

        # Metrics data and the labels used to look values up in it
        self.metrics = metrics
        self.list_of_stars = list_of_stars
        self.metric_name = metric_name
        self.rmse = rmse
        self.std_rmse = std_rmse
        self.rel_rmse = rel_rmse

    def get_metrics(self, dataset):
        """Get Metrics.

        A function to collect the metrics (rmse, std_rmse and rel_rmse)
        of every input run, e.g. wf-outputs-xxxxxxxxxxxx.

        Parameters
        ----------
        dataset: str
            A str representing dataset type, i.e. test_metrics or train_metrics.

        Returns
        -------
        tuple:
            A tuple ``(metrics_id, rmse, std_rmse, rel_rmse)``.
            ``metrics_id`` is a list of ``"<run_id>-<key>"`` strings; the
            other three are lists of single-entry dictionaries keyed by
            ``"<key>-<run_id>"``, one per input run.

        """
        metrics_id, rmse, std_rmse, rel_rmse = [], [], [], []

        for config_key, runs in self.metrics.items():
            for run in runs:
                # The first key of a run dictionary is its run id.
                run_id = list(run)[0]
                run_key = "-".join((config_key, run_id))
                values = run[run_id][0][dataset][self.metric_name]

                metrics_id.append("-".join((run_id, config_key)))
                rmse.append({run_key: values[self.rmse]})
                std_rmse.append({run_key: values[self.std_rmse]})
                rel_rmse.append({run_key: values[self.rel_rmse]})

        return metrics_id, rmse, std_rmse, rel_rmse

    def plot(self):
        """Plot.

        A function to generate metric plots as function of number of
        stars, one for the test metrics and one for the train metrics.

        """
        stars_suffix = "_".join(str(nstar) for nstar in self.list_of_stars)

        for plot_dataset in ("test_metrics", "train_metrics"):
            metrics_id, rmse, std_rmse, rel_rmse = self.get_metrics(plot_dataset)

            output_name = (
                plot_dataset
                + "_"
                + self.metric_name
                + "_nstars_"
                + stars_suffix
                + "_RMSE.png"
            )

            make_plot(
                x_axis=self.list_of_stars,
                y_axis=rmse,
                y_axis_err=std_rmse,
                y2_axis=rel_rmse,
                label=metrics_id,
                plot_title="Stars " + plot_dataset + self.plot_title,
                x_axis_label="Number of stars",
                y_left_axis_label="Absolute error",
                y_right_axis_label="Relative error [%]",
                filename=os.path.join(self.plots_dir, output_name),
                plot_show=self.plotting_params.plot_show,
            )
class MonochromaticMetricsPlotHandler:
    """MonochromaticMetricsPlotHandler class.

    A class to handle plot parameters for monochromatic
    metrics results.

    Parameters
    ----------
    plotting_params: Recursive Namespace object
        RecursiveNamespace object containing plotting parameters
    metrics_confs: dict
        Dictionary containing the metric configurations as
        RecursiveNamespace objects for each run
    metrics: dict
        Dictionary containing list of metrics
    list_of_stars: list
        List containing the number of stars used for each training data set
    plots_dir: str
        Output directory for metrics plots

    """

    ids = ("mono_metrics",)

    def __init__(
        self,
        plotting_params,
        metrics_confs,
        metrics,
        list_of_stars,
        plots_dir,
    ):
        # Plot configuration
        self.plotting_params = plotting_params
        self.plots_dir = plots_dir

        # Metrics data and the matching metrics configurations
        self.metrics_confs = metrics_confs
        self.metrics = metrics
        self.list_of_stars = list_of_stars

    def plot(self):
        """Plot.

        A function to generate plots for the train and test
        metrics. Only entries whose metrics configuration has
        ``eval_mono_metric`` enabled contribute to the plots.

        """
        # Common data: x-axis values (plotted as wavelength in um)
        lambda_list = np.arange(0.55, 0.9, 0.01)
        stars_suffix = "_".join(str(nstar) for nstar in self.list_of_stars)

        for plot_dataset in ("test_metrics", "train_metrics"):
            y_axis, y_axis_err, y2_axis, metrics_id = [], [], [], []

            for config_key, runs in self.metrics.items():
                if not self.metrics_confs[config_key].metrics.eval_mono_metric:
                    continue

                for run in runs:
                    # The first key of a run dictionary is its run id.
                    run_id = list(run)[0]
                    run_key = "-".join((config_key, run_id))
                    mono_metric = run[run_id][0][plot_dataset]["mono_metric"]

                    metrics_id.append("-".join((run_id, config_key)))
                    y_axis.append({run_key: mono_metric["rmse_lda"]})
                    y_axis_err.append({run_key: mono_metric["std_rmse_lda"]})
                    y2_axis.append({run_key: mono_metric["rel_rmse_lda"]})

            output_name = (
                plot_dataset + "_nstars_" + stars_suffix + "_monochrom_pixel_RMSE.png"
            )

            make_plot(
                # Every run shares the same x-axis values.
                x_axis=[lambda_list] * len(y_axis),
                y_axis=y_axis,
                y_axis_err=y_axis_err,
                y2_axis=y2_axis,
                label=metrics_id,
                plot_title="Stars "
                + plot_dataset  # type: ignore
                + "\nMonochromatic pixel RMSE @ Euclid resolution",
                x_axis_label="Wavelength [um]",
                y_left_axis_label="Absolute error",
                y_right_axis_label="Relative error [%]",
                filename=os.path.join(self.plots_dir, output_name),
                plot_show=self.plotting_params.plot_show,
            )
class ShapeMetricsPlotHandler:
    """ShapeMetricsPlotHandler class.

    A class to handle plot parameters for shape
    metrics results.

    Parameters
    ----------
    plotting_params: Recursive Namespace object
        Recursive Namespace Object containing plotting parameters
    metrics: dict
        Dictionary containing list of metrics
    list_of_stars: list
        List containing the number of stars used for each training data set
    plots_dir: str
        Output directory for metrics plots

    """

    id = "shape_metrics"

    def __init__(self, plotting_params, metrics, list_of_stars, plots_dir):
        # Plot configuration
        self.plotting_params = plotting_params
        self.plots_dir = plots_dir

        # Metrics data
        self.metrics = metrics
        self.list_of_stars = list_of_stars

    def plot(self):
        """Plot.

        A generic function to generate plots for the train and test
        metrics: one plot per shape metric (e1, e2, R2_meanR2) and dataset.

        """
        # Requirement values used to normalise the shape metrics
        e1_req_euclid = 2e-04
        e2_req_euclid = 2e-04
        R2_req_euclid = 1e-03

        for plot_dataset in ("test_metrics", "train_metrics"):
            metrics_data = self.prepare_metrics_data(
                plot_dataset, e1_req_euclid, e2_req_euclid, R2_req_euclid
            )

            # One plot for each shape metric
            for metric, series in metrics_data.items():
                self.make_shape_metrics_plot(
                    series["rmse"],
                    series["std_rmse"],
                    series["rel_rmse"],
                    plot_dataset,
                    metric,
                )

    def prepare_metrics_data(
        self, plot_dataset, e1_req_euclid, e2_req_euclid, R2_req_euclid
    ):
        """Prepare Metrics Data.

        A function to prepare the metrics data for plotting.

        Parameters
        ----------
        plot_dataset: str
            A string representing the dataset, i.e. training or test metrics.
        e1_req_euclid: float
            A float denoting the Euclid requirement for the `e1` shape metric.
        e2_req_euclid: float
            A float denoting the Euclid requirement for the `e2` shape metric.
        R2_req_euclid: float
            A float denoting the Euclid requirement for the `R2` shape metric.

        Returns
        -------
        shape_metrics_data: dict
            A dictionary containing the shape metrics data from a set of runs.
            It has one entry per shape metric (``e1``, ``e2``, ``R2_meanR2``),
            each holding the lists ``rmse``, ``std_rmse`` and ``rel_rmse`` of
            single-entry dictionaries keyed by ``"<key>-<run_id>"``.

        """
        # Requirement each metric's rmse is divided by
        requirements = {
            "e1": e1_req_euclid,
            "e2": e2_req_euclid,
            "R2_meanR2": R2_req_euclid,
        }

        shape_metrics_data = {
            metric: {"rmse": [], "std_rmse": [], "rel_rmse": []}
            for metric in requirements
        }

        for config_key, runs in self.metrics.items():
            for run in runs:
                # The first key of a run dictionary is its run id.
                run_id = list(run)[0]
                run_key = f"{config_key}-{run_id}"

                for metric, requirement in requirements.items():
                    metric_rmse = run[run_id][0][plot_dataset]["shape_results_dict"][
                        f"rmse_{metric}"
                    ]
                    metric_std_rmse = run[run_id][0][plot_dataset][
                        "shape_results_dict"
                    ][f"std_rmse_{metric}"]

                    # NOTE(review): this is a plain ratio to the requirement,
                    # while it is plotted on an axis labelled "[%]" - confirm
                    # the intended unit.
                    relative_metric_rmse = metric_rmse / requirement

                    target = shape_metrics_data[metric]
                    target["rmse"].append({run_key: metric_rmse})
                    target["std_rmse"].append({run_key: metric_std_rmse})
                    target["rel_rmse"].append({run_key: relative_metric_rmse})

        return shape_metrics_data

    def make_shape_metrics_plot(
        self, rmse_data, std_rmse_data, rel_rmse_data, plot_dataset, metric
    ):
        """Make Shape Metrics Plot.

        A function to produce plots for the shape metrics.

        Parameters
        ----------
        rmse_data: list
            A list of dictionaries where each dictionary stores run as the key and the Root Mean Square Error (rmse) as the value.
        std_rmse_data: list
            A list of dictionaries where each dictionary stores run as the key and the Standard Deviation of the Root Mean Square Error (rmse) as the value.
        rel_rmse_data: list
            A list of dictionaries where each dictionary stores run as the key and the Root Mean Square Error (rmse) relative to the Euclid requirements as the value.
        plot_dataset: str
            A string denoting whether metrics are for the train or test datasets.
        metric: str
            A string representing the type of shape metric, i.e., e1, e2, or R2.

        """
        # The run keys of the rmse dictionaries serve as legend labels.
        labels = []
        for entry in rmse_data:
            labels.extend(entry)

        metric_tag = metric.upper()
        stars_suffix = "_".join(str(nstar) for nstar in self.list_of_stars)
        output_name = f"{plot_dataset}_nstars_{stars_suffix}_Shape_{metric_tag}_RMSE.png"

        make_plot(
            x_axis=self.list_of_stars,
            y_axis=rmse_data,
            y_axis_err=std_rmse_data,
            y2_axis=rel_rmse_data,
            label=labels,
            plot_title=f"Stars {plot_dataset}. Shape {metric_tag} RMSE",
            x_axis_label="Number of stars",
            y_left_axis_label="Absolute error",
            y_right_axis_label="Relative error [%]",
            filename=os.path.join(self.plots_dir, output_name),
            plot_show=self.plotting_params.plot_show,
        )
def get_number_of_stars(metrics):
    """Get Number of Stars.

    A function to get the number of stars used in training the model.

    The number is taken from the first run of digits found in each run id.
    If any run id contains no digits, the runs are numbered consecutively
    starting from 1 instead, so that the result always holds exactly one
    entry per run.

    Parameters
    ----------
    metrics: dict
        A dictionary containing the metrics results per run; every value
        is a list of run dictionaries whose first key is the run id

    Returns
    -------
    list_of_stars: list
        A list containing the number of training stars per run.

    """
    # The first key of a run dictionary is its run id.
    run_ids = [list(run)[0] for runs in metrics.values() for run in runs]
    matches = [re.search(r"\d+", run_id) for run_id in run_ids]

    if any(match is None for match in matches):
        # Fallback: one consecutive number per run (not per dictionary key),
        # which keeps the list aligned with the per-run metrics being plotted.
        return list(range(1, len(run_ids) + 1))

    return [int(match.group()) for match in matches]
def plot_metrics(plotting_params, list_of_metrics, metrics_confs, plot_saving_path):
    """Plot model results.

    Generates, in this order, the per-metric RMSE plots, the monochromatic
    metrics plots and the shape metrics plots.

    Parameters
    ----------
    plotting_params: RecursiveNamespace Object
        Recursive Namespace object containing plot configuration parameters
    list_of_metrics: dict
        Collection of all model metrics; handed to `get_number_of_stars`
        and to every plot handler
    metrics_confs: dict
        Collection of all metrics configuration parameters
    plot_saving_path: str
        Directory path for saving output plots

    """
    # (metric name, rmse label, std_rmse label, rel_rmse label, title suffix)
    metric_specs = (
        (
            "poly_metric",
            "rmse",
            "std_rmse",
            "rel_rmse",
            ".\nPolychromatic pixel RMSE @ Euclid resolution",
        ),
        (
            "opd_metric",
            "rmse_opd",
            "rmse_std_opd",
            "rel_rmse_opd",
            ".\nOPD RMSE",
        ),
        (
            "shape_results_dict",
            "pix_rmse",
            "pix_rmse_std",
            "rel_pix_rmse",
            "\nPixel RMSE @ 3x Euclid resolution",
        ),
    )

    list_of_stars = get_number_of_stars(list_of_metrics)

    for metric_name, rmse, std_rmse, rel_rmse, plot_title in metric_specs:
        MetricsPlotHandler(
            plotting_params,
            list_of_metrics,
            list_of_stars,
            metric_name,
            rmse,
            std_rmse,
            rel_rmse,
            plot_title,
            plot_saving_path,
        ).plot()

    MonochromaticMetricsPlotHandler(
        plotting_params, metrics_confs, list_of_metrics, list_of_stars, plot_saving_path
    ).plot()

    ShapeMetricsPlotHandler(
        plotting_params, list_of_metrics, list_of_stars, plot_saving_path
    ).plot()