"""Measurement methods"""
import copy
import numpy as jnp
import numpy.ma as ma
from .frame import get_affine, get_scale_angle_flip_shift
from .source import Component
[docs]
def max_pixel(component):
    """Locate the pixel holding the largest model value

    Parameters
    ----------
    component: :py:class:`~scarlet2.Component` or array
        Component to analyze or its hyperspectral model

    Returns
    -------
    array
        Index of the brightest pixel, one entry per model axis. For an array
        input the index is relative to the array itself; for a `Component`
        the origin of its `bbox` is added.
    """
    is_component = isinstance(component, Component)
    model = component() if is_component else component
    # arrays carry no bounding box, so their indices are left unshifted
    offset = jnp.array(component.bbox.origin) if is_component else 0
    flat_index = jnp.argmax(model)
    return jnp.array(jnp.unravel_index(flat_index, model.shape)) + offset
[docs]
def flux(component):
    """Sum the model over its two trailing (spatial) axes

    Parameters
    ----------
    component: :py:class:`~scarlet2.Component` or array
        Component to analyze or its hyperspectral model

    Returns
    -------
    float or array
        Total of the model over the last two axes; any leading axes
        (e.g. channels) are preserved.
    """
    if isinstance(component, Component):
        # a Component is evaluated by calling it
        component = component()
    return component.sum(axis=(-2, -1))
[docs]
def centroid(component):
    """Compute the flux-weighted mean position of the model

    Parameters
    ----------
    component: :py:class:`~scarlet2.Component` or array
        Component to analyze or its hyperspectral model

    Returns
    -------
    array
        (y, x) centroid. For an array input the coordinates are relative to
        the array; for a `Component` the spatial origin of its `bbox` is added.
    """
    if isinstance(component, Component):
        model, origin = component(), jnp.array(component.bbox.spatial.origin)
    else:
        model, origin = component, (0, 0)

    rows, cols = jnp.indices(model.shape[-2:])
    if model.ndim == 3:
        # add a leading axis so the grids broadcast against every channel
        rows, cols = rows[None, :, :], cols[None, :, :]

    total = flux(model)
    mean_y = (rows * model).sum(axis=(-2, -1)) / total + origin[0]
    mean_x = (cols * model).sum(axis=(-2, -1)) / total + origin[1]
    return jnp.array((mean_y, mean_x))
[docs]
def fwhm(component):
    """Determine the full width at half maximum in pixels

    The FWHM is estimated as the diameter of a circle whose area equals the
    number of pixels at or above half of the peak value.

    Parameters
    ----------
    component: :py:class:`~scarlet2.Component` or array
        Component to analyze or its hyperspectral model. A single 2D image is
        accepted as well.

    Returns
    -------
    float or array
        Equivalent-circle diameter; one value per leading (channel) index of
        the model, or a scalar for a 2D image.
    """
    model = component() if isinstance(component, Component) else component
    # the peak is located once over the whole model; only its spatial part is used
    peak_y, peak_x = max_pixel(model)[-2:]
    half_value = jnp.asarray(model[..., peak_y, peak_x]) / 2
    num_pixels = jnp.count_nonzero(model >= half_value[..., None, None], axis=(-2, -1))
    # area = pi * (d / 2)**2  =>  d = 2 * sqrt(area / pi)
    return 2 * jnp.sqrt(num_pixels / jnp.pi)
[docs]
def snr(component, observations):
    """Compute the signal-to-noise ratio, using `component` as weight function

    Parameters
    ----------
    component: :py:class:`~scarlet2.Component` or array
        Component to analyze or its hyperspectral model. Objects exposing
        `get_model` are evaluated through it; anything else is used as is.
    observations: :py:class:`scarlet2.Observation` or list
        The observations to use for the SNR computation. A single
        (non-iterable) observation is accepted.

    Returns
    -------
    float
    """
    if not hasattr(observations, "__iter__"):
        observations = (observations,)

    model = component.get_model(frame=None) if hasattr(component, "get_model") else component

    # Per observation: render the model, then flatten it, its per-channel
    # normalized version, and the noise variance. Flattening allows
    # observations with differing channel sets to be concatenated.
    signal_parts, weight_parts, variance_parts = [], [], []
    for obs in observations:
        # pixels with zero weight are masked; their fill value is infinite noise
        noise_rms = 1 / jnp.sqrt(ma.masked_equal(obs.weights, 0))
        ma.set_fill_value(noise_rms, jnp.inf)
        rendered = obs.render(model)
        channel_flux = rendered.sum(axis=(-2, -1))
        signal_parts.append(rendered.reshape(-1))
        weight_parts.append((rendered / (channel_flux[:, None, None])).reshape(-1))
        variance_parts.append((noise_rms**2).reshape(-1))

    signal = jnp.concatenate(signal_parts)
    weight_fn = jnp.concatenate(weight_parts)
    variance = jnp.concatenate(variance_parts)

    # SNR from Erben (2001), eq. 16, extended to multiple bands:
    # SNR = (I @ W) / sqrt(W @ Sigma^2 @ W), with a diagonal variance matrix
    numerator = (signal * weight_fn).sum()
    denominator = jnp.sqrt(((variance * weight_fn) * weight_fn).sum())
    return numerator / denominator
[docs]
class Moments(dict):
    """Moments of a brightness distribution

    A dict keyed by `(p, q)`, the powers of y and x of the respective moment.
    Most methods modify the moments in place and return `self` for chaining.
    """

    def __init__(self, component, N=2, center=None, weight=None):  # noqa: N803
        r"""Moments of the brightness distribution

        The dict is accessed by keys, which denote the power of y/x of the specific Moment:
        `m[p,q] = \int dx dy f(y,x) y^p x^q`.

        Notes
        -----
        Like all coordinates in scarlet2, moments are computed in (y,x) order.
        Moments of order 0 and 1 are measured on the pixel grid; the grid is
        then shifted to `center` (or to the measured centroid if `center` is
        None) so that moments of order >= 2 are taken about that point.

        Parameters
        ----------
        component: :py:class:`~scarlet2.Component` or array
            Component to analyze or its hyperspectral model
        N: int >=0
            Moment order
        center: array
            2D coordinate in frame of `component`
        weight: array
            weight function with same shape as `component`
        """
        super().__init__()
        model = component() if isinstance(component, Component) else component

        if weight is None:
            weight = 1

        grid_y, grid_x = jnp.indices(model.shape[-2:])
        if model.ndim == 3:
            grid_y = grid_y[None, :, :]
            grid_x = grid_x[None, :, :]

        self.N = N
        # only measurable once first moments exist, i.e. for N >= 1
        self._centroid = None

        # moment code adapted from https://github.com/pmelchior/shapelens/blob/master/src/Moments.cc
        for n in range(self.N + 1):
            for m in range(n + 1):
                # moments ordered by power in y, then x
                self[m, n - m] = (grid_y**m * grid_x ** (n - m) * model * weight).sum(axis=(-2, -1))

            if n == 1:
                self._centroid = jnp.array((self[1, 0] / self[0, 0], self[0, 1] / self[0, 0]))
                # shift grid to produce centered moments
                if center is None:
                    center = self._centroid
                    self[1, 0] = jnp.zeros_like(self[1, 0])
                    self[0, 1] = jnp.zeros_like(self[0, 1])
                else:
                    center = jnp.asarray(center)
                    # Centroid wrt given center. Built out-of-place so that it
                    # also works with immutable arrays, and consistent with the
                    # shifted first moments below: (m10 - c * m00) / m00.
                    self._centroid = jnp.array(
                        (self._centroid[0] - center[0], self._centroid[1] - center[1])
                    )
                    self[1, 0] = self[1, 0] - center[0] * self[0, 0]
                    self[0, 1] = self[0, 1] - center[1] * self[0, 0]
                if model.ndim == 3 and center[0].ndim == 1:
                    # per-channel centers need trailing axes to broadcast over the grid
                    center = center[0][:, None, None], center[1][:, None, None]
                grid_y = grid_y - center[0]
                grid_x = grid_x - center[1]

    @property
    def order(self):
        """Moment order

        Returns
        -------
        int
            Largest power of y among the stored keys
        """
        # TODO: why is this not simply self.N? Probably left over from a dynamic computation of higher moments
        return max(key[0] for key in self.keys())

    @property
    def flux(self):
        """Determine flux from 0th moment

        Returns
        -------
        float
        """
        return self[0, 0]

    @property
    def centroid(self):
        """Determine centroid from moments

        Returns
        -------
        array
            (y, x) centroid as measured at construction: in the pixel frame of
            the data if no `center` was given, otherwise relative to `center`.
            None if the moments were computed with N < 1.
        """
        return self._centroid

    @property
    def size(self):
        """Determine size from moments

        Returns
        -------
        float
            Fourth root of the determinant of the flux-normalized second moments
        """
        flux = self[0, 0]
        return (self[0, 2] / flux * self[2, 0] / flux - (self[1, 1] / flux) ** 2) ** (1 / 4)

    @property
    def ellipticity(self):
        """Determine complex ellipticity from moments

        Returns
        -------
        jnp.array
            Ellipticity (2D) of the data that defines these moments,
            as (real part, imaginary part)
        """
        ellipticity = (self[0, 2] - self[2, 0] + 2j * self[1, 1]) / (self[2, 0] + self[0, 2])
        return jnp.array((ellipticity.real, ellipticity.imag))

    def normalize(self):
        """Normalize moments with respect to the flux

        The moments are changed in place.

        Returns
        -------
        self
        """
        norm = self[0, 0]
        for key in self:
            # Out-of-place division: an in-place `/=` on a mutable array would
            # overwrite `norm` itself when key == (0, 0) and leave every other
            # moment divided by one.
            self[key] = self[key] / norm
        return self

    def convolve(self, p):
        """Convolve moments with moments `p`

        The moments are changed in place. Only orders up to the smaller of the
        two moment orders are updated.

        See Melchior et al. (2010), "Weak gravitational lensing with Deimos", Equation 9

        Parameters
        ----------
        p: Moments
            Moments of the kernel to convolve with

        Returns
        -------
        self
        """
        # the sums need the unconvolved values while self is being overwritten
        g = copy.deepcopy(self)
        n_min = min(p.order, g.order)
        for n in range(n_min + 1):
            for i in range(n + 1):
                j = n - i
                value = jnp.zeros_like(g[i, j])
                for k in range(i + 1):
                    for l in range(j + 1):  # noqa: E741
                        value = value + binomial(i, k) * binomial(j, l) * g[k, l] * p[i - k, j - l]
                self[i, j] = value
        return self

    def deconvolve(self, p):
        """Deconvolve moments from moments `p`

        The moments are changed in place. Only orders up to the smaller of the
        two moment orders are updated.

        See Melchior et al. (2010), "Weak gravitational lensing with Deimos", Table 1

        Parameters
        ----------
        p: Moments
            Moments of the kernel to deconvolve from

        Returns
        -------
        self
        """
        g = self
        n_min = min(p.order, g.order)
        p00 = p[0, 0]

        # Explicit relations for up to 2nd moments. Each order relies on the
        # already deconvolved lower orders, so the statement order matters.
        g[0, 0] = g[0, 0] / p00
        if n_min >= 1:
            g[0, 1] = (g[0, 1] - g[0, 0] * p[0, 1]) / p00
            g[1, 0] = (g[1, 0] - g[0, 0] * p[1, 0]) / p00
        if n_min >= 2:
            g[0, 2] = (g[0, 2] - (g[0, 0] * p[0, 2] + 2 * g[0, 1] * p[0, 1])) / p00
            g[1, 1] = (g[1, 1] - (g[0, 0] * p[1, 1] + g[0, 1] * p[1, 0] + g[1, 0] * p[0, 1])) / p00
            g[2, 0] = (g[2, 0] - (g[0, 0] * p[2, 0] + 2 * g[1, 0] * p[1, 0])) / p00

        # General formula for higher orders: remove every term of the
        # convolution sum except the one containing g[i, j] itself.
        for n in range(3, n_min + 1):
            for i in range(n + 1):
                j = n - i  # exactly the moments of order n
                value = g[i, j]
                for k in range(i):
                    for l in range(j):  # noqa: E741
                        value = value - binomial(i, k) * binomial(j, l) * g[k, l] * p[i - k, j - l]
                for k in range(i):
                    value = value - binomial(i, k) * g[k, j] * p[i - k, 0]
                for l in range(j):  # noqa: E741
                    value = value - binomial(j, l) * g[i, l] * p[0, j - l]
                g[i, j] = value / p00
        return self

    def resize(self, c):
        """Change moments for a change of factor `c` of the size/spatial resolution
        of the defining frame

        This operation arises when one adjust the moments for a change in the size
        of pixels of the defining frame, e.g. when asking "what would the moments
        be if the pixels were factor c smaller (or the source c times larger)"?

        The fluxes are unchanged, which corresponds to preservation of photons under resizing.
        The moments are changed in place.

        See Teague (1980), "Image analysis via the general theory of moments", eq. 34 for details.

        Parameters
        ----------
        c: float or list
            Scaling factor for the size change. Can be different along x and y:
            a pair is read as (factor along y, factor along x). Zero-dimensional
            arrays are treated as scalars.

        Returns
        -------
        self

        Raises
        ------
        AttributeError
            If `c` is neither a scalar nor a sequence of two components
        """
        if jnp.ndim(c) == 0:
            assert c > 0, "scaling factor must be positive"
            # c**(2 + p + q) divided by the flux change c**2
            for e in self:
                self[e] = self[e] * c ** (e[0] + e[1])
        elif len(c) == 2:
            assert c[0] > 0 and c[1] > 0, "scaling factors must be positive"
            # c0**(p + 1) * c1**(q + 1) divided by the flux change c0 * c1
            for e in self:
                self[e] = self[e] * c[0] ** e[0] * c[1] ** e[1]
        else:
            raise AttributeError("c must be a scalar or a list or array of two components")
        return self

    def fliplr(self):
        """Flip moments along the x-axis

        The moments are changed in place: moments with an odd power of x change sign.

        Returns
        -------
        self
        """
        for e in self:
            if e[1] % 2 == 1:
                self[e] = -self[e]
        return self

    def flipud(self):
        """Flip moments along the y-axis

        The moments are changed in place: moments with an odd power of y change sign.

        Returns
        -------
        self
        """
        for e in self:
            if e[0] % 2 == 1:
                self[e] = -self[e]
        return self

    def rotate(self, phi):
        """Change moments for rotation of angle `phi`.

        The moments are changed in place.

        See Teague (1980), "Image analysis via the general theory of moments", eq. 36 for details.

        Parameters
        ----------
        phi: float
            Rotation angle, in radian

        Returns
        -------
        self
        """
        # all rotated moments are computed from the unrotated ones before any is stored
        mu_p = {}
        for n in range(self.N + 1):
            for j in range(n + 1):
                k = n - j
                value = 0
                for r in range(j + 1):
                    for s in range(k + 1):
                        value += (
                            (-1) ** (k - s)
                            * binomial(j, r)
                            * binomial(k, s)
                            * jnp.cos(phi) ** (j - r + s)
                            * jnp.sin(phi) ** (k + r - s)
                            * self[j + k - r - s, r + s]
                        )
                mu_p[j, k] = value
        for e in self:
            self[e] = mu_p[e]
        return self

    def translate(self, shift):
        """Change moments for translation `shift`

        The moments are changed in place.

        Note: This changes all the moments, not just the dipole, for the new reference center.

        See Teague (1980), "Image analysis via the general theory of moments", eq. 30 for details.

        Parameters
        ----------
        shift: (y, x)
            translation, in pixels

        Returns
        -------
        self
        """
        # all translated moments are computed from the original ones before any is stored
        mu = {}
        for n in range(self.N + 1):
            for j in range(n + 1):
                k = n - j
                value = 0
                for r in range(j + 1):
                    for s in range(k + 1):
                        value += (
                            binomial(j, r)
                            * binomial(k, s)
                            * (shift[0]) ** (j - r)
                            * (shift[1]) ** (k - s)
                            * self[r, s]
                        )
                mu[j, k] = value
        for e in self:
            self[e] = mu[e]
        return self

    def transfer(self, wcs_in, wcs_out):
        """Compute rescaling and rotation from WCSs and apply to moments

        The method adjusts moments measured with a frame defined by `wcs_in` to the frame `wcs_out`.
        The moments are changed in place. If either WCS is None, nothing is changed.

        Parameters
        ----------
        wcs_in: :py:class:`astropy.wcs.Wcsprm`
            WCS of the frame with original moments
        wcs_out: :py:class:`astropy.wcs.Wcsprm`
            WCS of the frame to which the moments should be adjusted

        Returns
        -------
        self
        """
        if (wcs_in is not None) and (wcs_out is not None):
            M_in = get_affine(wcs_in)  # noqa: N806
            M_out = get_affine(wcs_out)  # noqa: N806
            M = jnp.linalg.inv(M_out) @ M_in  # noqa: N806, transformation from in pixel -> sky -> out pixels
            scale, angle, flip, _ = get_scale_angle_flip_shift(M)
            # if flipped: go to right-handed coord first before applying rotation
            if flip == -1:
                self.flipud()  # our flip convention is along y-axis
            self.rotate(angle).resize(scale)
        return self
# def moments(component, N=2, center=None, weight=None):
# return Moments(component, N=N, center=center, weight=weight)
# adapted from https://github.com/pmelchior/shapelens/blob/src/DEIMOS.cc
[docs]
def binomial(n, k):
    """Binomial coefficient

    Parameters
    ----------
    n: int
        Number of items
    k: int
        Number of items chosen

    Returns
    -------
    int
        Number of ways to choose `k` out of `n` items; 0 if `k` is negative
        or larger than `n`.
    """
    if k == 0:
        return 1
    if k < 0 or k > n:
        # no way to choose; previously the empty product 1 was returned
        return 0
    # symmetry C(n, k) == C(n, n - k) keeps the loop short
    k = min(k, n - k)
    result = 1
    for i in range(1, k + 1):
        # after each step result == C(n, i), so the division is exact
        result = result * (n - i + 1) // i
    return result
[docs]
def forced_photometry(scene, obs):
    """Fit the amplitude of every source model in every channel of `obs`

    The rendered models of all sources are treated as fixed templates, and
    their best-fit amplitudes are found per channel as a weighted linear
    least-squares problem.

    If sources/sources components in `scene` have non-flat spectra, the output of this function is
    the correction factor that needs to be applied to those spectra to best match each channel of `obs`.

    Parameters
    ----------
    scene: :py:class:`scarlet2.scene.Scene`
        Scene for which the spectra should be computed
    obs: :py:class:`~scarlet2.Observation`
        The observation used to determine the spectra.

    Returns
    -------
    array
        Array of the spectra, in the order of the sources in the scene, with
        one row per source and one column per channel. Entries of sources that
        are excluded from the fit in a channel remain zero.

    Raises
    ------
    ValueError
        If two sources have numerically identical models
    """
    # evaluate every source, so that the fit includes it even if its spectrum is not updated
    models = []
    for i, src in enumerate(scene.sources):
        model = scene.evaluate_source(src)  # assumes all sources are single components
        # identical models cannot be told apart by the fit, see scarlet repo issue #282
        for model_indx, earlier in enumerate(models):
            if jnp.allclose(model, earlier):
                raise ValueError(
                    f"Source {i} has a model identical to source {model_indx}.\n"
                    "This is likely not intended, and the second source should be deleted."
                )
        models.append(model)
    models = jnp.array(models)
    num_models = len(models)

    # independent channels, no mixing:
    # spectrum = (M^T Sigma^-1 M)^-1 M^T Sigma^-1 * im
    num_channels = obs.frame.C
    images = obs.data
    weights = obs.weights
    morphs = jnp.stack([obs.render(model) for model in models], axis=0)

    spectra = jnp.zeros((num_models, num_channels))
    for c in range(num_channels):
        im = images[c].reshape(-1)
        w = weights[c].reshape(-1)
        m = morphs[:, c, :, :].reshape(num_models, -1)
        mw = m * w[None, :]

        # Because of convolutions, flux can fall outside the box, so weighted
        # and unweighted flux are compared (equal up to a constant for
        # constant weights). A source takes part in the fit only if enough of
        # its flux comes from pixels with non-zero weight.
        weighted_fraction = jnp.sum(mw, axis=1) / jnp.sum(m, axis=1) / jnp.mean(w)
        active = jnp.flatnonzero(weighted_fraction > 0.1)

        # when every source is active this selects all rows, i.e. the full problem
        covar = jnp.linalg.inv(mw[active] @ m[active].T)
        spectra = spectra.at[active, c].set(covar @ m[active] @ (im * w))
    return spectra
[docs]
def correlation_function(img, maxlength=2, threshold=0):
    """Measure the 2D pixel correlation function of an image

    For every offset the image is multiplied with a shifted copy of itself and
    the products are averaged over the overlapping pixels whose product is
    non-zero.

    Parameters
    ----------
    img: :py:class:`numpy.ndarray`
        Image array, 2D or 3D. Masked pixels must be set to 0 in `img`.
    maxlength: int
        Maximum length of the correlation function
    threshold: float
        Minimum correlation coefficient to maintain; smaller values are
        raised to `threshold`

    Returns
    -------
    dict, with keys (dy,dx) specifying the 2D offset in image pixels.
        Populated for offsets with 0 <= dy, dx <= `maxlength` and for their
        point reflections (-dy, -dx). A 2D input is treated as a cube with a
        single leading entry.
    """
    # work on image cubes so the same slicing serves 2D and 3D inputs
    cube = img if img.ndim != 2 else img[None, :, :]
    n_rows, n_cols = cube.shape[-2:]
    shifts = [(dy, dx) for dy in range(maxlength + 1) for dx in range(maxlength + 1)]

    xi = {}
    for dy, dx in shifts:
        products = cube[..., dy:, dx:] * cube[..., : n_rows - dy, : n_cols - dx]
        total = jnp.sum(products, axis=(-2, -1))
        count = jnp.sum(products != 0, axis=(-2, -1))
        # normalize (guarding against division by 0) and clip at threshold
        # Note: possibly safer to set the largest negative correlation (which should not exist) as threshold
        xi[dy, dx] = jnp.maximum(total / jnp.maximum(count, 1), threshold)

    # fill in the symmetric negative offsets
    for dy, dx in shifts:
        xi[-dy, -dx] = xi[dy, dx]
    return xi