import jax
import jax.numpy as jnp
from . import Scenery
from .bbox import overlap_slices
from .frame import Frame
from .module import Module
from .validation_utils import print_validation_results
[docs]
class Scene(Module):
"""Model of the celestial scene
This class connects the main functionality of `scarlet2`: the fitting of an
:py:class:`~scarlet2.Observation` (or several) by a :py:class:`~scarlet2.Source`
model (or several). Model parameters can be optimized or samples with any method
implemented in jax, but this class provides the :py:func:`fit` and
:py:func:`sample` methods as built-in solutions.
"""
frame: Frame
"""Portion of the sky represented by this model"""
sources: list
"""List of :py:class:`~scarlet2.Source` comprised in this model"""
def __init__(self, frame):
"""
Parameters
----------
frame: `Frame`
Portion of the sky represented by this model
Examples
--------
The class provides a context so that sources can be added to the same model frame:
>>> with Scene(model_frame) as scene:
>>> Source(center, spectrum, morphology)
This adds a single source to the list :py:attr:`~scarlet2.Scene.sources`
of `scene`. The context provides a common definition of the model frame,
so that, e.g., `center` can be given as :py:class:`astropy.coordinates.SkyCoord`
and will automatically be converted to the pixel coordinate in the model frame.
The constructed source does not go out of scope after the `with` context
is closed, it is stored in the scene.
See Also
--------
:py:class:`~scarlet2.Scenery`, :py:class:`~scarlet2.Source`
"""
self.frame = frame
self.sources = list()
[docs]
def __call__(self):
"""What to run when the scene is called"""
model = jnp.zeros(self.frame.bbox.shape)
for source in self.sources:
model += self.evaluate_source(source)
return model
[docs]
def evaluate_source(self, source):
"""Evaluate a single source in the frame of this scene.
This method inserts the model of `source` into the proper location in `scene`.
Parameters
----------
source: :py:class:`~scarlet2.Source`
The source to evaluate.
Returns
-------
array
Array of the dimension indicated by :py:attr:`shape`.
"""
model_ = source()
# cut out region from model, add single source model
bbox, bbox_ = overlap_slices(self.frame.bbox, source.bbox, return_boxes=True)
sub_model_ = jax.lax.dynamic_slice(model_, bbox_.start, bbox_.shape)
# add model_ back in full model
model = jnp.zeros(self.frame.bbox.shape)
model = jax.lax.dynamic_update_slice(model, sub_model_, bbox.start)
return model
def __enter__(self):
# this scene might have parameters defined, we need to reset its registry key
object.__setattr__(self, "registry_key", "")
# context manager to register sources
# purpose is to provide scene.frame to source inits that will need some of its information
# also allows us to append the sources automatically to the scene
Scenery.scene = self
return self
def __exit__(self, exc_type, exc_value, traceback):
Scenery.scene = None
# (re)-import `VALIDATION_SWITCH` at runtime to avoid using a static/old value
from .validation_utils import VALIDATION_SWITCH
if VALIDATION_SWITCH:
from .validation import check_scene
validation_results = check_scene(self)
print_validation_results("Source validation results", validation_results)
[docs]
def fit(
self,
observations,
*args,
schedule=None,
max_iter=100,
e_rel=1e-4,
progress_bar=True,
callback=None,
**kwargs,
):
"""Fit model `parameters` of every source in the scene to match `observations`.
Parameters
----------
observations: :py:class:`~scarlet2.Observation` or list
The observations to fit the model to.
*args: list, optional
Additional arguments passed. Only used for backwards (v0.3) compatibility.
schedule: callable, optional
A function that maps optimizer step count to value. See :py:class:`optax.Schedule` for details.
max_iter: int, optional
Maximum number of optimizer iterations
e_rel: float, optional
Upper limit for the relative change in the norm of any parameter to
terminate the optimization early.
progress_bar: bool, optional
Whether to show a progress bar
callback: callable, optional
Function to be called on the current state of the optimized scene.
Signature `callback(scene, convergence, loss) -> None`, where
`convergence` is a tree of the same structure as `scene`, and `loss`
is the current value of the log_posterior.
**kwargs: dict, optional
Additional keyword arguments passed to the `optax.scale_by_adam` optimizer.
Returns
-------
Scene
The scene model with updated parameters
See Also
--------
:py:func:`~scarlet2.fit`
"""
# making sure we can iterate
if not isinstance(observations, (list, tuple)):
observations = (observations,)
# don't use this function with observation parameters
if any(len(obs.parameters) for obs in observations):
msg = "For Scene.fit(), observations must not have parameters. Use scarlet2.fit() instead."
raise RuntimeError(msg)
from .infer import fit
scene_, _ = fit(
self,
observations,
schedule=schedule,
max_iter=max_iter,
e_rel=e_rel,
progress_bar=progress_bar,
callback=callback,
**kwargs,
)
return scene_
[docs]
def sample(
self, observations, *args, seed=0, num_warmup=100, num_samples=200, progress_bar=True, **kwargs
):
"""Sample `parameters` of every source in the scene to get posteriors given `observations`.
This method runs the HMC NUTS sampler from `numpyro` to get parameter
posteriors. It uses the likelihood of `observations` as well as the `prior`
attribute set for every :py:class:`~scarlet2.Parameter` in `parameters`.
Parameters
----------
observations: :py:class:`~scarlet2.Observation` or list
The observations to fit the models to.
*args: list, optional
Additional arguments passed. Only used for backwards (v0.3) compatibility.
seed: int, optional
RNG seed for the sampler
num_warmup: int, optional
Number of samples during HMC warm-up
num_samples: int, optional
Number of samples to create from tuned HMC
progress_bar: bool, optional
Whether to show a progress bar
**kwargs: dict, optional
Additional keyword arguments passed to the `numpyro.infer.NUTS` sampler.
Returns
-------
numpyro.infer.mcmc.MCMC
See Also
--------
:py:func:`~scarlet2.sample`
"""
from .infer import sample
return sample(
self,
observations,
seed=seed,
num_warmup=num_warmup,
num_samples=num_samples,
progress_bar=progress_bar,
**kwargs,
)