Scene

class scarlet2.Scene(frame)

Bases: Module

Model of the celestial scene

This class connects the main functionality of scarlet2: the fitting of an Observation (or several) by a Source model (or several). Model parameters can be optimized or sampled with any method implemented in JAX, but this class provides the fit() and sample() methods as built-in solutions.

Parameters:

frame (Frame) – Portion of the sky represented by this model

Examples

The class provides a context so that sources can be added to the same model frame:

>>> with Scene(model_frame) as scene:
...     Source(center, spectrum, morphology)

This adds a single source to the sources list of scene. The context provides a common definition of the model frame, so that, e.g., center can be given as astropy.coordinates.SkyCoord and will automatically be converted to pixel coordinates in the model frame.

The constructed source does not go out of scope when the with context is closed; it is stored in the scene.
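The registration behavior described above can be illustrated with a small, self-contained sketch. It is not scarlet2's implementation: the classes below are toy stand-ins that merely share the names Scene and Source, and the class-level `_active` stack is an assumption made for illustration.

```python
class Scene:
    """Toy stand-in for a scene that collects sources created inside its context."""

    _active = []  # stack of scenes whose `with` block is currently open (assumption)

    def __init__(self, frame):
        self.frame = frame
        self.sources = []

    def __enter__(self):
        Scene._active.append(self)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        Scene._active.pop()
        return False  # do not suppress exceptions raised inside the block


class Source:
    """Toy source that registers itself with the innermost open scene."""

    def __init__(self, center, spectrum, morphology):
        if not Scene._active:
            raise RuntimeError("Source must be created inside a Scene context")
        self.center = center
        self.spectrum = spectrum
        self.morphology = morphology
        Scene._active[-1].sources.append(self)


with Scene(frame="model_frame") as scene:
    Source(center=(10, 12), spectrum=[1.0, 0.5], morphology="blob")

# The source outlives the `with` block because the scene keeps a reference to it.
stored = scene.sources[0]
```

The point of the pattern is that the source constructor needs no explicit scene argument: whichever scene is open at construction time receives the source.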

See also

Scenery, Source

__call__()

What to run when the scene is called

evaluate_source(source)

Evaluate a single source in the frame of this scene.

This method inserts the model of source into the proper location in scene.

Parameters:

source (Source) – The source to evaluate.

Returns:

Array of the dimension indicated by shape.

Return type:

array
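As a rough illustration of what "inserting the model of source into the proper location" can mean, the sketch below adds a small 2D patch into a zero-filled frame at a given offset. It is a simplification under stated assumptions: plain nested lists instead of JAX arrays, a single 2D plane instead of a multi-channel cube, and a hypothetical helper name `insert_into_frame` that is not part of scarlet2.

```python
def insert_into_frame(frame_shape, model, origin):
    """Add a small 2D `model` into a zero-filled frame at `origin` = (row, col).

    Pixels of the model that fall outside the frame are dropped.
    """
    n_rows, n_cols = frame_shape
    frame = [[0.0] * n_cols for _ in range(n_rows)]
    r0, c0 = origin
    for i, row in enumerate(model):
        for j, value in enumerate(row):
            r, c = r0 + i, c0 + j
            if 0 <= r < n_rows and 0 <= c < n_cols:
                frame[r][c] += value
    return frame


model = [[1.0, 2.0], [3.0, 4.0]]

# Fully inside the 3x4 frame: all four pixels are placed.
inside = insert_into_frame((3, 4), model, origin=(1, 2))

# Overlapping the lower-right corner: only the top-left model pixel survives.
clipped = insert_into_frame((3, 4), model, origin=(2, 3))
```

The result always has the shape of the frame, regardless of the size or position of the source patch.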

fit(observations, *args, schedule=None, max_iter=100, e_rel=0.0001, progress_bar=True, callback=None, **kwargs)

Fit model parameters of every source in the scene to match observations.

Parameters:
  • observations (Observation or list) – The observations to fit the model to.

  • *args (list, optional) – Additional arguments passed. Only used for backwards (v0.3) compatibility.

  • schedule (callable, optional) – A function that maps optimizer step count to value. See optax.Schedule for details.

  • max_iter (int, optional) – Maximum number of optimizer iterations

  • e_rel (float, optional) – Upper limit for the relative change in the norm of any parameter to terminate the optimization early.

  • progress_bar (bool, optional) – Whether to show a progress bar

  • callback (callable, optional) – Function to be called on the current state of the optimized scene. Signature callback(scene, convergence, loss) -> None, where convergence is a tree of the same structure as scene, and loss is the current value of the log_posterior.

  • **kwargs (dict, optional) – Additional keyword arguments passed to the optax.scale_by_adam optimizer.
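To make the roles of schedule, max_iter, and e_rel concrete, here is a toy gradient-descent loop in plain Python. It is only one possible reading of the parameter descriptions above, not scarlet2's optimizer: the functions `exponential_schedule`, `has_converged`, and `fit_toy` are hypothetical, the loss is a simple sum of squares, and the convergence test assumes "relative change in the norm" means the norm of the update divided by the norm of the parameter.

```python
import math


def exponential_schedule(init_value, decay_rate):
    """Return a function mapping the optimizer step count to a step size."""
    def schedule(step):
        return init_value * decay_rate ** step
    return schedule


def norm(vec):
    return math.sqrt(sum(x * x for x in vec))


def has_converged(old_params, new_params, e_rel):
    """True if every parameter changed by at most e_rel relative to its norm."""
    for name, new in new_params.items():
        old = old_params[name]
        change = norm([n - o for n, o in zip(new, old)])
        if change > e_rel * norm(new):
            return False
    return True


def fit_toy(params, targets, schedule, max_iter=100, e_rel=1e-4):
    """Gradient descent on the squared distance between params and targets."""
    for step in range(max_iter):
        lr = schedule(step)
        new_params = {
            name: [p - lr * 2 * (p - t) for p, t in zip(vals, targets[name])]
            for name, vals in params.items()
        }
        converged = has_converged(params, new_params, e_rel)
        params = new_params
        if converged:
            return params, step + 1  # terminated early
    return params, max_iter


start = {"spectrum": [1.0, 1.0], "center": [0.0]}
targets = {"spectrum": [2.0, 3.0], "center": [5.0]}
fitted, n_iter = fit_toy(start, targets, exponential_schedule(0.25, 0.99))
```

Because the test requires every parameter to pass, a single slowly converging parameter keeps the loop running until max_iter is reached.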

Returns:

The scene model with updated parameters

Return type:

Scene
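A callback with the documented signature callback(scene, convergence, loss) -> None could, for instance, record the loss as the optimization proceeds. The sketch below is a standalone illustration: `make_loss_recorder` is a hypothetical helper, and the loop only mimics an optimizer by feeding made-up loss values, passing None where a real fit would supply the scene and its convergence tree.

```python
def make_loss_recorder(every=1):
    """Build a callback that stores the loss on every `every`-th call."""
    history = []
    calls = {"count": 0}

    def callback(scene, convergence, loss):
        if calls["count"] % every == 0:
            history.append(float(loss))
        calls["count"] += 1

    return callback, history


callback, history = make_loss_recorder(every=2)

# Stand-in for the optimizer loop, using illustrative loss values only.
for loss in [10.0, 4.0, 2.5, 2.1, 2.0]:
    callback(scene=None, convergence=None, loss=loss)
```

Keeping the recorded state in a closure lets the callback return None, as the signature requires, while the caller still has access to the collected values.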

See also

fit()

fit_info: dict = None

Diagnostic info populated by fit() (loss history, best loss, n_iter)

frame: Frame

Portion of the sky represented by this model

get(name=None)

Get parameter(s) from this module

Parameters:

name (str, optional) – Name of parameter. If not set, returns all parameters.

Returns:

requested data arrays for parameters

Return type:

dict

property parameters

Parameters defined for this module

Returns:

name: (node, param) mapping for all parameters

Return type:

dict

replace(name, value)

Replace member attribute name with value

Parameters:
  • name (str) – Name of member to replace

  • value (any) – Value to replace member with

Returns:

The modified module.

Return type:

Module

sample(observations, *args, seed=0, num_warmup=100, num_samples=200, progress_bar=True, **kwargs)

Sample parameters of every source in the scene to get posteriors given observations.

This method runs the HMC NUTS sampler from numpyro to get parameter posteriors. It uses the likelihood of observations as well as the prior attribute set for every Parameter in parameters.

Parameters:
  • observations (Observation or list) – The observations to fit the models to.

  • *args (list, optional) – Additional arguments passed. Only used for backwards (v0.3) compatibility.

  • seed (int, optional) – RNG seed for the sampler

  • num_warmup (int, optional) – Number of samples during HMC warm-up

  • num_samples (int, optional) – Number of samples to create from tuned HMC

  • progress_bar (bool, optional) – Whether to show a progress bar

  • **kwargs (dict, optional) – Additional keyword arguments passed to the numpyro.infer.NUTS sampler.

Return type:

numpyro.infer.mcmc.MCMC
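The meaning of seed, num_warmup, and num_samples can be illustrated with a much simpler sampler. The sketch below uses random-walk Metropolis, not the NUTS sampler that sample() runs, and a standard normal as a toy posterior. `sample_toy` and `log_posterior` are hypothetical names. Unlike real HMC warm-up, which also tunes the sampler, the warm-up phase here only discards early draws.

```python
import math
import random


def log_posterior(x):
    """Log-density of a standard normal, up to a constant (toy target)."""
    return -0.5 * x * x


def sample_toy(seed=0, num_warmup=100, num_samples=200, step_size=1.0):
    """Random-walk Metropolis: discard `num_warmup` draws, keep `num_samples`."""
    rng = random.Random(seed)  # a fixed seed makes the chain reproducible
    x = 0.0
    kept = []
    for i in range(num_warmup + num_samples):
        proposal = x + rng.gauss(0.0, step_size)
        log_accept = log_posterior(proposal) - log_posterior(x)
        # Accept with probability min(1, exp(log_accept)).
        if rng.random() < math.exp(min(0.0, log_accept)):
            x = proposal
        if i >= num_warmup:
            kept.append(x)
    return kept


samples = sample_toy(seed=0, num_warmup=100, num_samples=200)
```

Only the draws after warm-up are returned, so the length of the result equals num_samples, and repeating the call with the same seed reproduces the same chain.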

See also

sample()

set(values)

Set parameter(s) of this module to values

Parameters:

values (dict[str,jnp.array]) – values to replace parameters with, identified by their name

Returns:

new module with parameter(s) replaced by values

Return type:

Module
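The Returns entry speaks of a new module, which suggests a non-mutating update: the original object is left untouched and a modified copy is handed back. A minimal sketch of that pattern, assuming a frozen dataclass as a stand-in for the real Module class (`ToyModule` is hypothetical and stores plain tuples rather than JAX arrays):

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyModule:
    spectrum: tuple
    center: tuple

    def set(self, values):
        """Return a new module with the named fields replaced by `values`."""
        return replace(self, **values)


old = ToyModule(spectrum=(1.0, 0.5), center=(10.0, 12.0))
new = old.set({"center": (11.0, 12.5)})
```

With this style of update the return value must be captured, as in `new = old.set(...)`; calling the method and discarding the result changes nothing.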

sources: list

List of Source objects contained in this model