ScorePrior
- class scarlet2.nn.ScorePrior(model, shape, *args, **kwargs)
Bases: Distribution
Score-matching neural network to represent the prior distribution.
This class is used to compute the gradient of the log-probability of the prior distribution, i.e. its score. A custom VJP (vector-Jacobian product) is defined so that calling jax.grad() returns the score.
- Parameters:
model (callable) – Returns the score given a parameter value: model(x) -> score
shape (tuple) – Shape of the parameter the model accepts
*args (tuple) – Additional positional arguments passed to model, e.g. model(x, *args) -> score
**kwargs (dict) – Additional keyword arguments passed to model, e.g. model(x, **kwargs) -> score
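The idea of returning the score through a custom VJP can be sketched in plain JAX. This is not scarlet2's implementation. make_score_log_prob is a hypothetical helper, and the forward value of 0.0 is a placeholder assumption. The sketch only shows how jax.custom_vjp lets the backward pass hand back model(x), so that jax.grad yields the score.

```python
import jax
import jax.numpy as jnp


def make_score_log_prob(score_fn):
    """Return a scalar function whose jax.grad is score_fn(x).

    Hypothetical helper for illustration only. The forward value is a
    placeholder 0.0, because a score model alone does not provide the
    (normalized) log-density.
    """

    @jax.custom_vjp
    def log_prob(x):
        return jnp.zeros((), dtype=x.dtype)

    def log_prob_fwd(x):
        # Evaluate the score once and keep it as the residual
        return log_prob(x), score_fn(x)

    def log_prob_bwd(score, g):
        # Chain rule: upstream cotangent g times d(log p)/dx, which is the score
        return (g * score,)

    log_prob.defvjp(log_prob_fwd, log_prob_bwd)
    return log_prob


# Toy "score model": the score of a standard normal prior is -x
score_model = lambda x: -x
log_prob = make_score_log_prob(score_model)

x = jnp.array([1.0, -2.0, 0.5])
score = jax.grad(log_prob)(x)  # equals score_model(x), i.e. -x
```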
- property batch_shape: tuple[int, ...]
Returns the shape over which the distribution parameters are batched.
- Returns:
batch shape of the distribution.
- Return type:
tuple[int, …]
- cdf(value: Array | ndarray | bool | number | bool | int | float | complex) → Array | ndarray | bool | number | bool | int | float | complex
The cumulative distribution function of this distribution.
- Parameters:
value – samples from this distribution.
- Returns:
output of the cumulative distribution function evaluated at value.
- entropy() → Array | ndarray | bool | number | bool | int | float | complex
Returns the entropy of the distribution.
- enumerate_support(expand: bool = True) → Array | ndarray | bool | number | bool | int | float | complex
Returns an array with shape len(support) x batch_shape containing all values in the support.
- property event_dim: int
- Returns:
number of dimensions of individual events.
- Return type:
int
- property event_shape: tuple[int, ...]
Returns the shape of a single sample from the distribution without batching.
- Returns:
event shape of the distribution.
- Return type:
tuple[int, …]
- expand(batch_shape: tuple[int, ...]) → Distribution
Returns a new ExpandedDistribution instance with batch dimensions expanded to batch_shape.
- Parameters:
batch_shape (tuple) – batch shape to expand to.
- Returns:
an instance of ExpandedDistribution.
- Return type:
ExpandedDistribution
- expand_by(sample_shape: tuple[int, ...]) → Distribution
Expands a distribution by adding sample_shape to the left side of its batch_shape. To expand internal dims of self.batch_shape from 1 to something larger, use expand() instead.
- Parameters:
sample_shape (tuple) – The size of the iid batch to be drawn from the distribution.
- Returns:
An expanded version of this distribution.
- Return type:
ExpandedDistribution
- get_args() → dict
Get arguments of the distribution.
- icdf(q: Array | ndarray | bool | number | bool | int | float | complex) → Array | ndarray | bool | number | bool | int | float | complex
The inverse cumulative distribution function of this distribution.
- Parameters:
q – quantile values, should belong to [0, 1].
- Returns:
the samples whose cdf values equal q.
- classmethod infer_shapes(*args, **kwargs)
Infers batch_shape and event_shape given shapes of args to __init__().
Note
This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.
- Parameters:
*args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
**kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
- Returns:
A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.
- Return type:
tuple
- log_prob(x)
Evaluates the log probability density for a batch of samples given by x.
- Parameters:
x – A batch of samples from the distribution.
- Returns:
an array with shape x.shape[:x.ndim - len(self.event_shape)], i.e. the shape of x with the trailing event dimensions removed
- Return type:
ArrayLike
- mask(mask: Array) → MaskedDistribution
Masks a distribution by a boolean or boolean-valued array that is broadcastable to the distribution's Distribution.batch_shape.
- Parameters:
mask (bool or jnp.ndarray) – A boolean or boolean valued array (True includes a site, False excludes a site).
- Returns:
A masked copy of this distribution.
- Return type:
MaskedDistribution
Example:
>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.distributions import constraints
>>> from numpyro.infer import SVI, Trace_ELBO
>>> def model(data, m):
...     f = numpyro.sample("latent_fairness", dist.Beta(1, 1))
...     with numpyro.plate("N", data.shape[0]):
...         # only take into account the values selected by the mask
...         masked_dist = dist.Bernoulli(f).mask(m)
...         numpyro.sample("obs", masked_dist, obs=data)
>>> def guide(data, m):
...     alpha_q = numpyro.param("alpha_q", 5., constraint=constraints.positive)
...     beta_q = numpyro.param("beta_q", 5., constraint=constraints.positive)
...     numpyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q))
>>> data = jnp.concatenate([jnp.ones(5), jnp.zeros(5)])
>>> # select values equal to one
>>> masked_array = jnp.where(data == 1, True, False)
>>> optimizer = numpyro.optim.Adam(step_size=0.05)
>>> svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
>>> svi_result = svi.run(random.PRNGKey(0), 300, data, masked_array)
>>> params = svi_result.params
>>> # inferred_mean is closer to 1
>>> inferred_mean = params["alpha_q"] / (params["alpha_q"] + params["beta_q"])
- sample(key, sample_shape=())
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters:
key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
sample_shape (tuple) – the sample shape for the distribution.
- Returns:
an array of shape sample_shape + batch_shape + event_shape
- Return type:
numpy.ndarray
- sample_with_intermediates(key: prng_key | None, sample_shape: tuple[int, ...] = ()) → Array | ndarray | bool | number | bool | int | float | complex
Same as sample except that any intermediate computations are also returned (useful for TransformedDistribution).
- Parameters:
key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
sample_shape (tuple) – the sample shape for the distribution.
- Returns:
an array of shape sample_shape + batch_shape + event_shape
- Return type:
numpy.ndarray
- shape(sample_shape: tuple[int, ...] = ()) → tuple[int, ...]
The tensor shape of samples from this distribution.
Samples are of shape:
d.shape(sample_shape) == sample_shape + d.batch_shape + d.event_shape
- Parameters:
sample_shape (tuple) – the size of the iid batch to be drawn from the distribution.
- Returns:
shape of samples.
- Return type:
tuple
- to_event(reinterpreted_batch_ndims: int | None = None) → Distribution
Interpret the rightmost reinterpreted_batch_ndims batch dimensions as dependent event dimensions.
- Parameters:
reinterpreted_batch_ndims – Number of rightmost batch dims to interpret as event dims.
- Returns:
An instance of Independent distribution.
- Return type:
Independent
- validate_args(strict: bool = True) → None
Validate the arguments of the distribution.
- Parameters:
strict – Require strict validation, raising an error if the function is called inside jitted code.
- property variance: Array | ndarray | bool | number | bool | int | float | complex
Variance of the distribution.