Source code for scarlet2.infer

import functools
import operator
from pprint import pformat

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.tree as jt
import numpyro
import numpyro.distributions as dist
import numpyro.distributions.constraints as constraints
from numpyro.infer import MCMC, NUTS

from .scene import Scene
from .validation_utils import (
    ValidationError,
    ValidationInfo,
    ValidationMethodCollector,
    ValidationResult,
    ValidationWarning,
    print_validation_results,
)


# helper class to turn observation likelihood(s) into numpyro distribution
class _ObsDistribution(dist.Distribution):
    """Numpyro distribution wrapper around the likelihood of one observation.

    Stores an observation object and a rendered model; `log_prob` delegates to
    the observation's `_log_likelihood`. Drawing samples is not supported.
    """

    support = constraints.real_vector

    def __init__(self, obs, model, validate_args=None):
        """Store `obs` and `model`; the event shape is taken from the shape of `model`."""
        self.obs = obs
        self.model = model
        super().__init__(event_shape=jnp.shape(model), validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        """Not implemented: this distribution is only used to score observed data."""
        raise NotImplementedError

    def mean(self):
        """Return the model as rendered by the observation (`obs.render(model)`)."""
        return self.obs.render(self.model)

    @dist.util.validate_sample
    def log_prob(self, value):
        """Return the observation's log-likelihood of `value` given the stored model."""
        # numpyro wants the sampling distribution of the data (= value),
        # not the likelihood as a function of the parameters
        return self.obs._log_likelihood(self.model, value)


@functools.lru_cache(maxsize=16)
def _eqx_module_class(name: str, fields: tuple):
    """Create (or fetch from the cache) an `eqx.Module` subclass with the given fields.

    Every entry of `fields` becomes an attribute annotated as `jax.Array`.
    The result is memoized on `(name, fields)` so that repeated fit() calls with
    the same parameter set get the very same class object: class identity drives
    JAX's pytree treedef, and a fresh class would invalidate the JIT cache for
    any function taking this module as an argument.
    """
    namespace = {"__annotations__": dict.fromkeys(fields, jax.Array)}
    return type(name, (eqx.Module,), namespace)


def _dict_to_eqx_module(name: str, data: dict):
    """Instantiate an `eqx.Module` named `name` whose attributes are the items of `data`."""
    field_names = tuple(data.keys())
    module_cls = _eqx_module_class(name, field_names)
    return module_cls(**data)


# -----------------------------------------------------------------------------
# Pair-similarity regularizer
# -----------------------------------------------------------------------------
#
# Penalizes the source-pair similarity that is the geometric signature of
# parasitic flux: when source A acquires a B-shaped bump, the cosine
# similarity between A's and B's morphologies rises. The penalty is
#
#     R = sum_{A < B} sigma_AB * rho_AB
#
# where rho_AB is the cosine similarity of the band-summed morphologies and
# sigma_AB is the cosine similarity of the SEDs. The SED factor scales the
# penalty by the strength of the spectral degeneracy: maximal where the
# data likelihood cannot tell A and B apart, fading where SEDs differ.
# -----------------------------------------------------------------------------


class PairSimilarity(eqx.Module):
    """Settings of the pair-similarity regularizer accepted by :func:`fit`.

    The regularizer sums, over source pairs, the product of the morphology
    cosine similarity and the SED cosine similarity. It is aimed at the
    parasitic-flux failure mode: when source A absorbs flux in the shape of a
    neighbour B, their morphology cosine rises and the penalty pushes that
    B-shaped component back down. The SED cosine scales the penalty by how
    spectrally degenerate the pair is.

    Parameters
    ----------
    weight : float
        Relative strength of the regularizer, expressed as the target ratio of
        the penalty to the initial NLL. ``fit()`` evaluates the NLL and the
        unweighted penalty once at initialization and rescales internally so
        that ``R = weight * |NLL_init|`` at the start of optimization.
        ``0.0`` turns the regularizer into a no-op (bit-identical to the
        unregularized loss). Empirically, values around ``0.01`` (penalty ~1%
        of the NLL) work well.
    eps : float, optional
        Small constant added to the cosine denominators for numerical stability.

    Notes
    -----
    Both fields are static ``eqx.Module`` fields, so an instance passes through
    ``eqx.filter_jit`` without being treated as a differentiable leaf and
    without forcing recompilation across identical configurations.
    """

    weight: float = eqx.field(static=True)
    eps: float = eqx.field(default=1e-12, static=True)
def _cosine_matrix(X, eps):
    """Pairwise cosine similarity of the rows of `X`: shape (N, D) -> (N, N)."""
    row_norms = jnp.sqrt(jnp.sum(X * X, axis=1, keepdims=True))
    unit_rows = X / (row_norms + eps)
    return unit_rows @ unit_rows.T


def _pair_similarity_from_stack(per_source, cfg: PairSimilarity):
    """Unweighted pair-similarity penalty of a stack of per-source models.

    ``per_source`` is expected as ``(K, C, H, W)`` for multi-band scenes or
    ``(K, H, W)`` for single-band ones, all in a common frame. This is pure
    tensor math and does not need a Scene. ``cfg.weight`` is NOT applied here;
    only ``cfg.eps`` is read. The caller is responsible for the weighting.
    """
    clipped = jnp.maximum(per_source, 0.0)

    if clipped.ndim != 4:
        # Degenerate single-band case: the SED cosine is identically 1,
        # so only the morphology contributes.
        morph = clipped
        sed = jnp.ones((clipped.shape[0], 1))
    else:
        morph = clipped.sum(axis=1)  # (K, H, W)
        sed = clipped.sum(axis=(2, 3))  # (K, C)

    morph_flat = morph.reshape(morph.shape[0], -1)
    rho = _cosine_matrix(morph_flat, eps=cfg.eps)
    sigma = _cosine_matrix(sed, eps=cfg.eps)

    # Both matrices are products of non-negative matrices and hence non-negative
    # element-wise; the clipping only guards against round-off dust.
    pair_term = jnp.maximum(sigma, 0.0) * jnp.maximum(rho, 0.0)

    # Sum the strict upper triangle directly instead of `(sum - trace) / 2`:
    # in float32 those are two reductions of comparable magnitude (~K) summed
    # in different orders, so with tiny off-diagonals the diagonal does not
    # cancel exactly and the result could turn slightly negative.
    upper = jnp.triu(pair_term, k=1)
    return jnp.sum(upper)


def _evaluate_per_source(scene_obj):
    """Stack of per-source models in the scene frame: shape ``(K, C, H, W)``."""
    source_models = [scene_obj.evaluate_source(src) for src in scene_obj.sources]
    return jnp.stack(source_models, axis=0)


def _pair_similarity_penalty(scene_obj, cfg: PairSimilarity, per_source=None):
    """Weighted pair-similarity regularizer for a Scene.

    Returns a zero array right away when ``cfg.weight == 0.0``. Otherwise the
    per-source stack is taken from ``per_source`` if the caller already has it
    (avoids re-evaluating every source) or built from ``scene_obj``, and the
    unweighted penalty is multiplied by ``cfg.weight``.
    """
    if cfg.weight == 0.0:
        return jnp.array(0.0)

    stack = _evaluate_per_source(scene_obj) if per_source is None else per_source
    return cfg.weight * _pair_similarity_from_stack(stack, cfg)
def sample(scene, observations, *args, seed=0, num_warmup=100, num_samples=200, progress_bar=True, **kwargs):
    """Draw posterior samples of all `parameters` of `scene` given `observations`.

    Runs the HMC NUTS sampler from `numpyro`. The posterior combines the
    likelihood of `observations` with the `prior` attribute that must be set on
    every :py:class:`~scarlet2.Parameter` in `parameters`.

    Parameters
    ----------
    scene : :py:class:`~scarlet2.Scene`
        The model of the scene.
    observations: :py:class:`~scarlet2.Observation` or list
        The observations to fit the models to.
    *args: list, optional
        Additional arguments passed. Only used for backwards (v0.3) compatibility.
    seed: int, optional
        RNG seed for the sampler
    num_warmup: int, optional
        Number of samples during HMC warm-up
    num_samples: int, optional
        Number of samples to create from tuned HMC
    progress_bar: bool, optional
        Whether to show a progress bar
    **kwargs: dict, optional
        Additional keyword arguments passed to the `numpyro.infer.NUTS` sampler.
        An `init_strategy` entry overrides the default initialization at the
        current parameter values.

    Notes
    -----
    Requires `numpyro`

    Returns
    -------
    numpyro.infer.mcmc.MCMC

    Raises
    ------
    AttributeError
        If neither scene nor observations define a parameter, or if any
        parameter has no prior.
    """
    # a single observation is wrapped so that we can always iterate
    if not isinstance(observations, (list, tuple)):
        observations = (observations,)

    obs_params = {}
    for observation in observations:
        observation.check_set_renderer(scene.frame)
        obs_params.update(observation.parameters)

    # parameters may live on the scene and on the observations: merge them into one model
    parameters = scene.parameters | obs_params
    if len(parameters) == 0:
        raise AttributeError("Scene and Observation(s) must have at least one parameter. Found none.")

    # every parameter must come with a prior
    priors = {name: spec.prior for name, (_, spec) in parameters.items()}
    if any(prior is None for prior in priors.values()):
        raise AttributeError(f"All parameters need to have priors set. Got:\n{pformat(priors)}")

    # current parameter values of scene and observations, keyed by parameter name
    values = scene.get()
    for observation in observations:
        values |= observation.get()
    init_values = values.copy()

    # eqx.Module that holds all parameter arrays as attributes
    values = _dict_to_eqx_module("ParamModel", values)

    # the pyro model: every parameter value becomes a sample site, and each
    # observation is scored with its likelihood given the rendered model
    def pyro_model(values):
        samples = {name: numpyro.sample(name, spec.prior) for name, (_, spec) in parameters.items()}
        pred = scene.set(samples)()  # render the scene once for all observations
        for i, obs_ in enumerate(observations):
            numpyro.sample(f"obs.{i}", _ObsDistribution(obs_.set(values), pred), obs=obs_.data)

    # unless told otherwise: start the chain from the current values of the model
    init_strategy = kwargs.pop("init_strategy", None)
    if init_strategy is None:
        from numpyro.infer.initialization import init_to_value

        init_strategy = functools.partial(init_to_value, values=init_values)

    kernel = NUTS(pyro_model, init_strategy=init_strategy, **kwargs)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=progress_bar)
    mcmc.run(jax.random.PRNGKey(seed), values)
    return mcmc
def fit(
    scene,
    observations,
    *args,
    schedule=None,
    max_iter=100,
    e_rel=1e-4,
    progress_bar=True,
    callback=None,
    pair_similarity=None,
    **kwargs,
):
    """Fit model `parameters` of every source in `scene` to match `observations`.

    Computes the best-fit parameters of all components in every source by first-order gradient descent
    with the Yogi optimizer from `optax`.

    Parameters
    ----------
    scene : :py:class:`~scarlet2.Scene`
        The model of the scene.
    observations: :py:class:`~scarlet2.Observation` or list
        The observations to fit the model to.
    *args: list, optional
        Additional arguments passed. Only used for backwards (v0.3) compatibility.
    schedule: callable, optional
        A function that maps optimizer step count to value. See :py:class:`optax.Schedule` for details.
        Anything that is not callable (including the default ``None``) selects a constant schedule of 1.
    max_iter: int, optional
        Maximum number of optimizer iterations. With ``0`` no step is taken and the
        initial parameters are returned with empty histories.
    e_rel: float, optional
        Upper limit for the relative change in the norm of any parameter to terminate the optimization early.
    progress_bar: bool, optional
        Whether to show a progress bar
    callback: callable, optional
        Function to be called on the current state of the optimized scene.
        Signature `callback(scene, convergence, loss) -> None`, where
        `convergence` is a tree of the same structure as `scene`, and
        `loss` is the current value of the log_posterior.
    pair_similarity : :py:class:`~scarlet2.PairSimilarity`, optional
        If given, adds a pair-similarity regularizer to the loss. ``None``
        (default) disables it (bit-identical to the un-regularized version).
        Penalizes the cosine similarity of source-pair morphologies, scaled by
        the cosine similarity of their SEDs. Targets parasitic flux directly
        via the morphology cosine. See :py:class:`~scarlet2.PairSimilarity`
        for the available options.
    **kwargs: dict, optional
        Additional keyword arguments passed to the `optax.scale_by_yogi` optimizer.

    Notes
    -----
    Requires `optax` and `tqdm`.

    The returned scene carries the *best* parameters seen during optimization
    (lowest loss), not the last iteration's, since Yogi/Adam can be non-monotonic.
    Diagnostic info (loss history, best loss, iteration count) is attached as
    :py:attr:`Scene.fit_info`.

    Returns
    -------
    Scene, list(Observation)
        The scene and observation(s) with best-fit parameters. ``scene.fit_info``
        contains ``{"loss", "pair_similarity", "best_loss", "n_iter"}``.
        The ``pair_similarity`` array reports zero when the regularizer is disabled.

    Raises
    ------
    ImportError
        If `tqdm` cannot be imported.
    AttributeError
        If neither scene nor observations define a parameter.
    """
    try:
        from tqdm.auto import trange
    except ImportError as err:
        # the guarded import is tqdm; optax is imported later when the optimizer is built
        raise ImportError("scarlet2.Scene.fit() requires tqdm and optax.") from err

    # making sure we can iterate
    if not isinstance(observations, (list, tuple)):
        observations = (observations,)

    obs_params = {}
    for obs in observations:
        obs.check_set_renderer(scene.frame)
        obs_params.update(obs.parameters)

    # scene and observations can have parameters: combine them into one model
    parameters = scene.parameters | obs_params
    if len(parameters) == 0:
        msg = "Scene and Observation(s) must have at least one parameter. Found none."
        raise AttributeError(msg)

    # canonicalize: strip any fit_info from a previously-fit scene so the JIT cache key
    # is identical for fresh and re-fitted inputs (fit_info is a static eqx field)
    scene = scene.replace("fit_info", None)

    values = scene.get()
    for obs in observations:
        values |= obs.get()

    # construct eqx.Module containing all parameter arrays as attributes
    values = _dict_to_eqx_module("ParamModel", values)

    # same tree structure but for param specs so that we can use them below
    treedef = jt.structure(values)
    params = tuple(param for name, (node, param) in parameters.items())
    params = jt.unflatten(treedef, params)
    steps = jt.map(lambda param: param.stepsize, params)

    # build (or fetch from cache) the optax optimizer; identity-stable across fit() calls
    # with matching kwargs and schedule, so JAX's JIT cache stays warm
    schedule_fn = schedule if callable(schedule) else _unit_schedule
    optim = _build_optim(tuple(sorted(kwargs.items())), schedule_fn)

    # transform to unconstrained parameters
    values = _constraint_replace(values, params, inv=True)
    opt_state = optim.init(values)

    # default: regularizer off (weight=0.0 short-circuits to a no-op)
    pair_cfg = pair_similarity if pair_similarity is not None else PairSimilarity(weight=0.0)

    # Calibrate pair-similarity: user-facing `weight` is relative to the initial NLL.
    # We compute |NLL_init| and R_init (unweighted) once, then replace the cfg with
    # one carrying the absolute multiplier. Skip when weight=0 (no-op); when
    # R_init==0 the scale (and thus the absolute weight) becomes 0.
    if pair_cfg.weight != 0.0:
        values_init = _constraint_replace(values, params)
        scene_init = scene.set(values_init)
        nll_init = -sum(obs.set(values_init).log_likelihood(scene_init()) for obs in observations)
        r_init = _pair_similarity_from_stack(_evaluate_per_source(scene_init), pair_cfg)
        scale = jnp.where(r_init > 0, jnp.abs(nll_init) / (r_init + pair_cfg.eps), 0.0)
        pair_cfg = PairSimilarity(weight=float(pair_cfg.weight * scale), eps=pair_cfg.eps)

    # initialize best-fit tracker (loss is minimized, so start at +inf).
    # Use an explicit dtype so the array is strongly-typed: jnp.where()'s output
    # in iter 1 is strongly-typed, and JAX caches weak vs. strong separately —
    # passing a weak inf would force a second compile on iter 2.
    best_loss = jnp.array(jnp.inf, dtype=jnp.float32)
    best_values = values
    losses = []
    pair_history = []
    # number of completed optimizer steps; stays 0 if the loop body never runs (max_iter=0)
    n_iter = 0

    with trange(max_iter, disable=not progress_bar) as t:
        for step in t:
            # optimizer step
            values, loss, pair, opt_state, convergence, best_loss, best_values = _make_step(
                values,
                params,
                scene,
                observations,
                optim,
                opt_state,
                steps,
                best_loss,
                best_values,
                pair_cfg,
            )
            n_iter = step + 1
            losses.append(loss)
            pair_history.append(pair)

            # compute max change across all non-fixed parameters for convergence test
            max_change = jax.tree_util.tree_reduce(max, convergence)

            # report current iteration results to callback
            if callback is not None:
                values_ = _constraint_replace(values, params)
                scene_ = scene.set(values_)
                callback(scene_, convergence, loss)

            # Log the loss and max_change in the tqdm progress bar
            t.set_postfix(loss=f"{loss:08.2f}", max_change=f"{max_change:1.6f}")

            # test convergence
            if max_change < e_rel:
                break

    # transform best-fit values back to constrained variables and replace in scene
    best_values = _constraint_replace(best_values, params)
    # scene_ is a copy, but its registry_key still points to scene.parameters, can thus be reused
    scene_ = scene.set(best_values)
    obs_ = tuple(obs.set(best_values) for obs in observations)

    # attach diagnostic info to the returned scene;
    # jnp.stack rejects an empty list, so fall back to empty arrays when no step was taken
    empty_history = jnp.zeros((0,), dtype=best_loss.dtype)
    fit_info = {
        "loss": jnp.stack(losses) if losses else empty_history,
        "pair_similarity": jnp.stack(pair_history) if pair_history else empty_history,
        "best_loss": best_loss,
        "n_iter": n_iter,
    }
    scene_ = scene_.replace("fit_info", fit_info)

    # (re)-import `VALIDATION_SWITCH` at runtime to avoid using a static/old value
    from .validation_utils import VALIDATION_SWITCH

    if VALIDATION_SWITCH:
        from .validation import check_fit

        # NOTE(review): validation runs against the input observations, not `obs_` — confirm intended
        for obs in observations:
            validation_results = check_fit(scene_, obs)
            print_validation_results(f"Fit validation results for observation {obs.name}", validation_results)

    return scene_, obs_
def _constraint_replace(values, params, inv=False):
    """Map every constrained parameter value through its constraint bijector.

    `values` and `params` are trees of matching structure. For each leaf whose
    parameter has a `constraint`, the value is passed through
    `param.constraint_transform` (or its `.inv` when `inv` is not ``False``);
    values of parameters without constraint are returned unchanged.
    """

    def transform(value, param):
        if param.constraint is not None:
            func = param.constraint_transform if inv is False else param.constraint_transform.inv
            return func(value)
        else:
            return value

    return jt.map(transform, values, params)


def _unit_schedule(_):
    """Constant schedule that returns 1.0 for every step count."""
    # module-level so optax.scale_by_schedule(_unit_schedule) has stable identity across calls
    return 1.0


@functools.lru_cache(maxsize=8)
def _build_optim(kwargs_items, schedule):
    """Construct (and cache) the optax optimizer.

    Cached on hashable args so repeated `fit()` calls with the same configuration
    reuse the same optimizer instance. JAX's JIT cache keys static args by Python
    identity, so a fresh `optax.chain(...)` from each fit() call would otherwise
    miss the cache and recompile `_make_step` every time.
    See https://github.com/pmelchior/scarlet2/issues/120.

    Parameters
    ----------
    kwargs_items : tuple
        Hashable tuple of ``(key, value)`` pairs forwarded as keyword arguments
        to `optax.scale_by_yogi`.
    schedule : callable
        Schedule passed to `optax.scale_by_schedule`.
    """
    import optax

    return optax.chain(
        optax.scale_by_yogi(**dict(kwargs_items)),
        optax.scale_by_schedule(schedule),
    )


# update step for optax optimizer
@eqx.filter_jit
def _make_step(
    values, params, scene, observations, optim, opt_state, steps, best_loss, best_values, pair_cfg
):
    """Perform one optimizer step and update the best-fit tracker.

    Returns the tuple ``(new_values, loss, pair, opt_state, convergence,
    best_loss, best_values)``. ``loss`` and ``pair`` are evaluated at the
    *input* ``values``; ``convergence`` holds, per parameter, the norm of the
    update relative to the norm of the input value.
    """

    def loss_fn(values):
        # parameters now obey constraints
        # transformation happens in the grad path, so gradients are wrt to unconstrained variables
        # likelihood and prior grads transparently apply the Jacobians of these transformations
        values_ = _constraint_replace(values, params)
        scene_ = scene.set(values_)

        # If any regularizer is active, we need the per-source stack anyway —
        # build it once and reuse its sum as the model so we don't loop over
        # sources twice.
        needs_per_source = pair_cfg.weight != 0.0
        if needs_per_source:
            per_source = _evaluate_per_source(scene_)
            model = per_source.sum(axis=0)
        else:
            per_source = None
            model = scene_()

        log_like = sum(obs.set(values_).log_likelihood(model) for obs in observations)

        # add log prior for all parameters which define priors
        # Note: This calls priors separately even if they support batched execution
        # see https://github.com/pmelchior/scarlet2/issues/103 for a possible solution
        # however, testing after #103 got merged suggests that tree_reduce is faster than grouping
        log_prior = jt.reduce(
            operator.add,
            jt.map(
                lambda value, param: param.prior.log_prob(value) if param.prior is not None else 0,
                values_,
                params,
            ),
        )

        # pair-similarity regularizer (no-op when pair_cfg.weight == 0.0)
        pair = _pair_similarity_penalty(scene_, pair_cfg, per_source=per_source)

        # has_aux=True: return penalty values alongside the loss so they can be
        # tracked in fit_info without a second forward pass
        return -(log_like + log_prior) + pair, (pair,)

    (loss, (pair,)), grads = eqx.filter_value_and_grad(loss_fn, has_aux=True)(values)
    updates, opt_state = optim.update(grads, opt_state, values)

    # apply per-parameter stepsizes; minus sign because we want gradient descent
    updates = jt.map(
        lambda u, s, p: None if u is None else (-s(p) * u if callable(s) else -s * u),
        updates,
        steps,
        values,
        is_leaf=lambda x: x is None,
    )
    values_ = eqx.apply_updates(values, updates)

    # for convergence criterion: compute norms of parameters and updates
    def norm(x, dx):
        return 0 if dx is None else jnp.linalg.norm(dx) / jnp.linalg.norm(x)

    convergence = jt.map(norm, values, updates)

    # track best-fit: loss is at the *input* `values` (pre-update), so on improvement
    # we keep the input values, not the post-update ones
    is_better = loss < best_loss
    best_loss = jnp.where(is_better, loss, best_loss)
    best_values = jt.map(lambda b, v: jnp.where(is_better, v, b), best_values, values)

    return values_, loss, pair, opt_state, convergence, best_loss, best_values


class FitValidator(metaclass=ValidationMethodCollector):
    """A class containing all of the validation checks for a Scene object after calling `.fit()`.

    Note that the metaclass is `ValidationMethodCollector`, which collects all
    validation methods in this class into a single class attribute list called
    `validation_checks`. This allows for easy iteration over all checks.
    """

    def __init__(self, scene: Scene, observation):
        """Initialize the FitValidator.

        Parameters
        ----------
        scene : Scene
            The scene object to validate.
        observation : Observation
            The observation object containing the data to validate against.
        """
        self.scene = scene
        self.observation = observation

        # These are placeholders, waiting for the actual width of distribution to be
        # implemented - see issue https://github.com/pmelchior/scarlet2/issues/192
        self.chi2_tolerable_threshold = 1.5
        self.chi2_critical_threshold = 5.0

    def check_goodness_of_fit(self) -> ValidationResult:
        """Evaluate the goodness of the model fit to the data by calling the Observation
        class's `goodness_of_fit` method. Please see the docstring for that method
        for details.

        Returns
        -------
        ValidationResult
            A subclass of ValidationResult indicating the result of the check:
            ValidationInfo below the tolerable threshold, ValidationWarning between
            the thresholds, ValidationError at or above the critical threshold or
            when the value is NaN.
        """
        obs = self.observation
        chi2 = obs.goodness_of_fit(self.scene())
        context = {"chi2": chi2}

        ret_val: ValidationResult = ValidationInfo(
            "The model fit is good.", check=self.__class__.__name__, context=context
        )

        if self.chi2_tolerable_threshold <= chi2 < self.chi2_critical_threshold:
            ret_val = ValidationWarning(
                "The model fit is acceptable, but the goodness of fit is not optimal.",
                check=self.__class__.__name__,
                context=context,
            )
        elif chi2 >= self.chi2_critical_threshold or jnp.isnan(chi2):
            ret_val = ValidationError(
                "The model fit is poor.", check=self.__class__.__name__, context=context
            )

        return ret_val

    def check_chi_square_in_box_and_border(self) -> list[ValidationResult]:
        """Evaluate the weighted mean (weighted by the inverse variance weights) of the
        squared residuals for each source. Chi square is also computed for the perimeter
        outside the box with `border_width`.

        Returns
        -------
        list[ValidationResult]
            A list of ValidationResult subclasses for each source. For each source
            there will be two results. One for inside the bounding box and one for
            the border. The ValidationResults will each be one of the following:
            - If the chi-square is above the critical threshold (or NaN), a ValidationError.
            - If the chi-square is below the tolerable threshold, a ValidationInfo.
            - If the chi-square is between the two thresholds, a ValidationWarning.
        """
        obs = self.observation
        chi2_per_source = obs.eval_chi_square_in_box_and_border(self.scene)
        check_name = self.__class__.__name__

        validation_results: list[ValidationResult] = []
        for i, chi2 in chi2_per_source.items():
            # (region name used in the message, context key, chi-square value)
            regions = (
                ("box", "chi2_in", chi2["in"]),
                ("border", "chi2_border", chi2["out"]),
            )
            for region, context_key, value in regions:
                context = {context_key: value, "source": i}
                result: ValidationResult
                if value < self.chi2_tolerable_threshold:
                    result = ValidationInfo(
                        f"The chi-square in the {region} for source {i} is good.",
                        check=check_name,
                        context=context,
                    )
                elif value < self.chi2_critical_threshold:
                    result = ValidationWarning(
                        f"The chi-square in the {region} for source {i} is acceptable, but not optimal.",
                        check=check_name,
                        context=context,
                    )
                else:
                    # reached for values >= critical threshold AND for NaN (all comparisons
                    # are False), so a NaN chi-square is reported instead of silently dropped
                    result = ValidationError(
                        f"The chi-square in the {region} for source {i} is poor.",
                        check=check_name,
                        context=context,
                    )
                validation_results.append(result)

        return validation_results