Source code for scarlet2.nn

# To be removed as part of issue #168
# ruff: noqa: D101
# ruff: noqa: D102
# ruff: noqa: D103
# ruff: noqa: D106

"""Neural network priors"""

from functools import partial

import equinox as eqx
import jax.numpy as jnp
import numpyro.distributions as dist
import numpyro.distributions.constraints as constraints
from jax import custom_vjp, vjp


def pad_fwd(x, model_shape):
    """Zero-pads the input image to the model size

    Parameters
    ----------
    x : jnp.array
        data to be padded
    model_shape : tuple
        shape of the prior model to be used (a list is accepted as well)

    Returns
    -------
    x : jnp.array
        data padded to same size as model_shape
    pad : tuple or int
        per-axis ``(low, high)`` padding amounts, or ``0`` when ``x`` already
        has shape ``model_shape`` and no padding was applied

    Raises
    ------
    ValueError
        If ``model_shape`` has fewer axes than ``x`` or is smaller than ``x``
        along any axis of ``x``.
    """
    # normalise so that a list-valued shape compares equal to x.shape
    model_shape = tuple(model_shape)
    # explicit check instead of `assert`, which is stripped under `python -O`
    if len(model_shape) < x.ndim or any(
        target < size for target, size in zip(model_shape, x.shape)
    ):
        raise ValueError("Model size must be larger than data size")

    if model_shape == x.shape:
        return x, 0

    # split each gap evenly; an odd gap puts the extra pixel at the high end
    gaps = (target - size for target, size in zip(model_shape, x.shape))
    pad = tuple((gap // 2, gap - gap // 2) for gap in gaps)

    # perform the zero-padding
    x = jnp.pad(x, pad, "constant", constant_values=0)
    return x, pad
# undo the padding to return to the original size
def pad_back(x, pad):
    """Removes the zero-padding from the input image

    Inverse of ``pad_fwd``: crops ``x`` by the per-axis amounts in ``pad``.

    Parameters
    ----------
    x : jnp.array
        padded data, same size as model_shape
    pad : tuple or int
        per-axis ``(low, high)`` padding amounts as returned by ``pad_fwd``,
        or ``0`` (what ``pad_fwd`` returns when no padding was applied), in
        which case ``x`` is returned unchanged

    Returns
    -------
    x : jnp.array
        data returned to its pre-pad shape
    """
    # pad_fwd signals "no padding" with the integer 0 rather than a tuple
    if isinstance(pad, int) and pad == 0:
        return x
    # a high pad of 0 must become an open-ended slice: x[low:-0] would be empty
    slices = tuple(slice(low, -high if high > 0 else None) for low, high in pad)
    return x[slices]
# calculate the score function (gradient of the log-probability)
def calc_grad(x, model):
    """Calculates the gradient of the log-prior using the ScoreNet model chosen

    Parameters
    ----------
    x : jnp.array
        Parameter value(s). Either a single item with
        ``x.ndim == len(model.shape)``, or input with extra leading (batch)
        axes in front of the model axes.
    model : object with ``func`` and ``shape`` attributes
        ``func`` is the score model, which expects a batched
        ``(batch, *model.shape)`` input; ``shape`` is the unbatched input
        shape the score model accepts.

    Returns
    -------
    score_func : jnp.array
        Score returned by ``model.func`` with any zero-padding removed.
        Presumably this has the shape of ``x``, provided ``model.func``
        preserves its input shape (not checked here).

    Raises
    ------
    ValueError
        If ``x`` has fewer axes than ``model.shape``.
    """
    # cast to float32 before handing the data to the score model
    x = jnp.float32(x)
    shape = tuple(model.shape)
    n_lead = jnp.ndim(x) - len(shape)
    if n_lead < 0:
        raise ValueError("Data must have at least as many dimensions as the model shape")

    # pad only the trailing model axes; any leading batch axes are left as-is
    x_, pad = pad_fwd(x, x.shape[:n_lead] + shape)

    # the score model expects (batch, shape): give a single item a batch axis
    if n_lead == 0:
        x_ = jnp.expand_dims(x_, axis=0)
    score_func = model.func(x_)
    if n_lead == 0:
        score_func = jnp.squeeze(score_func, axis=0)

    # remove padding
    if pad != 0:
        score_func = pad_back(score_func, pad)
    return score_func
# JAX helper: gradient of sum(f(x)), via a vector-Jacobian product with a vector of ones
def vgrad(f, x):
    """Return the gradient of ``sum(f(x))`` with respect to ``x``.

    Computes the vector-Jacobian product of ``f`` at ``x`` with a cotangent
    of ones. This is the gradient of the summed output, and equals the
    ordinary gradient when ``f`` is scalar-valued.

    Parameters
    ----------
    f : callable
        Function to differentiate.
    x : jnp.array
        Point at which the gradient is evaluated.

    Returns
    -------
    jnp.array
        Gradient with the same shape as ``x``.
    """
    y, vjp_fn = vjp(f, x)
    # The cotangent must match the output's shape AND dtype. A fixed float32
    # `jnp.ones(y.shape)` is rejected for float16/float64 outputs.
    return vjp_fn(jnp.ones_like(y))[0]
# Here we define a custom vjp for the log_prob function # such that for gradient calls in jax, the score prior # is returned @partial(custom_vjp, nondiff_argnums=(0,)) def _log_prob(model, x): return 0.0 def _log_prob_fwd(model, x): score_func = calc_grad(x, model) return 0.0, score_func # cannot directly call log_prob in Class object def _log_prob_bwd(model, res, g): score_func = res # Get residuals computed in f_fwd return (g * score_func,) # create the vector (g) jacobian (score_func) product # register the custom vjp _log_prob.defvjp(_log_prob_fwd, _log_prob_bwd)
class ScorePrior(dist.Distribution):
    """Prior distribution represented by a score-matching neural network.

    ``log_prob`` always returns 0.0. Its gradient w.r.t. the parameter, e.g.
    from ``jax.grad``, is the score produced by the wrapped model through a
    custom VJP. Sampling and the mean are not implemented.
    """

    class ScoreWrapper(eqx.Module):
        # callable evaluating the score model (extra model arguments already bound)
        func: callable
        # shape of a single parameter the score model accepts
        shape: tuple

    support = constraints.real_vector
    _model = ScoreWrapper

    def __init__(self, model, shape, *args, **kwargs):
        """Score-matching neural network to represent the prior distribution

        This class is used to calculate the gradient of the log-probability of the prior distribution.
        A custom vjp is created to return the score when calling `jax.grad()`.

        Parameters
        ----------
        model: callable
            Returns the score value given a parameter: `model(x) -> score`
        shape: tuple
            Shape of the parameter the model can accept
        *args: tuple
            List of unnamed parameters for model, e.g. `model(x, *args) -> score`
        **kwargs: dict
            List of named parameters for model, e.g. `model(x, **kwargs) -> score`
        """

        # Bind the extra arguments AFTER x, as documented above.
        # functools.partial would prepend them and call model(*args, x) instead.
        def _score(x):
            return model.__call__(x, *args, **kwargs)

        # helper class that ensures the model function binds the args/kwargs and has a shape
        self._model = ScorePrior.ScoreWrapper(_score, shape)
        super().__init__(
            validate_args=None,
        )

    def __call__(self, x):
        """Evaluate the wrapped score model (with its bound arguments) at ``x``."""
        return self._model.func(x)

    def sample(self, key, sample_shape=()):
        """Not implemented: drawing samples from the score prior."""
        # TODO: add ability to draw samples from the prior, if desired
        raise NotImplementedError

    @property
    def mean(self):
        """Not implemented.

        This is a property so that it overrides numpyro's ``Distribution.mean``
        property with the same kind of attribute.
        """
        raise NotImplementedError

    def log_prob(self, x):
        """Return 0.0, with a custom gradient w.r.t. ``x`` equal to the score."""
        return _log_prob(self._model, x)