Sample from Posterior
scarlet2 can provide samples from the posterior distribution, both to pass to downstream operations and as the most precise option for uncertainty quantification. In principle, we can get posterior samples for every parameter, and any sampler that works by evaluating the log-posterior distribution can be used. For this guide we will use the Hamiltonian Monte Carlo sampler from numpyro, for which scarlet2 provides a convenient front-end.
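To illustrate what "any sampler" means in practice, here is a minimal sketch of a random-walk Metropolis sampler that only needs a function returning the log-posterior. It is a toy one-parameter problem in plain NumPy, not the scarlet2 or numpyro implementation; the function names (`log_prior`, `log_likelihood`, `log_posterior`, `metropolis`) and all numbers are assumptions made for this illustration.

```python
import numpy as np

def log_prior(theta):
    # flat prior on [0, 500] (hypothetical choice for this toy problem)
    return 0.0 if 0.0 <= theta <= 500.0 else -np.inf

def log_likelihood(theta, data, sigma):
    # Gaussian noise with known standard deviation sigma
    return -0.5 * np.sum(((data - theta) / sigma) ** 2)

def log_posterior(theta, data, sigma):
    # log-posterior = log-prior + log-likelihood (up to a constant)
    lp = log_prior(theta)
    if not np.isfinite(lp):
        return -np.inf
    return lp + log_likelihood(theta, data, sigma)

def metropolis(log_prob, start, step, num_samples, rng):
    chain = np.empty(num_samples)
    current, current_lp = start, log_prob(start)
    for i in range(num_samples):
        proposal = current + step * rng.normal()
        proposal_lp = log_prob(proposal)
        # accept with probability min(1, p(proposal) / p(current))
        if np.log(rng.uniform()) < proposal_lp - current_lp:
            current, current_lp = proposal, proposal_lp
        chain[i] = current
    return chain

rng = np.random.default_rng(0)
sigma = 2.0
data = 100.0 + sigma * rng.normal(size=50)  # toy data with true value 100

chain = metropolis(
    lambda t: log_posterior(t, data, sigma),
    start=90.0, step=0.5, num_samples=5000, rng=rng,
)
posterior = chain[1000:]  # discard the burn-in phase
print(posterior.mean(), posterior.std())
```

In the rest of this guide, NUTS takes the place of the toy `metropolis` function, and the Observation supplies the log-likelihood.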
We start from the Real-World Example, loading the same data and the best-fitting model.
# Import Packages and setup
import astropy.units as u
import jax.numpy as jnp
import matplotlib.pyplot as plt
import scarlet2 as sc2
sc2.set_validation(False)
Create Observation
We need to create the Observation because it contains the log_likelihood() method we need for the posterior:
from huggingface_hub import hf_hub_download
filename = hf_hub_download(
    repo_id="astro-data-lab/scarlet-test-data", filename="hsc_cosmos_35.npz", repo_type="dataset"
)
file = jnp.load(filename)
data = file["images"]
channels = [str(f) for f in file["filters"]]
centers = jnp.array([(src["y"], src["x"]) for src in file["catalog"]]) # Note: y/x convention!
weights = 1 / file["variance"]
psf = file["psfs"]
# create the observation
obs = sc2.Observation(data, weights, psf=psf, channels=channels)
Load Model
We can make use of the best-fit model from the Quickstart guide as the starting point of the sampler.
import scarlet2.io
id = 35
filename = "hsc_cosmos.h5"
scene = scarlet2.io.model_from_h5(filename, path="..", id=id)
Let’s have a look:
norm = sc2.plot.AsinhAutomaticNorm(obs)
sc2.plot.scene(scene, observation=obs, norm=norm, add_boxes=True)
plt.show()
Define Parameters with Prior
We will demonstrate sampling from the spectrum and the center position of the point source #0. We could include more parameters at the cost of a longer run time, so it makes sense to sample jointly only those parameters that might be correlated.
To run the sampler, we need to set the prior attribute for each of these parameters. The attribute stepsize is only used by the optimizer and will be ignored by the sampler.
import numpyro.distributions as dist
C = len(channels)
with sc2.Parameters(scene):
    # rough guess of source brightness across bands: 0 .. 500 in each channel
    # Note:
    # .to_event(1) declares the last 1 dimension of low/high to be the
    # dimension of the parameter, so we get one probability evaluation,
    # not C independent ones
    # More details here: https://pyro.ai/examples/tensor_shapes.html
    prior1 = dist.Uniform(low=jnp.zeros(C), high=500 * jnp.ones(C)).to_event(1)
    sc2.Parameter(scene.sources[0].spectrum, name="spectrum", prior=prior1)
    # initial position was integer pixel coordinate
    # assume 0.1 arcsec uncertainty around current value
    p2 = scene.sources[0].center
    prior2 = dist.Normal(centers[0], scale=0.1 * u.arcsec).to_event(1)
    sc2.Parameter(p2, name="center", prior=prior2)
Warning
You are responsible for setting reasonable priors, which describe what you know about the parameter before having looked at the data. In the example above, the spectrum gets a wide flat prior, and the center prior uses the position centers[0], which is given by the original detection catalog. Neither uses information from the optimized scene.
Also, if in doubt about how much the prior choices matter, vary them within reason.
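As a worked example of varying a prior within reason, consider the textbook case of a Gaussian prior on a mean with Gaussian noise of known size, where the posterior is available in closed form. This is a sketch with hypothetical numbers and a hypothetical helper `normal_posterior`; it does not use scarlet2 and is not how the sampler computes anything.

```python
import math

def normal_posterior(prior_mean, prior_std, data_mean, data_std, n):
    """Posterior mean and std of a Gaussian mean with known noise level."""
    prior_precision = 1.0 / prior_std**2
    data_precision = n / data_std**2
    post_var = 1.0 / (prior_precision + data_precision)
    post_mean = post_var * (prior_precision * prior_mean + data_precision * data_mean)
    return post_mean, math.sqrt(post_var)

# hypothetical setup: prior centred on 13.0, while 400 data points average 13.58
results = {}
for prior_std in (0.05, 0.1, 0.2, 0.4):
    results[prior_std] = normal_posterior(
        prior_mean=13.0, prior_std=prior_std, data_mean=13.58, data_std=0.05, n=400
    )
    mean, std = results[prior_std]
    print(f"prior_std={prior_std:4.2f} -> posterior mean={mean:.4f}, std={std:.4f}")
```

When the posterior barely moves as the prior width changes by a factor of several, as in this toy case, the data dominate and the exact prior choice matters little. If the posterior tracks the prior, the result should be reported together with the prior that was assumed.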
Run Sampler
Then we can run numpyro’s NUTS sampler with a call to sample():
mcmc = sc2.sample(
    scene,
    obs,
    num_warmup=100,
    num_samples=1000,
    progress_bar=False,
)
mcmc.print_summary()
                  mean       std    median      5.0%     95.0%     n_eff     r_hat
   center[0]     32.91      0.00     32.91     32.91     32.91   1362.50      1.00
   center[1]     13.58      0.00     13.58     13.58     13.58   1222.38      1.00
 spectrum[0]    166.85      0.16    166.85    166.61    167.10   1022.41      1.00
 spectrum[1]    264.45      0.14    264.45    264.20    264.68   1236.12      1.00
 spectrum[2]    320.42      0.18    320.42    320.15    320.69   1150.34      1.00
 spectrum[3]    346.79      0.29    346.79    346.34    347.27    499.64      1.01
 spectrum[4]    369.33      0.62    369.29    368.24    370.33    229.92      1.00
Number of divergences: 0
Access Samples
The samples can be accessed from the MCMC chain and are listed as arrays under the names chosen above for the respective Parameter.
import pprint
samples = mcmc.get_samples()
pprint.pprint(samples)
{'center': Array([[32.9129 , 13.581251],
[32.91349 , 13.580794],
[32.913067, 13.581468],
...,
[32.912937, 13.583058],
[32.915264, 13.581376],
[32.91406 , 13.581509]], dtype=float32),
'spectrum': Array([[166.65988, 264.08054, 320.49374, 346.48755, 370.5157 ],
[166.80688, 264.69696, 320.6934 , 346.64993, 369.98914],
[166.91756, 264.2449 , 320.21616, 346.9572 , 369.8897 ],
...,
[166.85081, 264.3168 , 320.45975, 346.94653, 368.85568],
[166.9205 , 264.4537 , 320.33667, 346.8351 , 368.71603],
[166.93126, 264.43652, 320.38535, 346.86258, 368.7247 ]], dtype=float32)}
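Because the samples are plain arrays with one row per sample, per-component statistics follow from reductions along the first axis. The sketch below uses NumPy and only the six spectrum rows visible in the printout above (the real chain has 1000), so the numbers are illustrative. The percentiles are a simple stand-in, and the interval columns of the numpyro summary may be computed differently.

```python
import numpy as np

# the six spectrum rows visible in the printout above
samples = {
    "spectrum": np.array([
        [166.65988, 264.08054, 320.49374, 346.48755, 370.5157],
        [166.80688, 264.69696, 320.6934, 346.64993, 369.98914],
        [166.91756, 264.2449, 320.21616, 346.9572, 369.8897],
        [166.85081, 264.3168, 320.45975, 346.94653, 368.85568],
        [166.9205, 264.4537, 320.33667, 346.8351, 368.71603],
        [166.93126, 264.43652, 320.38535, 346.86258, 368.7247],
    ]),
}

for name, values in samples.items():
    mean = values.mean(axis=0)               # average over samples
    std = values.std(axis=0, ddof=1)         # sample standard deviation
    lo, hi = np.percentile(values, [5, 95], axis=0)
    for k in range(values.shape[1]):
        print(f"{name}[{k}]  mean={mean[k]:.2f}  std={std[k]:.2f}  "
              f"5%={lo[k]:.2f}  95%={hi[k]:.2f}")
```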
To create versions of the scene for any of the samples, we first select a few at random and then use the method scarlet2.Module.set() to set their values:
# get values for three random samples
S = 3
seed = 42
import jax.random
key = jax.random.key(seed)
idxs = jax.random.randint(key, shape=(S,), minval=0, maxval=mcmc.num_samples)
# dictionary: parameter name -> value
values = [{name: samples[name][idx] for name in samples.keys()} for idx in idxs]
# create versions of the scene with these posterior samples
scenes = [scene.set(v) for v in values]
# display the source models
fig, axes = plt.subplots(1, S, figsize=(10, 4))
for s in range(S):
    source_array = scenes[s].sources[0]()
    axes[s].imshow(sc2.plot.img_to_rgb(source_array, norm=norm))
The differences are imperceptible for this source, which tells us that the data were highly informative. But we can still measure, e.g., the total fluxes for each sample:
print(f"-------------- {channels}")
# use a separate loop variable so the best-fit `scene` is not overwritten
for i, sample_scene in enumerate(scenes):
    print(f"Flux Sample {i}: {sc2.measure.flux(sample_scene.sources[0])}")
-------------- ['g', 'r', 'i', 'z', 'y']
Flux Sample 0: [166.84314 264.14975 320.51938 346.83893 369.36435]
Flux Sample 1: [166.73085 264.54507 320.12756 346.85266 370.91626]
Flux Sample 2: [166.81613 264.3625 320.36237 346.77423 367.91544]
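Per-sample measurements like these are what make posterior samples useful downstream: any derived quantity can be computed for each sample and then summarized. As a sketch, the following turns the three flux vectors printed above into a g - r color with a spread. It assumes the fluxes in all bands share the same zeropoint, which is not stated in this guide, and the helper `color` is hypothetical.

```python
import numpy as np

channels = ["g", "r", "i", "z", "y"]

# the three flux samples printed above, one row per posterior sample
fluxes = np.array([
    [166.84314, 264.14975, 320.51938, 346.83893, 369.36435],
    [166.73085, 264.54507, 320.12756, 346.85266, 370.91626],
    [166.81613, 264.3625, 320.36237, 346.77423, 367.91544],
])

def color(fluxes, band1, band2):
    """Magnitude difference band1 - band2 for each sample (row)."""
    i, j = channels.index(band1), channels.index(band2)
    return -2.5 * np.log10(fluxes[:, i] / fluxes[:, j])

g_r = color(fluxes, "g", "r")
print(f"g - r = {g_r.mean():.3f} +/- {g_r.std(ddof=1):.3f}")
```

With only three samples the spread is a crude estimate; using all samples from the chain would give a usable uncertainty on the color, including the correlation between the two bands.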
Visualize Posterior
We can also visualize the posterior distributions, e.g. with the corner package:
import corner
corner.corner(mcmc);
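If more control over the plot is needed, for instance to choose the axis labels, the samples dictionary can be flattened into a single two-dimensional array first. This is a sketch with a hypothetical helper `flatten_samples`, and the input is a random stand-in with the same layout as the output of `mcmc.get_samples()`, not real posterior samples.

```python
import numpy as np

def flatten_samples(samples):
    """Stack a dict of (num_samples, dim) arrays into one 2D array with labels."""
    columns, labels = [], []
    for name, values in samples.items():
        values = np.asarray(values).reshape(len(values), -1)
        for k in range(values.shape[1]):
            columns.append(values[:, k])
            labels.append(f"{name}[{k}]")
    return np.column_stack(columns), labels

# toy stand-in for mcmc.get_samples(): random numbers, not posterior samples
rng = np.random.default_rng(1)
toy = {"center": rng.normal(size=(200, 2)), "spectrum": rng.normal(size=(200, 5))}

flat, labels = flatten_samples(toy)
print(flat.shape, labels)
```

The resulting array and labels can then be passed on as `corner.corner(flat, labels=labels)`.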