Use Neural Priors#
The sampling tutorial demonstrated how to define parameters with priors.
This guide shows how to set up and use neural network priors when optimizing a scene.
We use the companion package galaxygrad, which can be installed with pip.
More details on using a score-based diffusion model as a prior can be found in the paper “Score-matching neural networks for improved multi-band source separation”, Sampson et al., 2024, A&C, 49, 100875.
This guide will follow the Real-World Example, with changes in the initialization and parameter specification.
# Import Packages and setup
import jax.numpy as jnp
import matplotlib.pyplot as plt
import scarlet2 as sc2
sc2.set_validation(False)
Create Observation#
Again we import the test data and create the observation:
from huggingface_hub import hf_hub_download
filename = hf_hub_download(
    repo_id="astro-data-lab/scarlet-test-data", filename="hsc_cosmos_35.npz", repo_type="dataset"
)
file = jnp.load(filename)
data = file["images"]
channels = [str(f) for f in file["filters"]]
centers = jnp.array([(src["y"], src["x"]) for src in file["catalog"]]) # Note: y/x convention!
weights = 1 / file["variance"]
psf = file["psfs"]
# create the observation
obs = sc2.Observation(
    data,
    weights,
    psf=psf,
    channels=channels,
)
model_frame = sc2.Frame.from_observations(obs)
obs.match(model_frame)
Initialize Sources#
with sc2.Scene(model_frame) as scene:
    for i, center in enumerate(centers):
        if i == 0:  # we know source 0 is a star
            spectrum = sc2.init.pixel_spectrum(obs, center, correct_psf=True)
            sc2.PointSource(center, spectrum)
        else:
            try:
                spectrum, morph = sc2.init.from_gaussian_moments(obs, center, min_corr=0.99)
            except ValueError:
                spectrum = sc2.init.pixel_spectrum(obs, center)
                morph = sc2.init.compact_morphology()
            sc2.Source(center, spectrum, morph)
Load Neural Prior#
# load in the model you wish to use
from galaxygrad import get_prior
from scarlet2.nn import ScorePrior
# instantiate the prior class
temp = 1e-2 # values in the range of [1e-3, 1e-1] produce good results
prior32 = get_prior("hsc32")
prior64 = get_prior("hsc64")
prior32 = ScorePrior(prior32, prior32.shape(), t=temp)
prior64 = ScorePrior(prior64, prior64.shape(), t=temp)
The prior model itself is a score-based diffusion model, which matches the score function, i.e. the gradient of the log-probability density of the training data with respect to the parameters. For an image-based parameterization, the free parameters are the pixels, which means the gradient has the same shape as the image. galaxygrad provides several pre-trained models; here we use priors that were trained on deblended, isolated sources in HSC data, with shapes of 32x32 and 64x64, respectively. These sizes denote the maximum image size for which each prior is trained.
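To make the notion of a score function concrete, here is a minimal sketch that uses an independent Gaussian per pixel as a stand-in for the learned model. The functions `log_density`, `score`, and `numerical_score` are hypothetical helpers for illustration and are not part of galaxygrad or scarlet2. The sketch checks the analytic score against finite differences and shows that the result has one entry per pixel.

```python
import math


def log_density(pixels, mean=0.0, sigma=1.0):
    """Log-density of an independent Gaussian per pixel (toy stand-in for a learned model)."""
    return sum(
        -0.5 * ((p - mean) / sigma) ** 2 - math.log(sigma * math.sqrt(2 * math.pi))
        for row in pixels
        for p in row
    )


def score(pixels, mean=0.0, sigma=1.0):
    """Analytic score: derivative of log_density with respect to every pixel."""
    return [[-(p - mean) / sigma**2 for p in row] for row in pixels]


def numerical_score(pixels, eps=1e-6, **kwargs):
    """Central finite differences of log_density, used only to check `score`."""
    grad = []
    for i, row in enumerate(pixels):
        grad_row = []
        for j in range(len(row)):
            up = [list(r) for r in pixels]
            down = [list(r) for r in pixels]
            up[i][j] += eps
            down[i][j] -= eps
            diff = log_density(up, **kwargs) - log_density(down, **kwargs)
            grad_row.append(diff / (2 * eps))
        grad.append(grad_row)
    return grad


image = [[0.0, 0.5, 0.0], [0.5, 2.0, 0.5]]
analytic = score(image, mean=0.1, sigma=0.5)
numeric = numerical_score(image, mean=0.1, sigma=0.5)
```

A neural score model plays the role of `score` here: it returns the gradient directly, so the log-density itself never has to be evaluated.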
We import ScorePrior to wrap the prior model. It automatically zero-pads any smaller image array up to the specified size and provides a custom gradient path that calls the underlying score model during optimization or HMC sampling. The t argument (set to temp above) is a fixed temperature for the diffusion process. For speed, we run a single diffusion step at this temperature.
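The padding step can be sketched as follows. This is not the ScorePrior implementation: `pad_to`, `crop`, `toy_score`, and `padded_score` are hypothetical helpers, the centered padding is an assumed convention, and the toy model treats the temperature `t` as a noise variance added to a unit-variance Gaussian. The point is only that a fixed-size model can serve smaller images if they are padded first and the gradient is cropped back afterwards.

```python
def pad_to(image, size):
    """Zero-pad a 2D list to size x size, keeping the image centered (assumed convention)."""
    height, width = len(image), len(image[0])
    if height > size or width > size:
        raise ValueError(f"image {height}x{width} exceeds prior size {size}")
    top = (size - height) // 2
    left = (size - width) // 2
    padded = [[0.0] * size for _ in range(size)]
    for i, row in enumerate(image):
        for j, value in enumerate(row):
            padded[top + i][left + j] = value
    return padded, (top, left)


def crop(image, offset, shape):
    """Cut the original footprint back out of a padded array."""
    top, left = offset
    height, width = shape
    return [row[left:left + width] for row in image[top:top + height]]


def toy_score(image, t):
    """Hypothetical fixed-size score model: unit-variance Gaussian plus noise variance t."""
    return [[-p / (1.0 + t) for p in row] for row in image]


def padded_score(image, size, t):
    """Pad, evaluate the fixed-size model once, and crop the gradient to the input shape."""
    shape = (len(image), len(image[0]))
    padded, offset = pad_to(image, size)
    return crop(toy_score(padded, t), offset, shape)


small = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
grad = padded_score(small, size=8, t=1e-2)
```

In this toy model a larger `t` shrinks the gradient everywhere, which loosely mirrors how a higher diffusion temperature smooths the distribution the score refers to.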
Define Parameters with Prior#
We use the same fitting routine as in the Quickstart guide, but replace constraints.unit_interval with prior=prior in the Parameter containing the source morphologies. Because we run an optimizer, not a sampler, only the gradients of the prior matter (which is what galaxygrad provides), but we still need to define a stepsize for this parameter. It is useful to keep the step sizes for the morphology updates small because large jumps can lead to unstable prior gradients.
from functools import partial
from numpyro.distributions import constraints
spec_step = partial(sc2.relative_step, factor=0.05)
morph_step = partial(sc2.relative_step, factor=1e-3)
with sc2.Parameters(scene):
    for i in range(len(scene.sources)):
        sc2.Parameter(
            scene.sources[i].spectrum,
            name=f"spectrum:{i}",
            constraint=constraints.positive,
            stepsize=spec_step,
        )
        if i == 0:
            sc2.Parameter(scene.sources[i].center, name=f"center:{i}", stepsize=0.1)
        else:
            # choose a prior of suitable size
            prior = prior32 if max(scene.sources[i].morphology.shape) <= 32 else prior64
            sc2.Parameter(
                scene.sources[i].morphology,
                name=f"morph:{i}",
                prior=prior,  # attach the prior here
                stepsize=morph_step,
            )
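To illustrate why gradients alone are enough for optimization, the following toy sketch runs plain gradient ascent on a log-posterior whose prior enters only through its score. It does not reproduce the scarlet2 optimizer: `toy_relative_step`, `likelihood_grad`, `prior_score`, and `fit` are hypothetical helpers, the step rule is an assumption rather than the definition of sc2.relative_step, and the prior is a simple zero-mean Gaussian instead of a neural network.

```python
def toy_relative_step(values, factor):
    """Hypothetical step rule: a fixed fraction of the mean absolute parameter value."""
    return factor * sum(abs(v) for v in values) / len(values)


def likelihood_grad(model, data, weight):
    """Gradient of the Gaussian log-likelihood -0.5 * weight * (data - model)**2."""
    return [weight * (d - m) for m, d in zip(model, data)]


def prior_score(model, width=1.0):
    """Toy prior score: gradient of a zero-mean Gaussian log-prior of the given width."""
    return [-m / width**2 for m in model]


def fit(data, weight, factor, n_iter):
    """Gradient ascent using only gradients; the prior value itself is never evaluated."""
    model = list(data)  # start from the data as an initial guess
    for _ in range(n_iter):
        step = toy_relative_step(model, factor)
        grad_like = likelihood_grad(model, data, weight)
        grad_prior = prior_score(model)
        model = [m + step * (gl + gp) for m, gl, gp in zip(model, grad_like, grad_prior)]
    return model


data = [1.0, 2.0, 0.5]
weight = 4.0  # inverse variance of the toy data
fitted = fit(data, weight, factor=0.05, n_iter=200)
# For this toy problem the maximum is known analytically.
expected = [weight * d / (weight + 1.0) for d in data]
```

Each update moves the parameters by the step size times the summed gradient, so a smaller relative step keeps consecutive morphologies close to each other, which is the regime in which the score model is queried in the fit above.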
Now we run the fitter again:
maxiter = 1000
print("Initial chi^2:", obs.goodness_of_fit(scene()))
scene_ = scene.fit(obs, max_iter=maxiter, e_rel=1e-4, progress_bar=True)
print("Optimized chi^2:", obs.goodness_of_fit(scene_()))
Initial chi^2: 433.39108
The fit reaches a goodness of fit quite comparable to the run in the Quickstart guide. But let’s look at the sources…
Check Results#
norm = sc2.plot.AsinhAutomaticNorm(obs)
sc2.plot.scene(
    scene_,
    obs,
    norm=norm,
    show_model=True,
    show_rendered=True,
    show_observed=True,
    show_residual=True,
    add_boxes=True,
)
plt.show()
sc2.plot.sources(
    scene_,
    norm=norm,
    observation=obs,
    show_model=True,
    show_rendered=True,
    show_observed=True,
    show_spectrum=False,
    add_labels=False,
    add_boxes=True,
)
plt.show()
The results for most of the galaxies now look very reasonable, in particular for the fainter ones, where the prior matters most: they remain compact and are not overly affected by noise. Source #1 has minor artifacts and picks up neighboring objects, indicating that this prior has not (yet) been trained on as many larger galaxies and is therefore still somewhat weak for them. An update will fix this soon.