Source code for scarlet2.plot

"""Plotting functions"""

from abc import ABC, abstractmethod

import jax
import jax.numpy as jnp
import jax.random as random
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
from jax import grad, jit, jvp
from matplotlib.patches import Polygon

from . import measure
from .bbox import Box, insert_into
from .renderer import ChannelRenderer


def channels_to_rgb(channels):
    """Get the linear mapping of multiple channels to RGB channels

    The mapping created here assumes that the channels are ordered in wavelength
    direction, starting with the shortest wavelength. The mapping seeks to produce
    relatively even weights across all channels. It does not consider e.g.
    signal-to-noise variations across channels or human perception.

    Parameters
    ----------
    channels: int in range(0, 8)
        Number of channels. For `channels=0` an empty array of shape (3, 0)
        is returned.

    Returns
    -------
    array (3, channels) to map onto RGB

    Raises
    ------
    AssertionError
        If `channels` is not an integer between 0 and 7.
    """
    max_channels = 7
    assert channels in range(0, max_channels + 1), (
        f"No mapping has been implemented for more than {max_channels} channels (got {channels})"
    )

    # rows are R, G, B; columns are the input channels (shortest wavelength first)
    channel_map = np.zeros((3, channels))
    if channels == 1:
        channel_map[0, 0] = channel_map[1, 0] = channel_map[2, 0] = 1
    elif channels == 2:
        channel_map[0, 1] = 0.667
        channel_map[1, 1] = 0.333
        channel_map[1, 0] = 0.333
        channel_map[2, 0] = 0.667
        channel_map /= 0.667
    elif channels == 3:
        channel_map[0, 2] = 1
        channel_map[1, 1] = 1
        channel_map[2, 0] = 1
    elif channels == 4:
        channel_map[0, 3] = 1
        channel_map[0, 2] = 0.333
        channel_map[1, 2] = 0.667
        channel_map[1, 1] = 0.667
        channel_map[2, 1] = 0.333
        channel_map[2, 0] = 1
        channel_map /= 1.333
    elif channels == 5:
        channel_map[0, 4] = 1
        channel_map[0, 3] = 0.667
        channel_map[1, 3] = 0.333
        channel_map[1, 2] = 1
        channel_map[1, 1] = 0.333
        channel_map[2, 1] = 0.667
        channel_map[2, 0] = 1
        channel_map /= 1.667
    elif channels == 6:
        channel_map[0, 5] = 1
        channel_map[0, 4] = 0.667
        channel_map[0, 3] = 0.333
        channel_map[1, 4] = 0.333
        channel_map[1, 3] = 0.667
        channel_map[1, 2] = 0.667
        channel_map[1, 1] = 0.333
        channel_map[2, 2] = 0.333
        channel_map[2, 1] = 0.667
        channel_map[2, 0] = 1
        channel_map /= 2
    elif channels == 7:
        # the longest-wavelength channel contributes equally to R, G and B
        channel_map[:, 6] = 2 / 3.0
        channel_map[0, 5] = 1
        channel_map[0, 4] = 0.667
        channel_map[0, 3] = 0.333
        channel_map[1, 4] = 0.333
        channel_map[1, 3] = 0.667
        channel_map[1, 2] = 0.667
        channel_map[1, 1] = 0.333
        channel_map[2, 2] = 0.333
        channel_map[2, 1] = 0.667
        channel_map[2, 0] = 1
        channel_map /= 2
    return channel_map
class Norm(ABC):
    """Abstract base for norms that turn three-channel images into RGB images"""

    def __init__(self):
        # largest value representable as uint8, stored as float for scaling
        self._uint8Max = float(np.iinfo(np.uint8).max)

    def get_intensity(self, im):
        """Sum the non-negative part of `im` along its first axis"""
        non_negative = jnp.maximum(0, im)
        return non_negative.sum(axis=0)

    def clip(self, im, min_value, max_value):
        """Shift `im` by `min_value` and limit the result to [0, `max_value` - `min_value`]"""
        shifted = im - min_value
        span = max_value - min_value
        capped = jnp.minimum(shifted, span)
        return jnp.maximum(0, capped)

    def convert_to_uint8(self, im):
        """Convert three-channel image in the range [0, 1] to RGB image with uint8 dtype"""
        unit_range = self.clip(im, 0, 1)
        as_uint8 = (unit_range * self._uint8Max).astype("uint8")
        # 3 x Ny x Nx -> Nx x Ny x 3 -> Ny x Nx x 3
        reversed_axes = as_uint8.transpose()
        return reversed_axes.swapaxes(0, 1)

    def make_rgb_image(self, *im):
        """Compute uint8 RGB image from separately given channel images"""
        # backwards compatible to astropy Mapping Call
        stacked = jnp.stack(im, axis=0)
        normalized = self(stacked)
        return self.convert_to_uint8(normalized)

    @abstractmethod
    def __call__(self, im):
        """Compute normalized three-channel image"""
        pass
class LinearNorm(Norm):
    """Norm that rescales values linearly between two fixed bounds"""

    def __init__(self, minimum, maximum):
        """Linear norm, mapping the interval [`minimum`, `maximum`] to [0,1]

        Parameters
        ----------
        minimum: float
            Value that will be mapped to 0
        maximum: float
            Value that will be mapped to 1
        """
        self.min_value = minimum
        self.max_value = maximum
        super().__init__()

    def __call__(self, im):
        """Compute linear normalized image"""
        lower, upper = self.min_value, self.max_value
        clipped = self.clip(im, lower, upper)
        return clipped / (upper - lower)
class LinearPercentileNorm(LinearNorm):
    """Linear norm whose bounds are percentiles of a reference image"""

    def __init__(self, img, percentiles=(1, 99)):
        """Norm that is linear between the two elements of `percentiles` of `img`

        Parameters
        ----------
        img: array
            Image to normalize
        percentiles: array-like, optional
            Lower and upper percentile to consider. Pixel values below will be
            set to zero, above to saturated. Default is (1, 99)
        """
        assert len(percentiles) == 2
        lower, upper = np.percentile(img, percentiles)
        super().__init__(minimum=lower, maximum=upper)
class AsinhNorm(Norm):
    """Norm with arcsinh scaling of the total intensity"""

    def __init__(self, min_value, max_value, beta):
        """Norm that scales as arcsinh(I / beta) between `min_value` and `max_value`

        See Lupton+(2004) https://ui.adsabs.harvard.edu/abs/2004PASP..116..133L

        Parameters
        ----------
        min_value: float
            Minimum value to consider
        max_value: float
            Maximum value to consider
        beta: float
            Turnover point of arcsinh. Below it norm behaves linear, above it
            norm approximates ln(2*I)
        """
        self.min_value = min_value
        self.max_value = max_value
        self.beta = beta
        # divisor applied in __call__; 1 until set_rgb_max() is used
        self._rgb_max = 1
        super().__init__()

    def set_rgb_max(self, img, vibrance=0.15):
        """Set maximum value of normalized image

        Parameters
        ----------
        img: array
            Three-channel image
        vibrance: float
            Allowance to exceed normalization of three-channel image.
            Makes images more vibrant but causes slight color shifts towards
            white in the highlights.
        """
        # evaluated with the current `_rgb_max`; non-finite pixels are ignored
        normalized = self(img)
        finite = np.isfinite(normalized)
        self._rgb_max = normalized[finite].max() / (1 + vibrance)

    def __call__(self, img):
        """Compute Asinh normalized image"""
        lower, upper = self.min_value, self.max_value
        total = self.get_intensity(img)

        # n.b. np.where can't and doesn't short-circuit
        with np.errstate(invalid="ignore", divide="ignore"):
            # clip between lower and upper bound
            clipped = self.clip(total - lower, 0, upper - lower)
            # arcsinh scaling from Lupton+(2004); no need to normalize, done below
            stretched = np.arcsinh(clipped / self.beta)
            # rescale every channel by the same per-pixel factor
            per_pixel_scale = total / stretched
            rgb = img / per_pixel_scale[None, :, :]
            # keep rgb between 0 and 1 (up to the allowance folded into `_rgb_max`)
            rgb = rgb / self._rgb_max
        return rgb
class AsinhPercentileNorm(AsinhNorm):
    """Asinh norm whose parameters are percentiles of a reference image"""

    def __init__(self, img, percentiles=(45, 50, 99), vibrance=0.15):
        """Norm that scales as arcsinh(I / beta) between bottom and top percentile

        Uses the middle percentile to define the turnover `beta`. The defaults
        are chosen such that the median (percentile 50) tries to catch emission
        slightly above the sky level, while the minimum is aiming for the sky
        intensity itself.

        Parameters
        ----------
        img: array_like
            Image to normalize
        percentiles: array_like
            Lower, middle, and upper percentile to consider. Pixel values below
            will be set to zero, above to one. Asinh turnover is given by middle
            percentile. Default is (45,50,99)
        vibrance: float
            Allowance to exceed normalization of three-channel image.
            Makes images more vibrant but causes slight color shifts in the
            highlights.
        """
        assert len(percentiles) == 3
        lower, turnover, upper = np.percentile(img, percentiles)
        super().__init__(lower, upper, turnover)
        super().set_rgb_max(img, vibrance=vibrance)
class AsinhAutomaticNorm(AsinhNorm):
    """Asinh norm whose parameters are derived from an observation"""

    def __init__(
        self,
        observation,
        channel_map=None,
        minimum=0,
        upper_percentile=99.5,
        noise_level=1,
        vibrance=0.15,
    ):
        """Norm that scales as arcsinh(I / beta) with parameters chosen automatically

        The turnover `beta` is set to `noise_level` * RMS, where RMS is the
        median square root of the total variance of the observation. This norm
        should automatically create an image scaling that picks out
        low-surface brightness features and highlights.

        Parameters
        ----------
        observation: :py:class:`~scarlet2.Observation`
            Observation object with weights
        channel_map: array
            Linear mapping from channels to RGB, dimensions (3, channels)
        minimum: float
            Minimum value to consider.
        upper_percentile: float
            Upper percentile: Pixel values above will be saturated.
        noise_level: float
            Factor to be multiplied to the total noise RMS to define the
            turnover point
        vibrance: float
            Allowance to exceed normalization of three-channel image.
            Makes images more vibrant but causes slight color shifts in the
            highlights.
        """
        if channel_map is None:
            channel_map = channels_to_rgb(observation.frame.C)

        image_rgb = img_to_3channel(observation.data, channel_map=channel_map)

        # inverse weights as variance; pixels with 0 weight are filtered out
        variance = np.where(observation.weights > 0, 1 / observation.weights, 0)
        variance_rgb = img_to_3channel(variance, channel_map=channel_map)

        # total intensity and variance images
        total_intensity = self.get_intensity(image_rgb)
        total_variance = self.get_intensity(variance_rgb)

        # find upper clipping point
        (saturation,) = np.percentile(total_intensity.flatten(), [upper_percentile])

        # find a good turnover point for arcsinh: ~noise level
        typical_rms = np.median(np.sqrt(total_variance))
        turnover = typical_rms * noise_level

        super().__init__(minimum, saturation, turnover)
        super().set_rgb_max(image_rgb, vibrance=vibrance)
def img_to_3channel(img, channel_map=None):
    """Convert multi-band image cube into 3 RGB channels

    Parameters
    ----------
    img: array
        This should be an array with dimensions (channels, height, width).
        A single image with dimensions (height, width) is treated as a cube
        with one channel.
    channel_map: array
        Linear mapping from channels to RGB, dimensions (3, channels).
        If `None`, the default mapping for the number of channels is used.

    Returns
    -------
    array
        Dimensions (3, height, width), type float

    Raises
    ------
    AssertionError
        If `img` is not 2D or 3D, or if `channel_map` does not have the shape
        (3, channels).
    """
    assert img.ndim in [2, 3]

    # expand single img into cube
    if img.ndim == 2:
        ny, nx = img.shape
        img_ = img.reshape(1, ny, nx)
    else:
        img_ = img
    num_channels = len(img_)

    if channel_map is None:
        channel_map = channels_to_rgb(num_channels)
    else:
        # compare against the channel count of the expanded cube: for a 2D
        # input `len(img)` would be the image height, not the number of channels
        assert np.shape(channel_map) == (3, num_channels), (
            f"channel_map needs shape (3, {num_channels}), got {np.shape(channel_map)}"
        )

    # filter out masked and bad values
    img_ = jnp.where(jnp.isfinite(img_), img_, 0)

    # map channels onto RGB channels
    _, ny, nx = img_.shape
    flat_pixels = img_.reshape(num_channels, -1)
    rgb = jnp.dot(jnp.asarray(channel_map), flat_pixels).reshape(3, ny, nx)
    return rgb
def img_to_rgb(img, channel_map=None, norm=None, mask=None):
    """Convert images to normalized RGB.

    If normalized values are outside of the range [0..255], they will be
    truncated such as to preserve the corresponding color.

    Parameters
    ----------
    img: array
        This should be an array with dimensions (channels, height, width).
    channel_map: array
        Linear mapping from channels to RGB, dimensions (3, channels)
    norm: Norm, optional
        Norm to use for mapping in the allowed range [0..255]. If `norm=None`,
        `LinearPercentileNorm` will be used.
    mask: array_like, optional
        A binary mask (boolean or 0/1) with dimensions (height, width) that
        defines where pixels will have 0 opacity.

    Returns
    -------
    array
        RGB image with dimensions (height, width, 3), or RGBA image with
        dimensions (height, width, 4) if `mask` is given
    """
    im3 = img_to_3channel(img, channel_map=channel_map)
    if norm is None:
        norm = LinearPercentileNorm(im3)
    rgb = norm.make_rgb_image(*im3)
    if mask is not None:
        # alpha channel: transparent (0) where masked, opaque (255) elsewhere.
        # The mask is cast to bool first because `~` on an integer 0/1 mask is a
        # bitwise complement (-1/-2) and would yield negative alpha values.
        masked = jnp.asarray(mask).astype(bool)
        alpha = jnp.where(masked, 0, 255).astype(rgb.dtype)
        rgb = jnp.dstack([rgb, alpha])
    return rgb
# Edge length of one panel; the plotting functions below multiply it by the
# number of panel columns and rows to get the default `figsize`
panel_size = 4.0
def observation(
    observation,
    norm=None,
    channel_map=None,
    sky_coords=None,
    show_psf=False,
    add_labels=True,
    split_channels=False,
    fig_kwargs=None,
    title_kwargs=None,
    label_kwargs=None,
):
    """Plot observation

    Show entire content of `observation`, optionally with list of sources given
    by `sky_coords` or a PSF image.

    Parameters
    ----------
    observation: :py:class:`~scarlet2.Observation`
        The observation object to plot
    norm: Norm, optional
        Norm to scale the intensity of `observation` into RGB 0..255
    channel_map: array, optional
        Linear mapping from channels to RGB, dimensions (3, channels).
        Ignored if `split_channels` is True because every panel then shows a
        single channel.
    sky_coords: list, optional
        2D coordinates (in pixel coordinates or sky coordinates). If in sky
        coordinates, the Frame of `observation` needs to have a valid WCS.
    show_psf: bool, optional
        Whether to plot a panel with the PSF model of `observation` centered in
        the middle
    add_labels: bool, optional
        Whether to plot a text label with the running number for each of the
        sources in `sky_coords`
    split_channels: bool, optional
        Whether to split the observation into separate channels
    fig_kwargs: dict, optional
        Additional arguments for `mpl.subplots`
    title_kwargs: dict, optional
        Additional arguments for `mpl.set_title`
    label_kwargs: dict, optional
        Additional arguments for `mpl.text`.
        Default is None and will be set to
        `{"color": "w", "ha": "center", "va": "center"}`

    Returns
    -------
    mpl.Figure
    """
    # copy so that popping "figsize" below does not alter the caller's dict
    fig_kwargs = {} if fig_kwargs is None else dict(fig_kwargs)
    if title_kwargs is None:
        title_kwargs = {}
    if label_kwargs is None:
        label_kwargs = {"color": "w", "ha": "center", "va": "center"}

    if show_psf:
        assert observation.frame.psf is not None, "show_psf requires observation.frame.psf to be set"
        psf_model = observation.frame.psf()

    rows = len(observation.frame.channels) if split_channels else 1
    panels = 2 if show_psf else 1

    figsize = fig_kwargs.pop("figsize", None)
    if figsize is None:
        figsize = (panel_size * panels, panel_size * rows)
    # squeeze=False: `ax` is always a 2D array indexed by [row, panel]
    fig, ax = plt.subplots(rows, panels, figsize=figsize, squeeze=False, **fig_kwargs)

    extent = observation.frame.bbox.get_extent()

    for row in range(rows):
        if split_channels:
            data = observation.data[row]
            mask = observation.weights[row] == 0
            name = observation.frame.channels[row]
            # a single-channel panel cannot use the multi-channel mapping
            channel_map_ = None
            if show_psf:
                psf = psf_model[row]
                # make PSF as bright as the brightest pixel of the observation
                # (out of place, so that the PSF model itself is never modified)
                psf = psf * (data.max() / psf.max())
        else:
            data = observation.data
            # Mask any pixels with zero weight in all channels
            mask = np.sum(observation.weights, axis=0) == 0
            name = getattr(observation, "name", "")
            channel_map_ = channel_map
            if show_psf:
                # make PSF as bright as the brightest pixel of the observation
                brightness = observation.data.mean(axis=0).max() / psf_model.mean(axis=0).max()
                psf = psf_model * brightness

        # if there are no masked pixels, do not use a mask
        if np.all(mask == 0):
            mask = None

        panel = 0
        ax[row, panel].imshow(
            img_to_rgb(data, norm=norm, channel_map=channel_map_, mask=mask),
            extent=extent,
            origin="lower",
        )
        ax[row, panel].set_title(f"Observation {name}", **title_kwargs)

        if add_labels and sky_coords is not None:
            for k, center in enumerate(sky_coords):
                center_ = observation.frame.get_pixel(center)
                ax[row, panel].text(*center_[::-1], k, **label_kwargs)  # x,y

        if show_psf:
            panel = 1
            # insert PSF into the middle of a "blank" observation
            psf_image = np.zeros(data.shape)
            full_box = Box(psf_image.shape)
            shift = tuple(psf_image.shape[d] // 2 - psf.shape[d] // 2 for d in range(full_box.D))
            model_box = Box(psf.shape) + shift
            psf_image = insert_into(psf_image, psf, model_box)
            # same norm and channel mapping as the observation panel
            ax[row, panel].imshow(
                img_to_rgb(psf_image, norm=norm, channel_map=channel_map_),
                origin="lower",
            )
            ax[row, panel].set_title("PSF", **title_kwargs)

    fig.tight_layout()
    return fig
# ------------------------------------------------------ # # include a routine to calculate the hallucination score # # ----------------------------------------------------- #
def _box_span(center, size, length):
    """Locate a window of `size` pixels around `center` on an axis with `length` pixels

    Returns the slice bounds of the part of the window that lies on the axis,
    and the number of window pixels that fall before and after the axis.
    """
    start = center - size // 2
    stop = start + size
    before = min(size, max(0, -start))
    after = min(size, max(0, stop - length))
    return start + before, stop - after, before, after


def cut_square_box(arr, center, size):
    """Cut out a square box from an array based on the center and size.

    The box covers the last two axes of `arr`; leading axes (e.g. channels) are
    kept in full. Parts of the box that lie outside of `arr` are filled with
    zeros on the side where the box leaves the array, so that the result always
    has `size` pixels along both box axes and pixel positions are preserved.

    Parameters
    ----------
    arr: array
        The input array with dimensions (height, width) or
        (channels, height, width).
    center: tuple
        The center of the box in the format (row_center, col_center).
    size: int
        The size of the square box (side length).

    Returns
    -------
    array
        The square box extracted from the input array, with dimensions
        (..., size, size). The result is a numpy array if padding was needed,
        otherwise a slice of `arr`.
    """
    row_center, col_center = (int(c) for c in center)
    n_rows, n_cols = arr.shape[-2], arr.shape[-1]

    row_lo, row_hi, rows_before, rows_after = _box_span(row_center, size, n_rows)
    col_lo, col_hi, cols_before, cols_after = _box_span(col_center, size, n_cols)

    # Cut out the part of the box that overlaps with the array
    square_box = arr[..., row_lo:row_hi, col_lo:col_hi]

    # pad array up if needed (i.e. box outside array bounds)
    if rows_before or rows_after or cols_before or cols_after:
        leading_axes = [(0, 0)] * (arr.ndim - 2)
        pad_width = leading_axes + [(rows_before, rows_after), (cols_before, cols_after)]
        square_box = np.pad(square_box, pad_width, mode="constant", constant_values=0)
    return square_box
@jax.grad
def neural_grad(galaxy, src):
    """Gradient with respect to `galaxy` of twice the summed log-prior of `src`

    Every parameter of `src` that has a prior contributes the log-probability
    of that prior evaluated at `galaxy`.
    """
    log_prior = 0
    for _name, (_value, info) in src.get_parameters(return_info=True).items():
        prior = info["prior"]
        if prior is None:
            continue
        log_prior = log_prior + prior.log_prob(galaxy)
    return 2 * log_prior
def log_like(morph, spectrum, data, weights):
    """Calculate the log-likelihood of the model given the data

    The model is the outer product of `spectrum` (one value per channel) and
    the 2D `morph`; `weights` multiply the squared residuals.
    """
    model = spectrum[:, None, None] * morph[None, :, :]
    residual = model - data
    weighted_chi2 = jnp.sum(weights * residual**2)

    # number of data points with non-zero weight
    n_used = jnp.prod(jnp.asarray(data.shape)) - jnp.sum(weights == 0)
    normalization = n_used / 2 * jnp.log(2 * jnp.pi)

    return -weighted_chi2 / 2 - normalization
# --------------------- # # Hessian approximation # # --------------------- # # https://arxiv.org/pdf/2006.00719.pdf # for regular functions f
def hvp(f, primals, tangents):
    """Calculate the Hessian-vector product of a function f

    Evaluated as the forward-mode derivative of `grad(f)` at `primals` in the
    direction `tangents`.
    """
    _, tangent_out = jvp(grad(f), primals, tangents)
    return tangent_out
# for score functions
def hvp_grad(grad_f, primals, tangents):
    """Calculate the Hessian-vector product of a gradient function grad_f

    Evaluated as the forward-mode derivative of `grad_f` at `primals` in the
    direction `tangents`.
    """
    _, tangent_out = jvp(grad_f, primals, tangents)
    return tangent_out
# diagonals of Hessian from HVPs
def hvp_rad(hvp, shape, max_iters=100, tol=1e-6):
    """Approximate the diagonal of the Hessian

    Averages `z * hvp(z)` over random vectors `z` with entries +1/-1 until the
    running average changes by less than `tol` relative to its norm, or until
    `max_iters` vectors have been used.

    Parameters
    ----------
    hvp: callable
        Function that returns the Hessian-vector product for an array of
        shape `shape`
    shape: tuple
        Shape of the arrays accepted and returned by `hvp`
    max_iters: int, optional
        Maximum number of random vectors. Default is 100
    tol: float, optional
        Relative change of the running average below which the iteration
        stops. Default is 1e-6

    Returns
    -------
    array
        Estimate of the Hessian diagonal, with shape `shape`

    Raises
    ------
    ValueError
        If `max_iters` is smaller than 1
    """
    if max_iters < 1:
        raise ValueError(f"max_iters needs to be at least 1, got {max_iters}")

    total = jnp.zeros(shape, dtype=jnp.float32)
    previous_total = total
    for i in range(max_iters):
        # seeding with the iteration number makes the estimate reproducible
        key = random.PRNGKey(i)
        z = random.rademacher(key, shape, dtype=jnp.float32)
        total = total + jnp.multiply(z, hvp(z))
        if i > 0:
            estimate = total / (i + 1)
            # NOTE(review): `ord=2` is the vector norm for 1D input but the
            # spectral matrix norm for 2D input, and is not defined for more
            # dimensions; confirm whether a norm over all elements was intended
            change = jnp.linalg.norm(estimate - previous_total / i, ord=2)
            if change < tol * jnp.linalg.norm(estimate, ord=2):
                # gets reasonable results with tol=1e-2
                break
        previous_total = total
    return total / (i + 1)
# TODO: fix the jit compilation errors here
def hallucination_score(scene, obs, src_num):
    """Calculate the hallucination score of a source in `scene` based on `obs`

    Compares the approximate Hessian diagonal of the prior term (from
    `neural_grad`) with the approximate Hessian diagonal of the
    log-likelihood (from `log_like`), both estimated with `hvp_rad`.

    Parameters
    ----------
    scene: :py:class:`~scarlet2.Scene`
        Scene that contains the source; it is evaluated to get the model
    obs: :py:class:`~scarlet2.Observation`
        Observation that provides `data` and `weights` for the likelihood
    src_num: int
        Index of the source in `scene.sources`

    Returns
    -------
    tuple
        The difference of the two Hessian diagonals multiplied by the source
        morphology, and the sum of that array
    """
    src = scene.sources[src_num]
    # NOTE(review): the center is reversed here, while `cut_square_box` below
    # documents its `center` as (row_center, col_center) -- confirm the order
    center = np.array(src.morphology.bbox.center)[::-1]
    morph = src.morphology.data

    # Hessian diagonal of the prior term, evaluated at the source morphology
    f = lambda morph: neural_grad(morph, src)
    jit_hvp_x2 = jit(lambda z: hvp_grad(f, (morph,), (z,)))
    hvp_nn = hvp_rad(jit_hvp_x2, morph.shape)
    hvp_nn = np.array(hvp_nn)

    model_scene = scene()
    morph = model_scene[
        src_num
    ]  # FIXME: this must be wrong because that is a channel image, not a source image
    # single-element spectrum: the model in `log_like` gets one channel
    spectrum = jnp.array((1,))
    data = obs.data
    weights = obs.weights

    # jit the HVP for this loss and this morph model
    f = lambda morph: log_like(morph, spectrum, data, weights)  # noqa: E731
    jit_hvp_x = jit(lambda z: hvp(f, (morph,), (z,)))
    hvp_ll = hvp_rad(jit_hvp_x, morph.shape)

    # side length of the box is taken from the second axis of the prior term
    box_size = hvp_nn.shape[1]
    # Cut out the square box
    hvp_ll_cut = cut_square_box(hvp_ll, center, box_size)

    hallucination = -hvp_nn + hvp_ll_cut
    return -hallucination * src.morphology(), jnp.sum(-hallucination * src.morphology())
def confidence(scene, observation):
    """The confidence of each source in `scene` based on the hallucination score

    Returns an array with one entry per source: the summed hallucination score
    (second return value of `hallucination_score`) of that source.
    """
    n_sources = len(scene.sources)
    metrics = np.zeros(n_sources)
    for index in range(n_sources):
        metrics[index] = hallucination_score(scene, observation, index)[1]
    return metrics
def sources(
    scene,
    observation=None,
    norm=None,
    channel_map=None,
    show_model=True,
    show_observed=False,
    show_rendered=False,
    show_spectrum=True,
    model_mask=None,
    add_labels=False,
    add_boxes=False,
    fig_kwargs=None,
    title_kwargs=None,
    label_kwargs=None,
    box_kwargs=None,
):
    """Plot all sources in `scene`

    Creates one figure, with each source in `scene` occupying one row.
    Depending on the chosen options, multiple panels per source will be created.

    Parameters
    ----------
    scene: :py:class:`~scarlet2.Scene`
        The scene object containing the sources and their models
    observation: :py:class:`~scarlet2.Observation`, optional
        The observation to render the sources for, or to show the data of.
        Only needed when `show_observed` or `show_rendered` is True.
    norm: Norm, optional
        Norm to scale the intensity of `observation` into RGB 0..255
    channel_map: array, optional
        Linear mapping from channels to RGB, dimensions (3, channels)
    show_model: bool, optional
        Whether to show the internal model of each source
    show_observed: bool, optional
        Whether to show the observation
    show_rendered: bool, optional
        Whether to show the model of each source rendered into the frame of
        `observation`
    show_spectrum: bool, optional
        Whether to show the spectrum of each source
    model_mask: array, optional
        A mask to apply to the model. If not given, no mask is applied
    add_labels: bool, optional
        Whether each source is labeled with its numerical index in the source
        list
    add_boxes: bool, optional
        Whether to plot the bounding box of each source in the rendered and
        observed panels
    fig_kwargs: dict, optional
        Additional arguments for `mpl.subplots`
    title_kwargs: dict, optional
        Additional arguments for `mpl.set_title`
    label_kwargs: dict, optional
        Additional arguments for `mpl.text` of the source labels.
        Defaults to {"color": "w", "ha": "center", "va": "center"}
    box_kwargs: dict, optional
        Additional arguments for `mpl.Polygon`.
        Defaults to {"facecolor": "none", "edgecolor": "w", "lw": 0.5}

    Returns
    -------
    mpl.Figure
    """
    # copy so that popping "figsize" below does not alter the caller's dict
    fig_kwargs = {} if fig_kwargs is None else dict(fig_kwargs)
    if title_kwargs is None:
        title_kwargs = {}
    if label_kwargs is None:
        label_kwargs = {"color": "w", "ha": "center", "va": "center"}
    if box_kwargs is None:
        box_kwargs = {"facecolor": "none", "edgecolor": "w", "lw": 0.5}

    if show_rendered or show_observed:
        assert observation is not None, "show_rendered or show_observed requires observation"

    srcs = scene.sources
    n_sources = len(srcs)
    panels = sum((show_model, show_observed, show_rendered, show_spectrum))

    figsize = fig_kwargs.pop("figsize", None)
    if figsize is None:
        figsize = (panel_size * panels, panel_size * n_sources)
    fig, ax = plt.subplots(n_sources, panels, figsize=figsize, squeeze=False, **fig_kwargs)

    for k, src in enumerate(srcs):
        panel = 0
        model = src()

        # looked up outside of the model panel: the labels in the rendered and
        # observed panels need the center even when `show_model` is False
        if add_labels:
            center = src.center

        # model in its bbox
        if show_model:
            if observation is not None:
                c = ChannelRenderer(scene.frame, observation.frame)
                model = c(model)
            # Show the unrendered model in its bbox
            extent = src.bbox.get_extent()
            ax[k][panel].imshow(
                img_to_rgb(model, norm=norm, channel_map=channel_map, mask=model_mask),
                extent=extent,
                origin="lower",
            )
            ax[k][panel].set_title(f"Source {k}", **title_kwargs)
            if add_labels:
                ax[k][panel].text(*(center[::-1]), k, **label_kwargs)  # x,y
            panel += 1

        # label and box positions in the frame of the observation
        if show_rendered or show_observed:
            observation.check_set_renderer(scene.frame)
            if add_labels:
                center_obs = observation.frame.get_pixel(scene.frame.get_sky_coord(center)).flatten()
            if add_boxes:
                start, stop = src.bbox.spatial.start, src.bbox.spatial.stop
                corners = jnp.array(
                    [start, jnp.array((start[0], stop[1])), stop, jnp.array((stop[0], start[1]))]
                )
                corners_obs = observation.frame.get_pixel(scene.frame.get_sky_coord(corners))

        # model in observation frame
        if show_rendered:
            model = scene.evaluate_source(src)
            model_ = observation.render(model)
            ax[k][panel].imshow(
                img_to_rgb(model_, norm=norm, channel_map=channel_map, mask=model_mask),
                origin="lower",
            )
            ax[k][panel].set_title(f"Source {k} Rendered", **title_kwargs)
            if add_labels:
                ax[k][panel].text(*(center_obs[::-1]), k, **label_kwargs)  # x,y
            if add_boxes:
                poly = Polygon(corners_obs[:, ::-1], closed=True, **box_kwargs)  # needs x,y
                ax[k][panel].add_artist(poly)
            panel += 1

        if show_observed:
            name = getattr(observation, "name", "")
            ax[k][panel].imshow(
                img_to_rgb(observation.data, norm=norm, channel_map=channel_map),
                origin="lower",
            )
            ax[k][panel].set_title(f"Observation {name}", **title_kwargs)
            if add_labels:
                ax[k][panel].text(*(center_obs[::-1]), k, **label_kwargs)  # x,y
            if add_boxes:
                poly = Polygon(corners_obs[:, ::-1], closed=True, **box_kwargs)  # needs x,y
                ax[k][panel].add_artist(poly)
            panel += 1

        if show_spectrum:
            # needs to be evaluated in the source box to prevent truncation
            spectra = [measure.flux(src)] + [measure.flux(component) for component in src.components]
            for spectrum in spectra:
                ax[k][panel].plot(spectrum)
            # one tick per entry of the spectrum
            ax[k][panel].set_xticks(range(len(spectra[-1])))
            if scene.frame.channels is not None:
                ax[k][panel].set_xticklabels(scene.frame.channels)
            ax[k][panel].set_title("Spectrum", **title_kwargs)
            ax[k][panel].set_xlabel("Channel")
            ax[k][panel].set_ylabel("Flux")

    fig.tight_layout()
    return fig
def scene(
    scene,
    observation=None,
    norm=None,
    channel_map=None,
    show_model=True,
    show_observed=False,
    show_rendered=False,
    show_residual=False,
    add_labels=True,
    add_boxes=False,
    split_channels=False,
    fig_kwargs=None,
    title_kwargs=None,
    label_kwargs=None,
    box_kwargs=None,
):
    """Plot all sources to recreate the scene.

    The functions provide a fast way of evaluating the quality of the entire
    model, i.e. the combination of all sources that seek to fit the observation.

    Parameters
    ----------
    scene: :py:class:`~scarlet2.Scene` or sequence of scenes
        The scene object containing the sources and their models. If a sequence
        of scenes is given, the first one is plotted and an animation over all
        of them is returned.
    observation: :py:class:`~scarlet2.Observation`, optional
        The observation containing the data. Required for `show_observed`,
        `show_rendered` and `show_residual`.
    norm: Norm
        Norm to scale the intensity of `observation` into RGB 0..255
    channel_map: array_like
        Linear mapping from channels to RGB, dimensions (3, channels).
        Ignored if `split_channels` is True.
    show_model: bool
        Whether the internal model is shown in the model frame
    show_observed: bool
        Whether the observation is shown
    show_rendered: bool
        Whether the model, rendered to match the observation, is shown
    show_residual: bool
        Whether the residuals between rendered model and observation is shown
    add_labels: bool
        Whether each source is labeled with its numerical index in the source
        list
    add_boxes: bool
        Whether each source box is shown
    split_channels: bool
        Whether to split the observation into separate channels
    fig_kwargs: dict
        kwargs for plt.subplots()
    title_kwargs: dict
        kwargs for the panel titles
    label_kwargs: dict
        kwargs for source labels, default
        {"color": "w", "ha": "center", "va": "center"}
    box_kwargs: dict
        kwargs for source boxes, default
        {"facecolor": "none", "edgecolor": "w", "lw": 0.5}

    Returns
    -------
    mpl.Figure, or mpl.animation.FuncAnimation if a sequence of scenes is given
    """
    # copy so that popping "figsize" below does not alter the caller's dict
    fig_kwargs = {} if fig_kwargs is None else dict(fig_kwargs)
    if title_kwargs is None:
        title_kwargs = {}
    if label_kwargs is None:
        label_kwargs = {"color": "w", "ha": "center", "va": "center"}
    if box_kwargs is None:
        box_kwargs = {"facecolor": "none", "edgecolor": "w", "lw": 0.5}

    # for animations with multiple scenes: draw the first, animate the rest
    scenes = None
    if hasattr(scene, "__iter__"):
        scenes = scene
        scene = scenes[0]

    if show_observed or show_rendered or show_residual:
        assert observation is not None, "Provide matched observation to show observed frame"

    # without observation the model stays in the frame of the scene
    channel_frame = observation.frame if observation is not None else scene.frame
    rows = len(channel_frame.channels) if split_channels else 1
    panels = sum((show_model, show_observed, show_rendered, show_residual))

    figsize = fig_kwargs.pop("figsize", None)
    if figsize is None:
        figsize = (panel_size * panels, panel_size * rows)
    fig, ax = plt.subplots(rows, panels, figsize=figsize, squeeze=False, **fig_kwargs)

    model = scene()
    if show_rendered or show_residual:
        observation.check_set_renderer(scene.frame)
        model_rendered = observation.render(model)
    if show_model and observation is not None:
        c = ChannelRenderer(scene.frame, observation.frame)
        model = c(model)
    if show_observed or show_residual:
        data = observation.data
        mask = observation.weights == 0

    # only the observed and residual panels are masked; defined up front so
    # that it always exists for the animation below
    mask_ = None

    for row in range(rows):
        if split_channels:
            sel = row
            name = channel_frame.channels[row]
            # a single-channel panel cannot use the multi-channel mapping
            channel_map = None
        else:
            sel = slice(None)
            name = getattr(observation, "name", "")

        panel = 0
        if show_model:
            extent = scene.frame.bbox.get_extent()
            model_img = ax[row, panel].imshow(
                img_to_rgb(model[sel], norm=norm, channel_map=channel_map),
                extent=extent,
                origin="lower",
            )
            ax[row, panel].set_title("Model", **title_kwargs)
            panel += 1

        if show_rendered:
            rendered_img = ax[row, panel].imshow(
                img_to_rgb(model_rendered[sel], norm=norm, channel_map=channel_map),
                origin="lower",
            )
            ax[row, panel].set_title("Model Rendered", **title_kwargs)
            panel += 1

        # same condition as for the definition of `mask` above
        if show_observed or show_residual:
            if split_channels:
                mask_ = mask[sel]
            else:
                # Mask pixels with zero weight in any of the channels
                mask_ = np.sum(mask, axis=0) > 0
            # if there are no masked pixels, do not use a mask
            if np.all(mask_ == 0):
                mask_ = None

        if show_observed:
            ax[row, panel].imshow(
                img_to_rgb(data[sel], norm=norm, channel_map=channel_map, mask=mask_),
                origin="lower",
            )
            ax[row, panel].set_title(f"Observation {name}", **title_kwargs)
            panel += 1

        if show_residual:
            residual = data[sel] - model_rendered[sel]
            # residuals get their own norm because they are centered on zero
            norm_ = LinearPercentileNorm(residual)
            residual_img = ax[row, panel].imshow(
                img_to_rgb(residual, norm=norm_, channel_map=channel_map, mask=mask_),
                origin="lower",
            )
            ax[row, panel].set_title("Obs - Model", **title_kwargs)
            panel += 1

        for k, src in enumerate(scene.sources):
            if add_boxes:
                start, stop = src.bbox.spatial.start, src.bbox.spatial.stop
                corners = jnp.array(
                    [start, jnp.array((start[0], stop[1])), stop, jnp.array((stop[0], start[1]))]
                )
                if observation is not None:
                    corners_obs = observation.frame.get_pixel(scene.frame.get_sky_coord(corners))
                for panel in range(panels):
                    # the model panel is in the model frame, all others in the observed frame
                    corners_ = corners if panel == 0 and show_model else corners_obs
                    poly = Polygon(corners_[:, ::-1], closed=True, **box_kwargs)  # needs x,y
                    ax[row, panel].add_artist(poly)

            if add_labels:
                center = src.center
                if observation is not None:
                    center_obs = observation.frame.get_pixel(scene.frame.get_sky_coord(center)).flatten()
                for panel in range(panels):
                    center_ = center if panel == 0 and show_model else center_obs
                    ax[row, panel].text(*(center_[::-1]), k, **label_kwargs)  # x,y

    fig.tight_layout()

    if scenes is None:
        return fig

    # animate multiple scenes
    n_frames = len(scenes)

    # update only images dependent on the current state of scene
    # NOTE(review): the image handles are those of the last row, and the model
    # panel is updated without the ChannelRenderer used for the first frame --
    # confirm that animations are only meant for split_channels=False
    def update(i):
        updated = []
        scene = scenes[i]
        model = scene()
        if show_model:
            model_img.set_data(img_to_rgb(model, norm=norm, channel_map=channel_map))
            updated.append(model_img)
        if show_rendered or show_residual:
            model = observation.render(model)
        if show_rendered:
            rendered_img.set_data(img_to_rgb(model, norm=norm, channel_map=channel_map, mask=mask_))
            updated.append(rendered_img)
        if show_residual:
            residual = observation.data - model
            norm_ = LinearPercentileNorm(residual)
            residual_img.set_data(img_to_rgb(residual, norm=norm_, channel_map=channel_map, mask=mask_))
            updated.append(residual_img)
        return updated

    ani = animation.FuncAnimation(fig=fig, func=update, frames=n_frames, interval=30)
    return ani