Source code for shapepipe.modules.psfex_interp_package.psfex_interp

"""PSFEX INTERPOLATION SCRIPT.

This module computes the PSFs from a PSFEx model at several galaxy positions.

:Authors: Morgan Schmitz and Axel Guinot

"""

import os
import re

import numpy as np
from astropy.io import fits
from sqlitedict import SqliteDict

from shapepipe.pipeline import file_io

try:
    import galsim.hsm as hsm
    from galsim import Image
except ImportError:
    import_fail = True
else:
    import_fail = False


NOT_ENOUGH_STARS = 'Fail_stars'
BAD_CHI2 = 'Fail_chi2'
FILE_NOT_FOUND = 'File_not_found'


def interpsfex(dotpsfpath, pos, thresh_star, thresh_chi2):
    """Interpolate PSFEx.

    Use PSFEx generated model to perform spatial PSF interpolation.

    Parameters
    ----------
    dotpsfpath : str
        Path to ``.psf`` file (PSFEx output)
    pos : numpy.ndarray
        Positions where the PSF model should be evaluated
    thresh_star : int
        Threshold of stars under which the PSF is not interpolated
    thresh_chi2 : int
        Threshold for chi squared

    Returns
    -------
    numpy.ndarray
        Array of PSFs, each row is the PSF image at the corresponding
        position requested

    """
    if not os.path.exists(dotpsfpath):
        return FILE_NOT_FOUND

    # read PSF model and extract basis and polynomial degree and scale position
    PSF_model = fits.open(dotpsfpath)[1]

    # Check number of stars used to compute the PSF
    if PSF_model.header['ACCEPTED'] < thresh_star:
        return NOT_ENOUGH_STARS
    if PSF_model.header['CHI2'] > thresh_chi2:
        return BAD_CHI2

    PSF_basis = np.array(PSF_model.data)[0][0]
    try:
        deg = PSF_model.header['POLDEG1']
    except KeyError:
        # constant PSF model
        return PSF_basis[0, :, :]

    # scale coordinates
    x_interp, x_scale = (
        PSF_model.header['POLZERO1'],
        PSF_model.header['POLSCAL1']
    )
    y_interp, y_scale = (
        PSF_model.header['POLZERO2'],
        PSF_model.header['POLSCAL2']
    )
    xs, ys = (pos[:, 0] - x_interp) / x_scale, (pos[:, 1] - y_interp) / y_scale

    # compute polynomial coefficients
    coeffs = np.array([[x ** idx for idx in range(deg + 1)] for x in xs])
    cross_coeffs = np.array([
        np.concatenate([
            [
                (x ** idx_j) * (y ** idx_i)
                for idx_j in range(deg - idx_i + 1)
            ]
            for idx_i in range(1, deg + 1)
        ])
        for x, y in zip(xs, ys)
    ])
    coeffs = np.hstack((coeffs, cross_coeffs))

    # compute interpolated PSF
    PSFs = np.array([
        np.sum(
            [coeff * atom for coeff, atom in zip(coeffs_posi, PSF_basis)],
            axis=0,
        )
        for coeffs_posi in coeffs
    ])

    return PSFs
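
# A minimal usage sketch (not part of the module; the ``.psf`` file name and
# positions below are hypothetical). ``interpsfex`` returns one PSF stamp per
# requested position, or one of the failure strings defined above:
#
#   pos = np.array([[1024.5, 2387.1], [512.0, 768.3]])
#   psfs = interpsfex('star_selection-2367890-0.psf', pos, 20, 2)
#   if isinstance(psfs, str):
#       print('interpolation failed:', psfs)  # NOT_ENOUGH_STARS, BAD_CHI2,
#                                             # or FILE_NOT_FOUND
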

class PSFExInterpolator(object):
    """The PSFEx Interpolator Class.

    This class uses a PSFEx output file to compute the PSF at desired
    positions.

    Parameters
    ----------
    dotpsf_path : str
        Path to PSFEx output file
    galcat_path : str
        Path to SExtractor-like galaxy catalogue
    output_path : str
        Path to folder where output PSFs should be written
    img_number : str
        File number string
    w_log : logging.Logger
        Logging instance
    pos_params : list, optional
        Desired position parameters; if provided, there should be exactly
        two, and they must also be present in the galaxy catalogue;
        otherwise, they are read directly from the ``.psf`` file
    get_shapes : bool
        If ``True``, compute shapes for the PSF model
    star_thresh : int
        Threshold of stars under which the PSF is not interpolated
    chi2_thresh : int
        Threshold for chi squared

    """

    def __init__(
        self,
        dotpsf_path,
        galcat_path,
        output_path,
        img_number,
        w_log,
        pos_params=None,
        get_shapes=True,
        star_thresh=20,
        chi2_thresh=2,
    ):
        # Path to PSFEx output file
        if (
            isinstance(dotpsf_path, type(None))
            or os.path.isfile(dotpsf_path)
        ):
            self._dotpsf_path = dotpsf_path
        else:
            raise ValueError(f'Could not find file {dotpsf_path}.')

        # Path to catalogue containing galaxy positions
        if os.path.isfile(galcat_path):
            self._galcat_path = galcat_path
        else:
            raise ValueError(f'Could not find file {galcat_path}.')

        # Path to output file to be written
        self._output_path = output_path + '/galaxy_psf'

        # Path to output file to be written for validation
        self._output_path_validation = output_path + '/validation_psf'

        # if required, compute and save shapes
        self._compute_shape = get_shapes

        # Number of stars under which we don't interpolate the PSF
        self._star_thresh = star_thresh

        self._chi2_thresh = chi2_thresh

        # Logging
        self._w_log = w_log

        # handle provided, but empty pos_params (for use within
        # CosmoStat's ShapePipe)
        if pos_params:
            if not len(pos_params) == 2:
                raise ValueError(
                    f'{len(pos_params)} position parameters were passed on; '
                    + 'there should be exactly two.'
                )
            self._pos_params = pos_params
        else:
            self._pos_params = None

        self.gal_pos = None
        self.interp_PSFs = None

        self._img_number = img_number
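
    # A minimal construction sketch (paths and file number are hypothetical,
    # and a standard ``logging`` logger stands in for the pipeline-provided
    # one); ``process`` then writes ``galaxy_psf<img_number>.fits``:
    #
    #   import logging
    #   interp = PSFExInterpolator(
    #       dotpsf_path='star_selection-2367890-0.psf',
    #       galcat_path='galaxy_selection-2367890-0.fits',
    #       output_path='./output',
    #       img_number='-2367890-0',
    #       w_log=logging.getLogger('psfex_interp'),
    #   )
    #   interp.process()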

    def process(self):
        """Process.

        Process the PSF interpolation single-epoch run.

        """
        if self.gal_pos is None:
            self._get_galaxy_positions()

        if self.interp_PSFs is None:
            self._interpolate()

        if (
            isinstance(self.interp_PSFs, str)
            and self.interp_PSFs == NOT_ENOUGH_STARS
        ):
            self._w_log.info(
                'Not enough stars to interpolate the PSF in the file '
                + f'{self._dotpsf_path}.'
            )
        elif (
            isinstance(self.interp_PSFs, str)
            and self.interp_PSFs == BAD_CHI2
        ):
            self._w_log.info(
                f'Bad chi2 for the PSF model in the file {self._dotpsf_path}.'
            )
        elif (
            isinstance(self.interp_PSFs, str)
            and self.interp_PSFs == FILE_NOT_FOUND
        ):
            self._w_log.info(f'PSF model file {self._dotpsf_path} not found.')
        else:
            if self._compute_shape:
                self._get_psfshapes()

            self._write_output()

    def _get_position_parameters(self):
        """Get Position Parameters.

        Read position parameters from ``.psf`` file.

        """
        dotpsf = file_io.FITSCatalogue(self._dotpsf_path)
        dotpsf.open()
        self._pos_params = [
            dotpsf.get_header()['POLNAME1'],
            dotpsf.get_header()['POLNAME2']
        ]
        dotpsf.close()

    def _get_galaxy_positions(self):
        """Get Galaxy Positions.

        Extract galaxy positions from galaxy catalogue.

        """
        if self._pos_params is None:
            self._get_position_parameters()

        galcat = file_io.FITSCatalogue(self._galcat_path, SEx_catalogue=True)
        galcat.open()

        try:
            self.gal_pos = np.array([
                [x, y] for x, y in zip(
                    galcat.get_data()[self._pos_params[0]],
                    galcat.get_data()[self._pos_params[1]]
                )
            ])
            self._w_log.info(
                f'Read {self.gal_pos.shape[0]} positions from galaxy catalog'
            )

        except KeyError as detail:
            # extract erroneous position parameter from original exception
            err_pos_param = detail.args[0][4:-15]
            pos_param_err = (
                f'Required position parameter {err_pos_param} '
                + 'was not found in galaxy catalog. Leave '
                + 'pos_params (or EXTRA_CODE_OPTION) blank to '
                + 'read them from the .psf file.'
            )
            raise KeyError(pos_param_err)

        galcat.close()

    def _interpolate(self):
        """Interpolate.

        Run Sheldon and Rykoff's PSFEx interpolator method at the desired
        positions.

        """
        self.interp_PSFs = interpsfex(
            self._dotpsf_path,
            self.gal_pos,
            self._star_thresh,
            self._chi2_thresh,
        )

    def _get_psfshapes(self):
        """Get PSF Shapes.

        Compute shapes of PSF at galaxy positions using HSM.

        """
        if import_fail:
            raise ImportError('GalSim is required to get shape information')

        psf_moms = [
            hsm.FindAdaptiveMom(Image(psf), strict=False)
            for psf in self.interp_PSFs
        ]

        self.psf_shapes = np.array([
            [
                moms.observed_shape.g1,
                moms.observed_shape.g2,
                moms.moments_sigma,
                int(bool(moms.error_message))
            ]
            for moms in psf_moms
        ])

    def _write_output(self):
        """Write Output.

        Save computed PSFs to a FITS file.

        """
        output = file_io.FITSCatalogue(
            self._output_path + self._img_number + '.fits',
            open_mode=file_io.BaseCatalogue.OpenMode.ReadWrite,
            SEx_catalogue=True,
        )

        if self._compute_shape:
            data = {
                'VIGNET': self.interp_PSFs,
                'E1_PSF_HSM': self.psf_shapes[:, 0],
                'E2_PSF_HSM': self.psf_shapes[:, 1],
                'SIGMA_PSF_HSM': self.psf_shapes[:, 2],
                'FLAG_PSF_HSM': self.psf_shapes[:, 3].astype(int)
            }
        else:
            data = {'VIGNET': self.interp_PSFs}
        output.save_as_fits(data, sex_cat_path=self._galcat_path)

    def process_validation(self, psfex_cat_path):
        """Process Validation.

        Process validation steps.

        Parameters
        ----------
        psfex_cat_path : str
            Path to PSFEx catalogue

        """
        if not os.path.isfile(psfex_cat_path):
            raise ValueError(f'Could not find file {psfex_cat_path}.')

        if self.gal_pos is None:
            self._get_galaxy_positions()

        if self.interp_PSFs is None:
            self._interpolate()

        if (
            isinstance(self.interp_PSFs, str)
            and self.interp_PSFs == NOT_ENOUGH_STARS
        ):
            self._w_log.info(
                'Not enough stars to interpolate the PSF in the file '
                + f'{self._dotpsf_path}.'
            )
        elif (
            isinstance(self.interp_PSFs, str)
            and self.interp_PSFs == BAD_CHI2
        ):
            self._w_log.info(
                f'Bad chi2 for the PSF model in the file {self._dotpsf_path}.'
            )
        elif (
            isinstance(self.interp_PSFs, str)
            and self.interp_PSFs == FILE_NOT_FOUND
        ):
            self._w_log.info(f'PSF model file {self._dotpsf_path} not found.')
        else:
            star_cat = file_io.FITSCatalogue(
                self._galcat_path,
                SEx_catalogue=True,
            )
            star_cat.open()
            star_dict = {}
            star_vign = np.copy(star_cat.get_data()['VIGNET'])
            star_dict['NUMBER'] = np.copy(star_cat.get_data()['NUMBER'])
            star_dict['X'] = np.copy(star_cat.get_data()['XWIN_IMAGE'])
            star_dict['Y'] = np.copy(star_cat.get_data()['YWIN_IMAGE'])
            star_dict['RA'] = np.copy(star_cat.get_data()['XWIN_WORLD'])
            star_dict['DEC'] = np.copy(star_cat.get_data()['YWIN_WORLD'])
            star_dict['MAG'] = np.copy(star_cat.get_data()['MAG_AUTO'])
            star_dict['SNR'] = np.copy(star_cat.get_data()['SNR_WIN'])
            star_cat.close()

            self._get_psfshapes()
            self._get_starshapes(star_vign)
            psfex_cat_dict = self._get_psfexcatdict(psfex_cat_path)

            self._write_output_validation(star_dict, psfex_cat_dict)
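
    # A minimal validation sketch (the catalogue name is hypothetical). Here
    # the catalogue passed as ``galcat_path`` at construction is expected to
    # be a star catalogue, since its VIGNET stamps are measured with HSM
    # alongside the interpolated PSFs:
    #
    #   interp.process_validation('psfex_cat-2367890-0.cat')
    #   # writes ``validation_psf<img_number>.fits`` with HSM shape columns
    #   # for both the PSF model and the stars, the star positions, and the
    #   # ACCEPTED flag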

    def _get_starshapes(self, star_vign):
        """Get Star Shapes.

        Compute shapes of stars at star positions using HSM.

        Parameters
        ----------
        star_vign : numpy.ndarray
            Array containing the star vignets

        """
        if import_fail:
            raise ImportError('GalSim is required to get shape information')

        # pixels set to -1e30 in the vignets are bad pixels; flag them in a
        # mask passed to HSM
        masks = np.zeros_like(star_vign)
        masks[np.where(star_vign == -1e30)] = 1

        star_moms = [
            hsm.FindAdaptiveMom(Image(star), badpix=Image(mask), strict=False)
            for star, mask in zip(star_vign, masks)
        ]

        self.star_shapes = np.array([
            [
                moms.observed_shape.g1,
                moms.observed_shape.g2,
                moms.moments_sigma,
                int(bool(moms.error_message))
            ]
            for moms in star_moms
        ])

    def _get_psfexcatdict(self, psfex_cat_path):
        """Get PSFEx Catalogue Dictionary.

        Get data from the PSFEx ``.cat`` file.

        Parameters
        ----------
        psfex_cat_path : str
            Path to the ``.cat`` file from PSFEx

        Returns
        -------
        dict
            Dictionary containing information from the PSFEx ``.cat`` file

        """
        psfex_cat = file_io.FITSCatalogue(psfex_cat_path, SEx_catalogue=True)
        psfex_cat.open()

        psfex_cat_dict = {}
        psfex_cat_dict['SOURCE_NUMBER'] = np.copy(
            psfex_cat.get_data()['SOURCE_NUMBER']
        )
        psfex_cat_dict['DELTAX_IMAGE'] = np.copy(
            psfex_cat.get_data()['DELTAX_IMAGE']
        )
        psfex_cat_dict['DELTAY_IMAGE'] = np.copy(
            psfex_cat.get_data()['DELTAY_IMAGE']
        )
        psfex_cat_dict['CHI2_PSF'] = np.copy(
            psfex_cat.get_data()['CHI2_PSF']
        )
        psfex_cat.close()

        return psfex_cat_dict

    def _write_output_validation(self, star_dict, psfex_cat_dict):
        """Write Output Validation.

        Save computed PSFs and stars to a FITS file.

        Parameters
        ----------
        star_dict : dict
            Dictionary containing star information
        psfex_cat_dict : dict
            Dictionary containing information from the PSFEx ``.cat`` file

        """
        output = file_io.FITSCatalogue(
            self._output_path_validation + self._img_number + '.fits',
            open_mode=file_io.BaseCatalogue.OpenMode.ReadWrite,
            SEx_catalogue=True,
        )

        data = {
            'E1_PSF_HSM': self.psf_shapes[:, 0],
            'E2_PSF_HSM': self.psf_shapes[:, 1],
            'SIGMA_PSF_HSM': self.psf_shapes[:, 2],
            'FLAG_PSF_HSM': self.psf_shapes[:, 3].astype(int),
            'E1_STAR_HSM': self.star_shapes[:, 0],
            'E2_STAR_HSM': self.star_shapes[:, 1],
            'SIGMA_STAR_HSM': self.star_shapes[:, 2],
            'FLAG_STAR_HSM': self.star_shapes[:, 3].astype(int)
        }
        data = {**data, **star_dict}

        # mark the stars that PSFEx did not keep to fit the model
        data['ACCEPTED'] = np.ones_like(data['NUMBER'], dtype='int16')
        star_used = psfex_cat_dict.pop('SOURCE_NUMBER')

        for idx in range(len(data['NUMBER'])):
            if idx + 1 not in star_used:
                data['ACCEPTED'][idx] = 0

        output.save_as_fits(data, sex_cat_path=self._galcat_path)

    def process_me(self, dot_psf_dir, dot_psf_pattern, f_wcs_path):
        """Process Multi-Epoch.

        Process the multi-epoch run.

        Parameters
        ----------
        dot_psf_dir : str
            Path to the directory containing the ``.psf`` files
        dot_psf_pattern : str
            Common pattern of the ``.psf`` files
        f_wcs_path : str
            Path to the log file containing the WCS for each CCD

        """
        if os.path.exists(dot_psf_dir):
            self._dot_psf_dir = dot_psf_dir
        else:
            raise ValueError(f'Could not find directory {dot_psf_dir}.')

        self._dot_psf_pattern = dot_psf_pattern

        if os.path.isfile(f_wcs_path):
            self._f_wcs_file = SqliteDict(f_wcs_path)
        else:
            raise ValueError(f'Could not find file {f_wcs_path}.')

        if self.gal_pos is None:
            self._get_galaxy_positions()

        output_dict = self._interpolate_me()

        self._write_output_me(output_dict)
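
    # A minimal multi-epoch sketch (directory, pattern, and WCS log name are
    # hypothetical; the ``.psf`` files are expected to be named
    # ``<pattern>-<exp_name>-<ccd>.psf``):
    #
    #   interp.process_me(
    #       dot_psf_dir='psf_models',
    #       dot_psf_pattern='star_selection',
    #       f_wcs_path='log_exp_headers.sqlite',
    #   )
    #   # writes ``galaxy_psf<img_number>.sqlite``, one entry per object NUMBER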

    def _interpolate_me(self):
        """Interpolate Multi-Epoch.

        Interpolate PSFs for the multi-epoch run.

        Returns
        -------
        dict
            Dictionary containing object IDs, the interpolated PSFs and,
            optionally, their shapes

        """
        cat = file_io.FITSCatalogue(self._galcat_path, SEx_catalogue=True)
        cat.open()

        all_id = np.copy(cat.get_data()['NUMBER'])
        n_epoch = np.copy(cat.get_data()['N_EPOCH'])

        list_ext_name = cat.get_ext_name()
        hdu_ind = [
            idx for idx in range(len(list_ext_name))
            if 'EPOCH' in list_ext_name[idx]
        ]

        final_list = []
        for hdu_index in hdu_ind:
            exp_name = cat.get_data(hdu_index)['EXP_NAME'][0]
            ccd_list = list(set(cat.get_data(hdu_index)['CCD_N']))
            array_psf = None
            array_id = None
            array_shape = None
            array_exp_name = None
            for ccd in ccd_list:
                if ccd == -1:
                    continue
                dot_psf_path = (
                    f'{self._dot_psf_dir}/{self._dot_psf_pattern}-{exp_name}'
                    + f'-{ccd}.psf'
                )
                ind_obj = np.where(cat.get_data(hdu_index)['CCD_N'] == ccd)[0]
                obj_id = all_id[ind_obj]
                gal_pos = np.array(
                    self._f_wcs_file[exp_name][ccd]['WCS'].all_world2pix(
                        self.gal_pos[:, 0][ind_obj],
                        self.gal_pos[:, 1][ind_obj],
                        0,
                    )
                ).T

                self.interp_PSFs = interpsfex(
                    dot_psf_path,
                    gal_pos,
                    self._star_thresh,
                    self._chi2_thresh,
                )

                if (
                    isinstance(self.interp_PSFs, str)
                    and self.interp_PSFs == NOT_ENOUGH_STARS
                ):
                    self._w_log.info(
                        f'Not enough stars found on CCD {ccd} of exposure '
                        + f'{exp_name}. Objects inside this CCD will lose '
                        + 'an epoch.'
                    )
                    continue
                if (
                    isinstance(self.interp_PSFs, str)
                    and self.interp_PSFs == BAD_CHI2
                ):
                    self._w_log.info(
                        f'Bad chi2 for the PSF model on CCD {ccd} of exposure '
                        + f'{exp_name}. Objects inside this CCD will lose '
                        + 'an epoch.'
                    )
                    continue
                if (
                    isinstance(self.interp_PSFs, str)
                    and self.interp_PSFs == FILE_NOT_FOUND
                ):
                    self._w_log.info(
                        f'PSF model file {dot_psf_path} not found. '
                        + 'Objects inside this CCD will lose an epoch.'
                    )
                    continue

                if array_psf is None:
                    array_psf = np.copy(self.interp_PSFs)
                else:
                    array_psf = np.concatenate(
                        (array_psf, np.copy(self.interp_PSFs))
                    )

                if array_id is None:
                    array_id = np.copy(obj_id)
                else:
                    array_id = np.concatenate((array_id, np.copy(obj_id)))

                if self._compute_shape:
                    self._get_psfshapes()
                    if array_shape is None:
                        array_shape = np.copy(self.psf_shapes)
                    else:
                        array_shape = np.concatenate((
                            array_shape,
                            np.copy(self.psf_shapes),
                        ))
                else:
                    array_shape = None

                exp_name_tmp = np.array([
                    exp_name + '-' + str(ccd) for _ in range(len(obj_id))
                ])
                if array_exp_name is None:
                    array_exp_name = exp_name_tmp
                else:
                    array_exp_name = np.concatenate(
                        (array_exp_name, exp_name_tmp)
                    )

            final_list.append([
                array_id,
                array_psf,
                array_shape,
                array_exp_name
            ])

        self._f_wcs_file.close()
        cat.close()

        output_dict = {}
        n_empty = 0
        for id_tmp in all_id:
            output_dict[id_tmp] = {}
            counter = 0
            for j in range(len(final_list)):
                where_res = np.where(final_list[j][0] == id_tmp)[0]
                if len(where_res) != 0:
                    output_dict[id_tmp][final_list[j][3][where_res[0]]] = {}
                    output_dict[id_tmp][
                        final_list[j][3][where_res[0]]
                    ]['VIGNET'] = final_list[j][1][where_res[0]]
                    if self._compute_shape:
                        shape_dict = {}
                        shape_dict['E1_PSF_HSM'] = (
                            final_list[j][2][where_res[0]][0]
                        )
                        shape_dict['E2_PSF_HSM'] = (
                            final_list[j][2][where_res[0]][1]
                        )
                        shape_dict['SIGMA_PSF_HSM'] = (
                            final_list[j][2][where_res[0]][2]
                        )
                        shape_dict['FLAG_PSF_HSM'] = (
                            final_list[j][2][where_res[0]][3]
                        )
                        output_dict[id_tmp][
                            final_list[j][3][where_res[0]]
                        ]['SHAPES'] = shape_dict
                    counter += 1
            if counter == 0:
                output_dict[id_tmp] = 'empty'
                n_empty += 1

        self._w_log.info(f'{n_empty}/{len(all_id)} PSFs are empty')

        return output_dict
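
    # The returned dictionary is keyed by object NUMBER; each value is either
    # the string 'empty' (no epoch could be interpolated) or a per-epoch
    # mapping, with shapes present only when ``get_shapes`` is enabled:
    #
    #   {
    #       '<exp_name>-<ccd>': {
    #           'VIGNET': <2D PSF stamp>,
    #           'SHAPES': {'E1_PSF_HSM': ..., 'E2_PSF_HSM': ...,
    #                      'SIGMA_PSF_HSM': ..., 'FLAG_PSF_HSM': ...},
    #       },
    #       ...
    #   }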

    def _write_output_me(self, output_dict):
        """Write Output Multi-Epoch.

        Save computed PSFs to an SQLite file for the multi-epoch run.

        Parameters
        ----------
        output_dict : dict
            Dictionary of outputs to save

        """
        output_file = SqliteDict(
            self._output_path + self._img_number + '.sqlite'
        )
        for idx in output_dict.keys():
            output_file[str(idx)] = output_dict[idx]
        output_file.commit()
        output_file.close()
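
# A minimal sketch for reading the multi-epoch output back (the file name is
# hypothetical but follows the ``galaxy_psf<img_number>.sqlite`` convention
# used above):
#
#   from sqlitedict import SqliteDict
#   psf_file = SqliteDict('galaxy_psf-2367890-0.sqlite')
#   entry = psf_file['42']  # keys are object NUMBERs stored as strings
#   if entry != 'empty':
#       for epoch_key, epoch_data in entry.items():
#           vignet = epoch_data['VIGNET']
#   psf_file.close()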