"""Helper methods to initialize sources"""
import operator
from functools import reduce
import astropy.units as u
import jax.numpy as jnp
from . import Scenery, measure
from .bbox import Box
from .morphology import GaussianMorphology
from .observation import Observation
# function to calculate values of perimeter pixels
def _get_edge_pixels(img, box):
box_slices = box.slices
box_start = box.start
box_stop = box.stop
edge = [
img[:, box_slices[0], box_start[1]],
img[:, box_slices[0], box_stop[1] - 1],
img[:, box_start[0], box_slices[1]],
img[:, box_stop[0] - 1, box_slices[1]],
]
return jnp.concatenate(edge, axis=1)
[docs]
def best_box_moments(obs, center_pix, sizes=[11, 17, 25, 35, 47, 61, 77], min_snr=3, min_corr=0.99): # noqa: B006
"""Make a bounding box for source at center
This method finds small box around the center so that the edge flux has a minimum SNR,
the color of the edge pixels remains highly correlated to the center pixel color,
and the flux in all channels is lower than with the last acceptable box size.
This should effectively isolate the source against the noise background and neighboring objects.
Parameters
----------
obs: :py:class:`~scarlet2.Observation`
The Observation instance to use when defining the bounding box.
center_pix: tuple
source center, in pixel coordinates
sizes: list[int]
a list of box sizes to cycle through
min_snr: float
minimum SNR of edge pixels (aggregated over all observation channel) to allow increase of box size
min_corr: float
minimum correlation coefficient between center and edge color to allow increase of box size
Returns
-------
:py:class:`~scarlet2.Box`, :py:class:`~scarlet2.Moments`
"""
assert isinstance(obs, Observation)
assert obs.weights is not None, "Observation weights are required"
assert obs.frame.bbox.spatial.contains(center_pix), (
f"Center pixel {center_pix} not contained in observation"
)
assert len(sizes) > 0
# increase box size from list until SNR is below threshold or spectrum changes significantly
peak_spectrum = pixel_spectrum(obs, center_pix, correct_psf=True)
last_spectrum = jnp.empty(len(peak_spectrum))
moments = []
for i in range(len(sizes)):
box2d = Box((sizes[i], sizes[i]))
box2d.set_center(center_pix.astype(int))
edge_pixels = _get_edge_pixels(obs.data, box2d)
valid_edge_pixel = edge_pixels != 0
edge_spectrum = jnp.sum(edge_pixels, axis=-1) / jnp.sum(valid_edge_pixel, axis=-1)
weight_edge_pixels = _get_edge_pixels(obs.weights, box2d)
snr_edge_pixels = jnp.abs(edge_pixels) * jnp.sqrt(weight_edge_pixels)
mean_snr = jnp.sum(jnp.sum(snr_edge_pixels, axis=-1) / jnp.sum(valid_edge_pixel, axis=-1))
spec_corr = (
jnp.dot(edge_spectrum, peak_spectrum)
/ jnp.sqrt(jnp.dot(peak_spectrum, peak_spectrum))
/ jnp.sqrt(jnp.dot(edge_spectrum, edge_spectrum))
)
m = boxed_moments(obs, center_pix, bbox=box2d)
if mean_snr < min_snr or max(box2d.shape) > max(obs.frame.bbox.spatial.shape):
break
if i > 0 and (
spec_corr < min_corr or jnp.any(edge_spectrum > last_spectrum) or jnp.any(jnp.isnan(m.size))
):
box2d = Box((sizes[i - 1], sizes[i - 1]))
box2d.set_center(center_pix.astype(int))
m = moments[-1]
break
last_spectrum = edge_spectrum
moments.append(m)
box = obs.frame.bbox[0] @ box2d
m = standardize_moments(m, obs)
return box, m
[docs]
def boxed_moments(
obs,
center,
footprint=None,
bbox=None,
):
"""Measure 2nd moments of the observation in the region of the bounding box
The methods cuts out the pixel included in `bbox` or in `footprint`, measures
their 2nd moments with respect to `center`, adjust the spatial coordinates
and the PSF to match the model frame.
Parameters
----------
obs: :py:class:`~scarlet2.Observation`
The Observation instance to derive moments from
center: tuple
central pixel of the source
footprint: array, optional
2D image with non-zero values for all pixels associated with this source
(aka a segmentation map or footprint)
bbox: :py:class:`~scarlet2.BBox`, optional
box to cut out source from observation, in pixel coordinates
Returns
-------
:py:class:`~scarlet2.Moments`
2nd moments, corrected to match the model frame
Raises
------
AssertionError
If neither `bbox` or `footprint` are set
"""
assert isinstance(obs, Observation)
# construct box from footprint
if bbox is None:
assert footprint is not None
bbox = Box.from_data(footprint)
# construct footprint as step function inside bbox
if footprint is None:
assert bbox is not None
footprint = jnp.zeros(obs.frame.bbox.spatial.shape)
footprint = footprint.at[bbox.spatial.slices].set(1)
if bbox.D == 2:
bbox = obs.frame.bbox[0] @ bbox
# cutout image and footprint
cutout_img = obs.data[bbox.slices]
cutout_fp = footprint[bbox.spatial.slices]
center_ = center - jnp.array(bbox.spatial.origin)
m = measure.Moments(cutout_img, center=center_, weight=cutout_fp[None, :, :], N=2)
# m = standardize_moments(m, obs)
return m
[docs]
def standardize_moments(g, obs):
"""Standardize the 2nd moments to match the model frame
Adjust the spatial coordinates and deconvolved from the observation PSF to mimic moments measured in the
model frame.
Parameters
----------
g: :py:class:`~scarlet2.Moments`
2nd moments, measured in observed frame
obs: :py:class:`~scarlet2.Observation`
The Observation instance to derive moments from
center: tuple
central pixel of the source
footprint: array, optional
2D image with non-zero values for all pixels associated with this source
(aka a segmentation map or footprint)
bbox: :py:class:`~scarlet2.BBox`, optional
box to cut out source from observation, in pixel coordinates
Returns
-------
:py:class:`~scarlet2.Moments`
2nd moments, deconvolved, and in the coordinates of the model frame
"""
try:
frame = Scenery.scene.frame
except AttributeError:
print("Adaptive morphology can only be created within the context of a Scene")
print("Use 'with Scene(frame) as scene: Source(...)'")
raise
# adjust moments for model frame
g.transfer(obs.frame.wcs, frame.wcs)
# deconvolve from PSF (actually: difference kernel between obs PSF and model frame PSF)
if frame.psf is not None and obs.frame.psf is not None:
if hasattr(obs, "_dp"):
p = obs._dp
else:
# moments of difference kernel between the model PSF and the observed PSF
p = measure.Moments(obs.frame.psf(), N=2)
p.transfer(obs.frame.wcs, frame.wcs)
p0 = measure.Moments(frame.psf(), N=2)
p.deconvolve(p0)
# store in obs for repeated use
object.__setattr__(obs, "_dp", p)
# deconvolve from difference kernel
g.deconvolve(p)
return g
[docs]
def from_gaussian_moments(
obs,
center,
box_sizes=None,
min_snr=3,
min_corr=0.99,
min_value=1e-6,
max_value=1 - 1e-6,
):
"""Create a Gaussian-shaped morphology and associated spectrum from the observation(s).
The method determines an suitable bounding box that contains the source given its `center`,
computes the deconvolved moments up to order 2, constructs the spectrum from
the 0th moment and a morphology image from the 2nd moments (assuming a Gaussian shape).
If multiple observations are given, it takes the median of the moments in the same channel.
Parameters
----------
obs: :py:class:`~scarlet2.Observation` or list
Observation from which the source is initialized.
center: tuple
Central pixel of the source
box_sizes: None, list[int] or list[SkyCoord]
A list of box sizes to choose from. If `None`, chooses multiples of the PSF FWHM from `obs`.
min_snr: float
Minimum SNR of edge pixels (aggregated over all observation channel) to allow increase of box size
min_corr: float
Minimum correlation coefficient between center and edge color to allow increase of box size
min_value: float
Minimum pixel value (useful to set to > 0 for positivity constraints)
max_value: float
Minimum pixel value (useful to set to < 1 for unit interval constraints)
Returns
-------
(array,array)
Spectrum and morphology arrays
Warnings
--------
This method is stable only for isolated sources. In cases of significant blending,
the size of the bounding box and the measured moments are likely biased high.
See Also
--------
make_bbox: Defines bounding box that contains the source
standardized_moments: Computes 2nd moments for source in bounding box
"""
try:
frame = Scenery.scene.frame
except AttributeError:
print("from_gaussian_moments() can only be called within the context of a Scene")
print("Use 'with Scene(frame) as scene: Source(...)'")
raise
# TODO: implement with source footprints given for each observation
# get moments from all channels in all observations
observations = obs if isinstance(obs, (list, tuple)) else (obs,)
# centers and box_sizes are defined in pixel in the model frame
# therefore need to convert back to skycoord and convert to pixel in obs frames
centers = [obs_.frame.get_pixel(center) for obs_ in observations]
if box_sizes is None:
# growing sizes in units of the observed PSF
psf_sizes = [
measure.fwhm(obs.frame.psf()).min() if obs.frame.psf is not None else 1 for obs in observations
] # in obs pixels
magic_number = lambda i: 6.0 if i == 0 else 1.5 * magic_number(i - 1) # noqa:E731
# NOTE: Not forced to be odd
box_sizes = [[int(psf_size * magic_number(i)) for i in range(10)] for psf_size in psf_sizes]
else:
assert len(box_sizes) > 0
if u.get_physical_type(box_sizes[0]) == "angle":
box_sizes = [[obs.frame.u_to_pixel(size) for size in box_sizes] for obs in observations]
else:
# assume that all pixels are in proper observed frame
box_sizes = [box_sizes for obs in observations]
boxes_moments = [
best_box_moments(obs_, center_, sizes=sizes_, min_snr=min_snr, min_corr=min_corr)
for obs_, center_, sizes_ in zip(observations, centers, box_sizes, strict=False)
]
# flat lists of spectra, sorted in order of model frame channels
spectra = jnp.concatenate([m[0, 0] for box, m in boxes_moments])
channels = reduce(operator.add, [obs_.frame.channels for obs_ in observations])
spectrum = _sort_spectra(spectra, channels)
# average over all channels
moments = [m.normalize() for box, m in boxes_moments] # flux normalization
m = moments[0] # moments from first observation
for key in m:
m[key] = jnp.concatenate([m[key] for box, m in boxes_moments]) # combine all observations
m[key] = jnp.median(
m[key]
) # this is not SNR weighted nor consistent aross different moments, but works(?)
# average box size across observations
if frame.wcs is not None:
size = jnp.mean(
jnp.array(
[
frame.u_to_pixel(obs.frame.pixel_to_angle(max(box.spatial.shape)))
for (box, moments), obs in zip(boxes_moments, observations, strict=False)
]
)
).astype(int)
else:
size = jnp.mean(jnp.array([max(box.spatial.shape) for box, moments in boxes_moments])).astype(int)
# create morphology and evaluate at center
morph = GaussianMorphology.from_moments(m, shape=(size, size))
morph = morph()
spectrum /= morph.sum()
morph = jnp.minimum(jnp.maximum(morph, min_value), max_value)
return spectrum, morph
[docs]
def compact_morphology(min_value=1e-6, max_value=1 - 1e-6):
"""Create image of the point source morphology model, i.e. the most compact source possible
Parameters
----------
min_value: float
Minimum pixel value (needed for positively constrained morphologies)
max_value: float
minimum pixel value (useful to set to < 1 for unit interval constraints)
Returns
-------
array
2D array, normalized to the range [0,1]
"""
try:
frame = Scenery.scene.frame
except AttributeError:
print("Compact morphology can only be created within the context of a Scene")
print("Use 'with Scene(frame) as scene: Source(...)'")
raise
if frame.psf is None:
raise AttributeError("Compact morphology can only be create with a PSF in the model frame")
morph = frame.psf.morphology()
morph = jnp.minimum(jnp.maximum(morph, min_value), max_value)
return morph
# initialise the spectrum
[docs]
def pixel_spectrum(obs, pos, correct_psf=False):
"""Get the spectrum at a given position in the observation(s).
Yields the spectrum of a single-pixel source with flux 1 in every channel,
concatenated for all observations.
Parameters
----------
obs: `:py:class:`~scarlet2.Observation` or list
Observation(s) to extract pixel SED from
pos: tuple
Position in the observation. Needs to be in sky coordinates if multiple
observations have different locations or pixel scales.
correct_psf: bool, optional
Whether PSF shape variations in the observations should be corrected.
If `True`, this method homogenizes the PSFs of the observations, which
yields the correct spectrum for a flux=1 point source.
Returns
-------
array or list
If `obs` is a list, the method returns the associate list of spectra.
"""
# for multiple observations, get spectrum from each observation and then
# combine channels in order of model frame
if isinstance(obs, (list, tuple)):
# flat lists of spectra and channels in order of observations
spectra = jnp.concatenate([pixel_spectrum(obs_, pos, correct_psf=correct_psf) for obs_ in obs])
channels = reduce(operator.add, [obs_.frame.channels for obs_ in obs])
spectrum = _sort_spectra(spectra, channels)
return spectrum
assert isinstance(obs, Observation)
pixel = obs.frame.get_pixel(pos).astype(int)
if not obs.frame.bbox.spatial.contains(pixel):
raise ValueError(f"Pixel coordinate expected, got {pixel}")
spectrum = obs.data[:, pixel[0], pixel[1]].copy()
if correct_psf and obs.frame.psf is not None:
try:
frame = Scenery.scene.frame
except AttributeError:
print("Adaptive morphology can only be created within the context of a Scene")
print("Use 'with Scene(frame) as scene: Source(...)'")
raise
if frame.psf is None:
raise AttributeError("Adaptive morphology can only be create with a PSF in the model frame")
# correct spectrum for PSF-induced change in peak pixel intensity
psf_model = obs.frame.psf()
psf_peak = psf_model.max(axis=(-2, -1))
psf0_model = frame.psf()
psf0_peak = psf0_model.max(axis=(-2, -1))
spectrum /= psf_peak / psf0_peak
if jnp.any(spectrum <= 0):
# If the flux in all channels is <=0,
# the new sed will be filled with NaN values,
# which will cause the code to crash later
msg = f"Zero or negative spectrum {spectrum} at {pos}"
if jnp.all(spectrum <= 0):
print("Zero or negative spectrum in all channels: Setting spectrum to 1")
spectrum = jnp.ones_like(spectrum)
print(msg)
return spectrum
def _sort_spectra(spectra, channels):
try:
frame = Scenery.scene.frame
except AttributeError:
print("Multi-observation initialization can only be created within the context of a Scene")
print("Use 'with Scene(frame) as scene: ...")
raise
spectrum = []
for channel in frame.channels:
try:
idx = channels.index(channel)
spectrum.append(spectra[idx])
except ValueError:
msg = f"Channel '{channel}' not found in observations. Setting amplitude to 0."
print(msg)
spectrum.append(0)
spectrum = jnp.array(spectrum)
return spectrum