wf_psf.inference.psf_inference
Inference.
A module which provides a PSFInference class to perform inference with trained PSF models. It is able to load a trained model, perform inference on a dataset of SEDs and positions, and generate polychromatic PSFs.
- Authors:
Jennifer Pollack <jennifer.pollack@cea.fr>, Tobias Liaudat <tobias.liaudat@cea.fr>
Classes
- InferenceConfigHandler – Handle configuration loading and management for PSF inference.
- PSFInference – Perform PSF inference using a pre-trained WaveDiff model.
- PSFInferenceEngine – Engine to perform PSF inference using a trained model.
- class wf_psf.inference.psf_inference.InferenceConfigHandler(inference_config_path: str)
Bases: object
Handle configuration loading and management for PSF inference.
This class manages the loading of inference, training, and data configuration files required for PSF inference operations.
- Parameters:
inference_config_path (str) – Path to the inference configuration YAML file.
- inference_config
Loaded inference configuration.
- Type:
RecursiveNamespace or None
- training_config
Loaded training configuration.
- Type:
RecursiveNamespace or None
- data_config
Loaded data configuration.
- Type:
RecursiveNamespace or None
- trained_model_path
Path to the trained model directory.
- Type:
Path
- trained_model_config_path
Path to the training configuration file.
- Type:
Path
Methods
load_configs() – Load configuration files based on the inference config.
overwrite_model_params([training_config, ...]) – Overwrite training model_params with values from inference_config if available.
Extract and set the configuration paths from the inference config.
- ids = ('inference_conf',)
- load_configs()
Load configuration files based on the inference config.
Loads the inference configuration first, then uses it to determine and load the training and data configurations.
Notes
Updates the following attributes in-place: inference_config, training_config, and data_config (the latter only if data_config_path is specified).
- static overwrite_model_params(training_config=None, inference_config=None)
Overwrite training model_params with values from inference_config if available.
- Parameters:
training_config (RecursiveNamespace) – Configuration object from training phase.
inference_config (RecursiveNamespace) – Configuration object from inference phase.
Notes
Updates are applied in-place to training_config.training.model_params.
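The in-place update described above can be illustrated with a small stand-alone function. This is a hypothetical sketch, not the wf_psf code: `types.SimpleNamespace` stands in for `RecursiveNamespace`, and both the location of the overrides (`inference_config.inference.model_params`) and the rule of replacing only parameters that already exist in the training config are assumptions.

```python
from types import SimpleNamespace


def overwrite_model_params(training_config=None, inference_config=None):
    """Hypothetical sketch: copy model_params overrides into the training config.

    Assumes the overrides live under inference_config.inference.model_params;
    the real configuration layout may differ.
    """
    if training_config is None or inference_config is None:
        return  # nothing to merge
    inference_section = getattr(inference_config, "inference", None)
    overrides = getattr(inference_section, "model_params", None)
    if overrides is None:
        return  # no overrides available in the inference config
    target = training_config.training.model_params
    for key, value in vars(overrides).items():
        # Assumed policy: only replace parameters the trained model already defines.
        if hasattr(target, key):
            setattr(target, key, value)  # in-place update of the training config


# Usage with toy configs (keys and values are illustrative only).
training = SimpleNamespace(
    training=SimpleNamespace(model_params=SimpleNamespace(output_dim=32, n_bins_lambda=8))
)
inference = SimpleNamespace(
    inference=SimpleNamespace(model_params=SimpleNamespace(output_dim=64))
)
overwrite_model_params(training, inference)
print(training.training.model_params.output_dim)  # 64
```

Because the update is in-place, the function returns nothing; callers keep using the same training_config object afterwards.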
- class wf_psf.inference.psf_inference.PSFInference(inference_config_path: str, x_field=None, y_field=None, seds=None, sources=None, masks=None)
Bases: object
Perform PSF inference using a pre-trained WaveDiff model.
This class handles the setup for PSF inference, including loading configuration files, instantiating the PSF simulator and data handler, and preparing the input data required for inference.
- Parameters:
inference_config_path (str) – Path to the inference configuration YAML file.
x_field (array-like, optional) – x coordinates in SHE convention.
y_field (array-like, optional) – y coordinates in SHE convention.
seds (array-like, optional) – Spectral energy distributions (SEDs).
sources (array-like, optional) – Postage stamps of sources, e.g. star images (shape: [n_stars, h, w]).
masks (array-like, optional) – Corresponding masks for the sources (same shape as sources). Defaults to None.
- x_field
x coordinates for PSF positions.
- Type:
array-like or None
- y_field
y coordinates for PSF positions.
- Type:
array-like or None
- seds
Spectral energy distributions.
- Type:
array-like or None
- sources
Source postage stamps.
- Type:
array-like or None
- masks
Source masks.
- Type:
array-like or None
- engine
The inference engine instance.
- Type:
PSFInferenceEngine or None
Examples
Basic usage with position coordinates and SEDs:
    psf_inf = PSFInference(
        inference_config_path="config.yaml",
        x_field=[100.5, 200.3],
        y_field=[150.2, 250.8],
        seds=sed_array,
    )
    psf_inf.run_inference()
    psf = psf_inf.get_psf(0)
- Attributes:
batch_size – Get the batch size for inference.
config_handler – Get or create the configuration handler.
cycle – Get the cycle number for inference.
data_config – Get the data configuration.
data_handler – Get or create the data handler.
inference_config – Get the inference configuration.
n_bins_lambda – Get the number of wavelength bins for inference.
output_dim – Get the output dimension for PSF inference.
simPSF – Get or create the PSF simulator.
trained_psf_model – Get or load the trained PSF model.
training_config – Get the training configuration.
Methods
clear_cache() – Clear all cached properties and reset the instance.
get_positions() – Combine x_field and y_field into position pairs.
get_psf([index]) – Get the PSF at a specific index.
get_psfs() – Get all inferred PSFs.
load_inference_model() – Load the trained PSF model based on the inference configuration.
prepare_configs() – Prepare the configuration for inference.
run_inference() – Run PSF inference and return the full PSF array.
- property batch_size
Get the batch size for inference.
- Returns:
The batch size for processing during inference.
- Return type:
int
- clear_cache()
Clear all cached properties and reset the instance.
This method resets all lazy-loaded properties, including the config handler, PSF simulator, data handler, trained model, and inference engine. Useful for freeing memory or forcing a fresh initialization.
Notes
After calling this method, accessing any property will trigger re-initialization.
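The lazy-loading and reset behaviour can be sketched as follows. The class `LazyInferenceSketch`, its two cached slots and the `build_counts` bookkeeping are hypothetical; placeholder `object()` instances stand in for the real config handler and inference engine.

```python
class LazyInferenceSketch:
    """Hypothetical sketch of lazily built, resettable properties."""

    # Names of the private slots that back the lazy properties.
    _CACHED_ATTRS = ("_config_handler", "_engine")

    def __init__(self):
        self.build_counts = {"config_handler": 0, "engine": 0}
        self.clear_cache()  # start with every cached slot empty

    @property
    def config_handler(self):
        if self._config_handler is None:  # build on first access only
            self.build_counts["config_handler"] += 1
            self._config_handler = object()  # placeholder for a real handler
        return self._config_handler

    @property
    def engine(self):
        if self._engine is None:
            self.build_counts["engine"] += 1
            self._engine = object()  # placeholder for a real engine
        return self._engine

    def clear_cache(self):
        # Reset every cached slot; the next property access rebuilds it.
        for name in self._CACHED_ATTRS:
            setattr(self, name, None)


sketch = LazyInferenceSketch()
first = sketch.config_handler
second = sketch.config_handler  # served from the cache, no rebuild
sketch.clear_cache()
third = sketch.config_handler  # rebuilt after clear_cache()
print(sketch.build_counts)  # {'config_handler': 2, 'engine': 0}
```

Keeping the slot names in one tuple means a new lazy property only has to be registered once to be covered by the reset.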
- property config_handler
Get or create the configuration handler.
- Returns:
The configuration handler instance with loaded configs.
- Return type:
InferenceConfigHandler
- property cycle
Get the cycle number for inference.
- Returns:
The cycle number used for loading the trained model.
- Return type:
int
- property data_config
Get the data configuration.
- Returns:
The data configuration object, or None if not available.
- Return type:
RecursiveNamespace or None
- property data_handler
Get or create the data handler.
- Returns:
The data handler instance configured for inference.
- get_positions()
Combine x_field and y_field into position pairs.
- Returns:
Array of shape (num_positions, 2) where each row contains [x, y] coordinates. Returns None if either x_field or y_field is None.
- Return type:
numpy.ndarray or None
- Raises:
ValueError – If x_field and y_field have different lengths.
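A minimal NumPy sketch of the pairing logic, assuming both fields are array-like and are flattened to 1-D before stacking (the flattening and the float dtype are assumptions, not documented behaviour):

```python
import numpy as np


def get_positions(x_field, y_field):
    """Hypothetical stand-alone sketch of the behaviour documented above."""
    if x_field is None or y_field is None:
        return None  # positions cannot be formed without both coordinates
    x = np.asarray(x_field, dtype=float).ravel()
    y = np.asarray(y_field, dtype=float).ravel()
    if x.size != y.size:
        raise ValueError(f"x_field has {x.size} entries but y_field has {y.size}.")
    # Each row is one [x, y] pair, giving shape (num_positions, 2).
    return np.column_stack((x, y))


positions = get_positions([100.5, 200.3], [150.2, 250.8])
print(positions.shape)  # (2, 2)
```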
- get_psf(index: int = 0) → ndarray
Get the PSF at a specific index.
- Parameters:
index (int, optional) – Index of the PSF to retrieve (default is 0).
- Returns:
The inferred PSF at the specified index with shape (output_dim, output_dim).
- Return type:
numpy.ndarray
Notes
Automatically ensures that inference has been completed before the PSF is accessed. If only a single star was passed during instantiation, the index defaults to 0 and bounds checking is relaxed.
- get_psfs()
Get all inferred PSFs.
- Returns:
Array of inferred PSFs with shape (n_samples, output_dim, output_dim).
- Return type:
numpy.ndarray
Notes
Automatically ensures that inference has been completed before the PSFs are accessed.
- property inference_config
Get the inference configuration.
- Returns:
The inference configuration object.
- Return type:
RecursiveNamespace
- load_inference_model()
Load the trained PSF model based on the inference configuration.
- Returns:
The loaded trained PSF model.
- Return type:
Model
Notes
Constructs the weights path pattern based on the trained model path, model subdirectory, model name, id name, and cycle number specified in the configuration files.
- property n_bins_lambda
Get the number of wavelength bins for inference.
- Returns:
The number of wavelength bins used during inference.
- Return type:
int
- property output_dim
Get the output dimension for PSF inference.
- Returns:
The output dimension (height and width) of the inferred PSFs.
- Return type:
int
- prepare_configs()
Prepare the configuration for inference.
Overwrites training model parameters with inference configuration values.
- run_inference()
Run PSF inference and return the full PSF array.
- Returns:
Array of inferred PSFs with shape (n_samples, output_dim, output_dim).
- Return type:
numpy.ndarray
Notes
Prepares configurations and input data, initializes the inference engine, and computes the PSF for all input positions.
- property simPSF
Get or create the PSF simulator.
- Returns:
The PSF simulator instance.
- Return type:
simPSF
- property trained_psf_model
Get or load the trained PSF model.
- Returns:
The loaded trained PSF model.
- Return type:
Model
- property training_config
Get the training configuration.
- Returns:
The training configuration object.
- Return type:
RecursiveNamespace
- class wf_psf.inference.psf_inference.PSFInferenceEngine(trained_model, batch_size: int, output_dim: int)
Bases: object
Engine to perform PSF inference using a trained model.
This class handles the batch-wise computation of PSFs using a trained PSF model. It manages the batching of input positions and SEDs, and caches the inferred PSFs for later access.
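- Parameters:
trained_model (Model) – The trained PSF model used for inference.
batch_size (int) – Number of samples processed per inference batch.
output_dim (int) – Height and width of each inferred PSF, which has shape (output_dim, output_dim).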
- trained_model
The trained PSF model used for inference.
- Type:
Model
Examples
>>> engine = PSFInferenceEngine(model, batch_size=32, output_dim=64)
>>> psfs = engine.compute_psfs(positions, seds)
>>> single_psf = engine.get_psf(0)
- Attributes:
inferred_psfs – Access the cached inferred PSFs, if available.
Methods
clear_cache() – Clear cached inferred PSFs.
compute_psfs(positions, sed_data) – Compute and cache PSFs for the input source parameters.
get_psf(index) – Get the PSF at a specific index.
get_psfs() – Get all the generated PSFs.
- clear_cache()
Clear cached inferred PSFs.
Resets the internal PSF cache to free memory. After calling this method, compute_psfs() must be called again before accessing PSFs.
- compute_psfs(positions: Tensor, sed_data: Tensor) → ndarray
Compute and cache PSFs for the input source parameters.
- Parameters:
positions (tf.Tensor) – Tensor of shape (n_samples, 2) containing the (x, y) positions.
sed_data (tf.Tensor) – Tensor of shape (n_samples, n_bins, 2) containing the SEDs.
- Returns:
Array of inferred PSFs with shape (n_samples, output_dim, output_dim).
- Return type:
numpy.ndarray
Notes
PSFs are computed in batches according to the specified batch_size. Results are cached internally for subsequent access via get_psfs() or get_psf().
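The batching and caching described above can be sketched with plain NumPy. This is a hypothetical class, not `PSFInferenceEngine` itself: `toy_model` is an arbitrary callable standing in for the trained TensorFlow model, and it only needs to return an array of shape (batch, output_dim, output_dim).

```python
import numpy as np


class BatchedPSFEngineSketch:
    """Hypothetical sketch of batch-wise inference with a result cache."""

    def __init__(self, model, batch_size, output_dim):
        self.model = model  # any callable: (positions, seds) -> (batch, dim, dim)
        self.batch_size = batch_size
        self.output_dim = output_dim
        self._psfs = None  # cache, filled by compute_psfs()

    def compute_psfs(self, positions, sed_data):
        n_samples = positions.shape[0]
        psfs = np.empty((n_samples, self.output_dim, self.output_dim), dtype=np.float32)
        # Walk through the inputs in slices of at most batch_size samples.
        for start in range(0, n_samples, self.batch_size):
            stop = min(start + self.batch_size, n_samples)
            psfs[start:stop] = self.model(positions[start:stop], sed_data[start:stop])
        self._psfs = psfs
        return psfs

    def get_psf(self, index):
        if self._psfs is None:
            raise ValueError("PSFs have not been computed yet; call compute_psfs() first.")
        return self._psfs[index]

    def clear_cache(self):
        self._psfs = None  # free memory; compute_psfs() must run again


calls = []


def toy_model(pos, seds):
    # Toy stand-in: each "PSF" is filled with its x coordinate.
    calls.append(len(pos))
    return np.ones((len(pos), 4, 4)) * pos[:, 0, None, None]


engine = BatchedPSFEngineSketch(toy_model, batch_size=2, output_dim=4)
positions = np.arange(10, dtype=float).reshape(5, 2)  # x values: 0, 2, 4, 6, 8
seds = np.zeros((5, 8, 2))  # (n_samples, n_bins, 2); contents unused by toy_model
psfs = engine.compute_psfs(positions, seds)
print(psfs.shape, calls)  # (5, 4, 4) [2, 2, 1]
```

The final slice is shorter when n_samples is not a multiple of batch_size, which is why `stop` is clipped to n_samples.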
- get_psf(index: int) → ndarray
Get the PSF at a specific index.
- Parameters:
index (int) – Index of the PSF to retrieve.
- Returns:
The inferred PSF at the specified index with shape (output_dim, output_dim).
- Return type:
numpy.ndarray
- Raises:
ValueError – If PSFs have not yet been computed.
- get_psfs() → ndarray
Get all the generated PSFs.
- Returns:
Array of inferred PSFs with shape (n_samples, output_dim, output_dim).
- Return type:
numpy.ndarray
- property inferred_psfs: ndarray
Access the cached inferred PSFs, if available.
- Returns:
The cached inferred PSFs, or None if not yet computed.
- Return type:
numpy.ndarray or None