# Source code for scarlet2.bbox

import equinox as eqx
import jax.numpy as jnp


class Box(eqx.Module):
    """Bounding Box for data array

    A Bounding box describes the location of a data array in the model coordinate system.
    It is used to identify spatial and channel overlap and to map from model
    to observed frames and back.

    The `Box` code is agnostic about the meaning of the dimensions.
    We generally use this convention:

    - 2D shapes denote (Height, Width)
    - 3D shapes denote (Channels, Height, Width)
    """

    shape: tuple
    """Size of the array"""
    origin: tuple
    """Start coordinate (in 2D: lower-left corner) of the array in model frame"""

    def __init__(self, shape, origin=None):
        """
        Parameters
        ----------
        shape: tuple
            Size of the array
        origin: tuple, optional
            Start coordinate (in 2D: lower-left corner) of the array in model frame.
            Defaults to 0 in every dimension.
        """
        self.shape = tuple(shape)
        if origin is None:
            # use the materialized tuple so `shape` may be any iterable (e.g. a generator)
            origin = (0,) * len(self.shape)
        self.origin = tuple(origin)

    @staticmethod
    def from_bounds(*bounds):
        """Initialize a box from its bounds

        Parameters
        ----------
        bounds: tuple of (min,max) pairs
            Min/Max coordinate for every dimension

        Returns
        -------
        bbox: Box
            A new box bounded by the input bounds. Dimensions with max < min get size 0.
        """
        shape = tuple(max(0, cmax - cmin) for cmin, cmax in bounds)
        origin = tuple(cmin for cmin, _ in bounds)
        return Box(shape, origin=origin)

    @staticmethod
    def from_data(X, min_value=0):  # noqa: N803
        """Define box where `X` is above `min_value`

        Parameters
        ----------
        X : jnp.ndarray
            Data to threshold
        min_value : float
            Minimum value of the result.

        Returns
        -------
        bbox : :class:`scarlet2.bbox.Box`
            Bounding box for the thresholded `X`. If no element exceeds `min_value`,
            an empty box (size 0 at the origin) of the same dimensionality is returned.
        """
        sel = min_value < X
        if sel.any():
            nonzero = jnp.where(sel)
            # convert to Python ints: 0-d jax arrays are unhashable and would make
            # the resulting Box fail in __hash__
            bounds = [(int(idx.min()), int(idx.max()) + 1) for idx in nonzero]
        else:
            bounds = [(0, 0)] * len(X.shape)
        return Box.from_bounds(*bounds)

    def contains(self, p):
        """Whether the box contains a given coordinate `p`

        Raises
        ------
        ValueError
            If `p` does not have the same dimensionality as the box.
        """
        if len(p) != self.D:
            raise ValueError(f"Dimension mismatch in {p} and {self.D}")
        # half-open interval per dimension: origin <= p < origin + shape
        return all(o <= c < o + s for c, o, s in zip(p, self.origin, self.shape, strict=False))

    def get_extent(self):
        """Return the start and end coordinates of the last two dimensions.

        The order is ``[start[-1], stop[-1], start[-2], stop[-2]]``, i.e.
        (x_start, x_stop, y_start, y_stop) for (..., Height, Width) boxes.
        """
        return [self.start[-1], self.stop[-1], self.start[-2], self.stop[-2]]

    @property
    def D(self):  # noqa: N802
        """Dimensionality of this BBox"""
        return len(self.shape)

    @property
    def start(self):
        """Tuple of start coordinates"""
        return self.origin

    @property
    def stop(self):
        """Tuple of stop coordinates"""
        return tuple(o + s for o, s in zip(self.origin, self.shape, strict=False))

    @property
    def center(self):
        """Tuple of center coordinates"""
        return tuple(o + s // 2 for o, s in zip(self.origin, self.shape, strict=False))

    @property
    def bounds(self):
        """Bounds of the box"""
        return tuple((o, o + s) for o, s in zip(self.origin, self.shape, strict=False))

    @property
    def slices(self):
        """Bounds of the box as slices"""
        return tuple([slice(o, o + s) for o, s in zip(self.origin, self.shape, strict=False)])

    @property
    def spatial(self):
        """Spatial component (last two dimensions) of higher-dimensional box"""
        assert self.D >= 2
        return self[-2:]

    def set_center(self, pos):
        """Center box at given position

        The box is modified in place. Note that this changes its hash, so do not
        re-center a box that is in use as a dict key or set member.
        """
        # `.item()` unwraps numpy/jax scalars; plain Python numbers are used as-is
        pos_ = tuple(p.item() if hasattr(p, "item") else p for p in pos)
        origin = tuple(o + p - c for o, p, c in zip(self.origin, pos_, self.center, strict=False))
        # bypass the frozen-dataclass __setattr__ of eqx.Module to update in place
        object.__setattr__(self, "origin", origin)

    def grow(self, delta):
        """Grow the Box by the given delta in each direction

        `delta` is either a scalar used for every dimension or a per-dimension sequence.
        """
        if not hasattr(delta, "__iter__"):
            delta = [delta] * self.D
        origin = tuple([self.origin[d] - delta[d] for d in range(self.D)])
        shape = tuple([self.shape[d] + 2 * delta[d] for d in range(self.D)])
        return Box(shape, origin=origin)

    def shrink(self, delta):
        """Shrink the Box by the given delta in each direction

        `delta` is either a scalar used for every dimension or a per-dimension sequence.
        """
        if not hasattr(delta, "__iter__"):
            delta = [delta] * self.D
        origin = tuple([self.origin[d] + delta[d] for d in range(self.D)])
        shape = tuple([self.shape[d] - 2 * delta[d] for d in range(self.D)])
        return Box(shape, origin=origin)

    def __or__(self, other):
        """Union of two bounding boxes

        Parameters
        ----------
        other: `Box`
            The other bounding box in the union

        Returns
        -------
        result: `Box`
            The smallest rectangular box that contains *both* boxes.

        Raises
        ------
        ValueError
            If the boxes have different dimensionality.
        """
        if other.D != self.D:
            raise ValueError(f"Dimension mismatch in the boxes {other} and {self}")
        bounds = [
            (min(self.start[d], other.start[d]), max(self.stop[d], other.stop[d])) for d in range(self.D)
        ]
        return Box.from_bounds(*bounds)

    def __and__(self, other):
        """Intersection of two bounding boxes

        If there is no intersection between the two bounding boxes then an
        empty bounding box is returned.

        Parameters
        ----------
        other: `Box`
            The other bounding box in the intersection

        Returns
        -------
        result: `Box`
            The rectangular box that is in the overlap region of both boxes.

        Raises
        ------
        ValueError
            If the boxes have different dimensionality.
        """
        if other.D != self.D:
            raise ValueError(f"Dimension mismatch in the boxes {other} and {self}")
        bounds = [
            (max(self.start[d], other.start[d]), min(self.stop[d], other.stop[d])) for d in range(self.D)
        ]
        return Box.from_bounds(*bounds)

    def __getitem__(self, i):
        """Box restricted to the dimension(s) selected by `i` (an int or a slice)"""
        s_ = self.shape[i]
        o_ = self.origin[i]
        if not hasattr(s_, "__iter__"):
            # integer index: wrap the single dimension back into tuples
            s_ = (s_,)
            o_ = (o_,)
        return Box(s_, origin=o_)

    def __add__(self, offset):
        """Box shifted by `offset` (scalar for every dimension, or per-dimension sequence)"""
        if not hasattr(offset, "__iter__"):
            offset = (offset,) * self.D
        origin = tuple([a + o for a, o in zip(self.origin, offset, strict=False)])
        return Box(self.shape, origin=origin)

    def __sub__(self, offset):
        """Box shifted by `-offset` (scalar for every dimension, or per-dimension sequence)"""
        if not hasattr(offset, "__iter__"):
            offset = (offset,) * self.D
        origin = tuple([a - o for a, o in zip(self.origin, offset, strict=False)])
        return Box(self.shape, origin=origin)

    def __matmul__(self, bbox):
        """Concatenate the dimensions of two boxes (e.g. 1D channel box @ 2D spatial box -> 3D box)"""
        bounds = self.bounds + bbox.bounds
        return Box.from_bounds(*bounds)

    def __copy__(self):
        return Box(self.shape, origin=self.origin)

    def __eq__(self, other):
        # let Python fall back to identity comparison for unrelated types
        # instead of raising AttributeError on `other.shape`
        if not isinstance(other, Box):
            return NotImplemented
        return self.shape == other.shape and self.origin == other.origin

    def __hash__(self):
        return hash((self.shape, self.origin))
def overlap_slices(bbox1, bbox2, return_boxes=False):
    """Slices of bbox1 and bbox2 that overlap

    Parameters
    ----------
    bbox1: `~scarlet2.bbox.Box`
        The first box to use for comparing overlap.
    bbox2: `~scarlet2.bbox.Box`
        The second box to use for comparing overlap.
    return_boxes: bool
        If True, return the overlap region as two boxes, expressed relative to
        the origin of `bbox1` and of `bbox2`, respectively.
        If False, return the slices of those two boxes. Default False.

    Returns
    -------
    slices: tuple of slices
        The slice of an array bounded by `bbox1` and the slice of an array bounded
        by `bbox2` in the overlapping region. If `return_boxes` is True, the two
        relative boxes are returned instead of their slices.
    """
    overlap = bbox1 & bbox2
    # shift the common region into each box's local (array-index) coordinates
    _bbox1 = overlap - bbox1.origin
    _bbox2 = overlap - bbox2.origin
    if return_boxes:
        return _bbox1, _bbox2
    return _bbox1.slices, _bbox2.slices


def insert_into(image, sub, bbox):
    """Insert `sub` into `image` according to `bbox`

    Only the part of `sub` that overlaps with `image` is inserted.

    Parameters
    ----------
    image: array
        Full image, whose bounding box starts at the origin.
    sub: array
        Smaller sub-image
    bbox: Box
        Bounding box that describes the shape and position of `sub` in the pixel
        coordinates of `image`.

    Returns
    -------
    image: array
        Image with `sub` inserted at `bbox`. Arrays that support item assignment
        (numpy) are modified in place and returned; jax arrays are returned as an
        updated copy.
    """
    imbox = Box(image.shape)
    im_slices, sub_slices = overlap_slices(imbox, bbox)
    patch = sub[sub_slices]
    try:
        image[im_slices] = patch  # numpy arrays
    except TypeError:
        # jax arrays are immutable and reject item assignment with a TypeError.
        # Any other array type raising TypeError is re-raised here rather than
        # failing later with a confusing AttributeError on `.at`.
        if not hasattr(image, "at"):
            raise
        image = image.at[im_slices].set(patch)  # jax arrays
    return image