Source code for scarlet2.interpolation

"""Interpolation methods

Some of the code to perform interpolation in Fourier space has been adapted from
https://github.com/GalSim-developers/JAX-GalSim/blob/main/jax_galsim/interpolant.py
https://github.com/GalSim-developers/JAX-GalSim/blob/main/jax_galsim/interpolatedimage.py
"""

import equinox as eqx
import jax
import jax.numpy as jnp


### Interpolant class
class Interpolant(eqx.Module):
    """Common interface for interpolation kernels.

    The base class only declares the interface: every method below raises
    `NotImplementedError` and is meant to be overridden by concrete kernels.
    """

    extent: int
    """Size of the interpolation kernel"""

    def __call__(self):
        """Call hook; not implemented in the base class.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError

    def kernel(self, x):
        """Evaluate the kernel in configuration (real) space at `x`

        Parameters
        ----------
        x: float or array
            Position in real space

        Returns
        -------
        float or array

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError

    def uval(self, u):
        """Evaluate the kernel in Fourier space at frequency `u`

        Parameters
        ----------
        u: complex or array
            Position in Fourier space

        Returns
        -------
        complex or array

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError
### Quintic interpolant
class Quintic(Interpolant):
    """Quintic interpolation from Gruen & Bernstein (2014)"""

    def __init__(self):
        # The piecewise polynomial below is non-zero only for |x| <= 3.
        self.extent = 3

    # Polynomial pieces of the kernel; each one is applied to |x| on its own interval.
    def _f_0_1_q(self, x):
        # piece used for |x| <= 1
        return 1 + x * x * x * (-95 + 138 * x - 55 * x * x) / 12

    def _f_1_2_q(self, x):
        # piece used for 1 < |x| <= 2
        return (x - 1) * (x - 2) * (-138 + 348 * x - 249 * x * x + 55 * x * x * x) / 24

    def _f_2_3_q(self, x):
        # piece used for 2 < |x| <= 3
        return (x - 2) * (x - 3) * (x - 3) * (-54 + 50 * x - 11 * x * x) / 24

    def _f_3_q(self, x):
        # fallback piece (|x| > 3): the kernel vanishes
        return jnp.zeros_like(x, dtype=x.dtype)

    def kernel(self, x):
        """Evaluate the quintic kernel in real space.

        Parameters
        ----------
        x: float or array
            Position in real space

        Returns
        -------
        float or array
            Kernel value at `x`; zero for `|x| > 3`.
        """
        x = jnp.abs(x)  # quintic kernel is even
        within_1 = x <= 1
        within_2 = x <= 2
        within_3 = x <= 3
        # Mutually exclusive interval masks; the last function is the default
        # applied where none of the conditions hold.
        conditions = [within_1, (~within_1) & within_2, (~within_2) & within_3]
        pieces = [self._f_0_1_q, self._f_1_2_q, self._f_2_3_q, self._f_3_q]
        return jnp.piecewise(x, conditions, pieces)

    def uval(self, u):
        """Evaluate the quintic kernel in Fourier space.

        Parameters
        ----------
        u: float or array
            Frequency; only its absolute value is used.

        Returns
        -------
        float or array
        """
        u = jnp.abs(u)
        sinc_u = jnp.sinc(u)
        pi_u = jnp.pi * u
        cos_pi_u = jnp.cos(pi_u)
        sinc_sq = sinc_u * sinc_u
        pi_u_sq = pi_u * pi_u
        bracket = sinc_u * (55.0 - 19.0 * pi_u_sq) + 2.0 * cos_pi_u * (pi_u_sq - 27.0)
        return sinc_u * sinc_sq * sinc_sq * bracket
class Lanczos(Interpolant):
    """Lanczos interpolation class"""

    def __init__(self, n):
        """Lanczos interpolant

        Parameters
        ----------
        n: int
            Lanczos order
        """
        self.extent = n

    def _f_1(self, x, n):
        # Small-|x| branch. Approximation from galsim:
        # res = n/(pi x)^2 * sin(pi x) * sin(pi x / n)
        #     ~= (1 - 1/6 pix^2) * (1 - 1/6 pix^2 / n^2)
        #     = 1 - 1/6 pix^2 ( 1 + 1/n^2 )
        pi_x = jnp.pi * x
        sixth_pi_x_sq = 1.0 / 6.0 * pi_x * pi_x
        return 1.0 - sixth_pi_x_sq * (1.0 + 1.0 / (n * n))

    def _f_2(self, x, n):
        # Closed-form expression, used away from x = 0 where dividing by (pi x)^2 is safe.
        pi_x = jnp.pi * x
        return n * jnp.sin(pi_x) * jnp.sin(pi_x / n) / pi_x / pi_x

    def _lanczos_n(self, x, n=3):
        """Evaluate the order-`n` Lanczos kernel at `x` (zero for `|x| > n`)."""
        abs_x = jnp.abs(x)
        near_zero = abs_x <= 1e-4
        inside_window = abs_x <= n

        def _outside(x, n):
            # default branch: outside the window the kernel is zero
            return jnp.array(0)

        return jnp.piecewise(
            x,
            [near_zero, (~near_zero) & inside_window],
            [self._f_1, self._f_2, _outside],
            n,
        )

    def kernel(self, x):
        """Evaluate the Lanczos kernel of order `self.extent` in real space.

        Parameters
        ----------
        x: float or array
            Position in real space

        Returns
        -------
        float or array
        """
        return self._lanczos_n(x, self.extent)
### Resampling function
def resample2d(signal, coords, warp, interpolant=Lanczos(3)):  # noqa: B008
    """Resample a 2-dimensional image with an interpolation kernel

    For every target position the `2 * extent + 1` nearest grid samples along
    each axis are weighted with the separable kernel
    `interpolant.kernel(dx) * interpolant.kernel(dy)` and summed. Samples whose
    index falls outside the input grid contribute zero.

    Parameters
    ----------
    signal: array
        2d array containing the signal, indexed as `signal[y, x]`.
        Shape: `[Ny, Nx]`
    coords: array
        Coordinates on which the signal is sampled.
        Shape: `[Ny, Nx, 2]`
        y-coordinates are read from `coords[0,:,0]`, x-coordinates from `coords[:,0,1]`.
        The grid spacing is taken from the first two x-coordinates and used for both axes.
    warp: array
        Coordinates on which to resample the signal, as (y, x) pairs in the last axis.
        Shape: `[ny, nx, 2]`, e.g.
        [ [[0, 0], [0, 1], ..., [0, N-1]],
          [ ... ],
          [[N-1, 0], [N-1, 1], ..., [N-1, N-1]] ]
    interpolant: Interpolant
        Instance of interpolant; must provide `kernel` and `extent`.

    Returns
    -------
    array
        Resampled `signal` at the location indicated by `warp`, shape `[ny, nx]`
    """
    out_shape = warp[..., 0].shape
    y = warp[..., 0].flatten()
    x = warp[..., 1].flatten()

    grid_y = coords[0, :, 0]
    grid_x = coords[:, 0, 1]
    # NOTE(review): grid_y is sliced along axis 1 but bounded below by coords.shape[0]
    # (and vice versa for x); the two agree for square grids -- confirm for others.
    n_y = coords.shape[0]
    n_x = coords.shape[1]

    # pixel spacing, measured along x and reused for y
    h = grid_x[1] - grid_x[0]

    # index of the grid sample at or just below each target position
    x_base = jnp.floor((x - grid_x[0]) / h).astype(jnp.int32)
    y_base = jnp.floor((y - grid_y[0]) / h).astype(jnp.int32)

    first, stop = -interpolant.extent, interpolant.extent + 1

    def accumulate_x(offset, carry):
        total, y_idx, weight_y, valid_y = carry
        x_idx = x_base + offset
        valid_x = (x_idx >= 0) & (x_idx < n_x)
        weight_x = interpolant.kernel((x - grid_x[x_idx]) / h)
        weight = weight_x * weight_y
        valid = valid_x & valid_y
        # out-of-grid samples are masked to zero rather than extrapolated
        total = total + jnp.where(valid, weight * signal[y_idx, x_idx], 0)
        return total, y_idx, weight_y, valid_y

    def accumulate_y(offset, total):
        y_idx = y_base + offset
        valid_y = (y_idx >= 0) & (y_idx < n_y)
        weight_y = interpolant.kernel((y - grid_y[y_idx]) / h)
        inner = jax.lax.fori_loop(first, stop, accumulate_x, (total, y_idx, weight_y, valid_y))
        return inner[0]

    start = jnp.zeros_like(x).astype(signal.dtype)
    resampled = jax.lax.fori_loop(first, stop, accumulate_y, start)
    return resampled.reshape(out_shape)
# @partial(jax.jit, static_argnums=(3))
def resample3d(signal, coords, warp, interpolant):
    """Resample a 3-dimensional image with an interpolation kernel

    Applies `resample2d` independently to every slice along the leading axis of
    `signal`, using the same `coords`, `warp` and `interpolant` for all slices.

    Parameters
    ----------
    signal: array
        3d array containing the signal.
        Shape: `[C, Ny, Nx]`
    coords: array
        Coordinates on which the signal is sampled.
        Shape: `[Ny, Nx, 2]`
        y-coordinates are `coords[0,:,0]`, x-coordinates are `coords[:,0,1]`.
    warp: array
        Coordinates on which to resample the signal.
        Shape: `[ny, nx, 2]`, e.g.
        [ [[0, 0], [0, 1], ..., [0, N-1]],
          [ ... ],
          [[N-1, 0], [N-1, 1], ..., [N-1, N-1]] ]
    interpolant: Interpolant
        Instance of interpolant

    Returns
    -------
    array
        Resampled `signal` at the location indicated by `warp`, one slice per
        entry of the leading axis.

    See Also
    --------
    resample2d
    """

    def _resample_slice(plane):
        # everything except the image plane is shared across the leading axis
        return resample2d(plane, coords, warp, interpolant=interpolant)

    return jax.vmap(_resample_slice, in_axes=0, out_axes=0)(signal)
def resample_hermitian(signal, warp, x_min, y_min, interpolant=Quintic()):  # noqa: B008
    """Resample a 2-dimensional Hermitian image using an interpolation kernel

    The signal is assumed to be Hermitian, i.e. f(-x, -y) = conjugate(f(x, y)),
    with only one half stored along its second axis. Samples whose wrapped
    x-index lies below `signal.shape[1] - 1` are not read directly: they are
    reconstructed as the complex conjugate of the mirrored entry.

    Parameters
    ----------
    signal: array
        2d array containing the signal.
        `signal.shape[0]` is used as the period for index wrapping and
        `signal.shape[1] - 1` as the offset of the stored half.
    warp: array
        Coordinates on which to resample the signal, in units of signal pixels;
        `warp[..., 0]` holds x and `warp[..., 1]` holds y.
        Shape: `[nx, ny, 2]`, e.g.
        [ [[0, 0], [0, 1], ..., [0, N-1]],
          [ ... ],
          [[N-1, 0], [N-1, 1], ..., [N-1, N-1]] ]
    x_min: float
        Left coordinate of corner of bounding box that defines the location of `signal`
    y_min: float
        Low coordinate of corner of bounding box that defines the location of `signal`
    interpolant: Interpolant
        Instance of interpolant; must provide `kernel` and `extent`.

    Returns
    -------
    array
        Resampled `signal` at the location indicated by `warp`, with the shape
        of `warp[..., 0]`.
    """
    out_shape = warp[..., 0].shape
    x = warp[..., 0].flatten()
    y = warp[..., 1].flatten()

    # integer pixel at or just below each target position, and its coordinate
    x_base = jnp.floor(x - x_min).astype(jnp.int32)
    y_base = jnp.floor(y - y_min).astype(jnp.int32)
    x_base_pos = x_base + x_min
    y_base_pos = y_base + y_min

    half = signal.shape[1] - 1  # offset of the stored half along the second axis
    period = signal.shape[0]  # wrap length, used for both x and y indices

    first, stop = -interpolant.extent, interpolant.extent + 1

    def accumulate_x(offset, carry):
        total, y_idx, weight_y = carry
        x_idx = (x_base + offset) % period
        weight_x = interpolant.kernel(x - (x_base_pos + offset))
        weight = weight_x * weight_y
        # Hermitian symmetry: the missing half is the conjugate of the mirrored sample
        mirrored = jnp.conjugate(signal[(period - y_idx) % period, period - x_idx - half])
        stored = signal[y_idx, x_idx - half]
        sample = jnp.where(x_idx < half, mirrored, stored)
        total = total + sample * weight
        return total, y_idx, weight_y

    def accumulate_y(offset, total):
        y_idx = y_base + offset
        weight_y = interpolant.kernel(y - (y_base_pos + offset))
        inner = jax.lax.fori_loop(first, stop, accumulate_x, (total, y_idx, weight_y))
        return inner[0]

    start = jnp.zeros_like(x).astype(signal.dtype)
    resampled = jax.lax.fori_loop(first, stop, accumulate_y, start)
    return resampled.reshape(out_shape)
def resample_fourier(
    kimage,
    shape_in,
    shape_out,
    jacobian=None,
    shift=None,
    interpolant=Quintic(),  # noqa: B008
):
    """Resampling operation in Fourier space

    This method uses the Fourier space resampling technique from
    Gruen & Bernstein (2014). It assumes that the signal is Hermitian, i.e.
    f(-x, -y) = conjugate(f(x, y)), with only one half stored along the last axis.

    Parameters
    ----------
    kimage: array
        Complex array of image in Fourier space. The leading axis is mapped
        over; each slice is resampled with `resample_hermitian`.
    shape_in: int
        Size of the input image in configuration space. It is used as a scalar
        (`-shape_in / 2`, `kcoords * shape_in`), so a tuple will not work.
    shape_out: int
        Size of the output image in configuration space. It is used as a scalar
        sample count (`shape_out // 2 + 1` and `shape_out`), so a tuple will not work.
    jacobian: jnp.array
        2x2 transformation matrix from warped to unwarped coordinates (in x/y convention).
        If None, the frequencies are not transformed.
    shift: tuple or array
        Shift of the output image (in units of output pixels). The order is
        reversed internally because x,y is needed for the phase factor.
    interpolant: Interpolant
        Interpolation kernel function; must provide `kernel`, `uval` and `extent`.

    Returns
    -------
    array
        Resampled Fourier image multiplied by the Fourier transform of the
        interpolant and, if `shift` is given, by the corresponding phase factor.
    """
    # Output frequency grid: the first component spans [0, 1/2] with
    # shape_out // 2 + 1 samples (half-plane), the second spans [-1/2, 1/2]
    # with shape_out samples.
    kcoords_out = jnp.stack(
        jnp.meshgrid(
            jnp.linspace(0, 1 / 2, shape_out // 2 + 1),
            jnp.linspace(-1 / 2, 1 / 2, shape_out),
        ),
        -1,
    )

    # Apply scale, rotation, flip to the frequencies. By the affine theorem
    #   FFT[f(Jx)] \propto FFT[f](J^-T k)
    # (Ronald N. Bracewell, Fourier Analysis and Imaging, Springer (2003), p. 159-161),
    # frequencies transform with the inverse transpose of the x-space matrix.
    # The resampling jacobian goes from model to obs/warped coordinates, whereas the
    # Fourier resampling needs the warped frequencies expressed in model coordinates
    # (i.e. the mapping J^-1), so only the transpose of the jacobian is needed here.
    if jacobian is None:
        kcoords_unwarped = kcoords_out
    else:
        grid_shape = kcoords_out.shape
        flat_k = kcoords_out.reshape((-1, 2))
        kcoords_unwarped = (jacobian.T @ flat_k.T).T.reshape(grid_shape)

    # k interpolation of original signal
    # TODO: why do we need to multiply shape_in into coordinates?
    resample_each = jax.vmap(resample_hermitian, in_axes=(0, None, None, None, None))
    k_resampled = resample_each(
        kimage, kcoords_unwarped * shape_in, -shape_in / 2, -shape_in / 2, interpolant
    )

    # fft of x-interpolant (separable in the two frequency components)
    xint_val = interpolant.uval(kcoords_out[..., 0]) * interpolant.uval(kcoords_out[..., 1])

    # apply shift
    if shift is not None:
        # Coerce to an array first: JAX's `@` rejects plain tuples/lists, which
        # would make the documented tuple input fail with a TypeError.
        shift_xy = jnp.asarray(shift)[::-1]  # x,y needed here
        pfac = jnp.exp(-1j * 2 * jnp.pi * (kcoords_out @ shift_xy))
    else:
        pfac = 1

    return k_resampled * jnp.expand_dims(xint_val, 0) * pfac