Source code for scarlet2.morphology

import astropy.units as u
import equinox as eqx
import jax.numpy as jnp
import jax.scipy

from . import Scenery, measure
from .module import Module
from .wavelets import starlet_reconstruction, starlet_transform


class Morphology(Module):
    """Morphology base class

    Concrete morphologies derive from this class and must provide
    :py:attr:`shape`.
    """

    @property
    def shape(self):
        """Shape (2D) of the morphology model

        Raises
        ------
        NotImplementedError
            Always; subclasses have to override this property.
        """
        raise NotImplementedError
class ProfileMorphology(Morphology):
    """Base class for morphologies based on a radial profile

    Subclasses implement :py:meth:`f`, which receives the squared radius in
    units of `size` and returns the profile value.
    """

    size: float
    """Size of the profile

    Can be given as an astropy angle, which will be transformed with the WCS of the current
    :py:class:`~scarlet2.Scene`.
    """
    ellipticity: (None, jnp.array)
    """Ellipticity of the profile"""
    _shape: tuple = eqx.field(repr=False)

    def __init__(self, size, ellipticity=None, shape=None):
        """Set up the profile

        Parameters
        ----------
        size: float or astropy.units.Quantity
            Size of the profile. A quantity is converted to pixels with the
            frame of the currently active scene.
        ellipticity: array, optional
            Two-component ellipticity; `None` yields a circular profile.
        shape: tuple, optional
            Shape of the bounding box. Defaults to a square, odd-sized box
            of at least `15*size` pixels.

        Raises
        ------
        AttributeError
            If `size` is a quantity and the conversion through
            `Scenery.scene.frame` fails with an attribute error (the hint
            printed below points at a missing Scene context).
        """
        if isinstance(size, u.Quantity):
            try:
                size = Scenery.scene.frame.u_to_pixel(size)
            except AttributeError:
                print("`size` defined in astropy units can only be used within the context of a Scene")
                print("Use 'with Scene(frame) as scene: (...)'")
                raise

        self.size = size
        self.ellipticity = ellipticity

        # default shape: square 15x size
        if shape is None:
            # explicit call to int() to avoid bbox sizes being jax-traced
            box = int(jnp.ceil(15 * self.size))
            # odd shapes for unique center pixel
            if box % 2 == 0:
                box += 1
            shape = (box, box)
        self._shape = shape

    @property
    def shape(self):
        """Shape of the bounding box for the profile.

        If not set during `__init__`, uses a square box with an odd number of pixels
        not smaller than `15*size`.
        """
        return self._shape

    def f(self, r2):
        """Radial profile function

        Parameters
        ----------
        r2: float or array
            Radius (distance from the center) squared

        Raises
        ------
        NotImplementedError
            Always; subclasses have to provide the profile.
        """
        raise NotImplementedError

    def __call__(self, delta_center=jnp.zeros(2)):  # noqa: B008
        """Evaluate the model

        Parameters
        ----------
        delta_center: array
            Offset of the profile center from the reference pixel
            (index `n // 2` along each axis); the last two entries are used
            as the (y, x) shift.

        Returns
        -------
        array
            Profile image whose shape equals the last two entries of `shape`.
        """
        n_y, n_x = self.shape[-2], self.shape[-1]
        # One coordinate per pixel, zero at index n // 2. The former
        # arange(-(n // 2), n // 2 + 1) produced n + 1 samples for even n,
        # so the image did not match `shape`; for odd n both forms agree.
        _y = jnp.arange(n_y, dtype=float) - (n_y // 2) - delta_center[-2]
        _x = jnp.arange(n_x, dtype=float) - (n_x // 2) - delta_center[-1]

        if self.ellipticity is None:
            r2 = _y[:, None] ** 2 + _x[None, :] ** 2
        else:
            e1, e2 = self.ellipticity
            # rescale (e1, e2) to (g1, g2); presumably the conversion from
            # distortion-type ellipticity to reduced shear -- TODO confirm
            g_factor = 1 / (1.0 + jnp.sqrt(1.0 - (e1**2 + e2**2)))
            g1, g2 = self.ellipticity * g_factor
            norm = jnp.sqrt(1 - (g1**2 + g2**2))
            # linear transformation of the pixel grid by (g1, g2)
            __x = ((1 - g1) * _x[None, :] - g2 * _y[:, None]) / norm
            __y = (-g2 * _x[None, :] + (1 + g1) * _y[:, None]) / norm
            r2 = __y**2 + __x**2

        # express the radius in units of `size`
        r2 = r2 / self.size**2
        r2 = jnp.maximum(r2, 1e-3)  # prevents infs at R2 = 0
        return self.f(r2)
class GaussianMorphology(ProfileMorphology):
    """Gaussian radial profile"""

    def f(self, r2):
        """Radial profile function

        Parameters
        ----------
        r2: float or array
            Radius (distance from the center) squared

        Returns
        -------
        float or array
            `exp(-r2 / 2)`
        """
        return jnp.exp(-r2 / 2)

    def __call__(self, delta_center=jnp.zeros(2)):  # noqa: B008
        """Evaluate the model

        Parameters
        ----------
        delta_center: array
            Offset of the profile center from the reference pixel
            (index `n // 2` along each axis); the last two entries are used
            as the (y, x) shift.

        Returns
        -------
        array
            For `ellipticity=None`, the outer product of two pixel-integrated
            1D Gaussians with one sample per pixel of `shape`; otherwise the
            result of the parent class evaluation.
        """
        if self.ellipticity is not None:
            return super().__call__(delta_center)

        # faster circular 2D Gaussian: instead of N^2 evaluations,
        # use outer product of 2 1D Gaussian evals
        n_y, n_x = self.shape[-2], self.shape[-1]
        # One coordinate per pixel, zero at index n // 2. The former
        # arange(-(n // 2), n // 2 + 1) produced n + 1 samples for even n,
        # so the image did not match `shape`; for odd n both forms agree.
        _y = jnp.arange(n_y, dtype=float) - (n_y // 2) - delta_center[-2]
        _x = jnp.arange(n_x, dtype=float) - (n_x // 2) - delta_center[-1]

        def pixel_integrated(x, s):
            # with pixel integration: evaluated from erfc at the two pixel
            # edges, half a pixel to either side of x
            return 0.5 * (
                1
                - jax.scipy.special.erfc((0.5 - x) / jnp.sqrt(2) / s)
                + 1
                - jax.scipy.special.erfc((0.5 + x) / jnp.sqrt(2) / s)
            )

        # without pixel integration one would use
        # exp(-(x ** 2) / (2 * s ** 2)) / (sqrt(2 * pi) * s) instead
        return jnp.outer(pixel_integrated(_y, self.size), pixel_integrated(_x, self.size))

    @staticmethod
    def from_image(image):
        """Create Gaussian radial profile from the 2nd moments of `image`

        Parameters
        ----------
        image: array
            2D array to measure :py:class:`~scarlet2.measure.Moments` from.

        Returns
        -------
        GaussianMorphology
            Morphology whose bounding box has the shape of `image`.
        """
        assert image.ndim == 2
        center = measure.centroid(image)
        # compute moments and create Gaussian from it
        g = measure.Moments(image, center=center, N=2)
        return GaussianMorphology.from_moments(g, shape=image.shape)

    @staticmethod
    def from_moments(g, shape=None):
        """Create Gaussian radial profile from the moments `g`

        Parameters
        ----------
        g: :py:class:`~scarlet2.measure.Moments`
            Moments, order >= 2
        shape: tuple
            Shape of the bounding box

        Returns
        -------
        GaussianMorphology

        Raises
        ------
        ValueError
            If the size or the ellipticity derived from `g` is not finite.
        """
        t = g.size
        ellipticity = g.ellipticity

        # a Gaussian with these 2nd moments only exists for finite values
        if not (jnp.isfinite(t) and jnp.isfinite(ellipticity).all()):
            raise ValueError(
                f"Gaussian morphology not possible with size={t}, and ellipticity={ellipticity}!"
            )
        return GaussianMorphology(t, ellipticity, shape=shape)
class SersicMorphology(ProfileMorphology):
    """Sersic radial profile"""

    n: float
    """Sersic index"""

    def __init__(self, n, size, ellipticity=None, shape=None):
        """Set up the Sersic profile

        Parameters
        ----------
        n: float
            Sersic index
        size, ellipticity, shape
            Forwarded unchanged to the parent class constructor.
        """
        self.n = n
        super().__init__(size, ellipticity=ellipticity, shape=shape)

    def f(self, r2):
        """Radial profile function

        Parameters
        ----------
        r2: float or array
            Radius (distance from the center) squared

        Returns
        -------
        float or array
            `exp(-bn * (r2 ** (0.5 / n) - 1))` with `bn` computed from the
            Sersic index `n`.
        """
        index = self.n
        index_sq = index * index

        # Choice of bn:
        # * simplest form, Capaccioli (1989): bn = 1.9992 * n - 0.3271
        # * used here: Ciotti & Bertin (1999), eq. 18, stable to n > 0.36
        #   with small errors (the earlier note read "errors < 10^5",
        #   presumably 10^-5 was meant -- TODO confirm against the paper)
        # * MacArthur, Courteau, & Holtzman (2003), eq. A2, is much more
        #   stable for n < 0.36:
        #   bn = 0.01945 - 0.8902 * n + 10.95 * n2 - 19.67 * n2 * n + 13.43 * n2 * n2
        #   not used here to avoid an if clause in jitted code
        b_n = (
            2 * index
            - 0.333333
            + 0.009877 / index
            + 0.001803 / index_sq
            + 0.000114 / (index_sq * index)
            - 0.000072 / (index_sq * index_sq)
        )

        # Graham & Driver 2005, eq. 1
        # we're given R^2, so the exponent is 0.5/n instead of 1/n
        exponent = 0.5 / index
        return jnp.exp(-b_n * (r2**exponent - 1))
class StarletMorphology(Morphology):
    """Morphology in the starlet basis

    Notes
    -----
    The starlet basis is overcomplete, which means it can exactly represent the same image in
    multiple ways. If used without constraints or priors on the starlet coefficients, this
    morphology model is functionally indistinguishable from a 2D pixel array, while using
    more operations.

    See Also
    --------
    scarlet2.wavelets.Starlet
    """

    coeffs: jnp.ndarray
    """Starlet coefficients"""

    def __call__(self, **kwargs):
        """Evaluate the model

        Parameters
        ----------
        **kwargs
            Accepted but not used by this method.

        Returns
        -------
        array
            Result of `starlet_reconstruction` applied to `coeffs`.
        """
        return starlet_reconstruction(self.coeffs)

    @property
    def shape(self):
        """Shape (2D) of the morphology model"""
        return self.coeffs.shape[-2:]  # wavelet coeffs: scales x n1 x n2

    @staticmethod
    def from_image(image, min_value=None, max_value=None):
        """Create starlet morphology from `image`

        Parameters
        ----------
        image: array
            2D image array to determine coefficients from.
        min_value: float, optional
            Minimum value threshold for coefficients: every coefficient
            below it is set to `min_value`.
        max_value: float, optional
            Maximum value threshold for coefficients: every coefficient
            above it is set to `max_value`.

        Returns
        -------
        StarletMorphology
        """
        # Starlet transform of image (n1,n2) into coefficient with 3 dimensions: (scales+1,n1,n2)
        coeffs = starlet_transform(image)
        if min_value is not None:
            coeffs = coeffs.at[coeffs < min_value].set(min_value)
        if max_value is not None:
            coeffs = coeffs.at[coeffs > max_value].set(max_value)
        return StarletMorphology(coeffs)