import equinox as eqx
import jax.numpy as jnp
from . import Scenery
from .module import Module
[docs]
class Spectrum(Module):
    """Base class for all spectrum models.

    Concrete spectra derive from this class and override :py:attr:`shape`.
    """

    @property
    def shape(self):
        """Shape (1D) of the spectrum model.

        Raises
        ------
        NotImplementedError
            Always; the base class does not define a shape.
        """
        raise NotImplementedError
[docs]
class StaticArraySpectrum(Spectrum):
    """Static (non-variable) source in a transient scene

    In the frames of transient scenes, the attribute :py:attr:`~scarlet2.Frame.channels`
    is overloaded and defined with a spectral and a temporal component, e.g.
    `channel = (band, epoch)`.

    This class is for models that do not vary in time, i.e. only have a spectral dependency.
    The length of :py:attr:`data` is thus given by the number of distinct spectral bands.
    """

    data: jnp.array
    """Data to describe the static spectrum

    The order in this array should be given by :py:attr:`bands`.
    """
    bands: list
    """Identifier for the list of unique bands in the model frame channels"""
    _channelindex: jnp.array = eqx.field(repr=False)

    def __init__(self, data, bands, band_selector=lambda channel: channel[0]):
        """
        Parameters
        ----------
        data: array
            Spectrum without temporal variation. Contains as many elements as there
            are spectral channels in the model.
        bands: list, array
            Identifier for the list of unique bands in the model frame channels
        band_selector: callable, optional
            Identify the spectral "band" component from the name/ID used in the
            channels of the model frame

        Raises
        ------
        AttributeError
            If no scene is active, i.e. `Scenery.scene.frame` cannot be accessed.
        ValueError
            If the band selected from a frame channel is not contained in `bands`.

        Examples
        --------
        >>> # model channels: [('G',0),('G',1),('R',0),('R',1),('R',2)]
        >>> spectrum = jnp.ones(2)
        >>> bands = ['G','R']
        >>> band_selector = lambda channel: channel[0]
        >>> StaticArraySpectrum(spectrum, bands, band_selector=band_selector)

        This constructs a 2-element spectrum to describe the spectral properties
        in all epochs 0,1,2.

        See Also
        --------
        TransientArraySpectrum
        """
        try:
            frame = Scenery.scene.frame
        except AttributeError:
            print("Source can only be created within the context of a Scene")
            print("Use 'with Scene(frame) as scene: Source(...)'")
            raise

        self.data = data
        self.bands = bands
        # `bands` is documented as "list, array", but arrays have no `.index`
        # method. Look the bands up in a plain-list copy so that any sequence
        # works; the stored attribute is left exactly as passed in.
        band_ids = list(bands)
        # For every channel of the model frame, store the position of its band
        # in `bands`, i.e. the element of `data` that describes this channel.
        self._channelindex = jnp.array([band_ids.index(band_selector(c)) for c in frame.channels])

    def __call__(self):
        """Return the spectrum expanded to one value per channel of the model frame

        Every channel receives the element of :py:attr:`data` that belongs to its band.
        """
        return self.data[self._channelindex]

    @property
    def shape(self):
        """The shape of the spectrum returned by calling this model

        This is a 1-tuple with the number of channels in the model frame.
        """
        return (len(self._channelindex),)
[docs]
class TransientArraySpectrum(Spectrum):
    """Variable source in a transient scene with possible quiescent periods

    In the frames of transient scenes, the attribute :py:attr:`~scarlet2.Frame.channels`
    is overloaded and defined with a spectral and a temporal component, e.g.
    `channel = (band, epoch)`. This class is for models that vary in time, especially
    if they have periods of inactivity. The length of :py:attr:`data` is given by
    the number of channels in the model frame, but during inactive epochs, the emission
    is set to zero.
    """

    data: jnp.array
    """Data to describe the variable spectrum.

    The length of this vector is identical to the number of channels in the model frame.
    """
    epochs: list
    """Identifier for the list of active epochs. If set to `None`, all epochs are
    considered active"""
    _epochmultiplier: jnp.array = eqx.field(repr=False)

    def __init__(self, data, epochs=None, epoch_selector=lambda channel: channel[1]):
        """
        Parameters
        ----------
        data: array
            Spectrum array. Contains as many elements as there are spectro-temporal
            channels in the model.
        epochs: list, array, optional
            List of temporal "epoch" identifiers for the active phases of the source.
            If `None` (default), the source is active in all epochs.
        epoch_selector: callable, optional
            Identify the temporal "epoch" component from the name/ID used in the
            channels of the model frame

        Raises
        ------
        AttributeError
            If no scene is active, i.e. `Scenery.scene.frame` cannot be accessed.

        Examples
        --------
        >>> # model channels: [('G',0),('G',1),('R',0),('R',1),('R',2)]
        >>> spectrum = jnp.ones(5)
        >>> epochs = [0, 1]
        >>> epoch_selector = lambda channel: channel[1]
        >>> TransientArraySpectrum(spectrum, epochs, epoch_selector=epoch_selector)

        This sets the spectrum to active during epochs 0 and 1, and masks the
        spectrum element for `('R',2)` with zero.

        See Also
        --------
        StaticArraySpectrum
        """
        try:
            frame = Scenery.scene.frame
        except AttributeError:
            print("Source can only be created within the context of a Scene")
            print("Use 'with Scene(frame) as scene: Source(...)'")
            raise

        self.data = data
        self.epochs = epochs
        # `epochs=None` is the default and means "active everywhere". A membership
        # test against None would raise TypeError, so short-circuit it.
        all_active = epochs is None
        # One multiplier per channel of the model frame: 1 for active epochs, 0 otherwise.
        self._epochmultiplier = jnp.array(
            [1.0 if all_active or epoch_selector(c) in epochs else 0.0 for c in frame.channels]
        )

    def __call__(self):
        """Return the spectrum with the elements of inactive epochs set to zero

        This is the element-wise product of :py:attr:`data` and the per-channel
        activity multiplier.
        """
        return jnp.multiply(self.data, self._epochmultiplier)

    @property
    def shape(self):
        """The shape of the spectrum data"""
        return self.data.shape