import astropy.units as u
import astropy.wcs
import jax.numpy as jnp
import numpy as np
from astropy.coordinates import SkyCoord
from .bbox import Box
from .module import Module
from .psf import PSF, ArrayPSF, GaussianPSF
[docs]
class Frame(Module):
"""Definition of a view of the sky
This class combines all elements to determine how a piece of the sky will appear.
It includes metadata about the spatial and spectral coverage and resolution.
"""
bbox: Box
"""Bounding box of the frame"""
psf: (PSF, None)
"""PSF of the frame"""
wcs: astropy.wcs.WCS
"""WCS information of the frame"""
channels: list
"""Identifiers for the spectral elements"""
def __init__(self, bbox, psf=None, wcs=None, channels=None):
self.bbox = bbox
if isinstance(psf, (list, tuple, np.ndarray, jnp.ndarray)):
psf = jnp.asarray(psf).astype(float)
psf = ArrayPSF(psf)
self.psf = psf
if wcs is None:
wcs = _wcs_default(bbox.spatial.shape)
self.wcs = wcs
if channels is None:
channels = list(range(bbox.shape[0]))
self.channels = channels
def __hash__(self):
return hash(self.bbox)
@property
def C(self) -> int: # noqa: N802
"""Number of channels"""
return len(self.channels)
@property
def H(self) -> int: # noqa: N802
"""Height: number of pixels in the vertical direction"""
return self.bbox.spatial.shape[0]
@property
def W(self) -> int: # noqa: N802
"""Width: number of pixels in the horizontal direction"""
return self.bbox.spatial.shape[1]
[docs]
def get_pixel(self, pos):
"""Get the sky coordinates from a world coordinate
Parameters
----------
pos: jnp.ndarray or SkyCoord
Coordinates on the sky
Returns
---------
pixel coordinates in the model frame
"""
if isinstance(pos, SkyCoord):
# WCS uses x/y convention
wcs = self.wcs.celestial # only use celestial portion
x, y = pos.to_pixel(wcs)
pos = jnp.stack([y, x], axis=-1) # convert to y/x
return pos
[docs]
def get_sky_coord(self, pos):
"""Get the sky coordinate from a pixel coordinate
Parameters
----------
pos: jnp.ndarray
Coordinates in the pixel space
Returns
----------
astropy.coordinates.SkyCoord
"""
wcs = self.wcs.celestial # only use celestial portion
if jnp.ndim(pos) > 1:
pos = jnp.asarray(pos).reshape(-1, 2).T.reshape(2, *jnp.shape(pos)[:-1])
sky_coord = SkyCoord.from_pixel(pos[1], pos[0], wcs)
return sky_coord
[docs]
def convert_pixel_to(self, target, pixel=None):
"""Converts pixel coordinates from this frame to `target` frame
Parameters
----------
target: :py:class:`~scarlet2.Frame`
target frame
pixel: array
Pixel coordinates in this frame. If not set, convert all pixels in this frame
Returns
-------
array
coordinates at the location of `pixel` in the frame `target`
"""
if pixel is None:
y, x = jnp.indices(self.bbox.spatial.shape, dtype="float32")
pixel = jnp.stack((y.flatten(), x.flatten()), axis=1)
ra_dec = self.get_sky_coord(pixel)
return target.get_pixel(ra_dec)
[docs]
def u_to_pixel(self, distance):
"""Converts celestial distance to pixel size according to this frame WCS
Parameters
----------
distance: :py:class:`astropy.units.Quantity`
Physical size, must be `PhysicalType("angle")`
Returns
-------
float
size in pixels
"""
if isinstance(distance, u.Quantity):
pixel_size = get_pixel_size(self.wcs)
return (distance / pixel_size).to(1, equivalencies=u.dimensionless_angles()).value
else:
return distance
[docs]
def pixel_to_angle(self, size):
"""Converts pixel size to celestial distance according to this frame WCS
Parameters
----------
size: float
The size in pixels
Returns
-------
distance: :py:class:`astropy.units.Quantity`
"""
pixel_size = get_pixel_size(self.wcs)
distance = size * pixel_size
return distance
[docs]
@staticmethod
def from_observations(observations, model_psf=None, model_wcs=None, reference_id=None, coverage="union"):
"""Generates a suitable model frame for a set of observations.
This method generates a frame from a set of observations by identifying the highest resolution
and the smallest PSF and use them to construct a common frame for all observations.
Parameters
----------
observations: list
list of :py:class:`~scarlet2.Observation` to determine a common frame
model_psf: :py:class:`~scarlet2.PSF`, optional
PSF to be adopted for the model frame. This is the effective resolution
of the model, and all observations are to be deconvolved to this limit.
If None, uses the smallest PSF across all observations and channels.
model_wcs: :py:class:`astropy.wcs.WCS`
WCS for the model frame. If None, uses WCS information of the
observation with the smallest pixels.
reference_id: int, optional
index of the reference observation.
If set to None, uses the observation with the smallest pixels.
coverage: "union" or "intersection"
Sets the frame to incorporate the pixels covered by any observation
('union') or by all observations ('intersection').
"""
assert coverage in ["union", "intersection"]
if not hasattr(observations, "__iter__"):
observations = (observations,)
# Array of pixel sizes for each observation
pix_tab = []
# Array of psf size for each psf of each observation
small_psf_size = None
channels = []
# Create frame channels and find smallest and largest psf
for c, obs in enumerate(observations):
# Concatenate all channels
channels = channels + obs.frame.channels
# concatenate all pixel sizes
h_temp = get_pixel_size(obs.frame.wcs)
if isinstance(h_temp, u.Quantity):
h_temp = h_temp.to(u.arcsec).value # standardize pixel sizes, using simple scalars below
pix_tab.append(h_temp)
# Looking for the sharpest PSF
psf = obs.frame.psf.morphology
for psf_channel in psf:
psf_size = get_psf_size(psf_channel) * h_temp
if (
model_psf is None
and ((reference_id is None) or (c == reference_id))
and ((small_psf_size is None) or (psf_size < small_psf_size))
):
small_psf_size = psf_size
# Find a reference observation. Either provided by obs_id or as the
# observation with the smallest pixel
if reference_id is None:
p = jnp.array(pix_tab)
obs_ref = observations[jnp.where(p == p.min())[0][0]]
else:
# Frame defined from obs_id
obs_ref = observations[reference_id]
# Reference wcs
if model_wcs is None:
model_wcs = obs_ref.frame.wcs.deepcopy()
# Scale of the model pixel
h = get_pixel_size(model_wcs)
# If needed and psf is not provided: interpolate psf to smallest pixel
if model_psf is None:
# create Gaussian PSF with a sigma smaller than the smallest observed PSF
sigma = 0.7
assert small_psf_size / h > sigma, (
f"Default model PSF width ({sigma} pixel) too large for best-seeing observation"
)
model_psf = GaussianPSF(sigma=sigma)
# Dummy frame for WCS computations
model_shape = (len(channels), 0, 0)
model_frame = Frame(Box(model_shape), channels=channels, psf=model_psf, wcs=model_wcs)
# Determine overlap of all observations in pixel coordinates of the model frame
for c, obs in enumerate(observations):
obs_coord = obs.frame.convert_pixel_to(model_frame)
y_min, y_max = _minmax_int(obs_coord[:, 0])
x_min, x_max = _minmax_int(obs_coord[:, 1])
# +1 because Box.shape is a length, not a coordinate
this_box = Box.from_bounds((y_min, y_max + 1), (x_min, x_max + 1))
if c == 0:
model_box = this_box
else:
if coverage == "union":
model_box |= this_box
else:
model_box &= this_box
# update model_wcs to change NAXIS1/2 and CRPIX1/2, but don't change frame_origin!
model_wcs._naxis = list(model_wcs._naxis)
model_wcs._naxis[:2] = model_box.shape[::-1] # x/y needed here
model_wcs.wcs.crpix[:2] -= model_box.origin[::-1] # x/y needed here
# frame_origin = (0,) + model_box.origin
frame_shape = (len(channels),) + model_box.shape
model_frame = Frame(Box(shape=frame_shape), channels=channels, psf=model_psf, wcs=model_wcs)
return model_frame
def get_psf_size(psf):
"""Measures the size of a psf by computing the size of the area in 3 sigma around the center.
This is an approximate method to estimate the size of the psf for setting the size of the frame,
which does not require a precise measurement.
Parameters
----------
psf: `scarlet.PSF`
PSF for which to compute the size
Returns
-------
sigma3: `float`
radius of the area inside 3 sigma around the center in pixels
"""
# Normalisation by maximum
psf_frame = psf / jnp.max(psf)
# Pixels in the FWHM set to one, others to 0:
psf_frame = jnp.where(psf_frame > 0.5, 1.0, 0.0)
# Area in the FWHM:
area = jnp.sum(psf_frame)
# Diameter of this area
d = 2 * (area / jnp.pi) ** 0.5
# 3-sigma:
sigma3 = 3 * d / (2 * (2 * jnp.log(2)) ** 0.5)
return sigma3
def get_affine(wcs=None, linear=True):
"""Return the WCS transformation matrix
The transformation to intermediate world coordinates is given by the equation
$q = M\\cdot (p - r)$, where $p$ is the pixel coordinate, $r$ is `CRPIX`, and $M$ is the `CD` matrix.
This method provides the augmented matrix of the affine transformation:
$T = \begin{bmatrix} M & -M\\cdot r\\ 0 & 1\\end{bmatrix}$, for the extended vector $(p,1)$.
See Greisen & Calabretta (2002) for details.
Parameters
----------
wcs: `astropy.wcs.WCS`
WCS structure
linear: `bool`
Return only linear 2x2 matrix
Returns
-------
array (3x3 or 2x2)
"""
if wcs is None:
return jnp.diag(jnp.ones(3))
wcs_ = wcs.celestial
try:
m = wcs_.wcs.pc
except AttributeError:
try:
m = wcs_.cd
except AttributeError:
m = wcs_.wcs.cd
m = m[:2, :2] # avoid using channel information that is not declared "spectral" in the WCS
if linear:
return m
r = wcs_.wcs.crpix - 1 # CRPIX is 1-based!?!
b = -m @ r
t = jnp.zeros((3, 3))
t = t.at[:2, :2].set(m).at[:2, 2].set(b).at[2, 2].set(1)
return t
# for WCS linear matrix calculations:
# rotation matrix for counter-clockwise rotation from positive x-axis
# uses (x,y) coordinates and phi in radian!!
def _rot_matrix(phi, d=2):
sinphi, cosphi = jnp.sin(phi), jnp.cos(phi)
if d == 2:
return jnp.array([[cosphi, -sinphi], [sinphi, cosphi]])
else:
return jnp.array([[cosphi, -sinphi, 0], [sinphi, cosphi, 0], [0, 0, 1]])
# flip in y!!!
# uses (x,y) coordinates!
_flip_matrix = lambda flip, d=2: (
jnp.diag(jnp.array((1, flip), dtype=float)) if d == 2 else jnp.diag(jnp.array((1, flip, 1), dtype=float))
)
# 2x2 matrix determinant
_det2 = lambda m: m[0, 0] * m[1, 1] - m[0, 1] * m[1, 0]
# round coordinate to nearest integer (use python, not jnp)
_minmax_int = lambda x: tuple(int(f) for f in jnp.round(jnp.sort(x)[jnp.array([0, -1])])) # noqa:E731
# create trivial WCS for image with given shape
# scale = 1, pixel center in the middle of image
def _wcs_default(shape):
shape_ = shape[-2:][::-1] # x/y
wcs = astropy.wcs.WCS(naxis=2)
wcs._naxis = shape_
wcs.wcs.ctype = ["RA---TAN", "DEC--TAN"]
wcs.wcs.crpix = jnp.array((shape_[0] // 2, shape_[1] // 2)) + 1 # 1-based pixel coordinates
return wcs
def get_scale_angle_flip_shift(trans):
"""Return, scale, angle, flip, translation from the WCS transformation matrix
Parameters
----------
trans: (`astropy.wcs.WCS`, array)
WCS or WCS transformation matrix
Returns
-------
scale: `float` or `astropy.units.Quantity`
angle: `float`, in radian
flip: -1 or 1
shift: `numpy.ndarray`
See Also
--------
get_affine
"""
if isinstance(trans, (np.ndarray, jnp.ndarray)): # noqa: SIM108
m = trans # noqa: N806
use_unit = False
else:
m = get_affine(trans) # noqa: N806
use_unit = True
# get shift and then reduce to 2x2 for linear part
if m.shape == (3, 3):
shift = tuple(b.item() for b in m[:2, 2])[::-1] # prevent tracing, y/x convention
m = m[:2, :2]
else:
shift = (0, 0)
det = _det2(m)
# this requires pixels to be square
# if not, use scale = jnp.linalg.svd(M, compute_uv=False)
# but be careful with rotations as anisotropic stretch and rotation do not commute
scale = jnp.sqrt(jnp.abs(det)).item(0)
if use_unit:
scale = scale * u.deg
# if rotation is improper: need to apply y-flip to M to get pure rotation matrix (and unique angle)
improper = det < 0
flip = -1 if improper else 1
f = _flip_matrix(flip) # noqa: N806, flip in y, is identity if flip = 1!!!
m_ = f @ m # noqa: N806, flip = inverse flip
angle = jnp.arctan2(m_[1, 0], m_[0, 0]).item()
return scale, angle, flip, shift
def get_relative_jacobian_shift(frame_in, frame_out):
"""Return the linear transformation matrix and shift between two frame WCSs
Parameters
----------
frame_in: `~scarlet2.Frame`
The frame that defines the origin of the transformation
frame_out: `~scarlet2.Frame`
The frame that defines the target of the transformation
Returns
-------
jacobian: jnp.ndarray
2x2 Jacobian matrix
shift: tuple
2D shift of the center of the frames
"""
# Extract rotation angle, flip, scale between WCSs
m_in = get_affine(frame_in.wcs)
m_out = get_affine(frame_out.wcs)
jacobian = jnp.linalg.inv(m_out) @ m_in # transformation from model pixel -> sky -> obs pixels
# shift can be defined by the extended 3x3 Jacobian of the affine transformation matrix,
# but it would ignore CRPIX/CRVAL difference between frmes
# so we define it from the shift of the center of the two frames
center_model = jnp.array(frame_in.bbox.spatial.center)
center_model_in_obs = frame_out.get_pixel(frame_in.get_sky_coord(center_model))
center_obs = jnp.array(frame_out.bbox.spatial.center)
shift = center_obs - center_model_in_obs
shift = tuple(c.item() for c in shift) # avoid tracing
return jacobian, shift
def get_pixel_size(wcs):
"""Extracts the pixel size from a wcs, and returns it in deg/pixel
Parameters
----------
wcs: `astropy.wcs.WCS`
WCS structure or transformation matrix
Returns
-------
pixel_size: `float`
"""
scale, _, _, _ = get_scale_angle_flip_shift(wcs)
return scale
def get_scale(wcs, separate=False):
"""Get WCS axis scales in deg/pixel
Parameters
----------
wcs: `astropy.wcs.WCS`
WCS structure or transformation matrix
separate: `bool`
Compute separate axis scales
Returns
-------
float
"""
if separate:
M = get_affine(wcs) # noqa: N806
c1 = (M[0, :] ** 2).sum() ** 0.5
c2 = (M[1, :] ** 2).sum() ** 0.5
scale = jnp.array([c1, c2]) * u.deg
return scale
else:
scale, _ = get_scale_angle_flip_shift(wcs)
return scale
def get_angle(wcs):
"""Get WCS rotation angle
The angle is computed counter-clockwise from the positive x-axis, in radians.
Parameters
----------
wcs: `astropy.wcs.WCS`
WCS structure or transformation matrix
Returns
-------
`astropy.units.quantity.Quantity`, unit = u.rad
"""
scale, angle, flip, shift = get_scale_angle_flip_shift(wcs)
return angle
def get_flip(wcs):
"""Return WCS sign convention
A negative sign means that the rotation is improper and requires a flip.
By convention, we define this to be a flip in the y-axis.
Parameters
----------
wcs: `astropy.wcs.WCS`
WCS structure or transformation matrix
Returns
-------
-1 or 1
"""
scale, angle, flip, shift = get_scale_angle_flip_shift(wcs)
return flip
def get_shift(wcs):
"""Return WCS shift
The WCS specify an affine transformation via the `CRPIX` keyword. This method
returns the affine shift parameter in standard form.
Parameters
----------
wcs: `astropy.wcs.WCS`
WCS structure or transformation matrix
Returns
-------
array
See Also
--------
get_affine
"""
scale, angle, flip, shift = get_scale_angle_flip_shift(wcs)
return shift