Source code for scarlet2.init

"""Helper methods to initialize sources"""

import operator
import warnings
from dataclasses import dataclass
from functools import reduce

import astropy
import astropy.units as u
import jax.numpy as jnp
from equinox import tree_pformat

from . import Scenery, measure
from .bbox import Box, insert_into
from .detect import HierarchicalFootprint, hierarchical_footprints
from .frame import get_relative_jacobian_shift
from .morphology import GaussianMorphology
from .observation import Observation


# function to calculate values of perimeter pixels
def _get_edge_pixels(img, box):
    box_slices = box.slices
    box_start = box.start
    box_stop = box.stop
    edge = [
        img[:, box_slices[0], box_start[1]],
        img[:, box_slices[0], box_stop[1] - 1],
        img[:, box_start[0], box_slices[1]],
        img[:, box_stop[0] - 1, box_slices[1]],
    ]
    return jnp.concatenate(edge, axis=1)


def best_box_moments(
    obs,
    center_pix,
    sizes=(11, 17, 25, 35, 47, 61, 77),
    min_snr=3,
    min_corr=0.99,
):
    """Make a bounding box for the source at `center_pix` and measure its moments

    This method finds a small box around the center so that the edge flux has a minimum SNR,
    the color of the edge pixels remains highly correlated to the center pixel color,
    and the flux in all channels is lower than with the last acceptable box size.
    This should effectively isolate the source against the noise background and neighboring objects.

    Parameters
    ----------
    obs: :py:class:`~scarlet2.Observation`
        The Observation instance to use when defining the bounding box. Its `weights` must be set.
    center_pix: array
        Source center, in pixel coordinates of `obs` (needs to support ``astype(int)``)
    sizes: sequence[int]
        Box sizes to cycle through, in increasing order
    min_snr: float
        Minimum SNR of edge pixels (aggregated over all observation channels) to allow increase of box size
    min_corr: float
        Minimum correlation coefficient between center and edge color to allow increase of box size

    Returns
    -------
    :py:class:`~scarlet2.Box`, :py:class:`~scarlet2.Moments`
        Box covering all channels of `obs` and the selected spatial region, and the
        2nd moments measured in that region, standardized with :py:func:`standardize_moments`.
    """
    assert isinstance(obs, Observation)
    assert obs.weights is not None, "Observation weights are required"
    assert obs.frame.bbox.spatial.contains(center_pix), (
        f"Center pixel {center_pix} not contained in observation"
    )
    assert len(sizes) > 0

    # increase box size from list until SNR is below threshold or spectrum changes significantly
    peak_spectrum = pixel_spectrum(obs, center_pix, correct_psf=True)
    max_obs_size = max(obs.frame.bbox.spatial.shape)
    last_spectrum = None  # only compared for i > 0, after it has been set
    moments = []
    for i, size in enumerate(sizes):
        box2d = Box((size, size))
        box2d.set_center(center_pix.astype(int))

        # mean edge flux per channel, ignoring pixels that are exactly zero
        edge_pixels = _get_edge_pixels(obs.data, box2d)
        valid_edge_pixel = edge_pixels != 0
        n_valid = jnp.sum(valid_edge_pixel, axis=-1)
        edge_spectrum = jnp.sum(edge_pixels, axis=-1) / n_valid

        # mean per-channel SNR of the edge pixels, summed over channels
        weight_edge_pixels = _get_edge_pixels(obs.weights, box2d)
        snr_edge_pixels = jnp.abs(edge_pixels) * jnp.sqrt(weight_edge_pixels)
        mean_snr = jnp.sum(jnp.sum(snr_edge_pixels, axis=-1) / n_valid)

        # cosine similarity between the edge and the peak spectrum
        spec_corr = (
            jnp.dot(edge_spectrum, peak_spectrum)
            / jnp.sqrt(jnp.dot(peak_spectrum, peak_spectrum))
            / jnp.sqrt(jnp.dot(edge_spectrum, edge_spectrum))
        )
        m = boxed_moments(obs, center_pix, bbox=box2d)

        # edges are at the noise level, or box outgrows the observation: keep the current box
        if mean_snr < min_snr or max(box2d.shape) > max_obs_size:
            break

        # edge color changes, edge flux rises, or moments are unusable: revert to the previous box
        if i > 0 and (
            spec_corr < min_corr
            or jnp.any(edge_spectrum > last_spectrum)
            or jnp.any(jnp.isnan(m.size))
        ):
            box2d = Box((sizes[i - 1], sizes[i - 1]))
            box2d.set_center(center_pix.astype(int))
            m = moments[-1]
            break

        last_spectrum = edge_spectrum
        moments.append(m)

    box = obs.frame.bbox[0] @ box2d
    m = standardize_moments(m, obs)
    return box, m
def boxed_moments(
    obs,
    center,
    footprint=None,
    bbox=None,
):
    """Measure 2nd moments of the observation in the region of the bounding box

    The method cuts out the pixels included in `bbox` or in `footprint` and measures their
    moments up to 2nd order with respect to `center`, using the footprint as weight.
    The moments are returned in the observed frame, without any coordinate or PSF correction;
    use :py:func:`standardize_moments` to transform them to the model frame.

    Parameters
    ----------
    obs: :py:class:`~scarlet2.Observation`
        The Observation instance to derive moments from
    center: array
        Central pixel of the source, in pixel coordinates of `obs`
    footprint: array, optional
        2D image with non-zero values for all pixels associated with this source
        (aka a segmentation map or footprint)
    bbox: :py:class:`~scarlet2.Box`, optional
        Box to cut out source from observation, in pixel coordinates.
        A 2D (spatial) box is extended to all channels of `obs`.

    Returns
    -------
    :py:class:`~scarlet2.Moments`
        2nd moments, measured in the observed frame

    Raises
    ------
    AssertionError
        If neither `bbox` nor `footprint` are set
    """
    assert isinstance(obs, Observation)

    # construct box from footprint
    if bbox is None:
        assert footprint is not None
        bbox = Box.from_data(footprint)

    # construct footprint as step function inside bbox
    if footprint is None:
        assert bbox is not None
        footprint = jnp.zeros(obs.frame.bbox.spatial.shape)
        footprint = footprint.at[bbox.spatial.slices].set(1)

    # extend a spatial-only box to all channels of the observation
    if bbox.D == 2:
        bbox = obs.frame.bbox[0] @ bbox

    # cutout image and footprint
    cutout_img = obs.data[bbox.slices]
    cutout_fp = footprint[bbox.spatial.slices]

    # center relative to the cutout origin
    center_ = center - jnp.array(bbox.spatial.origin)
    m = measure.Moments(cutout_img, center=center_, weight=cutout_fp[None, :, :], N=2)
    return m
def standardize_moments(g, obs):
    """Standardize the 2nd moments to match the model frame

    Transfers the moments from the WCS of the observation to the WCS of the model frame
    and, if both frames have a PSF, deconvolves them from the difference kernel between
    the observation PSF and the model-frame PSF, to mimic moments measured in the model frame.

    Note that `g` is modified in place (and returned).

    Parameters
    ----------
    g: :py:class:`~scarlet2.Moments`
        2nd moments, measured in observed frame
    obs: :py:class:`~scarlet2.Observation`
        The Observation in whose frame `g` was measured. The moments of the difference
        kernel are cached on it as ``obs._dp`` for repeated use.

    Returns
    -------
    :py:class:`~scarlet2.Moments`
        2nd moments, deconvolved, and in the coordinates of the model frame

    Raises
    ------
    AttributeError
        If called outside of the context of a Scene
    """
    try:
        frame = Scenery.scene.frame
    except AttributeError:
        print("Adaptive morphology can only be created within the context of a Scene")
        print("Use 'with Scene(frame) as scene: Source(...)'")
        raise

    # adjust moments for model frame
    g.transfer(obs.frame.wcs, frame.wcs)

    # deconvolve from PSF (actually: difference kernel between obs PSF and model frame PSF)
    if frame.psf is not None and obs.frame.psf is not None:
        # NOTE(review): the cache is keyed only on `obs`, not on the model frame; if `obs`
        # is reused in a Scene with a different frame PSF, the stale kernel is used. Confirm.
        if hasattr(obs, "_dp"):
            p = obs._dp
        else:
            # moments of difference kernel between the model PSF and the observed PSF
            p = measure.Moments(obs.frame.psf(), N=2)
            p.transfer(obs.frame.wcs, frame.wcs)
            p0 = measure.Moments(frame.psf(), N=2)
            p.deconvolve(p0)
            # store in obs for repeated use (bypasses frozen-attribute protection)
            object.__setattr__(obs, "_dp", p)

        # deconvolve from difference kernel
        g.deconvolve(p)
    return g
def from_gaussian_moments(
    obs,
    center,
    box_sizes=None,
    min_snr=3,
    min_corr=0.99,
    min_value=1e-6,
    max_value=1 - 1e-6,
):
    """Create a Gaussian-shaped morphology and associated spectrum from the observation(s).

    The method determines a suitable bounding box that contains the source given its `center`,
    computes the deconvolved moments up to order 2, constructs the spectrum from the 0th moment
    and a morphology image from the 2nd moments (assuming a Gaussian shape).

    The flux-normalized moments of all channels of all observations are combined by taking
    the median of each moment.

    Parameters
    ----------
    obs: :py:class:`~scarlet2.Observation` or list
        Observation from which the source is initialized.
    center: tuple
        Position of the source; converted to pixel coordinates of every observation
        with ``obs.frame.get_pixel``
    box_sizes: None, list[int] or list[SkyCoord]
        A list of box sizes to choose from. If `None`, chooses multiples of the PSF FWHM from `obs`.
    min_snr: float
        Minimum SNR of edge pixels (aggregated over all observation channels) to allow increase of box size
    min_corr: float
        Minimum correlation coefficient between center and edge color to allow increase of box size
    min_value: float
        Minimum pixel value (useful to set to > 0 for positivity constraints)
    max_value: float
        Maximum pixel value (useful to set to < 1 for unit interval constraints)

    Returns
    -------
    (array,array)
        Spectrum and morphology arrays

    Warnings
    --------
    This method is stable only for isolated sources. In cases of significant blending, the size of
    the bounding box and the measured moments are likely biased high.

    See Also
    --------
    best_box_moments: Defines bounding box that contains the source and measures its moments
    standardize_moments: Transforms 2nd moments to the model frame
    """
    try:
        frame = Scenery.scene.frame
    except AttributeError:
        print("from_gaussian_moments() can only be called within the context of a Scene")
        print("Use 'with Scene(frame) as scene: Source(...)'")
        raise

    # TODO: implement with source footprints given for each observation
    # get moments from all channels in all observations
    observations = obs if isinstance(obs, (list, tuple)) else (obs,)

    # convert the center to pixel coordinates of each observed frame
    centers = [obs_.frame.get_pixel(center) for obs_ in observations]

    if box_sizes is None:
        # growing sizes in units of the observed PSF
        psf_sizes = [
            measure.fwhm(obs_.frame.psf()).min() if obs_.frame.psf is not None else 1
            for obs_ in observations
        ]
        # geometric series 6, 9, 13.5, ... (exactly representable, in obs pixels)
        growth = [6.0 * 1.5**i for i in range(10)]
        # NOTE: Not forced to be odd
        box_sizes = [[int(psf_size * g) for g in growth] for psf_size in psf_sizes]
    else:
        assert len(box_sizes) > 0
        if u.get_physical_type(box_sizes[0]) == "angle":
            box_sizes = [[obs_.frame.u_to_pixel(size) for size in box_sizes] for obs_ in observations]
        else:
            # assume that all pixels are in proper observed frame
            box_sizes = [box_sizes for _ in observations]

    boxes_moments = [
        best_box_moments(obs_, center_, sizes=sizes_, min_snr=min_snr, min_corr=min_corr)
        for obs_, center_, sizes_ in zip(observations, centers, box_sizes, strict=False)
    ]

    # flat lists of spectra (0th moments), sorted in order of model frame channels
    spectra = jnp.concatenate([mom[0, 0] for _, mom in boxes_moments])
    channels = reduce(operator.add, [obs_.frame.channels for obs_ in observations])
    spectrum = _sort_spectra(spectra, channels)

    # flux normalization (after the spectra have been extracted)
    moments = [mom.normalize() for _, mom in boxes_moments]

    # combine the normalized moments of all channels in all observations into the first one
    m = moments[0]
    for key in m:
        m[key] = jnp.concatenate([mom[key] for mom in moments])
        # this is not SNR weighted nor consistent across different moments, but works(?)
        m[key] = jnp.median(m[key])

    # average box size across observations, in model frame pixels
    if frame.wcs is not None:
        size = jnp.mean(
            jnp.array(
                [
                    frame.u_to_pixel(obs_.frame.pixel_to_angle(max(box.spatial.shape)))
                    for (box, _), obs_ in zip(boxes_moments, observations, strict=False)
                ]
            )
        ).astype(int)
    else:
        size = jnp.mean(jnp.array([max(box.spatial.shape) for box, _ in boxes_moments])).astype(int)

    # create morphology and evaluate at center
    morph = GaussianMorphology.from_moments(m, shape=(size, size))
    morph = morph()
    spectrum /= morph.sum()
    morph = jnp.minimum(jnp.maximum(morph, min_value), max_value)
    return spectrum, morph
def compact_morphology(min_value=1e-6, max_value=1 - 1e-6):
    """Create image of the point source morphology model, i.e. the most compact source possible

    The morphology is taken from the PSF of the model frame of the current Scene.

    Parameters
    ----------
    min_value: float
        Minimum pixel value (needed for positively constrained morphologies)
    max_value: float
        Maximum pixel value (useful to set to < 1 for unit interval constraints)

    Returns
    -------
    array
        Morphology of the model-frame PSF, clipped to ``[min_value, max_value]``

    Raises
    ------
    AttributeError
        If called outside of the context of a Scene, or if the model frame has no PSF
    """
    try:
        frame = Scenery.scene.frame
    except AttributeError:
        print("Compact morphology can only be created within the context of a Scene")
        print("Use 'with Scene(frame) as scene: Source(...)'")
        raise

    if frame.psf is None:
        raise AttributeError("Compact morphology can only be create with a PSF in the model frame")

    morph = frame.psf.morphology()
    morph = jnp.minimum(jnp.maximum(morph, min_value), max_value)
    return morph
# initialise the spectrum
def pixel_spectrum(obs, pos, correct_psf=False):
    """Get the spectrum at a given position in the observation(s).

    Reads the pixel values of all channels at `pos`. Optionally, the values are rescaled
    by the ratio of the model-frame PSF peak to the observed PSF peak, so that PSF shape
    differences between observations do not alter the spectrum of a point source.

    Parameters
    ----------
    obs: :py:class:`~scarlet2.Observation` or list
        Observation(s) to extract pixel SED from
    pos: tuple
        Position in the observation. Needs to be in sky coordinates if multiple observations
        have different locations or pixel scales.
    correct_psf: bool, optional
        Whether PSF shape variations in the observations should be corrected.
        If `True`, this method homogenizes the PSFs of the observations, which yields the
        correct spectrum for a flux=1 point source.

    Returns
    -------
    array
        Spectrum at `pos`. If `obs` is a list, the spectra of all observations are combined
        and ordered according to the channels of the model frame; channels not present in
        any observation are set to 0. If the spectrum is zero or negative in all channels,
        it is replaced by ones.

    Raises
    ------
    ValueError
        If `pos` lies outside of the observation
    AttributeError
        If `correct_psf` is requested outside of a Scene context or without a model-frame PSF
    """
    # for multiple observations, get spectrum from each observation and then
    # combine channels in order of model frame
    if isinstance(obs, (list, tuple)):
        # flat lists of spectra and channels in order of observations
        spectra = jnp.concatenate([pixel_spectrum(obs_, pos, correct_psf=correct_psf) for obs_ in obs])
        channels = reduce(operator.add, [obs_.frame.channels for obs_ in obs])
        spectrum = _sort_spectra(spectra, channels)
        return spectrum

    assert isinstance(obs, Observation)

    pixel = obs.frame.get_pixel(pos).astype(int)
    if not obs.frame.bbox.spatial.contains(pixel):
        raise ValueError(f"Position {pos} (pixel {pixel}) is outside of the observation")

    spectrum = obs.data[:, pixel[0], pixel[1]].copy()

    if correct_psf and obs.frame.psf is not None:
        try:
            frame = Scenery.scene.frame
        except AttributeError:
            print("Adaptive morphology can only be created within the context of a Scene")
            print("Use 'with Scene(frame) as scene: Source(...)'")
            raise
        if frame.psf is None:
            raise AttributeError("Adaptive morphology can only be create with a PSF in the model frame")
        # correct spectrum for PSF-induced change in peak pixel intensity
        psf_model = obs.frame.psf()
        psf_peak = psf_model.max(axis=(-2, -1))
        psf0_model = frame.psf()
        psf0_peak = psf0_model.max(axis=(-2, -1))
        spectrum /= psf_peak / psf0_peak

    if jnp.any(spectrum <= 0):
        # If the flux in all channels is <=0,
        # the new sed will be filled with NaN values,
        # which will cause the code to crash later
        msg = f"Zero or negative spectrum {spectrum} at {pos}"
        if jnp.all(spectrum <= 0):
            print("Zero or negative spectrum in all channels: Setting spectrum to 1")
            spectrum = jnp.ones_like(spectrum)
        print(msg)

    return spectrum
def _sort_spectra(spectra, channels):
    """Reorder per-channel values into the channel order of the model frame.

    Parameters
    ----------
    spectra: array
        Flat array of values, one per entry of `channels`
    channels: list
        Channel identifiers matching `spectra`; for duplicates the first occurrence is used

    Returns
    -------
    array
        One value per channel of the model frame; channels missing from `channels`
        get amplitude 0 (and a message is printed)
    """
    try:
        frame = Scenery.scene.frame
    except AttributeError:
        print("Multi-observation initialization can only be created within the context of a Scene")
        print("Use 'with Scene(frame) as scene: ...")
        raise

    def _amplitude(channel):
        # look up the value for one model-frame channel, falling back to 0
        if channel not in channels:
            print(f"Channel '{channel}' not found in observations. Setting amplitude to 0.")
            return 0
        return spectra[channels.index(channel)]

    return jnp.array([_amplitude(channel) for channel in frame.channels])
@dataclass
class HierarchicalSourceInfo:
    """Information for initializing a hierarchical source.

    Attributes
    ----------
    peak: astropy.coordinates.SkyCoord
        Sky position of detected peak
    center: astropy.coordinates.SkyCoord
        Sky position of the center of the bounding box
    spectrum: jnp.ndarray
        1-D array of per-channel flux, shape ``(model.frame.C,)``
    morphology: jnp.ndarray
        2-D intensity map, clipped to the footprint bounding box, normalized to maximum of 1.
        For catalog positions without a detected footprint, this is the compact
        (model PSF) morphology instead.
    footprint: HierarchicalFootprint
        The footprint from :func:`~scarlet2.detect.hierarchical_footprints`, or `None`
        for catalog positions without a detected footprint
    """

    # NOTE(review): annotations are evaluated at class creation; this relies on
    # `astropy.coordinates` having been imported somewhere before this point.
    peak: astropy.coordinates.SkyCoord
    center: astropy.coordinates.SkyCoord
    spectrum: jnp.ndarray
    morphology: jnp.ndarray
    footprint: HierarchicalFootprint

    def __repr__(self):
        # the explicit __repr__ takes precedence over the dataclass-generated one
        return tree_pformat(self)
def _footprints_to_sources(obs, detect, footprints, catalog):
    """Convert hierarchical footprints into source initialization info.

    Footprints are processed from the largest to the smallest detection scale. For each
    footprint, the morphology is the (non-negative) detection image at the footprint's scale,
    restricted to the footprint. The spectrum is the least-squares amplitude of that detection
    image fit to the residual images, clipped at 0. The fitted source is subtracted from the
    residual before smaller/fainter sources are fit. Catalog entries without a footprint are
    initialized as compact sources with the spectrum read from the final residual.

    Parameters
    ----------
    obs: :py:class:`~scarlet2.Observation`
        Observation providing the images and the coordinate frame
    detect: indexable
        Detection images, indexed by scale; presumably each has the spatial shape of
        `obs` (TODO confirm)
    footprints: list
        :class:`~scarlet2.detect.HierarchicalFootprint` or `None` per entry
    catalog: array or None
        Pixel positions in `obs`, presumably one per entry of `footprints` (verify against caller)

    Returns
    -------
    list
        One :class:`HierarchicalSourceInfo` (or `None` if nothing could be initialized)
        per entry of `footprints`
    """
    try:
        frame = Scenery.scene.frame
    except AttributeError:
        print("Attributes defined in sky coordinates can only be created within the context of a Scene")
        print("Use 'with Scene(frame) as scene: (...)'")
        raise

    # we need a proper renderer to determine placement in the model frame
    obs.check_set_renderer(frame)

    # spatial differences
    jac, shift = get_relative_jacobian_shift(frame, obs.frame)
    if not jnp.allclose(jnp.linalg.det(jac), 1) or not jnp.allclose(shift, shift.astype(int)):
        warnings.warn(
            "hierarchical_sources does not support resampling. Align `obs` first!",
            stacklevel=2,
        )

    # map between frame and obs channels: assumes clean 1-to-1 mapping
    remap_channels = frame.channels != obs.frame.channels
    if remap_channels:
        channel_map = frame.map_channels(obs.frame)
        frame_channels = jnp.asarray(list(channel_map.keys()))
        obs_channels = jnp.asarray(list(channel_map.values()))

    _images = jnp.copy(obs.data)
    shape = obs.frame.bbox.spatial.shape
    res = [None] * len(footprints)

    # get all scales with footprints, from largest to smallest
    scales = sorted({fp.scale for fp in footprints if fp is not None}, reverse=True)
    for scale in scales:
        for i, fp in enumerate(footprints):
            if fp is not None and fp.scale == scale:
                footprint_map = insert_into(jnp.zeros(shape), jnp.asarray(fp.footprint), fp.bbox)
                source_detect = detect[scale] * footprint_map

                # spectrum: MLE of flux per band given the morphology from source_detect.
                # Guard the normalization: an all-zero detection map would give 0/0 = NaN,
                # which would then be subtracted from (and poison) the residual images.
                norm = jnp.sum(source_detect**2)
                norm = jnp.where(norm > 0, norm, 1.0)
                spectrum = (_images * source_detect[None, :, :]).sum(axis=(-2, -1)) / norm
                spectrum = jnp.maximum(spectrum, 0)
                _images -= spectrum[:, None, None] * source_detect[None, :, :]

                # morph: detection image at detection scale within the footprint
                morph = jnp.maximum(source_detect[fp.bbox.slices], 0)

                # max normalization (skipped for an all-zero morphology to avoid NaN)
                factor = morph.max()
                factor = jnp.where(factor > 0, factor, 1.0)
                morph = morph / factor
                spectrum *= factor

                # convert peak and box center to RA/Dec
                peak = obs.frame.get_sky_coord((fp.peak.y, fp.peak.x))
                center = obs.frame.get_sky_coord(fp.bbox.center)

                # ensure correct placement of observed spectrum in model_frame
                if remap_channels:
                    spectrum = jnp.zeros(frame.C).at[frame_channels].set(spectrum[obs_channels])

                res[i] = HierarchicalSourceInfo(peak, center, spectrum, morph, fp)

    # sweep for catalog sources with non-detections: initialize them as compact sources
    if catalog is not None:
        for i, peak in enumerate(catalog):
            if footprints[i] is None:
                pixel = peak.astype(int)
                spectrum = jnp.asarray(_images[:, pixel[0], pixel[1]])
                spectrum = jnp.maximum(spectrum, 0)
                if remap_channels:
                    spectrum = jnp.zeros(frame.C).at[frame_channels].set(spectrum[obs_channels])
                morph = compact_morphology()
                peak = obs.frame.get_sky_coord(peak)
                center = peak
                res[i] = HierarchicalSourceInfo(peak, center, spectrum, morph, footprints[i])

    return res
def hierarchical_sources(
    obs,
    scales=None,
    strict=True,
    K=3,
    split_peaks=True,
    image_type="ground",
    min_separation=0,
    min_area=9,
    thresh=0,
    catalog=None,
):
    """Initialize sources from a wavelet-based hierarchical footprint detection.

    Computes a detection image from ``obs``, decomposes it into a hierarchy of footprints
    across starlet scales, and returns one :class:`HierarchicalSourceInfo` per footprint,
    suitable for constructing :class:`~scarlet2.Source` objects.

    Spectra and morphologies are initialized from the wavelet detection image at each
    source's scale, with a least-squares flux estimate that progressively subtracts
    brighter/larger sources before fitting fainter/smaller ones (largest scale first).

    Parameters
    ----------
    obs : :class:`~scarlet2.Observation`
        The observation providing image data, per-pixel weights, and the coordinate
        frame used to convert ``catalog`` to pixel positions.
    scales : list of int, optional
        Starlet scales (indices into the coefficient array) to use for detection.
        If ``None``, the default of :func:`~scarlet2.detect.hierarchical_footprints` is used.
    strict : bool, optional
        If ``True``, the coarse residual plane is pushed one scale higher so that the
        selected ``scales`` are cleanly separated without bleed from the largest-scale
        smooth background. Default ``True``.
    K : float, optional
        Detection threshold multiplier: coefficients with ``|w| > K * sigma_j`` are
        considered significant. Also used for the SNR-based bounding box extension.
        Default ``3``.
    split_peaks : bool, optional
        If ``True`` (default), footprints with multiple peaks are split into separate
        sources using a watershed algorithm. Otherwise, additional peaks become children
        of the originating footprint, which retains the full footprint area, i.e. the
        children overlap. Splitting peaks reduces the overlap of mostly independent sources.
    image_type: str
        The type of image that is being used. This should be ``"ground"`` for ground based
        images with wide PSFs or ``"space"`` for images from space-based telescopes with a
        narrow PSF.
    min_separation : float, optional
        Minimum pixel separation between peaks within a footprint.
    min_area : int, optional
        Minimum number of pixels a footprint must contain to be kept.
    thresh : float, optional
        Detection threshold; pixels must strictly exceed this value.
    catalog : list of :class:`astropy.coordinates.SkyCoord`, optional
        If given, only footprints that contain one of these sky positions are returned.
        Converted to pixel coordinates via ``obs.frame.get_pixel``.

    Returns
    -------
    sources : list of :class:`HierarchicalSourceInfo`
        One entry per detected footprint. For every location in `catalog` that is not
        detected as a footprint, a "compact" source is created, i.e. one whose morphology
        is derived from the model PSF.

    See Also
    --------
    :func:`~scarlet2.detect.hierarchical_footprints`
    """
    catalog_ = obs.frame.get_pixel(catalog) if catalog is not None else None
    footprints, detect = hierarchical_footprints(
        obs,
        scales=scales,
        strict=strict,
        K=K,
        split_peaks=split_peaks,
        image_type=image_type,
        min_separation=min_separation,
        min_area=min_area,
        thresh=thresh,
        flatten=True,
        catalog=catalog_,
        return_detect=True,
    )
    return _footprints_to_sources(obs, detect, footprints, catalog_)