# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""
Defines coordinate frames and ties them to data axes.
"""
import numpy as np
from astropy.utils.misc import isiterable
import astropy.time
from astropy import units as u
from astropy import utils as astutil
from astropy import coordinates as coord
from astropy.wcs.wcsapi.low_level_api import (validate_physical_types,
VALID_UCDS)
# Public API of this module.  ``StokesFrame`` is one of the frame classes
# defined in this file, so it is exported together with the others.
__all__ = ['Frame2D', 'CelestialFrame', 'SpectralFrame', 'CompositeFrame',
           'CoordinateFrame', 'TemporalFrame', 'StokesFrame']

# Upper-cased names of everything listed in ``astropy.coordinates.builtin_frames``;
# ``CelestialFrame`` compares ``reference_frame.name.upper()`` against this list.
STANDARD_REFERENCE_FRAMES = [frame.upper() for frame in coord.builtin_frames.__all__]

# Names of reference positions referred to by the ``reference_position``
# parameter documentation of the frame classes.
STANDARD_REFERENCE_POSITION = ["GEOCENTER", "BARYCENTER", "HELIOCENTER",
                               "TOPOCENTER", "LSR", "LSRK", "LSRD",
                               "GALACTIC_CENTER", "LOCAL_GROUP_CENTER"]
class CoordinateFrame:
    """
    Base class for Coordinate Frames.

    Parameters
    ----------
    naxes : int
        Number of axes.
    axes_type : str or iterable of str
        Type of each axis, e.g. "SPATIAL", "SPECTRAL" or "TIME".
        A single string is stored as a one-element tuple.
    axes_order : tuple of int
        A dimension in the input data that corresponds to this axis.
    reference_frame : astropy.coordinates.builtin_frames
        Reference frame (usually used with output_frame to convert to world coordinate objects).
    reference_position : str
        Reference position - one of `STANDARD_REFERENCE_POSITION`
    unit : str, astropy.units.Unit or iterable of those
        Unit for each axis. A single string or unit is treated as the unit
        of a one-axis frame. When omitted, every axis is dimensionless.
    axes_names : str or list of str
        Names of the axes in this frame. Defaults to an empty string per axis.
    name : str
        Name of this frame. Defaults to the name of the class.
    axis_physical_types : str or iterable of str
        Physical type of each axis. Values that are neither a valid UCD nor
        prefixed with ``"custom:"`` get the ``"custom:"`` prefix added.
        When omitted the types are derived from the kind of frame.

    Raises
    ------
    ValueError
        If the number of units, axes names, axes types, axes order entries
        or physical types does not match ``naxes``.
    TypeError
        If ``axis_physical_types`` is neither a string nor an iterable.
    """

    def __init__(self, naxes, axes_type, axes_order, reference_frame=None,
                 reference_position=None, unit=None, axes_names=None,
                 name=None, axis_physical_types=None):
        self._naxes = naxes
        self._axes_order = tuple(axes_order)
        if isinstance(axes_type, str):
            self._axes_type = (axes_type,)
        else:
            self._axes_type = tuple(axes_type)
        self._reference_frame = reference_frame

        if unit is not None:
            # A string is iterable, but "um" must stay one unit rather than
            # being split into the characters "u" and "m".
            if isinstance(unit, str) or not astutil.isiterable(unit):
                unit = (unit,)
            else:
                unit = tuple(unit)
            if len(unit) != naxes:
                raise ValueError("Number of units does not match number of axes.")
            self._unit = tuple(u.Unit(au) for au in unit)
        else:
            # Use real (dimensionless) Unit objects so that code reading
            # ``unit[i].physical_type`` or multiplying by the unit keeps working.
            self._unit = tuple(u.Unit("") for _ in range(naxes))

        if axes_names is not None:
            if isinstance(axes_names, str):
                axes_names = (axes_names,)
            else:
                axes_names = tuple(axes_names)
            if len(axes_names) != naxes:
                raise ValueError("Number of axes names does not match number of axes.")
        else:
            axes_names = tuple([""] * naxes)
        self._axes_names = axes_names

        if name is None:
            self._name = self.__class__.__name__
        else:
            self._name = name

        self._reference_position = reference_position

        if len(self._axes_type) != naxes:
            raise ValueError("Length of axes_type does not match number of axes.")
        if len(self._axes_order) != naxes:
            raise ValueError("Length of axes_order does not match number of axes.")

        super().__init__()
        # Must run last: the defaults are derived from unit, axes names,
        # axes type and reference frame set above.
        self._axis_physical_types = self._set_axis_physical_types(axis_physical_types)

    def _set_axis_physical_types(self, pht=None):
        """
        Set the physical type of the coordinate axes using VO UCD1+ v1.23 definitions.

        Parameters
        ----------
        pht : str or iterable of str, optional
            User supplied physical types, one per axis. When `None` the
            types are derived from the class of the frame.

        Returns
        -------
        tuple of str
            One physical type per axis.

        Raises
        ------
        TypeError
            If ``pht`` is neither a string nor an iterable.
        ValueError
            If the length of ``pht`` differs from the number of axes.
        """
        if pht is not None:
            if isinstance(pht, str):
                pht = (pht,)
            elif not isiterable(pht):
                raise TypeError("axis_physical_types must be of type string or iterable of strings")
            if len(pht) != self.naxes:
                raise ValueError('"axis_physical_types" must be of length {}'.format(self.naxes))
            # Anything that is not a known UCD is namespaced as "custom:".
            ph_type = [
                axt if axt in VALID_UCDS or axt.startswith("custom:")
                else "custom:{}".format(axt)
                for axt in pht
            ]
            validate_physical_types(ph_type)
            return tuple(ph_type)

        if isinstance(self, CelestialFrame):
            if isinstance(self.reference_frame, coord.Galactic):
                ph_type = "pos.galactic.lon", "pos.galactic.lat"
            elif isinstance(self.reference_frame, (coord.GeocentricTrueEcliptic,
                                                   coord.GCRS,
                                                   coord.PrecessedGeocentric)):
                ph_type = "pos.bodyrc.lon", "pos.bodyrc.lat"
            elif isinstance(self.reference_frame, coord.builtin_frames.BaseRADecFrame):
                ph_type = "pos.eq.ra", "pos.eq.dec"
            elif isinstance(self.reference_frame, coord.builtin_frames.BaseEclipticFrame):
                ph_type = "pos.ecliptic.lon", "pos.ecliptic.lat"
            else:
                ph_type = tuple("custom:{}".format(t) for t in self.axes_names)
        elif isinstance(self, SpectralFrame):
            spectral_kind = self.unit[0].physical_type
            if spectral_kind == "frequency":
                ph_type = ("em.freq",)
            elif spectral_kind == "length":
                ph_type = ("em.wl",)
            else:
                # Any other unit still needs a physical type; without this
                # branch ``ph_type`` would be unbound below.
                ph_type = ("custom:{}".format(spectral_kind),)
        elif isinstance(self, TemporalFrame):
            ph_type = ("time",)
        elif isinstance(self, Frame2D):
            # Prefer the axes names, but only when every axis has one.
            if all(self.axes_names):
                ph_type = self.axes_names
            else:
                ph_type = self.axes_type
            ph_type = tuple("custom:{}".format(t) for t in ph_type)
        else:
            ph_type = tuple("custom:{}".format(t) for t in self.axes_type)

        validate_physical_types(ph_type)
        return ph_type

    def __repr__(self):
        fmt = '<{0}(name="{1}", unit={2}, axes_names={3}, axes_order={4}'.format(
            self.__class__.__name__, self.name,
            self.unit, self.axes_names, self.axes_order)
        # Reference position and frame are only shown when they are set.
        if self.reference_position is not None:
            fmt += ', reference_position="{0}"'.format(self.reference_position)
        if self.reference_frame is not None:
            fmt += ", reference_frame={0}".format(self.reference_frame)
        fmt += ")>"
        return fmt

    def __str__(self):
        # ``name`` has a setter, so it may have been reset to None.
        if self._name is not None:
            return self._name
        return self.__class__.__name__

    @property
    def name(self):
        """ A custom name of this frame."""
        return self._name

    @name.setter
    def name(self, val):
        """ A custom name of this frame."""
        self._name = val

    @property
    def naxes(self):
        """ The number of axes in this frame."""
        return self._naxes

    @property
    def unit(self):
        """The unit of this frame."""
        return self._unit

    @property
    def axes_names(self):
        """ Names of axes in the frame."""
        return self._axes_names

    @property
    def axes_order(self):
        """ A tuple of indices which map inputs to axes."""
        return self._axes_order

    @property
    def reference_frame(self):
        """ Reference frame, used to convert to world coordinate objects. """
        return self._reference_frame

    @property
    def reference_position(self):
        """ Reference Position. """
        return getattr(self, "_reference_position", None)

    @property
    def axes_type(self):
        """ Type of this frame : 'SPATIAL', 'SPECTRAL', 'TIME'. """
        return self._axes_type

    def coordinates(self, *args):
        """ Create world coordinates object"""
        # Pick the inputs belonging to this frame, in axes order.
        args = [args[i] for i in self.axes_order]
        # Inputs exposing ``to`` are converted to the frame unit; anything
        # else is multiplied by the unit.
        coo = tuple(arg.to(un) if hasattr(arg, "to") else arg * un
                    for arg, un in zip(args, self.unit))
        return coo

    def coordinate_to_quantity(self, *coords):
        """
        Given a rich coordinate object return an astropy quantity object.
        """
        # NoOp leaves it to the model to handle
        return coords

    @property
    def axis_physical_types(self):
        """ Physical type of each axis, as a tuple of strings."""
        return self._axis_physical_types
class CelestialFrame(CoordinateFrame):
    """
    Celestial Frame Representation

    Parameters
    ----------
    axes_order : tuple of int
        A dimension in the input data that corresponds to this axis.
    reference_frame : astropy.coordinates.builtin_frames
        A reference frame.
    unit : str or units.Unit instance or iterable of those
        Units on axes.
    axes_names : list
        Names of the axes in this frame.
    name : str
        Name of this frame.
    axis_physical_types : str or iterable of str
        Physical type of each axis, forwarded to the base class.
    """

    def __init__(self, axes_order=None, reference_frame=None,
                 unit=None, axes_names=None,
                 name=None, axis_physical_types=None):
        naxes = 2
        # Only frame objects (not None, not a frame given by name) can
        # supply default axes names and units.
        has_frame_object = (reference_frame is not None
                            and not isinstance(reference_frame, str))
        if has_frame_object and reference_frame.name.upper() in STANDARD_REFERENCE_FRAMES:
            component_names = list(reference_frame.representation_component_names.values())
            try:
                # The distance component is not an axis of this frame.
                component_names.remove('distance')
            except ValueError:
                pass
            if axes_names is None:
                axes_names = component_names
            naxes = len(component_names)

            component_units = list(reference_frame.representation_component_units.values())
            if unit is None and component_units:
                unit = component_units

        if axes_order is None:
            axes_order = tuple(range(naxes))
        if unit is None:
            unit = (u.degree,) * naxes

        super().__init__(naxes=naxes,
                         axes_type=['SPATIAL' for _ in range(naxes)],
                         axes_order=axes_order,
                         reference_frame=reference_frame,
                         unit=unit,
                         axes_names=axes_names,
                         name=name,
                         axis_physical_types=axis_physical_types)

    def coordinates(self, *args):
        """
        Create a SkyCoord object.

        Parameters
        ----------
        args : float
            inputs to wcs.input_frame
        """
        first = args[0]
        if not isinstance(first, coord.SkyCoord):
            return coord.SkyCoord(*args, unit=self.unit, frame=self.reference_frame)
        # An existing SkyCoord is only transformed to this frame.
        return first.transform_to(self.reference_frame)

    def coordinate_to_quantity(self, *coords):
        """
        Convert a ``SkyCoord``, or a pair of quantities, to ``(lon, lat)``.

        Accepts either two positional values or a single object (a
        ``SkyCoord`` or an iterable of quantities).

        Raises
        ------
        ValueError
            If more than two (or zero) coordinates are passed, or if the
            input is neither a ``SkyCoord`` nor made of quantities.
        """
        ncoords = len(coords)
        if ncoords not in (1, 2):
            raise ValueError("Unexpected number of coordinates in "
                             "input to frame {} : "
                             "expected 2, got {}".format(self.name, len(coords)))
        arg = coords if ncoords == 2 else coords[0]

        if isinstance(arg, coord.SkyCoord):
            sky = arg.transform_to(self._reference_frame)
            try:
                lon, lat = sky.data.lon, sky.data.lat
            except AttributeError:
                # The data representation has no lon/lat attributes;
                # read them from the spherical representation instead.
                spherical = sky.spherical
                lon, lat = spherical.lon, spherical.lat
            return lon, lat

        if not all(isinstance(a, u.Quantity) for a in arg):
            raise ValueError("Could not convert input {} to lon and lat quantities.".format(arg))
        return tuple(arg)
class SpectralFrame(CoordinateFrame):
    """
    Represents Spectral Frame

    Parameters
    ----------
    axes_order : tuple or int
        A dimension in the input data that corresponds to this axis.
    reference_frame : astropy.coordinates.builtin_frames
        Reference frame (usually used with output_frame to convert to world coordinate objects).
    unit : str or units.Unit instance
        Spectral unit.
    axes_names : str
        Spectral axis name.
    name : str
        Name for this frame.
    axis_physical_types : str or iterable of str
        Physical type of the axis, forwarded to the base class.
    reference_position : str
        Reference position - one of `STANDARD_REFERENCE_POSITION`
    """

    def __init__(self, axes_order=(0,), reference_frame=None, unit=None,
                 axes_names=None, name=None, axis_physical_types=None,
                 reference_position=None):
        super().__init__(naxes=1, axes_type="SPECTRAL", axes_order=axes_order,
                         axes_names=axes_names, reference_frame=reference_frame,
                         unit=unit, name=name,
                         reference_position=reference_position,
                         axis_physical_types=axis_physical_types)

    def coordinates(self, *args, equivalencies=[]):
        """
        Return the first input expressed in the unit of this frame.

        Parameters
        ----------
        args : number or object with a ``unit`` attribute
            Only the first positional argument is used.
        equivalencies : list
            Passed to ``to`` when the input already carries a unit.
        """
        # NOTE: the list default is kept so the signature is unchanged; it
        # is only passed through and never mutated here.
        value = args[0]
        if hasattr(value, 'unit'):
            return value.to(self.unit[0], equivalencies=equivalencies)
        # ``args`` is always a tuple, so the value to attach the unit to is
        # its first element (a former ``np.isscalar(args)`` branch could
        # never be taken).
        return value * self.unit[0]

    def coordinate_to_quantity(self, *coords):
        """
        Return the first input as a quantity.

        Inputs that already have a ``unit`` attribute are returned as they
        are; anything else is multiplied by the unit of this frame.
        """
        value = coords[0]
        if hasattr(value, 'unit'):
            return value
        return value * self.unit[0]
class TemporalFrame(CoordinateFrame):
    """
    A coordinate frame for time axes.

    Parameters
    ----------
    axes_order : tuple or int
        A dimension in the input data that corresponds to this axis.
    reference_frame : `object`
        The object to instantiate to represent the time coordinate. Defaults to
        `astropy.time.Time`. Use partial functions to customise the
        `~astropy.time.Time` instance.
    reference_time : `astropy.time.Time` or `None`
        Reference time, the time of the 0th coordinate. If none the values of
        the axis are assumed to be valid times.
    unit : str or units.Unit instance
        Time unit.
    axes_names : str
        Time axis name.
    name : str
        Name for this frame.
    axis_physical_types : str or iterable of str
        Physical type of the axis, forwarded to the base class.
    """

    def __init__(self, axes_order=(0,), reference_time=None,
                 reference_frame=astropy.time.Time, unit=None,
                 axes_names=None, name=None, axis_physical_types=None):
        # The reference time is stored in the base class'
        # ``reference_position`` slot.
        super().__init__(naxes=1, axes_type="TIME", axes_order=axes_order,
                         axes_names=axes_names, reference_frame=reference_frame,
                         unit=unit, name=name,
                         reference_position=reference_time,
                         axis_physical_types=axis_physical_types)

    def coordinates(self, *args):
        """
        Convert the first input to a time object.

        With a reference time set, the input is treated as an offset from
        it (given the frame unit when it has none). Without one, the input
        is passed to ``reference_frame``.
        """
        # ``args`` is always a tuple, so the value is its first element (a
        # former ``np.isscalar(args)`` branch could never be taken).
        dt = args[0]
        if self.reference_position is not None:
            if not hasattr(dt, 'unit'):
                dt = dt * self.unit[0]
            return self.reference_position + dt
        return self.reference_frame(dt)

    def coordinate_to_quantity(self, *coords):
        """
        Convert the first input to a quantity where possible.

        Raises
        ------
        ValueError
            If the input is neither an `~astropy.time.Time` nor an object
            with a ``unit`` attribute.
        """
        value = coords[0]
        if isinstance(value, astropy.time.Time):
            if self.reference_position is not None:
                return (value - self.reference_position).to(self.unit[0])
            # If we can't convert to a quantity just drop the object out
            # and hope the transform can cope.
            return value
        # Is already a quantity
        if hasattr(value, 'unit'):
            return value
        raise ValueError("Can not convert {} to Quantity".format(value))
class CompositeFrame(CoordinateFrame):
    """
    Represents one or more frames.

    Parameters
    ----------
    frames : list
        List of frames (TemporalFrame, CelestialFrame, SpectralFrame, CoordinateFrame).
    name : str
        Name for this frame.

    Raises
    ------
    ValueError
        If the ``axes_order`` values of the frames are not unique, or if
        one of them is not smaller than the total number of axes.
    """

    def __init__(self, frames, name=None):
        self._frames = frames[:]
        naxes = sum(frame._naxes for frame in self._frames)

        axes_order = []
        for frame in frames:
            axes_order.extend(frame.axes_order)

        # Validate the numbering before it is used to index the per-axis
        # lists below, so bad input is reported as a ValueError rather than
        # surfacing as an IndexError.
        if len(np.unique(axes_order)) != len(axes_order):
            raise ValueError("Incorrect numbering of axes, "
                             "axes_order should contain unique numbers, "
                             "got {}.".format(axes_order))
        if any(ind >= naxes for ind in axes_order):
            raise ValueError("Incorrect numbering of axes, "
                             "axes_order values should be smaller than "
                             "the number of axes ({}), "
                             "got {}.".format(naxes, axes_order))

        # Placeholders, overwritten per axis below.
        axes_type = list(range(naxes))
        unit = list(range(naxes))
        axes_names = list(range(naxes))
        ph_type = list(range(naxes))

        # Each property is stored at the position given by the axis order
        # of the frame it comes from.
        for frame in frames:
            for ind, axtype, un, n, pht in zip(frame.axes_order, frame.axes_type,
                                               frame.unit, frame.axes_names,
                                               frame.axis_physical_types):
                axes_type[ind] = axtype
                axes_names[ind] = n
                unit[ind] = un
                ph_type[ind] = pht

        super().__init__(naxes, axes_type=axes_type,
                         axes_order=axes_order,
                         unit=unit, axes_names=axes_names,
                         name=name)
        # Replace the types computed by the base class with the ones
        # collected from the individual frames.
        self._axis_physical_types = tuple(ph_type)

    @property
    def frames(self):
        """ The frames making up this composite frame."""
        return self._frames

    def __repr__(self):
        return repr(self.frames)

    def coordinates(self, *args):
        """
        Create the world coordinate object of every frame.

        Returns a list with one entry per frame. When the number of inputs
        equals the number of frames each frame receives one input;
        otherwise each frame receives the inputs selected by its
        ``axes_order``.
        """
        coo = []
        if len(args) == len(self.frames):
            for frame, arg in zip(self.frames, args):
                coo.append(frame.coordinates(arg))
        else:
            for frame in self.frames:
                fargs = [args[i] for i in frame.axes_order]
                coo.append(frame.coordinates(*fargs))
        return coo

    def coordinate_to_quantity(self, *coords):
        """
        Convert the inputs to a flat list of quantities.

        Raises
        ------
        ValueError
            If the number of inputs matches neither the number of frames
            nor the number of axes.
        """
        if len(coords) == len(self.frames):
            args = coords
        elif len(coords) == self.naxes:
            args = []
            for _frame in self.frames:
                if _frame.naxes > 1:
                    # Collect the arguments for this frame based on axes_order
                    args.append([coords[i] for i in _frame.axes_order])
                else:
                    args.append(coords[_frame.axes_order[0]])
        else:
            raise ValueError("Incorrect number of arguments")

        qs = []
        for _frame, arg in zip(self.frames, args):
            ret = _frame.coordinate_to_quantity(arg)
            # Frames with several axes return a tuple; flatten it.
            if isinstance(ret, tuple):
                qs += list(ret)
            else:
                qs.append(ret)
        return qs
class StokesFrame(CoordinateFrame):
    """
    A coordinate frame for representing stokes polarisation states

    Parameters
    ----------
    axes_order : tuple of int
        A dimension in the input data that corresponds to this axis.
    name : str
        Name of this frame.
    """

    def __init__(self, axes_order=(0,), name=None):
        # Position in this list is the numerical value of a component.
        self._stokes_components = ['I', 'Q', 'U', 'V']
        super().__init__(1, ["STOKES"], axes_order, name=name,
                         axes_names=("stokes",), unit=u.one)

    def coordinates(self, *args):
        """
        Return the Stokes component name for the first input.

        The input (or its ``value`` attribute, when it has one) is
        converted with ``int`` and used as an index into the components.
        """
        first = args[0]
        arg = first.value if hasattr(first, 'value') else first
        return self._stokes_components[int(arg)]

    def coordinate_to_quantity(self, *coords):
        """
        Convert a Stokes component name to its index as a quantity.

        Inputs that are not strings are returned unchanged.

        Raises
        ------
        ValueError
            If a string is given that is not one of the Stokes components.
        """
        value = coords[0]
        if not isinstance(value, str):
            return value
        if value not in self._stokes_components:
            # Fail loudly instead of implicitly returning None.
            raise ValueError("Can not convert {!r} to Quantity: expected one of "
                             "{}".format(value, self._stokes_components))
        return self._stokes_components.index(value) * u.pix
class Frame2D(CoordinateFrame):
    """
    A 2D coordinate frame.

    Parameters
    ----------
    axes_order : tuple of int
        A dimension in the input data that corresponds to this axis.
    unit : list of astropy.units.Unit
        Unit for each axis.
    axes_names : list
        Names of the axes in this frame.
    name : str
        Name of this frame.
    axis_physical_types : str or iterable of str
        Physical type of each axis, forwarded to the base class.
    """

    def __init__(self, axes_order=(0, 1), unit=(u.pix, u.pix), axes_names=('x', 'y'),
                 name=None, axis_physical_types=None):
        super().__init__(naxes=2,
                         axes_type=["SPATIAL", "SPATIAL"],
                         axes_order=axes_order,
                         name=name,
                         axes_names=axes_names,
                         unit=unit,
                         axis_physical_types=axis_physical_types)

    def coordinates(self, *args):
        """ Attach the frame units to the inputs selected by ``axes_order``."""
        ordered = [args[index] for index in self.axes_order]
        return tuple(value * axis_unit for value, axis_unit in zip(ordered, self.unit))

    def coordinate_to_quantity(self, *coords):
        """
        Return the two coordinates as a tuple of quantities.

        Accepts either two positional values or a single iterable holding
        them. Values without a ``unit`` attribute are multiplied by the
        unit of the corresponding axis.

        Raises
        ------
        ValueError
            If the inputs are neither a single iterable nor two values.
        """
        ncoords = len(coords)
        # A single list or tuple holding the values is unpacked.
        single_iterable = ncoords == 1 and astutil.isiterable(coords[0])
        if not single_iterable and ncoords != 2:
            raise ValueError("Unexpected number of coordinates in "
                             "input to frame {} : "
                             "expected 2, got {}".format(self.name, len(coords)))
        values = list(coords[0]) if single_iterable else list(coords)

        for index in (0, 1):
            if not hasattr(values[index], 'unit'):
                values[index] = values[index] * self.unit[index]
        return tuple(values)