Source code for scarlet2.detect
"""Detection methods
Higher-level helpers (``get_wavelets``, ``QuadTreeRegion``, ``get_peaks``, …)
are adapted from scarlet1's ``detect.py``.
Uses NumPy rather than JAX because detection involves dynamic data structures
(variable-length peak lists, irregularly shaped footprints) that are not
compatible with JAX's JIT. Both NumPy and JAX arrays are accepted as inputs.
"""
import heapq
from dataclasses import dataclass, field
import numpy as np
from equinox import tree_pformat
from scipy.ndimage import binary_fill_holes, find_objects
from scipy.ndimage import label as ndimage_label
from scipy.optimize import linear_sum_assignment
from .bbox import Box, overlap_slices
from .wavelets import get_multiresolution_support, starlet_transform
# ---------------------------------------------------------------------------
# Data classes
# ---------------------------------------------------------------------------
[docs]
@dataclass
class Peak:
    """Local maximum located inside a :class:`Footprint`.

    Attributes
    ----------
    y : int
        Row of the peak pixel, in full-image coordinates.
    x : int
        Column of the peak pixel, in full-image coordinates.
    flux : float
        Image value at ``(y, x)``.
    """

    # row index in the full image
    y: int
    # column index in the full image
    x: int
    # image value at the peak pixel
    flux: float
[docs]
@dataclass
class Footprint:
    """Connected region of above-threshold pixels detected in an image.

    Attributes
    ----------
    footprint : ndarray of bool, shape (height, width)
        Pixel mask of the region, cropped to ``bounds``.
    peaks : list of Peak
        Local maxima inside the region, ordered brightest first.
    bounds : tuple of two (min, max) pairs
        ``((y_min, y_max), (x_min, x_max))`` in full-image coordinates. The
        upper limits are exclusive, matching :class:`~scarlet2.bbox.Box`.
    """

    # boolean mask sized to the bounding box
    footprint: np.ndarray
    # peaks in this region, brightest first
    peaks: list[Peak]
    # ((y_min, y_max), (x_min, x_max)), exclusive upper limits
    bounds: tuple[tuple[int, int], tuple[int, int]]

    def __repr__(self):
        # Use equinox's pretty-printer for a readable nested representation.
        return tree_pformat(self)
# ---------------------------------------------------------------------------
# Low-level detection
# ---------------------------------------------------------------------------
def _get_patch_peaks(image, min_separation, y0=0, x0=0):
"""Find local maxima in an image patch.
A pixel is a local maximum if its value strictly exceeds all 8-connected
neighbours that lie within the image boundary.
Parameters
----------
image : 2D ndarray
The (possibly masked) image patch to search for peaks.
min_separation : float
Minimum pixel distance between peaks. When two peaks are closer than
this, the dimmer one is removed.
y0, x0 : int, optional
Row and column offsets added to peak coordinates so they refer to the
full image rather than the patch.
Returns
-------
peaks : list of Peak
Local maxima sorted brightest-first, with coordinates in the full image.
"""
height, width = image.shape
peaks = []
for i in range(height):
for j in range(width):
val = image[i, j]
# 4-connected neighbours
if i > 0 and val <= image[i - 1, j]:
continue
if i < height - 1 and val <= image[i + 1, j]:
continue
if j > 0 and val <= image[i, j - 1]:
continue
if j < width - 1 and val <= image[i, j + 1]:
continue
# diagonal neighbours
if i > 0 and j > 0 and val <= image[i - 1, j - 1]:
continue
if i < height - 1 and j < width - 1 and val <= image[i + 1, j + 1]:
continue
if i < height - 1 and j > 0 and val <= image[i + 1, j - 1]:
continue
if i > 0 and j < width - 1 and val <= image[i - 1, j + 1]:
continue
peaks.append(Peak(i + y0, j + x0, float(val)))
# Sort brightest first
peaks.sort(key=lambda p: p.flux, reverse=True)
# Remove peaks within min_separation of a brighter peak
min_separation2 = min_separation**2
i = 0
while i < len(peaks):
j = i + 1
while j < len(peaks):
dy = peaks[i].y - peaks[j].y
dx = peaks[i].x - peaks[j].x
if dy * dy + dx * dx < min_separation2:
peaks.pop(j)
else:
j += 1
i += 1
return peaks
[docs]
def footprints(image, min_separation=0, min_area=9, thresh=0):
    """Detect footprints and their peaks in an image.

    Thresholds the image at ``thresh``, labels 4-connected regions using
    :func:`scipy.ndimage.label`, filters by minimum area, and locates peaks
    within each footprint. Footprints in which no strict local maximum is
    found are discarded.

    Parameters
    ----------
    image : 2D array-like
        The image to detect sources in. Accepts NumPy or JAX arrays.
    min_separation : float, optional
        Minimum pixel separation between peaks within a footprint.
    min_area : int, optional
        Minimum number of pixels a footprint must contain to be kept.
    thresh : float, optional
        Detection threshold; pixels must strictly exceed this value.

    Returns
    -------
    footprints : list of Footprint
        Detected footprints, each containing the boolean mask (sized to the
        bounding box), peak list, and bounding box in the full image.
    """
    image = np.asarray(image)
    labeled, n_labels = ndimage_label(image > thresh)
    if n_labels == 0:
        return []

    detected = []
    # find_objects returns one slice tuple per label, in label order (1-based).
    for label, slices in enumerate(find_objects(labeled), start=1):
        if slices is None:
            continue
        sy, sx = slices
        sub_fp = labeled[sy, sx] == label
        if sub_fp.sum() < min_area:
            continue
        y0, y1, x0, x1 = sy.start, sy.stop, sx.start, sx.stop
        # Hide pixels outside this footprint with -inf rather than 0. Masking
        # with 0 breaks for negative thresholds: footprint pixels in
        # (thresh, 0] could lose to a masked neighbour, and a masked pixel
        # could itself be reported as a peak outside the footprint. For
        # thresh >= 0 the result is unchanged. np.where also returns a fresh
        # array, so read-only inputs (e.g. from JAX) are never written to.
        patch = np.where(sub_fp, image[sy, sx], -np.inf)
        peaks = _get_patch_peaks(patch, min_separation, y0=y0, x0=x0)
        if peaks:
            detected.append(Footprint(sub_fp, peaks, ((y0, y1), (x0, x1))))
    return detected
# ---------------------------------------------------------------------------
# Box / footprint utilities (adapted from scarlet v1 detect.py)
# ---------------------------------------------------------------------------
[docs]
def box_intersect(box1, box2):
    """Return whether two :class:`~scarlet2.bbox.Box` instances overlap.

    Parameters
    ----------
    box1 : Box
        First box.
    box2 : Box
        Second box.

    Returns
    -------
    bool
        ``False`` when the intersection ``box1 & box2`` has zero extent
        along either of its first two axes, ``True`` otherwise.
    """
    common_shape = (box1 & box2).shape
    if common_shape[0] == 0:
        return False
    return common_shape[1] != 0
[docs]
def footprint_intersect(footprint1, box1, footprint2, box2):
    """Return whether two footprint masks share at least one pixel.

    Parameters
    ----------
    footprint1, footprint2 : ndarray of bool
        Masks of the two footprints, each cropped to its own bounding box.
    box1, box2 : Box
        Bounding boxes matching ``footprint1`` and ``footprint2``.

    Returns
    -------
    overlap : bool
        ``False`` immediately if the boxes do not overlap; otherwise whether
        the masks have a common set pixel inside the shared region.
    """
    if box_intersect(box1, box2):
        view1, view2 = overlap_slices(box1, box2)
        shared = footprint1[view1] * footprint2[view2]
        return np.sum(shared) > 0
    return False
[docs]
def footprint_iou(source1, source2):
    """Intersection over union (IoU) of two sources' footprint masks.

    Parameters
    ----------
    source1, source2 : :class:`HierarchicalFootprint`
        Objects exposing ``.bbox`` (:class:`~scarlet2.bbox.Box`) and
        ``.footprint``. The footprint may be either a :class:`Footprint`
        (whose ``.footprint`` mask is used) or a boolean ndarray sized to
        the bbox.

    Returns
    -------
    iou : float
        Value in ``[0, 1]``; ``0.0`` when the boxes do not overlap or the
        union is empty.
    """
    if not box_intersect(source1.bbox, source2.bbox):
        return 0.0

    def mask_of(source):
        candidate = source.footprint
        if isinstance(candidate, Footprint):
            return candidate.footprint
        return candidate

    view1, view2 = overlap_slices(source1.bbox, source2.bbox)
    first, second = mask_of(source1), mask_of(source2)
    inter = int(np.sum(first[view1] & second[view2]))
    union = int(np.sum(first)) + int(np.sum(second)) - inter
    if union <= 0:
        return 0.0
    return float(inter / union)
# ---------------------------------------------------------------------------
# QuadTree
# ---------------------------------------------------------------------------
[docs]
class QuadTreeRegion:
    """A QuadTree that stores bounding boxes (rather than points).

    Boxes that span sub-region boundaries are stored in *all* overlapping
    sub-regions so that :meth:`query` always returns the full set of
    overlapping boxes.

    A leaf is only subdivided when subdivision can actually separate its
    boxes. A single-pixel leaf, or a leaf whose boxes all cover it entirely,
    keeps accumulating boxes instead. Without this rule such a leaf would
    split forever, because each quadrant would receive the same boxes again.
    """

    def __init__(self, bbox, capacity=5, sub_regions=None, boxes=None, depth=0, detect=None):
        """
        Parameters
        ----------
        bbox : Box
            The box that encloses this region.
        capacity : int
            Split threshold. A leaf that already holds ``capacity - 1``
            boxes is subdivided when another box arrives, provided
            subdivision can separate the boxes.
        sub_regions : list of QuadTreeRegion, optional
            Pre-existing sub-regions (normally left as ``None``).
        boxes : list of Box, optional
            Pre-existing boxes (normally left as ``None``).
        depth : int
            Depth of this node in the full tree.
        detect : array-like, optional
            Detection image. Stored as ``self.detect``; it is not used by the
            methods of this class (NOTE(review): previously described as
            enabling debug visualisations, which should be confirmed).
        """
        self.bbox = bbox
        self.sub_regions = sub_regions
        if boxes is not None:
            self.boxes = boxes
        else:
            # An inner node (one with sub-regions) must have boxes=None;
            # otherwise query() would return the empty list and never descend.
            self.boxes = [] if sub_regions is None else None
        self.capacity = capacity
        self.depth = depth
        self.detect = detect

    def footprint_image(self, bbox=None):
        """Return a 2-D image of all footprint masks in the tree.

        Parameters
        ----------
        bbox : Box, optional
            Output image bounding box. If ``None``, the union of ``Box((0, 0))``
            with every stored box is used.

        Returns
        -------
        image : ndarray
            Float image in which each pixel counts how many stored footprint
            masks cover it.
        """
        boxes = self.query(self.bbox)
        if bbox is None:
            bbox = Box((0, 0))
            for box in boxes:
                bbox = bbox | box
        footprint = np.zeros(bbox.shape)
        for box in boxes:
            full, local = overlap_slices(bbox, box)
            footprint[full] += box.footprint.footprint[local]
        return footprint

    @property
    def peaks(self):
        """Yield all :class:`Peak` objects contained in the tree."""
        for box in self.query(self.bbox):
            yield from box.footprint.peaks

    def add(self, other_box):
        """Add a box to the region.

        Boxes that do not overlap this region are ignored.

        Parameters
        ----------
        other_box : Box
            The box to insert.
        """
        if not box_intersect(self.bbox, other_box):
            return
        if self.sub_regions is not None:
            self._add_to_sub_regions(other_box)
            return
        if len(self.boxes) < self.capacity - 1 or not self._can_split(other_box):
            self.boxes.append(other_box)
            return
        self.split()
        self._add_to_sub_regions(other_box)

    def _can_split(self, other_box):
        """Return whether subdividing this leaf could separate its boxes."""
        height, width = self.bbox.shape
        if height <= 1 and width <= 1:
            # For a single pixel the only non-empty quadrant is the region itself.
            return False
        region_shape = tuple(self.bbox.shape)
        # If every box covers the whole region, each quadrant would receive
        # every box again, so splitting gains nothing.
        return any(
            tuple((box & self.bbox).shape) != region_shape for box in (*self.boxes, other_box)
        )

    def add_footprints(self, footprints):
        """Insert bounding boxes for a list of :class:`Footprint` objects.

        Each box gets a ``.footprint`` attribute pointing back to the
        originating :class:`Footprint` so it can be retrieved from a query.

        Parameters
        ----------
        footprints : list of Footprint
            Footprints to add bounding boxes for.

        Returns
        -------
        self
        """
        for fp in footprints:
            box = Box.from_bounds(*fp.bounds)
            box.footprint = fp
            self.add(box)
        return self

    def split(self):
        """Sub-divide this region into four quadrants.

        Boxes held by this node are redistributed to the quadrants, and the
        node becomes an inner node (``self.boxes`` is set to ``None``).
        """
        height, width = self.bbox.shape
        h2, w2 = height // 2, width // 2
        h3, w3 = height - h2, width - w2
        oy, ox = self.bbox.origin
        quadrants = [
            ((h2, w2), (oy, ox)),
            ((h3, w2), (oy + h2, ox)),
            ((h2, w3), (oy, ox + w2)),
            ((h3, w3), (oy + h2, ox + w2)),
        ]
        self.sub_regions = [
            QuadTreeRegion(Box(shape, origin), capacity=self.capacity, depth=self.depth + 1)
            for shape, origin in quadrants
        ]
        # Clear the boxes here, not only in add(): a stale list would shadow
        # the sub-regions in query().
        held, self.boxes = self.boxes, None
        for box in held or ():
            self._add_to_sub_regions(box)

    def _add_to_sub_regions(self, other_box):
        # Each sub-region ignores boxes that do not overlap it.
        for region in self.sub_regions:
            region.add(other_box)

    def query(self, other_box=None):
        """Return all boxes that overlap with ``other_box``.

        Parameters
        ----------
        other_box : Box, optional
            Query box. Defaults to the full region bbox.

        Returns
        -------
        results : set of Box
            Boxes that overlap with ``other_box``. A ``set`` is used so that
            boxes stored in multiple sub-regions are only returned once.
        """
        if other_box is None:
            other_box = self.bbox
        if self.boxes is not None:
            return {box for box in self.boxes if box_intersect(box, other_box)}
        if self.sub_regions is not None:
            results = set()
            for region in self.sub_regions:
                if box_intersect(region.bbox, other_box):
                    results |= region.query(other_box)
            return results
        return set()
# ---------------------------------------------------------------------------
# Multi-scale structure
# ---------------------------------------------------------------------------
[docs]
class SingleScaleStructure:
    """A connected set of pixels with common peaks at a single wavelet scale.

    Using the terminology of Starck et al. 2011, a *structure* is a connected
    set of significant wavelet coefficients at a given scale, together with
    any peaks contributed by overlapping structures at other scales.

    Attributes
    ----------
    scale : int
        The wavelet scale of this structure.
    footprint : Footprint
        The footprint at the primary scale.
    bbox : Box
        Bounding box of the primary footprint.
    peaks : dict
        ``{scale: [Peak, …]}`` — peaks contributed from each scale. The lists
        belong to this structure; the contributing footprints' ``peaks``
        lists are never modified.
    """

    def __init__(self, scale, footprint):
        """
        Parameters
        ----------
        scale : int
            Wavelet scale of the primary footprint.
        footprint : Footprint
            The footprint at the primary scale.
        """
        self.scale = scale
        self.footprint = footprint
        self.bbox = Box.from_bounds(*footprint.bounds)
        # Copy the list: storing footprint.peaks directly would let
        # add_footprint() extend the Footprint's own peak list.
        self.peaks = {scale: list(footprint.peaks)}
        # Cache for ``all_peaks``; reset whenever ``peaks`` changes.
        self._all_peaks = None

    def add_footprint(self, scale, footprint):
        """Add the peaks of a footprint detected at some (possibly other) scale.

        Parameters
        ----------
        scale : int
            The wavelet scale at which ``footprint`` was detected.
        footprint : Footprint
            The footprint whose peaks are appended under ``scale``.
        """
        # Build a new list rather than extending in place, so no list shared
        # with a Footprint is ever mutated.
        self.peaks[scale] = self.peaks.get(scale, []) + list(footprint.peaks)
        self._all_peaks = None

    def add_scale_tree(self, scale, tree):
        """Add all footprints from a :class:`QuadTreeRegion` at another scale
        that overlap with this structure's bounding box.

        Parameters
        ----------
        scale : int
            The wavelet scale the footprints in ``tree`` were detected at.
        tree : QuadTreeRegion
            The tree holding the footprints at that scale.

        Returns
        -------
        self : SingleScaleStructure
        """
        for box in tree.query(self.bbox):
            self.add_footprint(scale, box.footprint)
        return self

    @property
    def all_peaks(self):
        """Set of ``(x, y)`` tuples for every peak across all scales."""
        if self._all_peaks is None:
            self._all_peaks = {
                (peak.x, peak.y) for peaks in self.peaks.values() for peak in peaks
            }
        return self._all_peaks
# ---------------------------------------------------------------------------
# Wavelet-based detection helpers
# ---------------------------------------------------------------------------
[docs]
def get_wavelets(images, variance, max_scale=3, K=3):
    """Compute significant starlet coefficients for a multi-band image cube.

    Parameters
    ----------
    images : array-like, shape (bands, height, width)
        Observed images.
    variance : array-like, shape (bands, height, width)
        Per-pixel variances matching ``images``.
    max_scale : int
        Number of wavelet scales.
    K : float, optional
        Significance multiplier forwarded to
        :func:`~scarlet2.wavelets.get_multiresolution_support`. Default ``3``
        (the value previously hard-coded), matching
        :func:`get_detect_wavelets`.

    Returns
    -------
    coeffs : ndarray, shape (bands, max_scale+1, height, width)
        Starlet coefficients masked to the multi-resolution support.
    """
    images = np.asarray(images)
    variance = np.asarray(variance)
    # One noise level per band: the median per-pixel standard deviation.
    sigma = np.median(np.sqrt(variance), axis=(-2, -1))
    coeffs = []
    for b, image in enumerate(images):
        _coeffs = np.asarray(starlet_transform(image, scales=max_scale))
        M, _ = get_multiresolution_support(image, _coeffs, sigma[b], K=K, epsilon=1e-1, max_iter=20)
        coeffs.append(M * _coeffs)
    return np.array(coeffs)
[docs]
def get_detect_wavelets(image, variance, max_scale=3, K=3, image_type="ground"):
    """Starlet coefficients of a single-band detection image.

    The detection image is the inverse-variance weighted mean of ``image``
    over bands. Each band's weight is ``1 / sigma_b**2``, where ``sigma_b``
    is the median per-pixel standard deviation of that band. The noise level
    of the combined image is taken as ``sqrt(1 / sum(weights))``.

    Parameters
    ----------
    image : array-like, shape (bands, height, width)
        Multi-band image to build the detection image from.
    variance : array-like, shape (bands, height, width)
        Per-pixel variance matching ``image``.
    max_scale : int
        Number of wavelet scales.
    K : float
        Significance multiplier: coefficients with ``|w| > K * sigma_j`` are
        kept, where ``sigma_j`` is the noise estimate at scale ``j``.
    image_type : str
        ``"ground"`` for ground-based images with wide PSFs, ``"space"`` for
        space-based images with narrow PSFs.

    Returns
    -------
    coeffs : ndarray, shape (max_scale+1, height, width)
        Starlet coefficients of the detection image, masked to the
        multi-resolution support.
    sigma_j : ndarray, shape (max_scale+1,)
        Per-scale noise estimate used for thresholding.

    See Also
    --------
    :func:`~scarlet2.wavelets.get_multiresolution_support`
    """
    cube = np.asarray(image)
    var_cube = np.asarray(variance)

    # Per-band noise level and matching inverse-variance weight.
    band_sigma = np.median(np.sqrt(var_cube), axis=(-2, -1))
    band_weight = 1 / band_sigma**2

    # Weighted mean across bands, and the noise level of that mean.
    detection = np.sum(cube * band_weight[:, None, None], axis=0) / np.sum(band_weight)
    detection_sigma = np.sqrt(1 / band_weight.sum())

    planes = np.asarray(starlet_transform(detection, scales=max_scale))
    support, sigma_j = get_multiresolution_support(
        detection, planes, detection_sigma, K=K, image_type=image_type
    )
    return np.asarray(support) * planes, np.asarray(sigma_j)
[docs]
def get_blend_trees(detect, scales=None, min_separation=0, min_area=9, thresh=0):
    """Detect footprints at each selected scale and index them in QuadTrees.

    Parameters
    ----------
    detect : ndarray, shape (scales+1, height, width)
        Masked starlet coefficients (e.g. from :func:`get_detect_wavelets`).
    scales : list of int, optional
        Planes of ``detect`` to use, processed in ascending order. ``None``
        (default) selects every plane.
    min_separation : float, optional
        Minimum pixel separation between peaks within a footprint.
    min_area : int, optional
        Minimum number of pixels a footprint must contain to be kept.
    thresh : float, optional
        Detection threshold; pixels must strictly exceed this value.

    Returns
    -------
    trees : list of QuadTreeRegion
        One tree (capacity 10, spanning the full image) per selected scale.
    all_footprints : list of list of Footprint
        Raw footprints at each selected scale, aligned with ``trees``.
    """
    if scales is None:
        scales = list(range(len(detect)))
    else:
        scales = sorted(scales)

    all_footprints = [
        footprints(
            np.asarray(detect[s]),
            min_separation=min_separation,
            min_area=min_area,
            thresh=thresh,
        )
        for s in scales
    ]

    image_shape = detect.shape[-2:]
    trees = []
    for fps in all_footprints:
        tree = QuadTreeRegion(Box(image_shape), capacity=10)
        trees.append(tree.add_footprints(fps))
    return trees, all_footprints
[docs]
def get_blend_structures(detect, scales=None, min_separation=0, min_area=9, thresh=0):
    """Build :class:`SingleScaleStructure` objects at the largest selected scale.

    Each footprint at the largest selected scale becomes a structure. The
    structure is then linked to every footprint at the smaller selected
    scales whose bounding box overlaps it. This builds a hierarchy that
    connects fine-scale peaks to coarser detections.

    Parameters
    ----------
    detect : ndarray, shape (scales+1, height, width)
        Masked starlet coefficients (e.g. from :func:`get_detect_wavelets`).
    scales : list of int, optional
        Indices into ``detect`` specifying which scales to use. If ``None``
        (default) all scales are used.
    min_separation : float, optional
        Minimum pixel separation between peaks within a footprint.
    min_area : int, optional
        Minimum number of pixels a footprint must contain to be kept.
    thresh : float, optional
        Detection threshold; pixels must strictly exceed this value.

    Returns
    -------
    structures : list of SingleScaleStructure
        Structures at the largest selected scale, with peaks from smaller
        scales attached. Empty if no scales are selected.
    """
    scales = list(range(len(detect))) if scales is None else sorted(scales)
    if not scales:
        # Nothing to detect on (previously an IndexError on scales[-1]).
        return []

    all_footprints = [
        footprints(
            np.asarray(detect[s]),
            min_separation=min_separation,
            min_area=min_area,
            thresh=thresh,
        )
        for s in scales
    ]

    # Start with the footprints at the largest selected scale.
    structures = [SingleScaleStructure(scales[-1], fp) for fp in all_footprints[-1]]

    # Index the smaller selected scales in QuadTrees for overlap queries.
    box = Box(detect.shape[-2:])
    scale_trees = {
        scale: QuadTreeRegion(box, capacity=10).add_footprints(fps)
        for scale, fps in zip(scales[:-1], all_footprints[:-1])
    }
    for structure in structures:
        for scale, tree in scale_trees.items():
            structure.add_scale_tree(scale, tree)
    return structures
# ---------------------------------------------------------------------------
# Footprint splitting
# ---------------------------------------------------------------------------
[docs]
def split_footprint(fp, image, min_area=0):
    """Split a multi-peak :class:`Footprint` into single-peak sub-footprints.

    Segments the footprint area by finding the saddle points between peaks using
    a priority-queue flooding watershed. The wavelet coefficient image at the
    relevant scale is inverted so that peaks become low-cost basins; the watershed
    floods outward from each peak seed simultaneously in order of increasing cost,
    and region boundaries follow the intensity saddles between peaks.

    Enclosed holes of each watershed region are filled. The filling only uses
    pixels that the watershed did not assign to another peak, so one region
    never absorbs another peak's region, even when it fully surrounds it.

    Parameters
    ----------
    fp : Footprint
        The footprint to split. Returned unchanged (as a one-element list) if
        it contains at most one peak.
    image : 2D ndarray
        Wavelet coefficient image at the scale of ``fp``.
    min_area : int, optional
        Minimum number of pixels a sub-footprint must contain to be kept.
        Peaks whose watershed region is smaller than this are dropped.
        Default is ``0`` (keep all).

    Returns
    -------
    list of Footprint
        One single-peak :class:`Footprint` per peak in ``fp`` that meets the
        minimum area requirement.
    """
    if len(fp.peaks) <= 1:
        return [fp]

    (y0, y1), (x0, x1) = fp.bounds
    mask = fp.footprint
    sub_image = np.asarray(image[y0:y1, x0:x1], dtype=float)

    # Cost: invert intensity so peaks are cheap; normalize to [0, 1].
    vals = sub_image[mask]
    vmin, vmax = vals.min(), vals.max()
    if vmax > vmin:
        cost = np.where(mask, (vmax - sub_image) / (vmax - vmin), 1.0)
    else:
        cost = np.where(mask, 0.0, 1.0)

    # Priority-queue flooding watershed: expand from all seeds simultaneously
    # in order of increasing pixel cost (decreasing intensity). A flood can
    # only reach a pixel through adjacent labeled pixels, so it cannot arc
    # around another seed's region.
    h, w = mask.shape
    labels = np.zeros((h, w), dtype=np.int32)
    heap = []
    nbrs = ((-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1))

    def push_neighbours(i, j, label):
        # Queue unlabeled in-footprint 8-neighbours of (i, j) for this label.
        for di, dj in nbrs:
            ni, nj = i + di, j + dj
            if 0 <= ni < h and 0 <= nj < w and mask[ni, nj] and labels[ni, nj] == 0:
                heapq.heappush(heap, (cost[ni, nj], ni, nj, label))

    for k, peak in enumerate(fp.peaks):
        py, px = peak.y - y0, peak.x - x0
        labels[py, px] = k + 1
        push_neighbours(py, px, k + 1)

    while heap:
        _, i, j, label = heapq.heappop(heap)
        if labels[i, j] != 0:
            # Already claimed by a cheaper path.
            continue
        labels[i, j] = label
        push_neighbours(i, j, label)

    sub_footprints = []
    for k, peak in enumerate(fp.peaks):
        own = labels == k + 1
        raw_region = own & mask
        if raw_region.sum() < min_area:
            continue
        # Fill holes, but never with pixels that belong to another peak;
        # otherwise an enclosing region would swallow an enclosed one and
        # the sub-footprints would overlap.
        region = binary_fill_holes(raw_region) & (own | (labels == 0))
        if not region.any():
            continue
        rows = np.where(region.any(axis=1))[0]
        cols = np.where(region.any(axis=0))[0]
        ry0, ry1 = int(rows[0]), int(rows[-1]) + 1
        rx0, rx1 = int(cols[0]), int(cols[-1]) + 1
        bounds = ((y0 + ry0, y0 + ry1), (x0 + rx0, x0 + rx1))
        sub_footprints.append(Footprint(region[ry0:ry1, rx0:rx1], [peak], bounds))
    return sub_footprints
# ---------------------------------------------------------------------------
# Multiscale detections
# ---------------------------------------------------------------------------
[docs]
def multiscale_footprints(
    obs,
    scales=None,
    strict=True,
    K=3,
    split_peaks=True,
    image_type="ground",
    min_separation=0,
    min_area=9,
    thresh=0,
    return_intermediates=False,
):
    """Detect footprints at multiple starlet scales from an observation.

    Builds a detection image from the inverse-variance weighted sum of ``obs.data``,
    computes its starlet decomposition, and detects connected footprints at each
    requested scale. Optionally splits footprints that contain more than one
    peak into non-overlapping single-peak sub-footprints via a watershed algorithm.

    Parameters
    ----------
    obs : :class:`~scarlet2.Observation`
        The observation providing image data and per-pixel weights.
    scales : list of int, optional
        Starlet scale indices to detect on. Default is ``[1, 2, 3]``.
        Duplicates are ignored; the scales are processed in ascending order.
    strict : bool, optional
        If ``True``, the coarse residual plane is pushed one scale higher so that
        the selected ``scales`` are cleanly separated from the smooth background.
        Default ``True``.
    K : float, optional
        Detection threshold multiplier: wavelet coefficients with
        ``|w| > K * sigma_j`` are considered significant. Default ``3``.
    split_peaks : bool, optional
        If ``True`` (default), footprints that contain more than one peak are
        split into non-overlapping single-peak sub-footprints using a
        priority-queue watershed. If ``False``, multi-peak footprints are
        returned intact.
    image_type : str, optional
        PSF width regime used when computing the multi-resolution support.
        ``"ground"`` (default) for wide ground-based PSFs; ``"space"`` for
        narrow space-based PSFs.
    min_separation : float, optional
        Minimum pixel distance between peaks within a footprint. Dimmer peaks
        closer than this to a brighter one are suppressed. Default ``0``.
    min_area : int, optional
        Minimum number of pixels a footprint (or watershed sub-footprint) must
        contain to be kept. Default ``9``.
    thresh : float, optional
        Additional detection threshold; pixels must strictly exceed this value
        to be included in a footprint. Default ``0``.
    return_intermediates : bool, optional
        If ``True``, return a 4-tuple ``(detect, sigma, scales, all_footprints)``
        that exposes the intermediate products. Default ``False``.

    Returns
    -------
    all_footprints : dict mapping int → list of :class:`Footprint`
        Returned alone when ``return_intermediates`` is ``False``. Footprints
        at each requested scale, keyed by scale index. When ``split_peaks`` is
        ``True``, every footprint carries exactly one peak.
    (detect, sigma, scales, all_footprints) : tuple
        Returned instead when ``return_intermediates`` is ``True``:

        - ``detect`` : ndarray, masked starlet coefficients of the detection
          image, with ``max(scales) + 2`` planes if ``strict`` else
          ``max(scales) + 1``.
        - ``sigma`` : ndarray, per-scale noise estimates used for thresholding.
        - ``scales`` : list of int, the sorted, de-duplicated scales used.
        - ``all_footprints`` : dict as described above.

    Raises
    ------
    ValueError
        If ``scales`` is empty.

    See Also
    --------
    :func:`~scarlet2.detect.get_detect_wavelets` : builds the detection image.
    :func:`~scarlet2.detect.split_footprint` : watershed splitting of multi-peak footprints.
    :func:`~scarlet2.detect.hierarchical_footprints` : higher-level function that links
        footprints across scales into a source hierarchy.
    """
    # Default scales; duplicates would only be processed twice, since the
    # results are keyed by scale.
    scales = [1, 2, 3] if scales is None else sorted(set(scales))
    if not scales:
        raise ValueError("scales must contain at least one starlet scale index")

    # For strict scale separation, push the "remaining" coarse plane to one
    # scale above the largest requested scale.
    max_scale = max(scales) + int(bool(strict))
    detect, sigma = get_detect_wavelets(
        obs.data, 1 / obs.weights, max_scale=max_scale, K=K, image_type=image_type
    )
    all_footprints = {
        s: footprints(
            detect[s],
            min_separation=min_separation,
            min_area=min_area,
            thresh=thresh,
        )
        for s in scales
    }
    if split_peaks:
        # Pre-split multi-peak footprints so that all footprints at each scale
        # are non-overlapping and carry exactly one peak.
        all_footprints = {
            s: [sub for fp in fps for sub in split_footprint(fp, detect[s], min_area=min_area)]
            for s, fps in all_footprints.items()
        }
    if return_intermediates:
        return detect, sigma, scales, all_footprints
    return all_footprints
# ---------------------------------------------------------------------------
# Hierarchical detections
# ---------------------------------------------------------------------------
[docs]
@dataclass
class HierarchicalFootprint:
    """A source detected in the starlet hierarchy.

    Attributes
    ----------
    peak : tuple of int or Peak
        Peak position, refined to the finest scale reached. Annotated as a
        ``(y, x)`` tuple, but ``hierarchical_footprints`` also stores a
        :class:`Peak` here when creating children from split footprints.
        NOTE(review): confirm which form callers rely on.
    bbox : Box
        Bounding box at the scale this source was first detected.
    footprint : np.ndarray or Footprint
        Detection mask for the source, sized to its detection bounds. Either
        a boolean ndarray or a :class:`Footprint` (whose ``.footprint`` holds
        the mask); :func:`footprint_iou` accepts both forms.
    scale : int
        Largest wavelet scale at which this source was first identified.
    children : list of HierarchicalFootprint
        Sources whose peaks lie inside this source's footprint and are
        spatially inconsistent with this source's primary peak.
    """

    # (y, x) position, or a Peak (see class docstring)
    peak: tuple[int, int]
    bbox: Box
    # boolean mask or Footprint (see class docstring)
    footprint: np.ndarray
    scale: int
    children: list["HierarchicalFootprint"] = field(default_factory=list)

    def __repr__(self):
        # Use equinox's pretty-printer for a readable nested representation.
        return tree_pformat(self)
[docs]
def hierarchical_footprints(
obs,
scales=None,
strict=True,
K=3,
split_peaks=True,
image_type="ground",
min_separation=0,
min_area=9,
thresh=0,
catalog=None,
flatten=True,
return_detect=False,
):
"""Decompose an observed image into a list of :class:`HierarchicalFootprint` objects.
Creates a detection image from the inverse-variance weighted sum of `obs.data`,
then iterates from the largest starlet scale to the smallest. At each scale,
every detected footprint is matched to the best-overlapping source already
registered from larger scales (measured by IoU). If the registered source's
peak lies inside the new footprint, it is a *primary* match: the source
peak is refined and its footprint is grown to the union. Otherwise, the new
footprint becomes a *child* of the best-matching source. Footprints with no
overlap with any registered source are promoted to new top-level sources.
When ``split_peaks`` is ``True`` (default), footprints that contain more than
one peak are split into separate sources by using a watershed algorithm.
Otherwise, additional peaks become children of the originating footprint.
If ``catalog`` is provided, the detected sources are matched to the catalog
positions via a global bipartite assignment (Hungarian algorithm). The cost
of assigning catalog entry ``i`` to source ``j`` is the squared distance from
the catalog position to the nearest peak of source ``j``, restricted to cases
where the catalog position falls inside source ``j``'s footprint mask.
Catalog entries with no containing footprint are returned as ``None``.
Each footprint's bounding box is grown beyond the detection threshold to the noise level.
Assuming an exponential profile I(r) = I0*exp(-r/h), the scale length h is estimated from the footprint
size as h = r_foot / ln(S), where r_foot is the mean distance from the peak to the
edge of the footprint bounding box and S = I0 / (K*sigma_j). The box is grown
to the radius where the profile reaches the noise level (1*sigma_j):
half_size = r_foot * ln(S*K) / ln(S).
Parameters
----------
obs : :class:`~scarlet2.Observation`
The observation providing image data, per-pixel weights, and the
coordinate frame used to convert ``centers`` to pixel positions.
scales : list of int, optional
Starlet scales (indices into the coefficient array, default `[1,2,3]`) to use for detection.
strict : bool, optional
If ``True``, the coarse residual plane is pushed one scale higher so
that the selected ``scales`` are cleanly separated without bleed from
the largest-scale smooth background. Default ``False``.
K : float, optional
Detection threshold multiplier used when ``sigma_scales`` is given.
Must match the value passed to :func:`~scarlet2.detect.get_detect_wavelets`.
Default ``3``.
split_peaks : bool, optional
If ``True`` (default), footprints with multiple peaks are split into separate sources
using a watershed algorithm. Otherwise, additional peaks become children of the
originating footprint, which retains the full footprint area, i.e. the children overlap.
Splitting peaks allows to reduce the overlap of mostly independent sources.
image_type: str
The type of image that is being used.
This should be ``"ground"`` for ground based images with wide PSFs or
``"space"`` for images from space-based telescopes with a narrow PSF.
min_separation : float, optional
Minimum pixel separation between peaks within a footprint.
min_area : int, optional
Minimum number of pixels a footprint must contain to be kept. Also used
as the minimum area for watershed sub-footprints when ``split_peaks`` is
``True``.
thresh : float, optional
Detection threshold; pixels must strictly exceed this value.
catalog : list of (y, x) tuples, optional
If given, the output is catalog-indexed: one entry per catalog position,
matched to the best overlapping detected source, or ``None`` if undetected.
Matching is a global optimal assignment — each source is assigned to at
most one catalog entry.
flatten : bool, optional
Whether to flatten the source list so that children appear as independent
entries. Default ``True``.
return_detect: bool, optional
Whether to return the detection image. Default ``False``.
Returns
-------
sources : list of HierarchicalFootprint or None
When ``catalog`` is ``None``: top-level sources, each potentially carrying
a tree of children. Sources detected only at finer scales appear as
additional top-level entries with their ``scale`` set accordingly.
When ``catalog`` is given: catalog-length list where each entry is the
matched :class:`HierarchicalFootprint`, or ``None`` if no source was
detected at that catalog position.
"""
# Get initial multiscale footprints
detect, sigma, scales, all_footprints = multiscale_footprints(
obs,
scales=scales,
strict=strict,
K=K,
min_separation=min_separation,
min_area=min_area,
thresh=thresh,
image_type=image_type,
split_peaks=split_peaks,
return_intermediates=True,
)
def snr_bbox(fp, scale, sigma):
    """Bounding box grown to the SNR-predicted extent of an exponential profile.

    For a profile I(r) = I0*exp(-r/h), the footprint boundary lies at the
    detection threshold K*sigma_j, so h = r_foot / ln(S) where
    S = I0/(K*sigma_j) and r_foot is the mean distance from the peak to the
    edge of the footprint bbox. The box is grown to the radius
    where the profile drops to the noise level (1*sigma_j):
    ``half_size = r_foot * ln(S*K) / ln(S)``.

    Before that, the tight footprint bbox is padded so that its half-size is
    at least ``2 * 2**scale`` pixels (the kernel support at this scale). The
    result always contains this padded box (it never shrinks) and is clipped
    to the extent of the detection image.

    Parameters
    ----------
    fp : Footprint
        Footprint detected at ``scale``; ``fp.peaks[0]`` is the reference peak.
    scale : int
        Scale index of the footprint; sets the minimum half-size.
    sigma : float or None
        Noise level at ``scale``. If ``None`` or non-positive, only the
        kernel-support padding is applied.

    Returns
    -------
    Box
        Enlarged bounding box, clipped to the detection image.
    """
    (y0, y1), (x0, x1) = fp.bounds
    outer_box = Box(detect.shape[1:])

    # Pad the tight bbox up to the kernel support at this scale. The padded box
    # is built directly from the widened bounds so the padding always takes
    # effect on ``bbox`` itself.
    min_half = 2 * (2**scale)
    pad = [max(0, min_half - (hi - lo) // 2) for lo, hi in fp.bounds]
    bbox = Box.from_bounds(*[(lo - d, hi + d) for (lo, hi), d in zip(fp.bounds, pad)])

    if sigma is None or sigma <= 0:
        return bbox & outer_box

    peak = fp.peaks[0]
    S = peak.flux / (K * sigma)
    if S <= 1:
        # peak not above K*sigma: ln(S) <= 0, the profile model is undefined;
        # still clip to the image like every other return path
        return bbox & outer_box

    r_foot = np.mean([peak.y - y0, y1 - peak.y, peak.x - x0, x1 - peak.x])
    # guard against a negative size if S*K < 1 (only possible for K < 1)
    half_size = max(0, int(np.ceil(r_foot * np.log(S * K) / np.log(S))))
    snr_box = Box(
        (2 * half_size + 1, 2 * half_size + 1),
        origin=(peak.y - half_size, peak.x - half_size),
    )
    return (bbox | snr_box) & outer_box
def peak_in_footprint(y, x, fp):
    """Check whether full-image pixel ``(y, x)`` is set in the mask of ``fp``.

    Parameters
    ----------
    y, x : int
        Row and column in full-image coordinates.
    fp : Footprint
        Footprint whose ``bounds`` locate its ``footprint`` mask in the image.

    Returns
    -------
    bool
        False if the pixel lies outside the mask's bounding box, otherwise
        the mask value at that pixel.
    """
    row = y - fp.bounds[0][0]
    col = x - fp.bounds[1][0]
    n_rows, n_cols = fp.footprint.shape
    # return the failing comparison itself, exactly as a chained ``and`` would
    in_rows = 0 <= row < n_rows
    if not in_rows:
        return in_rows
    in_cols = 0 <= col < n_cols
    if not in_cols:
        return in_cols
    return bool(fp.footprint[row, col])
def all_nodes(node_list):
    """Depth-first, pre-order walk over a forest of nodes.

    Every entry of ``node_list`` is yielded, followed by all of its
    descendants (via ``.children``) before moving to the next entry.
    ``None`` entries are yielded as-is and have no descendants.
    """
    for root in node_list:
        yield root
        if root is None:
            continue
        yield from all_nodes(root.children)
def peaks2children(fp, scale):
    """Build child nodes for the additional peaks of ``fp``.

    ``fp`` is split into watershed sub-footprints on the detection image at
    ``scale``. Every sub-footprint except the first becomes a
    ``HierarchicalFootprint`` at the same scale.

    Returns
    -------
    list of HierarchicalFootprint
        One child per additional sub-footprint; empty if there are none.
    """
    sub_footprints = split_footprint(fp, detect[scale], min_area=min_area)
    # the first sub-footprint is skipped; presumably it corresponds to the
    # primary peak of ``fp`` itself (verify against split_footprint)
    return [
        HierarchicalFootprint(
            peak=sub.peaks[0],
            bbox=Box.from_bounds(*sub.bounds),
            footprint=sub,
            scale=scale,
        )
        for sub in sub_footprints[1:]
    ]
# --- initialize hierarchy at largest scale: additional peaks become children -
sources = []
scale = max(all_footprints.keys())
for fp in all_footprints[scale]:
node = HierarchicalFootprint(
peak=fp.peaks[0],
bbox=Box.from_bounds(*fp.bounds),
footprint=fp, # store entire Footprint temporarily, clean up later
scale=scale,
children=peaks2children(fp, scale),
)
sources.append(node)
# --- link smaller scale footprints to larger scale footprints -
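# exclude largest scale, 2nd largest to smallest
# NOTE(review): ``sorted(..., reverse=True)[::-1]`` yields ascending order and
# still contains the largest scale; ``[1:]`` appears to be what is intended.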
smaller_scales = sorted(list(all_footprints.keys()), reverse=True)[::-1]
for scale in smaller_scales:
registered_nodes = list(all_nodes(sources))
for fp in all_footprints[scale]:
peak = fp.peaks[0]
node = HierarchicalFootprint(
peak=peak,
bbox=Box.from_bounds(*fp.bounds),
footprint=fp, # store entire Footprint temporarily, clean up later
scale=scale,
children=peaks2children(fp, scale),
)
# new fps are either matches to an existing source, to one of their children, or an orphan
overlapping = [
i
for i, rfp in enumerate(registered_nodes)
if peak_in_footprint(peak.y, peak.x, rfp.footprint)
]
if len(overlapping) == 0: # orphan: add to sources
sources.append(node)
else:
# determine best match: intersection over union
overlap = {i: footprint_iou(node, registered_nodes[i]) for i in overlapping}
max_i = max(overlap, key=overlap.get)
parent = registered_nodes[max_i]
# if peak of parent is in footprint of new fp: primary match
# use fp center to refine parent center and update footprint with union
if peak_in_footprint(parent.peak.y, parent.peak.x, fp):
parent.peak.y, parent.peak.x = node.peak.y, node.peak.x
# union bbox and footprint mask with the primary fp at this finer scale
primary_bbox = Box.from_bounds(*fp.bounds)
parent_bbox = Box.from_bounds(*parent.footprint.bounds)
union_bbox = parent_bbox | primary_bbox
union_mask = np.zeros(union_bbox.shape, dtype=bool)
p_slices, pp_slices = overlap_slices(union_bbox, parent_bbox)
union_mask[p_slices] |= parent.footprint.footprint[pp_slices]
q_slices, qp_slices = overlap_slices(union_bbox, primary_bbox)
union_mask[q_slices] |= fp.footprint[qp_slices]
parent.footprint = Footprint(union_mask, parent.footprint.peaks, union_bbox.bounds)
parent.bbox = union_bbox
# if not: new child
else:
parent.children.append(node)
# --- flatten list: children are listed separately -
if flatten:
sources = list(all_nodes(sources))
# --- catalog filter: keep only footprints containing a given sky position -
if catalog is not None:
# Build cost matrix: rows = catalog entries, cols = detected sources.
# Cost is squared distance from catalog position to nearest source peak,
# restricted to cases where the catalog position falls inside the footprint.
# A large finite sentinel (``_no_match``) rather than np.inf marks invalid
# (position outside footprint) pairs, so the assignment problem stays feasible.
_no_match = 1e18
cost = np.full((len(catalog), len(sources)), _no_match)
for i, (py, px) in enumerate(catalog):
for j, s in enumerate(sources):
if peak_in_footprint(int(py), int(px), s.footprint):
cost[i, j] = min((py - p.y) ** 2 + (px - p.x) ** 2 for p in s.footprint.peaks)
# Solve the global assignment problem; pairs with sentinel cost are unmatched.
row_ind, col_ind = linear_sum_assignment(cost)
matches = [None] * len(catalog)
for i, j in zip(row_ind, col_ind, strict=False):
if cost[i, j] < _no_match:
matches[i] = j
sources = [sources[j] if j is not None else None for j in matches]
# clean up list: pad footprint mask to match the (possibly enlarged) bbox,
# then replace the Footprint object with the plain boolean array.
for i in range(len(sources)):
# with a catalog, undetected positions are None; skip those, and any node whose
# footprint has already been converted to a plain array, so that no node is
# post-processed twice.
if sources[i] is None or not isinstance(sources[i].footprint, Footprint):
continue
fp_obj = sources[i].footprint # still a Footprint at this point
enlarged_bbox = snr_bbox(fp_obj, sources[i].scale, sigma[sources[i].scale])
tight_bbox = Box.from_bounds(*fp_obj.bounds)
padded = np.zeros(enlarged_bbox.shape, dtype=bool)
enlarged_slices, tight_slices = overlap_slices(enlarged_bbox, tight_bbox)
padded[enlarged_slices] = fp_obj.footprint[tight_slices]
sources[i].bbox = enlarged_bbox # now only numpy array of footprint
sources[i].footprint = padded
if flatten:
sources[i].children = []
if return_detect:
return sources, detect
return sources