# Source code for scarlet2.infer

import operator
from pprint import pformat

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.tree as jt
import numpyro
import numpyro.distributions as dist
import numpyro.distributions.constraints as constraints
from numpyro.infer import MCMC, NUTS

from .scene import Scene
from .validation_utils import (
    ValidationError,
    ValidationInfo,
    ValidationMethodCollector,
    ValidationResult,
    ValidationWarning,
    print_validation_results,
)


# helper class to turn observation likelihood(s) into numpyro distribution
class _ObsDistribution(dist.Distribution):
    """Expose an observation's likelihood as a numpyro distribution over the data.

    Instances are conditioned on a rendered `model` and are used as observed sites
    (``numpyro.sample(..., obs=data)``), so they provide ``log_prob`` but cannot
    draw samples.

    Parameters
    ----------
    obs : :py:class:`~scarlet2.Observation`
        Observation whose ``render`` and ``_log_likelihood`` methods are used.
    model : array
        Model prediction of the scene; its shape sets ``event_shape``.
    validate_args : bool, optional
        Passed on to :py:class:`numpyro.distributions.Distribution`.
    """

    support = constraints.real_vector

    def __init__(self, obs, model, validate_args=None):
        self.obs = obs
        self.model = model
        # the entire model array forms one event (batch_shape stays empty)
        event_shape = jnp.shape(model)
        super().__init__(
            event_shape=event_shape,
            validate_args=validate_args,
        )

    def sample(self, key, sample_shape=()):
        # only used as an observed site: drawing synthetic data is not supported
        raise NotImplementedError

    @property
    def mean(self):
        """Expected data: the model rendered by the observation."""
        # numpyro declares `Distribution.mean` as a property; a plain method here would
        # make `d.mean` evaluate to a bound method instead of an array
        return self.obs.render(self.model)

    @dist.util.validate_sample
    def log_prob(self, value):
        # numpyro needs sampling distribution of data (=value), not likelihood function of parameters
        return self.obs._log_likelihood(self.model, value)


def _dict_to_eqx_module(name: str, data: dict):
    """Wrap a flat ``{field_name: value}`` dict into an instance of a new ``eqx.Module`` subclass.

    A fresh class called `name` is created on every call, with one annotated field per key
    of `data` (annotated with the runtime type of its value), and is then instantiated with
    the values of `data` as keyword arguments.

    NOTE(review): field order follows the insertion order of `data`; `fit` relies on the
    flattened leaves lining up with the parameter dict, presumably why a Module is used
    instead of a plain dict (whose pytree flattening order differs) — confirm.
    """
    field_types = {field: type(value) for field, value in data.items()}
    module_cls = type(name, (eqx.Module,), {"__annotations__": field_types})
    return module_cls(**data)


def sample(scene, observations, *args, seed=0, num_warmup=100, num_samples=200, progress_bar=True, **kwargs):
    """Sample `parameters` of every source in `scene` to get posteriors given `observations`.

    This method runs the HMC NUTS sampler from `numpyro` to get parameter posteriors. It uses the likelihood
    of `observations` as well as the `prior` attribute set for every :py:class:`~scarlet2.Parameter` in
    `parameters`.

    Parameters
    ----------
    scene : :py:class:`~scarlet2.Scene`
        The model of the scene.
    observations: :py:class:`~scarlet2.Observation` or list
        The observations to fit the models to.
    *args: list, optional
        Additional arguments passed. Only used for backwards (v0.3) compatibility.
    seed: int, optional
        RNG seed for the sampler
    num_warmup: int, optional
        Number of samples during HMC warm-up
    num_samples: int, optional
        Number of samples to create from tuned HMC
    progress_bar: bool, optional
        Whether to show a progress bar
    **kwargs: dict, optional
        Additional keyword arguments passed to the `numpyro.infer.NUTS` sampler. If no `init_strategy`
        is given, the chain is initialized at the current parameter values.

    Returns
    -------
    numpyro.infer.mcmc.MCMC

    Raises
    ------
    AttributeError
        If scene and observations have no parameters at all, or if any parameter has no prior.

    Notes
    -----
    Requires `numpyro`
    """
    # making sure we can iterate
    if not isinstance(observations, (list, tuple)):
        observations = (observations,)

    obs_params = {}
    for obs in observations:
        obs.check_set_renderer(scene.frame)
        obs_params.update(obs.parameters)

    # scene and observations can have parameters: combine them into one model
    parameters = scene.parameters | obs_params
    if len(parameters) == 0:
        msg = "Scene and Observation(s) must have at least one parameter. Found none."
        raise AttributeError(msg)

    # find all non-fixed parameters and their priors
    priors = {name: p.prior for name, (idx, p) in parameters.items()}
    has_none = any(prior is None for prior in priors.values())
    if has_none:
        msg = f"All parameters need to have priors set. Got:\n{pformat(priors)}"
        raise AttributeError(msg)

    # current parameter values; merged without mutating the dict returned by scene.get()
    init_values = scene.get()
    for obs in observations:
        init_values = init_values | obs.get()

    # define the pyro model, where every parameter value becomes a sample,
    # and the observations sample from their likelihood given the rendered model
    def pyro_model():
        samples = {name: numpyro.sample(name, param.prior) for name, (node, param) in parameters.items()}
        scene_ = scene.set(samples)
        pred = scene_()  # create scene once for all observations
        # evaluate likelihood of multiple observations; observation parameters must take the
        # sampled values as well, otherwise they never enter the likelihood
        for i, obs_ in enumerate(observations):
            numpyro.sample(f"obs.{i}", _ObsDistribution(obs_.set(samples), pred), obs=obs_.data)

    # if not told otherwise: use init from current value of model
    init_strategy = kwargs.pop("init_strategy", None)
    if init_strategy is None:
        from functools import partial

        from numpyro.infer.initialization import init_to_value

        init_strategy = partial(init_to_value, values=init_values)

    nuts_kernel = NUTS(pyro_model, init_strategy=init_strategy, **kwargs)
    mcmc = MCMC(nuts_kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=progress_bar)
    rng_key = jax.random.PRNGKey(seed)
    mcmc.run(rng_key)
    return mcmc
def fit(
    scene,
    observations,
    *args,
    schedule=None,
    max_iter=100,
    e_rel=1e-4,
    progress_bar=True,
    callback=None,
    **kwargs,
):
    """Fit model `parameters` of every source in `scene` to match `observations`.

    Computes the best-fit parameters of all components in every source by first-order gradient descent
    with the Adam optimizer from `optax`.

    Parameters
    ----------
    scene : :py:class:`~scarlet2.Scene`
        The model of the scene.
    observations: :py:class:`~scarlet2.Observation` or list
        The observations to fit the model to.
    *args: list, optional
        Additional arguments passed. Only used for backwards (v0.3) compatibility.
    schedule: callable, optional
        A function that maps optimizer step count to value. See :py:class:`optax.Schedule` for details.
    max_iter: int, optional
        Maximum number of optimizer iterations
    e_rel: float, optional
        Upper limit for the relative change in the norm of any parameter to terminate the optimization early.
    progress_bar: bool, optional
        Whether to show a progress bar
    callback: callable, optional
        Function to be called on the current state of the optimized scene.
        Signature `callback(scene, convergence, loss) -> None`, where `convergence` is a tree of the same
        structure as `scene`, and `loss` is the current value of the log_posterior.
    **kwargs: dict, optional
        Additional keyword arguments passed to the `optax.scale_by_adam` optimizer.

    Returns
    -------
    Scene, list(Observation)
        The scene and observation(s) with updated parameters

    Raises
    ------
    ImportError
        If `optax` or `tqdm` cannot be imported.
    AttributeError
        If scene and observations have no parameters at all.

    Notes
    -----
    Requires `optax` and `tqdm`
    """
    try:
        import optax
        import optax._src.base as base
        from tqdm.auto import trange
    except ImportError as err:
        raise ImportError("scarlet2.Scene.fit() requires optax and tqdm.") from err

    # making sure we can iterate
    if not isinstance(observations, (list, tuple)):
        observations = (observations,)

    obs_params = {}
    for obs in observations:
        obs.check_set_renderer(scene.frame)
        obs_params.update(obs.parameters)

    # scene and observations can have parameters: combine them into one model
    parameters = scene.parameters | obs_params
    if len(parameters) == 0:
        msg = "Scene and Observation(s) must have at least one parameter. Found none."
        raise AttributeError(msg)

    # merge current values without mutating the dict returned by scene.get()
    values = scene.get()
    for obs in observations:
        values = values | obs.get()

    # construct eqx.Module containing all parameter arrays as attributes
    values = _dict_to_eqx_module("ParamModel", values)

    # same tree structure but for param specs so that we can use them below
    treedef = jt.structure(values)
    params = tuple(param for name, (node, param) in parameters.items())
    params = jt.unflatten(treedef, params)
    steps = jt.map(lambda param: param.stepsize, params)

    def scale_by_stepsize() -> base.GradientTransformation:
        # adapted from optax.scale_by_param_block_norm()
        def init_fn(params):
            del params
            return base.EmptyState()

        def update_fn(updates, state, params):
            if params is None:
                raise ValueError(base.NO_PARAMS_MSG)
            updates = jt.map(
                # minus because we want gradient descent;
                # a callable stepsize is evaluated on the current parameter value
                lambda u, s, p: None if u is None else -s * u if not callable(s) else -s(p) * u,
                updates,
                steps,
                params,
                is_leaf=lambda x: x is None,
            )
            return updates, state

        return base.GradientTransformation(init_fn, update_fn)

    # run adam, followed by stepsize adjustments
    optim = optax.chain(
        optax.scale_by_adam(**kwargs),
        optax.scale_by_schedule(schedule if callable(schedule) else lambda x: 1),
        scale_by_stepsize(),
    )

    # transform to unconstrained parameters
    values = _constraint_replace(values, params, inv=True)
    opt_state = optim.init(values)

    with trange(max_iter, disable=not progress_bar) as t:
        for _ in t:
            # optimizer step
            values, loss, opt_state, convergence = _make_step(
                values, params, scene, observations, optim, opt_state
            )

            # compute max change across all non-fixed parameters for convergence test
            max_change = jt.reduce(max, convergence)

            # report current iteration results to callback
            if callback is not None:
                values_ = _constraint_replace(values, params)
                scene_ = scene.set(values_)
                callback(scene_, convergence, loss)

            # Log the loss and max_change in the tqdm progress bar
            t.set_postfix(loss=f"{loss:08.2f}", max_change=f"{max_change:1.6f}")

            # test convergence
            if max_change < e_rel:
                break

    # transform back to constrained variables and replace in scene
    values = _constraint_replace(values, params)
    # scene_ is a copy, but its registry_key still points to scene.parameters, can thus be reused
    scene_ = scene.set(values)
    obs_ = tuple(obs.set(values) for obs in observations)

    # (re)-import `VALIDATION_SWITCH` at runtime to avoid using a static/old value
    from .validation_utils import VALIDATION_SWITCH

    if VALIDATION_SWITCH:
        from .validation import check_fit

        # validate against the fitted observations: their parameters may have been optimized too
        for fitted_obs in obs_:
            print(f"Running validation checks on the fit of the scene for observation {fitted_obs.name}.")
            validation_results = check_fit(scene_, fitted_obs)
            print_validation_results("Fit validation results", validation_results)

    return scene_, obs_
def _constraint_replace(values, params, inv=False):
    """Map parameter values between constrained and unconstrained space.

    Parameters
    ----------
    values : pytree
        Parameter values, with the same tree structure as `params`.
    params : pytree
        :py:class:`~scarlet2.Parameter` specs matching `values` leaf by leaf.
    inv : bool, optional
        If False, apply ``param.constraint_transform`` (used by `fit` to go back to constrained
        values); if True, apply its ``.inv`` (used by `fit` to go to unconstrained values).

    Returns
    -------
    pytree
        `values` with every constrained leaf transformed; unconstrained leaves are passed through.
    """

    def transform(value, param):
        if param.constraint is not None:
            func = param.constraint_transform if inv is False else param.constraint_transform.inv
            return func(value)
        else:
            return value

    return jt.map(transform, values, params)


# update step for optax optimizer
@eqx.filter_jit
def _make_step(values, params, scene, observations, optim, opt_state):
    """Perform one optimizer step on the negative log posterior.

    Returns
    -------
    tuple
        ``(values_, loss, opt_state, convergence)``: updated (unconstrained) values, the loss
        ``-(log_likelihood + log_prior)`` at the input values, the new optimizer state, and a tree
        with the relative change ``|dx| / |x|`` of each parameter (0 where there is no update).
    """

    def loss_fn(values):
        # parameters now obey constraints
        # transformation happens in the grad path, so gradients are wrt to unconstrained variables
        # likelihood and prior grads transparently apply the Jacobians of these transformations
        values_ = _constraint_replace(values, params)
        scene_ = scene.set(values_)()
        log_like = sum(obs.set(values_).log_likelihood(scene_) for obs in observations)

        # add log prior for all parameters which define priors
        # Note: This calls priors separately even if they support batched execution
        # see https://github.com/pmelchior/scarlet2/issues/103 for a possible solution
        # however, testing after #103 got merged suggests that tree_reduce is faster than grouping
        log_prior = jt.reduce(
            operator.add,
            jt.map(
                lambda value, param: param.prior.log_prob(value) if param.prior is not None else 0,
                values_,
                params,
            ),
        )
        return -(log_like + log_prior)

    loss, grads = eqx.filter_value_and_grad(loss_fn)(values)
    updates, opt_state = optim.update(grads, opt_state, values)
    values_ = eqx.apply_updates(values, updates)

    # for convergence criterion: compute norms of parameters and updates
    def norm(x, dx):
        return 0 if dx is None else jnp.linalg.norm(dx) / jnp.linalg.norm(x)

    convergence = jt.map(norm, values, updates)
    return values_, loss, opt_state, convergence


class FitValidator(metaclass=ValidationMethodCollector):
    """A class containing all of the validation checks for a Scene object after calling `.fit()`.

    Note that the metaclass is defined as `ValidationMethodCollector`, which collects all validation
    methods in this class into a single class attribute list called `validation_checks`. This allows
    for easy iteration over all checks."""

    def __init__(self, scene: Scene, observation):
        """Initialize the FitValidator.

        Parameters
        ----------
        scene : Scene
            The scene object to validate.
        observation : Observation
            The observation object containing the data to validate against.
        """
        self.scene = scene
        self.observation = observation

        # These are placeholders, waiting for the actual width of distribution to be
        # implemented - see issue https://github.com/pmelchior/scarlet2/issues/192
        self.chi2_tolerable_threshold = 1.5
        self.chi2_critical_threshold = 5.0

    def check_goodness_of_fit(self) -> ValidationResult:
        """Evaluate the goodness of the model fit to the data by calling the Observation class's
        `goodness_of_fit` method. Please see the docstring for that method for details.

        Returns
        -------
        ValidationResult
            A subclass of ValidationResult indicating the result of the check. A NaN chi-square
            is reported as a ValidationError.
        """
        obs = self.observation
        chi2 = obs.goodness_of_fit(self.scene())

        context = {"chi2": chi2}
        ret_val: ValidationResult = ValidationInfo(
            "The model fit is good.", check=self.__class__.__name__, context=context
        )

        if self.chi2_tolerable_threshold <= chi2 < self.chi2_critical_threshold:
            ret_val = ValidationWarning(
                "The model fit is acceptable, but the goodness of fit is not optimal.",
                check=self.__class__.__name__,
                context=context,
            )
        elif chi2 >= self.chi2_critical_threshold or jnp.isnan(chi2):
            ret_val = ValidationError(
                "The model fit is poor.", check=self.__class__.__name__, context=context
            )

        return ret_val

    def check_chi_square_in_box_and_border(self) -> list[ValidationResult]:
        """Evaluate the weighted mean (weighted by the inverse variance weights) of the squared
        residuals for each source. Chi square is also computed for a border region around each
        source's box, as returned by `Observation.eval_chi_square_in_box_and_border`.

        Returns
        -------
        list[ValidationResult]
            A list of ValidationResult subclasses for each source. For each source there will be
            two results: one for inside the bounding box and one for the border.
            The ValidationResults will each be one of the following:
            - If the chi-square is below the tolerable threshold, a ValidationInfo.
            - If the chi-square is between the two thresholds, a ValidationWarning.
            - If the chi-square is at or above the critical threshold, or NaN, a ValidationError.
        """
        obs = self.observation
        chi2_per_source = obs.eval_chi_square_in_box_and_border(self.scene)
        check = self.__class__.__name__

        def classify(chi2, region, context_key, source):
            # one result per (source, region); region is "box" or "border"
            context = {context_key: chi2, "source": source}
            if chi2 < self.chi2_tolerable_threshold:
                return ValidationInfo(
                    f"The chi-square in the {region} for source {source} is good.",
                    check=check,
                    context=context,
                )
            if chi2 < self.chi2_critical_threshold:
                return ValidationWarning(
                    f"The chi-square in the {region} for source {source} is acceptable, but not optimal.",
                    check=check,
                    context=context,
                )
            # at/above the critical threshold, or NaN (every comparison with NaN is False);
            # previously a NaN chi-square produced no result at all
            return ValidationError(
                f"The chi-square in the {region} for source {source} is poor.",
                check=check,
                context=context,
            )

        validation_results: list[ValidationResult] = []
        for i, chi2 in chi2_per_source.items():
            validation_results.append(classify(chi2["in"], "box", "chi2_in", i))
            validation_results.append(classify(chi2["out"], "border", "chi2_border", i))

        return validation_results