"""Plotting functions"""
from abc import ABC, abstractmethod
from warnings import warn
import astropy
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Polygon
from . import measure
from .bbox import Box, insert_into
from .detect import HierarchicalFootprint
from .renderer import ChannelRenderer
[docs]
def channels_to_rgb(channels):
    """Get the linear mapping of multiple channels to RGB channels.

    Channels are assumed to be ordered by wavelength, shortest first.

    Channel ``i`` occupies the unit interval ``[i, i+1]`` of the spectral-index
    axis ``[0, N]``. That axis is cut into three contiguous bands of equal width,
    assigned to B, G, and R::

        Blue:  [0, N/3]
        Green: [N/3, 2N/3]
        Red:   [2N/3, N]

    The weight of channel ``i`` in an RGB output is the length of the overlap of
    ``[i, i+1]`` with that output's band. Consequently:

    * every column sums to 1, i.e. all of a channel's intensity is displayed;
    * every row sums to N/3, i.e. R, G, and B receive the same total weight.

    Parameters
    ----------
    channels: int
        Number of channels (any positive integer).

    Returns
    -------
    array
        Shape ``(3, channels)``; row 0 is R, row 1 is G, row 2 is B.
    """
    assert channels >= 1, "channels must be a positive integer"
    if channels == 1:
        # a single channel is split evenly between R, G, and B
        return np.full((3, 1), 1 / 3)

    width = channels / 3  # width of each RGB band in spectral-index units
    lower = np.arange(channels, dtype=float)
    upper = lower + 1.0
    # (output row, band start, band stop), from blue to red
    bands = ((2, 0.0, width), (1, width, 2 * width), (0, 2 * width, 3 * width))

    channel_map = np.zeros((3, channels))
    for row, start, stop in bands:
        overlap = np.minimum(upper, stop) - np.maximum(lower, start)
        channel_map[row] = np.maximum(0.0, overlap)
    return channel_map
[docs]
class Norm(ABC):
    """Abstract base for norms that turn multi-channel images into displayable RGB"""

    def __init__(self):
        # largest uint8 value; a normalized value of 1 is scaled to it
        self._uint8Max = np.iinfo(np.uint8).max

    def get_intensity(self, im):
        """Total intensity: the non-negative part of `im` summed over axis 0"""
        non_negative = np.maximum(0, im)
        return non_negative.sum(axis=0)

    def clip(self, im, min_value, max_value):
        """Shift `im` by `min_value` and clip it to ``[0, max_value - min_value]``

        Note that the result is offset: a pixel equal to `min_value` maps to 0.
        """
        upper = max_value - min_value
        return np.maximum(0, np.minimum(im - min_value, upper))

    def convert_to_uint8(self, im):
        """Convert a normalized three-channel image into a uint8 RGB image

        NaNs become 0, values are clipped to [0, 1] and scaled to [0, 255]
        (truncating, not rounding). The axes are reordered from
        ``3 x Ny x Nx`` to ``Ny x Nx x 3``.
        """
        finite = np.nan_to_num(im, nan=0.0)
        scaled = self.clip(finite, 0, 1) * self._uint8Max
        as_uint8 = scaled.astype("uint8")
        return np.swapaxes(as_uint8.T, 0, 1)

    def make_rgb_image(self, *im):
        """Stack the channel images `im`, normalize them, and return uint8 RGB

        The positional signature mirrors astropy's ``Mapping.make_rgb_image``
        for backwards compatibility.
        """
        stacked = np.stack(im, axis=0)
        normalized = self.__call__(stacked)
        return self.convert_to_uint8(normalized)

    @abstractmethod
    def __call__(self, im):
        """Compute normalized three-channel image"""
        pass
[docs]
class LinearNorm(Norm):
    """Class for linear normalization"""

    def __init__(self, minimum, maximum):
        """Linear norm, mapping the interval [`minimum`, `maximum`] to [0,1]

        Parameters
        ----------
        minimum: float
            Value that will be mapped to 0
        maximum: float
            Value that will be mapped to 1
        """
        self.min_value, self.max_value = minimum, maximum
        super().__init__()

    def __call__(self, im):
        """Compute linear normalized image

        Values at or below `min_value` map to 0 and values at or above `max_value`
        map to 1. If the interval is empty (``min_value == max_value``, e.g. the
        percentiles of a constant image), an all-zero image is returned instead of
        dividing 0 by 0.
        """
        clipped = self.clip(im, self.min_value, self.max_value)
        width = self.max_value - self.min_value
        if width == 0:
            # degenerate interval: avoid 0/0 (NaNs and RuntimeWarnings)
            return np.zeros(np.shape(clipped))
        return clipped / width
[docs]
class LinearPercentileNorm(LinearNorm):
    """Linear norm whose limits are two percentiles of a reference image"""

    def __init__(self, img, percentiles=(1, 99)):
        """Norm that is linear between the two elements of `percentiles` of `img`

        Parameters
        ----------
        img: array
            Image to normalize
        percentiles: array-like, optional
            Lower and upper percentile to consider. Pixel values below the lower one
            map to zero, values above the upper one are saturated. Default is (1, 99)
        """
        assert len(percentiles) == 2
        lower, upper = np.percentile(img, percentiles)
        super().__init__(minimum=lower, maximum=upper)
[docs]
class AsinhNorm(Norm):
    """Norm that scales the total intensity with arcsinh while keeping channel ratios"""

    def __init__(self, min_value, max_value, beta):
        """Norm that scales as arcsinh(I / beta) between `min_value` and `max_value`

        See Lupton+(2004) https://ui.adsabs.harvard.edu/abs/2004PASP..116..133L

        Parameters
        ----------
        min_value: float
            Minimum value to consider
        max_value: float
            Maximum value to consider
        beta: float
            Turnover point of arcsinh. Below it the norm behaves linearly, above
            it the norm approximates ln(2*I)
        """
        self.min_value, self.max_value, self.beta = min_value, max_value, beta
        # the normalized image is divided by this; updated by `set_rgb_max`
        self._rgb_max = 1
        super().__init__()

    def set_rgb_max(self, img, vibrance=0.15):
        """Set maximum value of normalized image

        After this call, the largest finite value of ``self(img)`` is
        ``1 + vibrance``. If `img` normalizes to no positive finite value
        (e.g. a blank image), the scaling is left at 1.

        Parameters
        ----------
        img: array
            Three-channel image
        vibrance: float
            Allowance to exceed normalization of three-channel image.
            Makes images more vibrant but causes slight color shifts towards white in the highlights.
        """
        # evaluate the un-rescaled norm; a value from an earlier call would
        # otherwise be divided out once more and compound
        self._rgb_max = 1
        rgb = self.__call__(img)
        finite = rgb[np.isfinite(rgb)]
        peak = finite.max() if finite.size > 0 else 0
        # a zero scale would turn every later normalization into 0/0
        if peak > 0:
            self._rgb_max = peak / (1 + vibrance)

    def __call__(self, img):
        """Compute Asinh normalized image"""
        min_value = self.min_value
        max_value = self.max_value
        intensity = self.get_intensity(img)
        with np.errstate(invalid="ignore", divide="ignore"):  # n.b. np.where can't and doesn't short-circuit
            # clip between min_value and max_value
            i_ = self.clip(intensity - min_value, 0, max_value - min_value)
            # arcsinh scaling from Lupton+(2004)
            f = np.arcsinh(i_ / self.beta)  # no need to normalize, done below
            rgb = np.where(intensity[None, :, :] > 0, img / (intensity / f)[None, :, :], 0)
            # divide by the scale chosen in set_rgb_max (1 unless it was called)
            rgb = rgb / self._rgb_max
        return rgb
[docs]
class AsinhPercentileNorm(AsinhNorm):
    """Asinh norm whose limits and turnover are percentiles of a reference image"""

    def __init__(self, img, percentiles=(45, 50, 99), vibrance=0.15):
        """Norm that scales as arcsinh(I / beta) between bottom and top percentile

        The middle percentile sets the turnover `beta`. With the defaults, the
        lower limit aims at the sky level itself, while the median aims at
        emission slightly above the sky.

        Parameters
        ----------
        img: array_like
            Image to normalize
        percentiles: array_like
            Lower, middle, and upper percentile to consider. Pixel values below the
            lower one map to zero, above the upper one to one. The asinh turnover is
            given by the middle percentile. Default is (45,50,99)
        vibrance: float
            Allowance to exceed normalization of three-channel image.
            Makes images more vibrant but causes slight color shifts in the highlights.
        """
        assert len(percentiles) == 3
        low, turnover, high = np.percentile(img, percentiles)
        super().__init__(low, high, turnover)
        super().set_rgb_max(img, vibrance=vibrance)
[docs]
class AsinhAutomaticNorm(AsinhNorm):
    """Asinh norm whose parameters are derived from an observation and its weights"""

    def __init__(
        self,
        observation,
        channel_map=None,
        minimum=0,
        upper_percentile=99.5,
        noise_level=1,
        vibrance=0.15,
    ):
        """Norm that scales as arcsinh(I / beta) with parameters chosen automatically

        The turnover `beta` is `noise_level` times the median per-pixel noise RMS.
        The RMS is the square root of the variance (``1 / weights``, zero where
        the weight is zero), mapped onto RGB and summed over the three RGB
        channels. This should give an image scaling that picks out both
        low-surface-brightness features and highlights.

        Parameters
        ----------
        observation: :py:class:`~scarlet2.Observation`
            Observation object with weights
        channel_map: array
            Linear mapping from channels to RGB, dimensions (3, channels)
        minimum: float
            Minimum value to consider.
        upper_percentile: float
            Upper percentile: Pixel values above will be saturated.
        noise_level: float
            Factor to be multiplied to the total noise RMS to define the turnover point
        vibrance: float
            Allowance to exceed normalization of three-channel image.
            Makes images more vibrant but causes slight color shifts in the highlights.
        """
        if channel_map is None:
            channel_map = channels_to_rgb(observation.frame.C)
        # per-pixel variance; pixels with zero weight contribute none
        variance = np.where(observation.weights > 0, 1 / observation.weights, 0)
        rgb_data = img_to_3channel(observation.data, channel_map=channel_map)
        rgb_variance = img_to_3channel(variance, channel_map=channel_map)

        # total intensity and total variance images
        intensity = self.get_intensity(rgb_data)
        total_variance = self.get_intensity(rgb_variance)

        # saturate above the requested percentile of the total intensity
        max_value = np.percentile(intensity.flatten(), [upper_percentile])[0]
        # arcsinh turnover at roughly the noise level
        beta = np.median(np.sqrt(total_variance)) * noise_level

        super().__init__(minimum, max_value, beta)
        super().set_rgb_max(rgb_data, vibrance=vibrance)
[docs]
def img_to_3channel(img, channel_map=None):
    """Convert multi-band image cube into 3 RGB channels

    Non-finite pixels (NaN, inf) are set to 0 before the channels are mapped.

    Parameters
    ----------
    img: array
        Array with dimensions (channels, height, width), or a single image with
        dimensions (height, width), which is treated as one channel.
    channel_map: array, optional
        Linear mapping from channels to RGB, dimensions (3, channels). If not
        given, :py:func:`channels_to_rgb` is used.

    Returns
    -------
    array
        Dimensions (3, height, width), type float
    """
    img = np.asarray(img)
    assert img.ndim in (2, 3)
    # promote a single image to a one-channel cube
    cube = img[None, :, :] if img.ndim == 2 else img
    num_channels, ny, nx = cube.shape

    if channel_map is None:
        channel_map = channels_to_rgb(num_channels)
    else:
        channel_map = np.asarray(channel_map)
        # compare against the channel count, not len(img) (= height for 2D input)
        assert channel_map.shape == (3, num_channels), (
            f"channel_map must have shape (3, {num_channels}), got {channel_map.shape}"
        )

    # filter out masked and bad values
    cube = np.where(np.isfinite(cube), cube, 0)
    # map channels onto RGB channels
    rgb = np.dot(channel_map, cube.reshape(num_channels, -1)).reshape(3, ny, nx)
    return rgb
[docs]
def img_to_rgb(img, channel_map=None, norm=None, mask=None):
    """Convert images to normalized RGB.

    The channels are mapped onto RGB and normalized by `norm`. The normalized
    values are then clipped per channel to the displayable range and converted
    to uint8.

    Parameters
    ----------
    img: array
        This should be an array with dimensions (channels, height, width).
    channel_map: array
        Linear mapping from channels to RGB, dimensions (3, channels)
    norm: Norm, optional
        Norm to use for mapping in the allowed range [0..255]. If `norm=None`,
        a :py:class:`LinearPercentileNorm` of the RGB-mapped image is used.
    mask: array_like, optional
        Binary mask (bool, or 0/1 integers); pixels where it is set get 0 opacity.

    Returns
    -------
    array
        Dimensions (height, width, 3), type uint8, or (height, width, 4) with an
        alpha channel if `mask` is given.
    """
    im3 = img_to_3channel(img, channel_map=channel_map)
    if norm is None:
        norm = LinearPercentileNorm(im3)
    rgb = norm.make_rgb_image(*im3)
    if mask is not None:
        # cast to bool: `~` on an integer 0/1 mask is a bitwise NOT (~1 == -2)
        alpha = np.where(np.asarray(mask, dtype=bool), 0, 255).astype(rgb.dtype)
        rgb = np.dstack([rgb, alpha])
    return rgb
# Default edge length of one subplot panel; the plotting functions below build
# their default `figsize` (matplotlib: inches) as multiples of it
panel_size = 4.0
[docs]
def _peak_centers(obs, peaks):
    """Convert `peaks` into a list of pixel centers (y, x) in the frame of `obs`

    Elements may be sky coordinates, footprints with a `peak`, pixel coordinates,
    or None. None entries are kept so that the running label numbers stay aligned.
    """
    centers = []
    if peaks is None:
        return centers
    for peak in peaks:
        if isinstance(peak, astropy.coordinates.SkyCoord):
            # flatten (as the other plotting functions do) so that [::-1] swaps y and x
            centers.append(np.ravel(obs.frame.get_pixel(peak)))
        elif isinstance(peak, HierarchicalFootprint):
            centers.append((peak.peak.y, peak.peak.x))
        else:
            centers.append(peak)
    return centers


def _footprint_map(obs, footprints):
    """Stack `footprints` into one image of the frame of `obs`; None if there are none

    Each footprint is weighted by 1 / (number of distinct footprint scales).
    """
    if footprints is None or len(footprints) == 0:
        return None
    shape = obs.frame.bbox.spatial.shape
    num_scales = len(np.unique(np.asarray([sfp.scale for sfp in footprints if sfp is not None])))
    footprint_map = np.zeros(shape)
    for sfp in footprints:
        if sfp is not None:
            footprint_map += insert_into(np.zeros(shape), 1 / num_scales * sfp.footprint, sfp.bbox)
    return footprint_map


def _centered_in(image, shape):
    """Return a zero array of `shape` with `image` inserted at its center"""
    canvas = np.zeros(shape)
    full_box = Box(canvas.shape)
    shift = tuple(canvas.shape[d] // 2 - image.shape[d] // 2 for d in range(full_box.D))
    model_box = Box(image.shape) + shift
    return insert_into(canvas, image, model_box)


def observation(
    observation,
    norm=None,
    channel_map=None,
    show_psf=False,
    add_peaks=None,
    add_footprints=None,
    add_labels=False,
    sky_coords=None,
    split_channels=False,
    fig_kwargs=None,
    title_kwargs=None,
    label_kwargs=None,
):
    """Plot observation

    Show the entire content of `observation`, optionally with labeled peaks,
    footprints, or a PSF image.

    Parameters
    ----------
    observation: :py:class:`~scarlet2.Observation`
        The observation object to plot
    norm: Norm, optional
        Norm to scale the intensity of `observation` into RGB 0..255
    channel_map: array, optional
        Linear mapping from channels to RGB, dimensions (3, channels). Also used
        for the PSF panel. Ignored when `split_channels` is set, because every
        panel then shows a single channel.
    show_psf: bool, optional
        Whether to plot a panel with the PSF model of `observation` centered in
        the middle
    add_peaks: (None, list[astropy.SkyCoords], list[SourceFootprints]), optional
        Whether to plot a text label with the running number for each of the listed coordinates
    add_footprints: (None, list[SourceFootprints]), optional
        Whether to plot the footprints as semi-transparent layer over the image
    add_labels: bool, optional
        Whether source IDs are shown at the location of `sky_coords`.
        Deprecated: use `add_peaks` instead!
    sky_coords: list[astropy.SkyCoords], optional
        Coordinates to plot source IDs. If given, it replaces `add_peaks`.
        Deprecated: use `add_peaks` instead!
    split_channels: bool, optional
        Whether to split the observation into separate channels
    fig_kwargs: dict, optional
        Additional arguments for `mpl.subplots`. The dict is not modified.
    title_kwargs: dict, optional
        Additional arguments for `mpl.set_title`
    label_kwargs: dict, optional
        Additional arguments for `mpl.text`. Default is None and will be set to
        `{"color": "w", "ha": "center", "va": "center"}`

    Returns
    -------
    mpl.Figure
    """
    # copy so that popping "figsize" below does not modify the caller's dict
    fig_kwargs = {} if fig_kwargs is None else dict(fig_kwargs)
    if title_kwargs is None:
        title_kwargs = {}
    if label_kwargs is None:
        label_kwargs = {"color": "w", "ha": "center", "va": "center"}
    if add_labels or sky_coords is not None:
        warn(
            "`add_labels` and `sky_coords` are deprecated, use `add_peaks=ra_dec` instead!",
            DeprecationWarning,
            stacklevel=2,
        )
        # only let the deprecated argument override add_peaks if it carries coordinates
        if sky_coords is not None:
            add_peaks = sky_coords

    psf_model = None
    if show_psf:
        assert observation.frame.psf is not None, "show_psf requires observation.frame.psf to be set"
        psf_model = observation.frame.psf()

    rows = len(observation.frame.channels) if split_channels else 1
    panels = 2 if show_psf else 1
    figsize = fig_kwargs.pop("figsize", None)
    if figsize is None:
        figsize = (panel_size * panels, panel_size * rows)
    fig, ax = plt.subplots(rows, panels, figsize=figsize, squeeze=False, **fig_kwargs)

    extent = observation.frame.bbox.get_extent()
    centers = _peak_centers(observation, add_peaks)
    footprint_map = _footprint_map(observation, add_footprints)

    for row in range(rows):
        if split_channels:
            data = observation.data[row]
            mask = observation.weights[row] == 0
            name = observation.frame.channels[row]
            row_map = None  # a single channel cannot use a (3, channels) map
            if show_psf:
                psf = psf_model[row]
                # make PSF as bright as the brightest pixel of the observation
                scale = data.max() / psf.max()
        else:
            data = observation.data
            # Mask any pixels with zero weight in all channels
            mask = np.sum(observation.weights, axis=0) == 0
            name = getattr(observation, "name", "")
            row_map = channel_map
            if show_psf:
                psf = psf_model
                # make PSF as bright as the brightest pixel of the observation
                scale = observation.data.mean(axis=0).max() / psf_model.mean(axis=0).max()
        # if there are no masked pixels, do not use a mask
        if not np.any(mask):
            mask = None

        axes = ax[row, 0]
        axes.imshow(
            img_to_rgb(data, norm=norm, channel_map=row_map, mask=mask),
            extent=extent,
            origin="lower",
        )
        axes.set_title(f"Observation {name}", **title_kwargs)
        for k, center in enumerate(centers):
            if center is not None:
                axes.text(*center[::-1], k, **label_kwargs)
        if footprint_map is not None:
            axes.imshow(footprint_map, cmap="grey", alpha=0.3, extent=extent, origin="lower")

        if show_psf:
            # scale a copy: rescaling in place would mutate the PSF model itself
            psf_image = _centered_in(psf * scale, data.shape)
            ax[row, 1].imshow(img_to_rgb(psf_image, norm=norm, channel_map=row_map), origin="lower")
            ax[row, 1].set_title("PSF", **title_kwargs)

    fig.tight_layout()
    return fig
# ------------------------------------------------------ #
# include a routine to calculate the hallucination score #
# ----------------------------------------------------- #
# ruff: noqa: F821
# ignore jnp functions for outdated hallucination score
[docs]
def cut_square_box(arr, center, size):
    """
    Cut out a square box from a 2D array (or a 3D channel cube) around `center`.

    Parts of the box that fall outside of `arr` are filled with zeros. The
    result therefore always has spatial shape ``(size, size)``, and every pixel
    keeps its position relative to `center`, also near or beyond the array edges.

    Parameters:
    arr: numpy.ndarray
        The input array, (height, width) or (channels, height, width).
    center: tuple
        The center of the box in the format (row_center, col_center), integer pixels.
    size: int
        The size of the square box (side length).

    Returns:
    numpy.ndarray: The square box, shape (size, size) or (channels, size, size),
    with the dtype of `arr`. It is a new array, not a view of `arr`.
    """
    arr = np.asarray(arr)
    row_center, col_center = center
    half_size = size // 2

    # requested window; it may extend beyond the array
    row_start = row_center - half_size
    col_start = col_center - half_size
    row_stop = row_start + size
    col_stop = col_start + size

    # the part of the window that overlaps the array
    height, width = arr.shape[-2:]
    r0, r1 = max(0, row_start), min(height, row_stop)
    c0, c1 = max(0, col_start), min(width, col_stop)

    square_box = np.zeros(arr.shape[:-2] + (size, size), dtype=arr.dtype)
    if r0 < r1 and c0 < c1:
        # place the overlap at its offset inside the box, so that clipping at
        # either edge leaves zeros on that same side
        square_box[..., r0 - row_start : r1 - row_start, c0 - col_start : c1 - col_start] = arr[
            ..., r0:r1, c0:c1
        ]
    return square_box
# @jax.grad
[docs]
def neural_grad(galaxy, src):
    """Twice the summed log-prior of `galaxy` under every prior attached to `src`

    Despite its name, this function returns ``2 * sum(log_prob(galaxy))``. It
    returns a gradient only if it is wrapped with ``jax.grad``, which is not
    done here.

    Parameters
    ----------
    galaxy: array
        Value passed to each prior's ``log_prob``
    src: object
        Source whose ``get_parameters(return_info=True)`` maps each name to a
        ``(value, info)`` pair; ``info["prior"]`` may be None (skipped)

    Returns
    -------
    float or array
        Twice the sum of the log-probabilities (0 if there are no priors)
    """
    total = 0
    for _name, (_value, info) in src.get_parameters(return_info=True).items():
        prior = info["prior"]
        if prior is not None:
            total = total + prior.log_prob(galaxy)
    return 2 * total
[docs]
def log_like(morph, spectrum, data, weights):
    """Calculate the Gaussian log-likelihood of the model given the data

    The model is the outer product ``spectrum[:, None, None] * morph[None]``.
    The normalization only contains the ``-d/2 * log(2 pi)`` term, where ``d``
    is the number of data elements with non-zero weight.

    Parameters
    ----------
    morph: array
        2D morphology image
    spectrum: array
        1D spectrum, one entry per channel
    data: array
        Observed data, (channels, height, width)
    weights: array
        Inverse-variance weights, same shape as `data`

    Returns
    -------
    scalar
        The log-likelihood
    """
    # jax is imported here because this module does not import it at the top level
    import jax.numpy as jnp

    model = morph[None, :, :] * spectrum[:, None, None]
    d = jnp.prod(jnp.asarray(data.shape)) - jnp.sum(weights == 0)
    log_norm = d / 2 * jnp.log(2 * jnp.pi)
    log_like = -jnp.sum(weights * (model - data) ** 2) / 2
    return log_like - log_norm
# --------------------- #
# Hessian approximation #
# --------------------- #
# https://arxiv.org/pdf/2006.00719.pdf
# for regular functions f
[docs]
def hvp(f, primals, tangents):
    """Calculate the Hessian-vector product of a scalar function f

    Computed forward-over-reverse, as the JVP of ``grad(f)`` at `primals`
    along `tangents`.

    Parameters
    ----------
    f: callable
        Scalar-valued function
    primals: tuple
        Point at which the Hessian is evaluated
    tangents: tuple
        Vector(s) the Hessian is multiplied with

    Returns
    -------
    array
        The Hessian-vector product
    """
    # jax is imported here because this module does not import it at the top level
    from jax import grad, jvp

    return jvp(grad(f), primals, tangents)[1]
# for score functions
[docs]
def hvp_grad(grad_f, primals, tangents):
    """Calculate the Hessian-vector product from a gradient (score) function grad_f

    Equals the JVP of `grad_f` at `primals` along `tangents`. Use this when the
    gradient is available directly (e.g. a score model), not the function itself.

    Parameters
    ----------
    grad_f: callable
        Gradient of the function whose Hessian is probed
    primals: tuple
        Point at which the Hessian is evaluated
    tangents: tuple
        Vector(s) the Hessian is multiplied with

    Returns
    -------
    array
        The Hessian-vector product
    """
    # jax is imported here because this module does not import it at the top level
    from jax import jvp

    return jvp(grad_f, primals, tangents)[1]
# diagonals of Hessian from HVPs
[docs]
def hvp_rad(hvp, shape, max_iters=100, rtol=1e-6):
    """Approximate the diagonal of the Hessian from Hessian-vector products

    Averages ``z * hvp(z)`` over random vectors ``z`` with entries of +/-1
    (Rademacher probes; their expectation is the Hessian diagonal, cf.
    https://arxiv.org/pdf/2006.00719.pdf). Stops early once the running mean
    changes by less than `rtol` relative to its norm.

    Parameters
    ----------
    hvp: callable
        Maps a vector ``z`` of shape `shape` to the Hessian-vector product ``H z``
    shape: tuple
        Shape of the probe vectors (and of the returned diagonal)
    max_iters: int, optional
        Maximum number of probes. Default 100
    rtol: float, optional
        Relative convergence tolerance. Default 1e-6 (1e-2 was noted to give
        reasonable results)

    Returns
    -------
    array
        Estimate of the Hessian diagonal, float32, shape `shape`

    Raises
    ------
    ValueError
        If `max_iters` is smaller than 1
    """
    # jax is imported here because this module does not import it at the top level
    import jax.numpy as jnp
    from jax import random

    if max_iters < 1:
        raise ValueError("max_iters must be at least 1")

    h = jnp.zeros(shape, dtype=jnp.float32)
    h_ = jnp.zeros(shape, dtype=jnp.float32)
    for i in range(max_iters):
        key = random.PRNGKey(i)
        z = random.rademacher(key, shape, dtype=jnp.float32)
        h += jnp.multiply(z, hvp(z))
        if i > 0:
            # change of the running mean since the previous iteration
            norm = jnp.linalg.norm(h / (i + 1) - h_ / i, ord=2)
            if norm < rtol * jnp.linalg.norm(h / (i + 1), ord=2):
                break
        h_ = h
    return h / (i + 1)
# TODO: fix the jit compilation errors here
[docs]
def hallucination_score(scene, obs, src_num):
    """Calculate the hallucination score of a source in `scene` based on `obs`

    The score compares the Hessian diagonal of the source's prior term with
    that of the data log-likelihood, cut to the source box.

    NOTE(review): this routine is flagged as outdated in this module and is
    known to be incomplete (see the FIXME and jit TODO); treat results with care.

    Parameters
    ----------
    scene: :py:class:`~scarlet2.Scene`
        Scene containing the source
    obs: :py:class:`~scarlet2.Observation`
        Observation providing `data` and `weights`
    src_num: int
        Index of the source in `scene.sources`

    Returns
    -------
    tuple
        Per-pixel map ``-hallucination * src.morphology()`` and its sum
    """
    # jax is imported here because this module does not import it at the top level
    import jax.numpy as jnp
    from jax import jit

    src = scene.sources[src_num]
    center = np.array(src.morphology.bbox.center)[::-1]
    morph = src.morphology.data
    # prior term; NOTE(review): this yields a Hessian diagonal only if
    # `neural_grad` returns a gradient, which it currently does not
    f = lambda morph: neural_grad(morph, src)  # noqa: E731
    jit_hvp_x2 = jit(lambda z: hvp_grad(f, (morph,), (z,)))
    hvp_nn = hvp_rad(jit_hvp_x2, morph.shape)
    hvp_nn = np.array(hvp_nn)
    model_scene = scene()
    morph = model_scene[
        src_num
    ]  # FIXME: this must be wrong because that is a channel image, not a source image
    spectrum = jnp.array((1,))
    data = obs.data
    weights = obs.weights
    # jit the HVP for this loss and this morph model
    f = lambda morph: log_like(morph, spectrum, data, weights)  # noqa: E731
    jit_hvp_x = jit(lambda z: hvp(f, (morph,), (z,)))
    hvp_ll = hvp_rad(jit_hvp_x, morph.shape)
    box_size = hvp_nn.shape[1]
    # Cut out the square box
    hvp_ll_cut = cut_square_box(hvp_ll, center, box_size)
    hallucination = -hvp_nn + hvp_ll_cut
    return -hallucination * src.morphology(), jnp.sum(-hallucination * src.morphology())
[docs]
def confidence(scene, observation):
    """Hallucination-score confidence of every source in `scene`

    Parameters
    ----------
    scene: :py:class:`~scarlet2.Scene`
        Scene whose sources are scored
    observation: :py:class:`~scarlet2.Observation`
        Observation the scores are computed against

    Returns
    -------
    array
        One value per source: the summed score from :py:func:`hallucination_score`
    """
    n_sources = len(scene.sources)
    metrics = np.zeros(n_sources)
    for index in range(n_sources):
        metrics[index] = hallucination_score(scene, observation, index)[1]
    return metrics
[docs]
def sources(
    scene,
    observation=None,
    norm=None,
    channel_map=None,
    show_model=True,
    show_observed=False,
    show_rendered=False,
    show_spectrum=True,
    model_mask=None,
    add_labels=False,
    add_boxes=False,
    fig_kwargs=None,
    title_kwargs=None,
    label_kwargs=None,
    box_kwargs=None,
):
    """Plot all sources in `scene`

    Creates one figure, with each source in `scene` occupying one row. Depending
    on the chosen options, multiple panels per source will be created.

    Parameters
    ----------
    scene: :py:class:`~scarlet2.Scene`
        The scene object containing the sources and their models
    observation: :py:class:`~scarlet2.Observation`, optional
        The observation to render the sources for, or to show the data of.
        Only needed when `show_observed` or `show_rendered` is True. If given,
        the model panel is mapped onto the channels of the observation.
    norm: Norm, optional
        Norm to scale the intensity of `observation` into RGB 0..255
    channel_map: array, optional
        Linear mapping from channels to RGB, dimensions (3, channels)
    show_model: bool, optional
        Whether to show the internal model of each source
    show_observed: bool, optional
        Whether to show the full observation next to each source
    show_rendered: bool, optional
        Whether to show the model of each source rendered into the frame of `observation`
    show_spectrum: bool, optional
        Whether to show the spectrum of each source
    model_mask: array, optional
        A mask to apply to the model. If not given, no mask is applied
    add_labels: bool, optional
        Whether each source is labeled with its numerical index in the source list
    add_boxes: bool, optional
        Whether to plot the bounding box of each source on the observation-frame panels
    fig_kwargs: dict, optional
        Additional arguments for `mpl.subplots`. The dict is not modified.
    title_kwargs: dict, optional
        Additional arguments for `mpl.set_title`
    label_kwargs: dict, optional
        Additional arguments for `mpl.text` of the source labels. Defaults to
        {"color": "w", "ha": "center", "va": "center"}
    box_kwargs: dict, optional
        Additional arguments for `mpl.Polygon`.
        Defaults to {"facecolor": "none", "edgecolor": "w", "lw": 0.5}

    Returns
    -------
    mpl.Figure
    """
    # copy so that popping "figsize" below does not modify the caller's dict
    fig_kwargs = {} if fig_kwargs is None else dict(fig_kwargs)
    if title_kwargs is None:
        title_kwargs = {}
    if label_kwargs is None:
        label_kwargs = {"color": "w", "ha": "center", "va": "center"}
    if box_kwargs is None:
        box_kwargs = {"facecolor": "none", "edgecolor": "w", "lw": 0.5}
    if show_rendered or show_observed:
        assert observation is not None, "show_rendered or show_observed requires observation"

    sources = scene.sources
    n_sources = len(sources)
    panels = sum((show_model, show_observed, show_rendered, show_spectrum))
    figsize = fig_kwargs.pop("figsize", None)
    if figsize is None:
        figsize = (panel_size * panels, panel_size * n_sources)
    fig, ax = plt.subplots(n_sources, panels, figsize=figsize, squeeze=False, **fig_kwargs)

    # the channel mapping from model to observation is the same for all sources
    channel_renderer = None
    if show_model and observation is not None:
        channel_renderer = ChannelRenderer(scene.frame, observation.frame)

    for k, src in enumerate(sources):
        panel = 0
        # labels of every panel need the center, not only the model panel's
        center = src.center if add_labels else None

        if show_model:
            model = src()
            if channel_renderer is not None:
                model = channel_renderer(model)
            # Show the unrendered model in its bbox
            ax[k][panel].imshow(
                img_to_rgb(model, norm=norm, channel_map=channel_map, mask=model_mask),
                extent=src.bbox.get_extent(),
                origin="lower",
            )
            ax[k][panel].set_title(f"Source {k}", **title_kwargs)
            if add_labels:
                ax[k][panel].text(*(center[::-1]), k, **label_kwargs)  # x,y
            panel += 1

        if show_rendered or show_observed:
            observation.check_set_renderer(scene.frame)
            if add_labels:
                center_obs = observation.frame.get_pixel(scene.frame.get_sky_coord(center)).flatten()
            if add_boxes:
                start, stop = src.bbox.spatial.start, src.bbox.spatial.stop
                corners = np.array(
                    [start, np.array((start[0], stop[1])), stop, np.array((stop[0], start[1]))]
                )
                corners_obs = observation.frame.get_pixel(scene.frame.get_sky_coord(corners))

        # model in observation frame
        if show_rendered:
            model_ = observation.render(scene.evaluate_source(src))
            ax[k][panel].imshow(
                img_to_rgb(model_, norm=norm, channel_map=channel_map, mask=model_mask),
                origin="lower",
            )
            ax[k][panel].set_title(f"Source {k} Rendered", **title_kwargs)
            if add_labels:
                ax[k][panel].text(*(center_obs[::-1]), k, **label_kwargs)  # x,y
            if add_boxes:
                poly = Polygon(corners_obs[:, ::-1], closed=True, **box_kwargs)
                ax[k][panel].add_artist(poly)
            panel += 1

        if show_observed:
            name = getattr(observation, "name", "")
            # the full observation; label and box mark where the source is
            ax[k][panel].imshow(
                img_to_rgb(observation.data, norm=norm, channel_map=channel_map),
                origin="lower",
            )
            ax[k][panel].set_title(f"Observation {name}", **title_kwargs)
            if add_labels:
                ax[k][panel].text(*(center_obs[::-1]), k, **label_kwargs)  # x,y
            if add_boxes:
                poly = Polygon(corners_obs[:, ::-1], closed=True, **box_kwargs)
                ax[k][panel].add_artist(poly)
            panel += 1

        if show_spectrum:
            # needs to be evaluated in the source box to prevent truncation
            spectra = [measure.flux(src)] + [measure.flux(component) for component in src.components]
            for spectrum in spectra:
                ax[k][panel].plot(spectrum)
            ax[k][panel].set_xticks(range(len(spectrum)))
            if scene.frame.channels is not None:
                ax[k][panel].set_xticklabels(scene.frame.channels)
            ax[k][panel].set_title("Spectrum", **title_kwargs)
            ax[k][panel].set_xlabel("Channel")
            ax[k][panel].set_ylabel("Flux")

    fig.tight_layout()
    return fig
[docs]
def scene(
    scene,
    observation=None,
    norm=None,
    channel_map=None,
    show_model=True,
    show_observed=False,
    show_rendered=False,
    show_residual=False,
    add_labels=True,
    add_boxes=False,
    split_channels=False,
    fig_kwargs=None,
    title_kwargs=None,
    label_kwargs=None,
    box_kwargs=None,
):
    """Plot all sources to recreate the scene.

    The functions provide a fast way of evaluating the quality of the entire model,
    i.e. the combination of all sources that seek to fit the observation.

    If `scene` is a sequence of scenes, the figure is laid out with the first
    one and an animation over all of them is returned.

    Parameters
    ----------
    scene: :py:class:`~scarlet2.Scene` or sequence of them
        The scene object containing the sources and their models
    observation: :py:class:`~scarlet2.Observation`, optional
        The observation containing the data. If given, the model is mapped onto
        the channels of the observation.
    norm: Norm
        Norm to scale the intensity of `observation` into RGB 0..255
    channel_map: array_like
        Linear mapping from channels to RGB, dimensions (3, channels).
        Ignored when `split_channels` is set.
    show_model: bool
        Whether the internal model is shown in the model frame
    show_observed: bool
        Whether the observation is shown
    show_rendered: bool
        Whether the model, rendered to match the observation, is shown
    show_residual: bool
        Whether the residuals between rendered model and observation is shown
    add_labels: bool
        Whether each source is labeled with its numerical index in the source list
    add_boxes: bool
        Whether each source box is shown
    split_channels: bool
        Whether to split the observation into separate channels, one per row
    fig_kwargs: dict
        kwargs for plt.subplots(). The dict is not modified.
    title_kwargs: dict
        kwargs for plt.title()
    label_kwargs: dict
        kwargs for source labels, default {"color": "w", "ha": "center", "va": "center"}
    box_kwargs: dict
        kwargs for source boxes, default {"facecolor": "none", "edgecolor": "w", "lw": 0.5}

    Returns
    -------
    mpl.Figure or matplotlib.animation.FuncAnimation
        The figure, or an animation over all scenes if a sequence was given
    """
    # copy so that popping "figsize" below does not modify the caller's dict
    fig_kwargs = {} if fig_kwargs is None else dict(fig_kwargs)
    if title_kwargs is None:
        title_kwargs = {}
    if label_kwargs is None:
        label_kwargs = {"color": "w", "ha": "center", "va": "center"}
    if box_kwargs is None:
        box_kwargs = {"facecolor": "none", "edgecolor": "w", "lw": 0.5}

    # for animations with multiple scenes: lay out the figure with the first one
    scenes = None
    if hasattr(scene, "__iter__"):
        scenes = scene
        scene = scenes[0]

    uses_observation = show_observed or show_rendered or show_residual
    needs_render = show_rendered or show_residual
    if uses_observation:
        assert observation is not None, "Provide matched observation to show observed frame"

    # all displayed images are in the channels of the observation, if there is one
    channel_frame = observation.frame if observation is not None else scene.frame
    rows = len(channel_frame.channels) if split_channels else 1
    panels = sum((show_model, show_observed, show_rendered, show_residual))
    figsize = fig_kwargs.pop("figsize", None)
    if figsize is None:
        figsize = (panel_size * panels, panel_size * rows)
    fig, ax = plt.subplots(rows, panels, figsize=figsize, squeeze=False, **fig_kwargs)

    if needs_render:
        observation.check_set_renderer(scene.frame)
    channel_renderer = None
    if show_model and observation is not None:
        channel_renderer = ChannelRenderer(scene.frame, observation.frame)

    def _evaluate(current):
        """Return the (channel-mapped) model and its rendering for one scene state."""
        model_ = current()
        rendered_ = observation.render(model_) if needs_render else None
        if channel_renderer is not None:
            model_ = channel_renderer(model_)
        return model_, rendered_

    def _annotate_sources(axes):
        """Draw source boxes and/or index labels on every panel of one row."""
        for k, src in enumerate(scene.sources):
            if add_boxes:
                start, stop = src.bbox.spatial.start, src.bbox.spatial.stop
                corners = np.array(
                    [start, np.array((start[0], stop[1])), stop, np.array((stop[0], start[1]))]
                )
                if observation is not None:
                    corners_obs = observation.frame.get_pixel(scene.frame.get_sky_coord(corners))
                for panel in range(panels):
                    corners_ = corners if panel == 0 and show_model else corners_obs
                    poly = Polygon(corners_[:, ::-1], closed=True, **box_kwargs)  # needs x,y
                    axes[panel].add_artist(poly)
            if add_labels:
                center = src.center
                if observation is not None:
                    center_obs = observation.frame.get_pixel(scene.frame.get_sky_coord(center)).flatten()
                for panel in range(panels):
                    center_ = center if panel == 0 and show_model else center_obs
                    axes[panel].text(*(center_[::-1]), k, **label_kwargs)  # x,y

    model, model_rendered = _evaluate(scene)
    if uses_observation:
        data = observation.data
        zero_weight = observation.weights == 0

    # per row: (channel selector, channel map, mask, {panel kind: image}) for the animation
    row_state = []
    for row in range(rows):
        if split_channels:
            sel = row
            name = channel_frame.channels[row]
            row_map = None  # a single channel cannot use a (3, channels) map
        else:
            sel = slice(None)
            name = getattr(observation, "name", "")
            row_map = channel_map

        row_mask = None
        if uses_observation:
            # Mask pixels with zero weight (in all channels, if they are combined),
            # consistent with `observation()`
            row_mask = zero_weight[sel] if split_channels else np.all(zero_weight, axis=0)
            if not np.any(row_mask):
                row_mask = None

        images = {}
        panel = 0
        if show_model:
            images["model"] = ax[row, panel].imshow(
                img_to_rgb(model[sel], norm=norm, channel_map=row_map),
                extent=scene.frame.bbox.get_extent(),
                origin="lower",
            )
            ax[row, panel].set_title("Model", **title_kwargs)
            panel += 1
        if show_rendered:
            images["rendered"] = ax[row, panel].imshow(
                img_to_rgb(model_rendered[sel], norm=norm, channel_map=row_map),
                origin="lower",
            )
            ax[row, panel].set_title("Model Rendered", **title_kwargs)
            panel += 1
        if show_observed:
            ax[row, panel].imshow(
                img_to_rgb(data[sel], norm=norm, channel_map=row_map, mask=row_mask),
                origin="lower",
            )
            ax[row, panel].set_title(f"Observation {name}", **title_kwargs)
            panel += 1
        if show_residual:
            residual = data[sel] - model_rendered[sel]
            images["residual"] = ax[row, panel].imshow(
                img_to_rgb(residual, norm=LinearPercentileNorm(residual), channel_map=row_map, mask=row_mask),
                origin="lower",
            )
            ax[row, panel].set_title("Obs - Model", **title_kwargs)
            panel += 1

        _annotate_sources(ax[row])
        row_state.append((sel, row_map, row_mask, images))

    fig.tight_layout()
    if scenes is None:
        return fig

    def update(i):
        """Redraw the images that depend on the state of scene `i`."""
        model_, rendered_ = _evaluate(scenes[i])
        updated = []
        for sel, row_map, row_mask, images in row_state:
            if "model" in images:
                images["model"].set_data(img_to_rgb(model_[sel], norm=norm, channel_map=row_map))
                updated.append(images["model"])
            if "rendered" in images:
                images["rendered"].set_data(img_to_rgb(rendered_[sel], norm=norm, channel_map=row_map))
                updated.append(images["rendered"])
            if "residual" in images:
                residual = data[sel] - rendered_[sel]
                images["residual"].set_data(
                    img_to_rgb(residual, norm=LinearPercentileNorm(residual), channel_map=row_map, mask=row_mask)
                )
                updated.append(images["residual"])
        return updated

    return animation.FuncAnimation(fig=fig, func=update, frames=len(scenes), interval=30)