# Source code for wf_psf.inference.psf_inference

"""Inference.

A module which provides a PSFInference class to perform inference
with trained PSF models. It is able to load a trained model,
perform inference on a dataset of SEDs and positions, and generate polychromatic PSFs.

:Authors: Jennifer Pollack <jennifer.pollack@cea.fr>, Tobias Liaudat <tobias.liaudat@cea.fr>

"""

import os
from pathlib import Path
import numpy as np
from wf_psf.data.data_handler import DataHandler
from wf_psf.utils.read_config import read_conf
from wf_psf.utils.utils import ensure_batch
from wf_psf.psf_models import psf_models
from wf_psf.psf_models.psf_model_loader import load_trained_psf_model
import tensorflow as tf


class InferenceConfigHandler:
    """Handle configuration loading and management for PSF inference.

    This class manages the loading of inference, training, and data
    configuration files required for PSF inference operations.

    Parameters
    ----------
    inference_config_path : str
        Path to the inference configuration YAML file.

    Attributes
    ----------
    inference_config_path : str
        Path to the inference configuration file.
    inference_config : RecursiveNamespace or None
        Loaded inference configuration.
    training_config : RecursiveNamespace or None
        Loaded training configuration.
    data_config : RecursiveNamespace or None
        Loaded data configuration.
    trained_model_path : Path
        Path to the trained model directory (set by ``set_config_paths``).
    model_subdir : str
        Subdirectory name for model files (set by ``set_config_paths``).
    trained_model_config_path : Path
        Path to the training configuration file (set by ``set_config_paths``).
    data_config_path : str or None
        Path to the data configuration file (set by ``set_config_paths``).
    """

    # Registry identifier(s) for this handler — presumably used by a config
    # dispatch mechanism elsewhere in the package; TODO confirm usage.
    ids = ("inference_conf",)

    def __init__(self, inference_config_path: str):
        self.inference_config_path = inference_config_path
        # Configs are populated lazily by load_configs().
        self.inference_config = None
        self.training_config = None
        self.data_config = None

    def load_configs(self):
        """Load configuration files based on the inference config.

        Loads the inference configuration first, then uses it to determine
        and load the training and data configurations.

        Notes
        -----
        Updates the following attributes in-place:
        - inference_config
        - training_config
        - data_config (if data_config_path is specified)
        """
        self.inference_config = read_conf(self.inference_config_path)
        self.set_config_paths()
        self.training_config = read_conf(self.trained_model_config_path)

        if self.data_config_path is not None:
            # Load the data configuration
            self.data_config = read_conf(self.data_config_path)

    def set_config_paths(self):
        """Extract and set the configuration paths from the inference config.

        Sets the following attributes:
        - trained_model_path
        - model_subdir
        - trained_model_config_path
        - data_config_path
        """
        config_paths = self.inference_config.inference.configs
        self.trained_model_path = Path(config_paths.trained_model_path)
        self.model_subdir = config_paths.model_subdir
        # The training config lives inside the trained-model directory.
        self.trained_model_config_path = (
            self.trained_model_path / config_paths.trained_model_config_path
        )
        self.data_config_path = config_paths.data_config_path

    @staticmethod
    def overwrite_model_params(training_config=None, inference_config=None):
        """Overwrite training model_params with values from inference_config.

        Only keys already present on the training model_params are
        overwritten; unknown inference keys are ignored.

        Parameters
        ----------
        training_config : RecursiveNamespace, optional
            Configuration object from training phase.
        inference_config : RecursiveNamespace, optional
            Configuration object from inference phase.

        Notes
        -----
        Updates are applied in-place to training_config.training.model_params.
        If either config is None, the call is a no-op.
        """
        # Guard: both parameters default to None, so dereferencing them
        # unconditionally would raise AttributeError on a default call.
        if training_config is None or inference_config is None:
            return

        model_params = training_config.training.model_params
        inf_model_params = inference_config.inference.model_params

        if model_params and inf_model_params:
            for key, value in inf_model_params.__dict__.items():
                if hasattr(model_params, key):
                    setattr(model_params, key, value)
class PSFInference:
    """Perform PSF inference using a pre-trained WaveDiff model.

    This class handles the setup for PSF inference, including loading
    configuration files, instantiating the PSF simulator and data handler,
    and preparing the input data required for inference.

    Parameters
    ----------
    inference_config_path : str
        Path to the inference configuration YAML file.
    x_field : array-like, optional
        x coordinates in SHE convention.
    y_field : array-like, optional
        y coordinates in SHE convention.
    seds : array-like, optional
        Spectral energy distributions (SEDs).
    sources : array-like, optional
        Postage stamps of sources, e.g. star images (shape: [n_stars, h, w]).
    masks : array-like, optional
        Corresponding masks for the sources (same shape as sources).
        Defaults to None.

    Attributes
    ----------
    inference_config_path : str
        Path to the inference configuration file.
    x_field : array-like or None
        x coordinates for PSF positions.
    y_field : array-like or None
        y coordinates for PSF positions.
    seds : array-like or None
        Spectral energy distributions.
    sources : array-like or None
        Source postage stamps.
    masks : array-like or None
        Source masks.
    engine : PSFInferenceEngine or None
        The inference engine instance.

    Examples
    --------
    Basic usage with position coordinates and SEDs:

    .. code-block:: python

        psf_inf = PSFInference(
            inference_config_path="config.yaml",
            x_field=[100.5, 200.3],
            y_field=[150.2, 250.8],
            seds=sed_array,
        )
        psf_inf.run_inference()
        psf = psf_inf.get_psf(0)
    """

    def __init__(
        self,
        inference_config_path: str,
        x_field=None,
        y_field=None,
        seds=None,
        sources=None,
        masks=None,
    ):
        self.inference_config_path = inference_config_path

        # Inputs for the model
        self.x_field = x_field
        self.y_field = y_field
        self.seds = seds
        self.sources = sources
        self.masks = masks

        # Internal caches for lazy-loading
        self._config_handler = None
        self._simPSF = None
        self._data_handler = None
        self._trained_psf_model = None
        self._n_bins_lambda = None
        self._batch_size = None
        self._cycle = None
        self._output_dim = None

        # Initialise PSF Inference engine
        self.engine = None

    @property
    def config_handler(self):
        """Get or create the configuration handler.

        Returns
        -------
        InferenceConfigHandler
            The configuration handler instance with loaded configs.
        """
        if self._config_handler is None:
            self._config_handler = InferenceConfigHandler(self.inference_config_path)
            self._config_handler.load_configs()
        return self._config_handler

    def prepare_configs(self):
        """Prepare the configuration for inference.

        Overwrites training model parameters with inference configuration
        values.
        """
        self.config_handler.overwrite_model_params(
            self.training_config, self.inference_config
        )

    @property
    def inference_config(self):
        """Get the inference configuration.

        Returns
        -------
        RecursiveNamespace
            The inference configuration object.
        """
        return self.config_handler.inference_config

    @property
    def training_config(self):
        """Get the training configuration.

        Returns
        -------
        RecursiveNamespace
            The training configuration object.
        """
        return self.config_handler.training_config

    @property
    def data_config(self):
        """Get the data configuration.

        Returns
        -------
        RecursiveNamespace or None
            The data configuration object, or None if not available.
        """
        return self.config_handler.data_config

    @property
    def simPSF(self):
        """Get or create the PSF simulator.

        Returns
        -------
        simPSF
            The PSF simulator instance.
        """
        if self._simPSF is None:
            self._simPSF = psf_models.simPSF(self.training_config.training.model_params)
        return self._simPSF

    def _prepare_dataset_for_inference(self):
        """Prepare dataset dictionary for inference.

        Returns
        -------
        dict or None
            Dictionary containing positions, sources, and masks, or None if
            positions are invalid.
        """
        positions = self.get_positions()
        if positions is None:
            return None
        return {"positions": positions, "sources": self.sources, "masks": self.masks}

    @property
    def data_handler(self):
        """Get or create the data handler.

        Returns
        -------
        DataHandler
            The data handler instance configured for inference.
        """
        if self._data_handler is None:
            # Instantiate the data handler; data loading is deferred
            # (load_data=False) since inputs are supplied directly.
            self._data_handler = DataHandler(
                dataset_type="inference",
                data_params=self.data_config,
                simPSF=self.simPSF,
                n_bins_lambda=self.n_bins_lambda,
                load_data=False,
                dataset=self._prepare_dataset_for_inference(),
                sed_data=self.seds,
            )
            self._data_handler.run_type = "inference"
        return self._data_handler

    @property
    def trained_psf_model(self):
        """Get or load the trained PSF model.

        Returns
        -------
        Model
            The loaded trained PSF model.
        """
        if self._trained_psf_model is None:
            self._trained_psf_model = self.load_inference_model()
        return self._trained_psf_model

    def get_positions(self):
        """Combine x_field and y_field into position pairs.

        Returns
        -------
        numpy.ndarray or None
            Array of shape (num_positions, 2) where each row contains [x, y]
            coordinates. Returns None if either x_field or y_field is None
            or empty.

        Raises
        ------
        ValueError
            If x_field and y_field have different lengths.
        """
        if self.x_field is None or self.y_field is None:
            return None

        x_arr = np.asarray(self.x_field)
        y_arr = np.asarray(self.y_field)

        if x_arr.size == 0 or y_arr.size == 0:
            return None

        if x_arr.size != y_arr.size:
            raise ValueError(
                f"x_field and y_field must have the same length. "
                f"Got {x_arr.size} and {y_arr.size}"
            )

        # Flatten arrays to handle any input shape, then stack
        x_flat = x_arr.flatten()
        y_flat = y_arr.flatten()
        return np.column_stack((x_flat, y_flat))

    def load_inference_model(self):
        """Load the trained PSF model based on the inference configuration.

        Returns
        -------
        Model
            The loaded trained PSF model.

        Notes
        -----
        Constructs the weights path pattern based on the trained model path,
        model subdirectory, model name, id name, and cycle number specified
        in the configuration files.
        """
        model_path = self.config_handler.trained_model_path
        model_dir = self.config_handler.model_subdir
        model_name = self.training_config.training.model_params.model_name
        id_name = self.training_config.training.id_name

        # Glob-style pattern matching the saved weights for the chosen cycle.
        weights_path_pattern = os.path.join(
            model_path,
            model_dir,
            f"{model_dir}*_{model_name}*{id_name}_cycle{self.cycle}*",
        )

        return load_trained_psf_model(
            self.training_config,
            self.data_handler,
            weights_path_pattern,
        )

    @property
    def n_bins_lambda(self):
        """Get the number of wavelength bins for inference.

        Returns
        -------
        int
            The number of wavelength bins used during inference.
        """
        if self._n_bins_lambda is None:
            self._n_bins_lambda = (
                self.inference_config.inference.model_params.n_bins_lda
            )
        return self._n_bins_lambda

    @property
    def batch_size(self):
        """Get the batch size for inference.

        Returns
        -------
        int
            The batch size for processing during inference.

        Raises
        ------
        ValueError
            If the configured batch size is not strictly positive.
        """
        if self._batch_size is None:
            batch_size = self.inference_config.inference.batch_size
            # Raise instead of assert: asserts are stripped under `python -O`,
            # which would silently allow an invalid batch size through.
            if batch_size <= 0:
                raise ValueError("Batch size must be greater than 0.")
            self._batch_size = batch_size
        return self._batch_size

    @property
    def cycle(self):
        """Get the cycle number for inference.

        Returns
        -------
        int
            The cycle number used for loading the trained model.
        """
        if self._cycle is None:
            self._cycle = self.inference_config.inference.cycle
        return self._cycle

    @property
    def output_dim(self):
        """Get the output dimension for PSF inference.

        Returns
        -------
        int
            The output dimension (height and width) of the inferred PSFs.
        """
        if self._output_dim is None:
            self._output_dim = self.inference_config.inference.model_params.output_dim
        return self._output_dim

    def _prepare_positions_and_seds(self):
        """Preprocess and return tensors for positions and SEDs.

        Handles single-star, multi-star, and even scalar inputs, ensuring:
        - positions: shape (n_samples, 2)
        - sed_data: shape (n_samples, n_bins, 2)

        Raises
        ------
        ValueError
            If positions and SEDs are inconsistent in size or shape.
        """
        # Ensure x_field and y_field are at least 1D
        x_arr = np.atleast_1d(self.x_field)
        y_arr = np.atleast_1d(self.y_field)

        if x_arr.size != y_arr.size:
            raise ValueError(
                f"x_field and y_field must have the same length. "
                f"Got {x_arr.size} and {y_arr.size}"
            )

        # Combine into positions array (n_samples, 2)
        positions = np.column_stack((x_arr, y_arr))
        positions = tf.convert_to_tensor(positions, dtype=tf.float32)

        # Ensure SEDs have shape (n_samples, n_bins, 2)
        sed_data = ensure_batch(self.seds)
        if sed_data.shape[0] != positions.shape[0]:
            raise ValueError(
                f"SEDs batch size {sed_data.shape[0]} does not match number of positions {positions.shape[0]}"
            )
        if sed_data.shape[2] != 2:
            raise ValueError(
                f"SEDs last dimension must be 2 (flux, wavelength). Got {sed_data.shape}"
            )

        # Process SEDs through the data handler
        self.data_handler.process_sed_data(sed_data)
        sed_data_tensor = self.data_handler.sed_data

        return positions, sed_data_tensor

    def run_inference(self):
        """Run PSF inference and return the full PSF array.

        Returns
        -------
        numpy.ndarray
            Array of inferred PSFs with shape (n_samples, output_dim, output_dim).

        Notes
        -----
        Prepares configurations and input data, initializes the inference
        engine, and computes the PSF for all input positions.
        """
        # Prepare the configuration for inference
        self.prepare_configs()

        # Prepare positions and SEDs for inference
        positions, sed_data = self._prepare_positions_and_seds()

        self.engine = PSFInferenceEngine(
            trained_model=self.trained_psf_model,
            batch_size=self.batch_size,
            output_dim=self.output_dim,
        )
        return self.engine.compute_psfs(positions, sed_data)

    def _ensure_psf_inference_completed(self):
        """Ensure that PSF inference has been completed.

        Runs inference if it has not been done yet.
        """
        if self.engine is None or self.engine.inferred_psfs is None:
            self.run_inference()

    def get_psfs(self):
        """Get all inferred PSFs.

        Returns
        -------
        numpy.ndarray
            Array of inferred PSFs with shape (n_samples, output_dim, output_dim).

        Notes
        -----
        Ensures automatically that inference has been completed before
        accessing the PSFs.
        """
        self._ensure_psf_inference_completed()
        return self.engine.get_psfs()

    def get_psf(self, index: int = 0) -> np.ndarray:
        """Get the PSF at a specific index.

        Parameters
        ----------
        index : int, optional
            Index of the PSF to retrieve (default is 0).

        Returns
        -------
        numpy.ndarray
            The inferred PSF at the specified index with shape
            (output_dim, output_dim).

        Notes
        -----
        Ensures automatically that inference has been completed before
        accessing the PSF. If only a single star was passed during
        instantiation, the index defaults to 0 and bounds checking is relaxed.
        """
        self._ensure_psf_inference_completed()
        inferred_psfs = self.engine.get_psfs()

        # If a single-star batch, ignore index bounds
        if inferred_psfs.shape[0] == 1:
            return inferred_psfs[0]

        # Otherwise, return the PSF at the requested index
        return inferred_psfs[index]

    def clear_cache(self):
        """Clear all cached properties and reset the instance.

        This method resets all lazy-loaded properties, including the config
        handler, PSF simulator, data handler, trained model, and inference
        engine. Useful for freeing memory or forcing a fresh initialization.

        Notes
        -----
        After calling this method, accessing any property will trigger
        re-initialization.
        """
        self._config_handler = None
        self._simPSF = None
        self._data_handler = None
        self._trained_psf_model = None
        self._n_bins_lambda = None
        self._batch_size = None
        self._cycle = None
        self._output_dim = None
        self.engine = None
class PSFInferenceEngine:
    """Engine to perform PSF inference using a trained model.

    This class handles the batch-wise computation of PSFs using a trained
    PSF model. It manages the batching of input positions and SEDs, and
    caches the inferred PSFs for later access.

    Parameters
    ----------
    trained_model : Model
        The trained PSF model to use for inference.
    batch_size : int
        The batch size for processing during inference.
    output_dim : int
        The output dimension (height and width) of the inferred PSFs.

    Attributes
    ----------
    trained_model : Model
        The trained PSF model used for inference.
    batch_size : int
        The batch size for processing during inference.
    output_dim : int
        The output dimension (height and width) of the inferred PSFs.

    Examples
    --------
    .. code-block:: python

        >>> engine = PSFInferenceEngine(model, batch_size=32, output_dim=64)
        >>> psfs = engine.compute_psfs(positions, seds)
        >>> single_psf = engine.get_psf(0)
    """

    def __init__(self, trained_model, batch_size: int, output_dim: int):
        self.trained_model = trained_model
        self.batch_size = batch_size
        self.output_dim = output_dim
        # Cache for computed PSFs; populated by compute_psfs().
        self._inferred_psfs = None

    @property
    def inferred_psfs(self) -> np.ndarray:
        """Access the cached inferred PSFs, if available.

        Returns
        -------
        numpy.ndarray or None
            The cached inferred PSFs, or None if not yet computed.
        """
        return self._inferred_psfs

    def compute_psfs(self, positions: tf.Tensor, sed_data: tf.Tensor) -> np.ndarray:
        """Compute and cache PSFs for the input source parameters.

        Parameters
        ----------
        positions : tf.Tensor
            Tensor of shape (n_samples, 2) containing the (x, y) positions
        sed_data : tf.Tensor
            Tensor of shape (n_samples, n_bins, 2) containing the SEDs

        Returns
        -------
        numpy.ndarray
            Array of inferred PSFs with shape (n_samples, output_dim, output_dim).

        Notes
        -----
        PSFs are computed in batches according to the specified batch_size.
        Results are cached internally for subsequent access via get_psfs()
        or get_psf().
        """
        n_samples = positions.shape[0]
        self._inferred_psfs = np.zeros(
            (n_samples, self.output_dim, self.output_dim), dtype=np.float32
        )

        # Walk the samples in batch_size chunks; the final batch may be short.
        counter = 0
        while counter < n_samples:
            end_sample = min(counter + self.batch_size, n_samples)

            batch_pos = positions[counter:end_sample, :]
            batch_seds = sed_data[counter:end_sample, :, :]
            batch_inputs = [batch_pos, batch_seds]

            # Generate PSFs for the current batch (inference mode).
            batch_psfs = self.trained_model(batch_inputs, training=False)

            # Write to the backing array directly rather than through the
            # read-only `inferred_psfs` property.
            self._inferred_psfs[counter:end_sample, :, :] = batch_psfs.numpy()

            counter = end_sample

        return self._inferred_psfs

    def get_psfs(self) -> np.ndarray:
        """Get all the generated PSFs.

        Returns
        -------
        numpy.ndarray
            Array of inferred PSFs with shape (n_samples, output_dim, output_dim).

        Raises
        ------
        ValueError
            If PSFs have not yet been computed.
        """
        if self._inferred_psfs is None:
            raise ValueError("PSFs not yet computed. Call compute_psfs() first.")
        return self._inferred_psfs

    def get_psf(self, index: int) -> np.ndarray:
        """Get the PSF at a specific index.

        Parameters
        ----------
        index : int
            Index of the PSF to retrieve.

        Returns
        -------
        numpy.ndarray
            The inferred PSF at the specified index with shape
            (output_dim, output_dim).

        Raises
        ------
        ValueError
            If PSFs have not yet been computed.
        """
        if self._inferred_psfs is None:
            raise ValueError("PSFs not yet computed. Call compute_psfs() first.")
        return self._inferred_psfs[index]

    def clear_cache(self):
        """Clear cached inferred PSFs.

        Resets the internal PSF cache to free memory. After calling this
        method, compute_psfs() must be called again before accessing PSFs.
        """
        self._inferred_psfs = None