wf_psf.inference.psf_inference
Inference.
This module provides the PSFInference class for running inference with trained PSF models. It loads a trained model, runs inference on a dataset of SEDs and positions, and generates polychromatic PSFs.
- Authors:
Jennifer Pollack <jennifer.pollack@cea.fr>, Tobias Liaudat <tobias.liaudat@cea.fr>
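Classes
InferenceConfigHandler(inference_config_path) – Handle configuration loading and management for PSF inference.
PSFInference(inference_config_path[, ...]) – Perform PSF inference using a pre-trained WaveDiff model.
PSFInferenceEngine(trained_model, batch_size, output_dim) – Engine to perform PSF inference using a trained model.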
- class wf_psf.inference.psf_inference.InferenceConfigHandler(inference_config_path: str)
Bases: object
Handle configuration loading and management for PSF inference.
This class manages the loading of inference, training, and data configuration files required for PSF inference operations.
- Parameters:
inference_config_path (str) – Path to the inference configuration YAML file.
- inference_config
Loaded inference configuration.
- Type:
RecursiveNamespace or None
- training_config
Loaded training configuration.
- Type:
RecursiveNamespace or None
- data_config
Loaded data configuration.
- Type:
RecursiveNamespace or None
- trained_model_path
Path to the trained model directory.
- Type:
Path
- trained_model_config_path
Path to the training configuration file.
- Type:
Path
- Attributes:
- schema_mode
Methods
load_configs() – Load configuration files based on the inference config.
overwrite_model_params([training_config, ...]) – Overwrite training model_params with values from inference_config if available.
Extract and set the configuration paths from the inference config.
- ids = ('inference_conf',)
- load_configs()
Load configuration files based on the inference config.
Loads the inference configuration first, then uses it to determine and load the training and data configurations.
Notes
- Updates the following attributes in-place:
inference_config
training_config
data_config (if data_config_path is specified)
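The two-stage loading order described above can be illustrated with a minimal sketch. It is not the wf_psf implementation: the in-memory FAKE_FILES mapping, the read_config helper, the ConfigHandlerSketch class, and the key names training_config_path and data_config_path are all hypothetical stand-ins for real YAML files and their layout.

```python
from types import SimpleNamespace

# Hypothetical in-memory stand-in for YAML files on disk.
FAKE_FILES = {
    "inference.yaml": {
        "training_config_path": "training.yaml",
        "data_config_path": None,  # no data config in this example
    },
    "training.yaml": {"model_params": {"output_dim": 32}},
}


def read_config(path):
    """Return the fake file contents as a simple namespace (hypothetical helper)."""
    return SimpleNamespace(**FAKE_FILES[path])


class ConfigHandlerSketch:
    """Illustrative handler: the inference config decides what else is loaded."""

    def __init__(self, inference_config_path):
        self.inference_config_path = inference_config_path
        self.inference_config = None
        self.training_config = None
        self.data_config = None

    def load_configs(self):
        # Stage 1: load the inference configuration.
        self.inference_config = read_config(self.inference_config_path)
        # Stage 2: follow the paths it names to the other configurations.
        self.training_config = read_config(
            self.inference_config.training_config_path
        )
        data_path = self.inference_config.data_config_path
        if data_path is not None:
            self.data_config = read_config(data_path)


handler = ConfigHandlerSketch("inference.yaml")
handler.load_configs()
print(handler.training_config.model_params)
print(handler.data_config)
```

Because no data configuration path is given in this sketch, data_config stays None, matching the "if data_config_path is specified" condition above.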
- static overwrite_model_params(training_config=None, inference_config=None)
Overwrite training model_params with values from inference_config if available.
- Parameters:
training_config (RecursiveNamespace) – Configuration object from training phase.
inference_config (RecursiveNamespace) – Configuration object from inference phase.
Notes
Updates are applied in-place to training_config.model_params.
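As a rough illustration of an in-place overwrite, the sketch below copies matching values from one namespace into another. It makes several assumptions that the documentation does not confirm: that the inference configuration exposes a model_params attribute at its top level, that only keys already present in the training configuration are overwritten, and that SimpleNamespace is an adequate stand-in for RecursiveNamespace. The parameter names n_bins, output_dim, and extra_option are invented for the example.

```python
from types import SimpleNamespace


def overwrite_model_params(training_config=None, inference_config=None):
    """Copy matching model_params values into training_config, in place."""
    inference_params = getattr(inference_config, "model_params", None)
    training_params = getattr(training_config, "model_params", None)
    if inference_params is None or training_params is None:
        return  # nothing to overwrite

    for key, value in vars(inference_params).items():
        # Assumption: keys unknown to the training config are ignored.
        if hasattr(training_params, key):
            setattr(training_params, key, value)


# Hypothetical parameter names, chosen only for this example.
training = SimpleNamespace(model_params=SimpleNamespace(n_bins=20, output_dim=32))
inference = SimpleNamespace(model_params=SimpleNamespace(n_bins=8, extra_option=True))

overwrite_model_params(training, inference)
print(training.model_params)
```

The function returns nothing; the caller observes the change through the training_config object it passed in.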
- property schema_mode: DatasetMode
- class wf_psf.inference.psf_inference.PSFInference(inference_config_path: str, x_field=None, y_field=None, seds=None, sources=None, masks=None)
Bases: object
Perform PSF inference using a pre-trained WaveDiff model.
This class handles the setup for PSF inference, including loading configuration files, instantiating the PSF simulator and data handler, and preparing the input data required for inference.
- Parameters:
inference_config_path (str) – Path to the inference configuration YAML file.
x_field (array-like, optional) – x coordinates in SHE convention.
y_field (array-like, optional) – y coordinates in SHE convention.
seds (array-like, optional) – Spectral energy distributions (SEDs).
sources (array-like, optional) – Postage stamps of sources, e.g. star images (shape: [n_stars, h, w]).
masks (array-like, optional) – Corresponding masks for the sources (same shape as sources). Defaults to None.
- x_field
x coordinates for PSF positions.
- Type:
array-like or None
- y_field
y coordinates for PSF positions.
- Type:
array-like or None
- seds
Spectral energy distributions.
- Type:
array-like or None
- sources
Source postage stamps.
- Type:
array-like or None
- masks
Source masks.
- Type:
array-like or None
- engine
The inference engine instance.
- Type:
PSFInferenceEngine or None
Examples
Basic usage with position coordinates and SEDs:
psf_inf = PSFInference(
    inference_config_path="config.yaml",
    x_field=[100.5, 200.3],
    y_field=[150.2, 250.8],
    seds=sed_array
)
psf_inf.run_inference()
psf = psf_inf.get_psf(0)
- Attributes:
batch_size – Get the batch size for inference.
config_handler – Get or create the configuration handler.
cycle – Get the cycle number for inference.
data_config – Get the data configuration.
inference_config – Get the inference configuration.
inference_data_adapter – Create and return a DataAdapter for inference data using the factory.
model_data_adapter – Create and return a Model DataAdapter for loading the trained PSF model using the factory.
n_bins_lambda – Get the number of wavelength bins for inference.
output_dim – Get the output dimension for PSF inference.
simPSF – Get or create the PSF simulator.
trained_psf_model – Get or load the trained PSF model.
training_config – Get the training configuration.
Methods
clear_cache() – Clear all cached properties and reset the instance.
get_positions() – Combine x_field and y_field into position pairs.
get_psf([index]) – Get the PSF at a specific index.
get_psfs() – Get all inferred PSFs.
load_inference_model() – Load the trained PSF model based on the inference configuration.
prepare_configs() – Prepare the configuration for inference.
run_inference() – Run PSF inference and return the full PSF array.
- property batch_size
Get the batch size for inference.
- Returns:
The batch size for processing during inference.
- Return type:
- clear_cache()
Clear all cached properties and reset the instance.
This method resets all lazy-loaded properties, including the config handler, PSF simulator, data handler, trained model, and inference engine. Useful for freeing memory or forcing a fresh initialization.
Notes
After calling this method, accessing any property will trigger re-initialization.
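The lazy-loading and cache-clearing behaviour described here follows a common pattern, sketched below with a single cached resource. The LazyResources class, its model property, and the fake_loader function are hypothetical; the sketch only shows how a property can be built on first access and rebuilt after the cache is cleared, not how PSFInference is actually written.

```python
class LazyResources:
    """Illustrative holder for one lazily created, cached resource."""

    def __init__(self, loader):
        self._loader = loader
        self._model = None  # nothing is loaded until first access

    @property
    def model(self):
        # Build the resource on first access, then reuse the cached copy.
        if self._model is None:
            self._model = self._loader()
        return self._model

    def clear_cache(self):
        # Drop the cached object; the next access triggers a reload.
        self._model = None


load_calls = []


def fake_loader():
    """Hypothetical stand-in for an expensive model load."""
    load_calls.append("load")
    return {"weights": [0.0, 1.0]}


resources = LazyResources(fake_loader)
first = resources.model       # triggers the loader
second = resources.model      # served from the cache
resources.clear_cache()
third = resources.model       # triggers the loader again
print(len(load_calls))
```

In this sketch the first and second accesses return the same object, while the access after clear_cache() returns a freshly loaded one.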
- property config_handler
Get or create the configuration handler.
- Returns:
The configuration handler instance with loaded configs.
- Return type:
- property cycle
Get the cycle number for inference.
- Returns:
The cycle number used for loading the trained model.
- Return type:
- property data_config
Get the data configuration.
- Returns:
The data configuration object, or None if not available.
- Return type:
RecursiveNamespace or None
- get_positions()
Combine x_field and y_field into position pairs.
- Returns:
Array of shape (num_positions, 2) where each row contains [x, y] coordinates. Returns None if either x_field or y_field is None.
- Return type:
- Raises:
ValueError – If x_field and y_field have different lengths.
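The documented contract (pairs of [x, y], None when a coordinate list is missing, ValueError on a length mismatch) can be sketched in plain Python. The function name combine_positions is hypothetical, and the sketch returns a nested list rather than the array of shape (num_positions, 2) that the real method is documented to return.

```python
def combine_positions(x_field, y_field):
    """Pair up x and y coordinates; return None if either input is missing."""
    if x_field is None or y_field is None:
        return None

    x_values = list(x_field)
    y_values = list(y_field)
    if len(x_values) != len(y_values):
        raise ValueError(
            f"x_field and y_field must have the same length, "
            f"got {len(x_values)} and {len(y_values)}."
        )

    # One [x, y] row per position, i.e. a (num_positions, 2) layout.
    return [[float(x), float(y)] for x, y in zip(x_values, y_values)]


positions = combine_positions([100.5, 200.3], [150.2, 250.8])
print(positions)
```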
- get_psf(index: int = 0) -> ndarray
Get the PSF at a specific index.
- Parameters:
index (int, optional) – Index of the PSF to retrieve (default is 0).
- Returns:
The inferred PSF at the specified index with shape (output_dim, output_dim).
- Return type:
numpy.ndarray
Notes
Automatically ensures that inference has been completed before the PSF is accessed. If only a single star was passed during instantiation, the index defaults to 0 and bounds checking is relaxed.
- get_psfs()
Get all inferred PSFs.
- Returns:
Array of inferred PSFs with shape (n_samples, output_dim, output_dim).
- Return type:
Notes
Automatically ensures that inference has been completed before the PSFs are accessed.
- property inference_config
Get the inference configuration.
- Returns:
The inference configuration object.
- Return type:
- property inference_data_adapter
Create and return a DataAdapter for inference data using the factory.
- Returns:
A fully prepared data adapter with LoadedDataset ready for inference.
- Return type:
- load_inference_model()
Load the trained PSF model based on the inference configuration.
- Returns:
The loaded trained PSF model.
- Return type:
Model
Notes
Constructs the weights path pattern based on the trained model path, model subdirectory, model name, id name, and cycle number specified in the configuration files.
- property model_data_adapter
Create and return a Model DataAdapter for loading trained PSF model using the factory.
- Returns:
A fully prepared model data adapter with LoadedDataset.
- Return type:
- property n_bins_lambda
Get the number of wavelength bins for inference.
- Returns:
The number of wavelength bins used during inference.
- Return type:
- property output_dim
Get the output dimension for PSF inference.
- Returns:
The output dimension (height and width) of the inferred PSFs.
- Return type:
- prepare_configs()
Prepare the configuration for inference.
Overwrites training model parameters with inference configuration values.
- run_inference()
Run PSF inference and return the full PSF array.
- Returns:
Array of inferred PSFs with shape (n_samples, output_dim, output_dim).
- Return type:
Notes
Prepares configurations and input data, initializes the inference engine, and computes the PSF for all input positions.
- property simPSF
Get or create the PSF simulator.
- Returns:
The PSF simulator instance.
- Return type:
simPSF
- property trained_psf_model
Get or load the trained PSF model.
- Returns:
The loaded trained PSF model.
- Return type:
Model
- property training_config
Get the training configuration.
- Returns:
The training configuration object.
- Return type:
- class wf_psf.inference.psf_inference.PSFInferenceEngine(trained_model, batch_size: int, output_dim: int)
Bases: object
Engine to perform PSF inference using a trained model.
This class handles the batch-wise computation of PSFs using a trained PSF model. It manages the batching of input positions and SEDs, and caches the inferred PSFs for later access.
- Parameters:
- trained_model
The trained PSF model used for inference.
- Type:
Model
Examples
engine = PSFInferenceEngine(model, batch_size=32, output_dim=64)
psfs = engine.compute_psfs(positions, seds)
single_psf = engine.get_psf(0)
- Attributes:
inferred_psfs – Access the cached inferred PSFs, if available.
Methods
clear_cache() – Clear cached inferred PSFs.
compute_psfs(positions, sed_data) – Compute and cache PSFs for the input source parameters.
get_psf(index) – Get the PSF at a specific index.
get_psfs() – Get all the generated PSFs.
- clear_cache()
Clear cached inferred PSFs.
Resets the internal PSF cache to free memory. After calling this method, compute_psfs() must be called again before accessing PSFs.
- compute_psfs(positions: Tensor, sed_data: Tensor) -> ndarray
Compute and cache PSFs for the input source parameters.
- Parameters:
positions (tf.Tensor) – Tensor of shape (n_samples, 2) containing the (x, y) positions.
sed_data (tf.Tensor) – Tensor of shape (n_samples, n_bins, 2) containing the SEDs.
- Returns:
Array of inferred PSFs with shape (n_samples, output_dim, output_dim).
- Return type:
numpy.ndarray
Notes
PSFs are computed in batches according to the specified batch_size. Results are cached internally for subsequent access via get_psfs() or get_psf().
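Batch-wise computation with an internal cache can be sketched as follows. This is an illustration of the general pattern only: BatchedEngineSketch, fake_model, and the tiny 2x2 "PSFs" are hypothetical stand-ins, plain lists replace tensors and arrays, and the real engine's call into the trained model is not shown in this documentation.

```python
class BatchedEngineSketch:
    """Illustrative engine: run a model over fixed-size batches and cache results."""

    def __init__(self, model_fn, batch_size):
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer.")
        self.model_fn = model_fn
        self.batch_size = batch_size
        self._inferred = None  # cache, filled by compute()

    def compute(self, positions, sed_data):
        if len(positions) != len(sed_data):
            raise ValueError("positions and sed_data must have the same length.")
        results = []
        # Walk over the inputs in slices of at most batch_size samples.
        for start in range(0, len(positions), self.batch_size):
            stop = start + self.batch_size
            results.extend(self.model_fn(positions[start:stop], sed_data[start:stop]))
        self._inferred = results
        return results

    def get(self, index):
        if self._inferred is None:
            raise ValueError("PSFs have not been computed yet; call compute() first.")
        return self._inferred[index]


batch_sizes = []


def fake_model(pos_batch, sed_batch):
    """Hypothetical model: one constant 2x2 image per sample, valued x + y."""
    batch_sizes.append(len(pos_batch))
    return [[[x + y] * 2 for _ in range(2)] for (x, y) in pos_batch]


positions = [(0.0, 1.0), (2.0, 3.0), (4.0, 5.0), (6.0, 7.0), (8.0, 9.0)]
seds = [[(1.0, 0.0)] for _ in positions]  # placeholder (n_bins, 2) entries
engine = BatchedEngineSketch(fake_model, batch_size=2)
psfs = engine.compute(positions, seds)
print(batch_sizes, len(psfs))
```

With five samples and a batch size of two, the stand-in model is called three times, the last time on a single leftover sample.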
- get_psf(index: int) -> ndarray
Get the PSF at a specific index.
- Returns:
The inferred PSF at the specified index with shape (output_dim, output_dim).
- Return type:
numpy.ndarray
- Raises:
ValueError – If PSFs have not yet been computed.
- get_psfs() -> ndarray
Get all the generated PSFs.
- Returns:
Array of inferred PSFs with shape (n_samples, output_dim, output_dim).
- Return type:
numpy.ndarray
- property inferred_psfs: ndarray
Access the cached inferred PSFs, if available.
- Returns:
The cached inferred PSFs, or None if not yet computed.
- Return type:
numpy.ndarray or None