"""Functions for converting SPARCL results to specutils objects.
"""
import numpy as np
try:
# specutils >= 2.0
from specutils import Spectrum
except ImportError:
from specutils import Spectrum1D as Spectrum
from specutils import SpectrumCollection, SpectrumList
from astropy.nddata import InverseVariance
import astropy.units as u
# Fallback unit strings returned by _get_units() when the object being
# converted has no Results header (for example a bare dict or _AttrDict
# record handed straight to one of the conversion functions). The
# converters pass these strings to astropy.units.Unit().
DEFAULT_WAVE_UNIT = 'AA'
DEFAULT_FLUX_UNIT = '10**-17 erg cm-2 s-1 AA-1'
def _get_units(results):
    """Return the ``(flux_unit, wave_unit)`` strings for *results*.

    Objects without an ``hdr`` attribute (a bare dict or ``_AttrDict``
    record) get the module defaults ``DEFAULT_FLUX_UNIT`` and
    ``DEFAULT_WAVE_UNIT``. Otherwise both units are read through
    ``results.unit_for('flux')`` and ``results.unit_for('wavelength')``.

    Parameters
    ----------
    results : Retrieved, dict, or _AttrDict
        Results object or single record.

    Returns
    -------
    flux_unit : str
        Flux unit string, suitable for ``astropy.units.Unit()``.
    wave_unit : str
        Wavelength unit string, suitable for ``astropy.units.Unit()``.

    Raises
    ------
    ValueError
        If a header is present but reports ``None`` for the flux or the
        wavelength unit (the situation produced by ``units=False`` on
        retrieve/find). Flux is checked first.
    """
    if not hasattr(results, 'hdr'):
        # Nothing to consult: use the module-level fallbacks.
        return DEFAULT_FLUX_UNIT, DEFAULT_WAVE_UNIT

    # Query both units before validating either one.
    flux_unit, wave_unit = (results.unit_for(name)
                            for name in ('flux', 'wavelength'))
    for label, unit in (('Flux', flux_unit), ('Wavelength', wave_unit)):
        if unit is None:
            raise ValueError(
                f"{label} unit is None in the results header. "
                "Cannot convert to specutils without unit information. "
                "Was units=False used on retrieve/find?")
    return flux_unit, wave_unit
def _validate_records(records, r0, collection):
    """Validate that records can be converted to Spectrum.

    Fields are read with item access (``r['wavelength']``), so plain dicts
    are accepted as well as ``_AttrDict`` records.

    Parameters
    ----------
    records : list of dict
        All records to validate.
    r0 : dict
        First record, used as the reference for validation.
    collection : bool
        If ``True``, the caller is building a
        :class:`~specutils.SpectrumCollection`, so wavelength values may
        differ between records (their lengths must still match).

    Raises
    ------
    ValueError
        If any record lacks a 'wavelength' field.
        If any record lacks a 'flux' field.
        If wavelength array lengths differ (suggests using SpectrumList).
        If wavelength pixel values differ and collection=False
        (suggests using SpectrumCollection).
    """
    # Check the reference record first, then all others, so a missing
    # field is reported here as a ValueError instead of surfacing later as
    # an AttributeError/KeyError during extraction.
    for rec in (r0, *records):
        if 'wavelength' not in rec:
            raise ValueError("Results do not have a wavelength attribute. "
                             "Conversion is not possible.")
        if 'flux' not in rec:
            raise ValueError("Results do not have a flux attribute. "
                             "Conversion is not possible.")

    ref_wave = r0['wavelength']
    n_ref = len(ref_wave)
    # NOTE: to_specutils() picks its fallback by searching these messages
    # for 'SpectrumList' / 'SpectrumCollection'; keep those words intact.
    if any(len(r['wavelength']) != n_ref for r in records):
        raise ValueError("Results do not have the same wavelength solution. "
                         "Consider using .to_SpectrumList instead.")
    # A plain Spectrum needs one shared spectral axis, so every record's
    # wavelengths must match r0's exactly.
    if not collection and not all(np.array_equal(r['wavelength'], ref_wave)
                                  for r in records):
        raise ValueError("Results do not have the same wavelength pixels. "
                         "Consider using SpectrumCollection instead.")
# Record fields that the extraction helpers copy into dedicated arrays or
# handle explicitly; the metadata loops skip every name listed here.
_SKIP_FIELDS = {
    'flux',
    'ivar',
    'mask',
    'model',
    'wavelength',
}
def _as_quantity(field, value, units_map):
    """Attach the unit known for *field* to *value*, when there is one.

    Parameters
    ----------
    field : str
        Field name looked up in *units_map*.
    value : object
        Raw value to wrap.
    units_map : dict
        Mapping of field name to unit string (or ``None``).

    Returns
    -------
    astropy.units.Quantity or object
        ``value * u.Unit(unit_str)`` when *field* maps to a non-None unit
        string. If there is no unit, or building the unit or the product
        raises ValueError/TypeError, *value* is returned unchanged.
    """
    unit_str = units_map.get(field)
    if unit_str is None:
        return value
    try:
        quantity = value * u.Unit(unit_str)
    except (ValueError, TypeError):
        # Best effort: values that cannot carry this unit stay as-is.
        return value
    return quantity
def _extract_single(record, flux, optional_fields, redshift,
                    has_redshift, meta, units_map):
    """Copy one record into 1-D arrays and scalar metadata, in place.

    Parameters
    ----------
    record : dict
        Record to extract. Fields are read with item access, so plain
        dicts work as well as ``_AttrDict`` records.
    flux : np.ndarray
        1-D array that receives ``record['flux']``.
    optional_fields : list of (np.ndarray, str)
        ``(target_array, field_name)`` pairs; each target receives
        ``record[field_name]``.
    redshift : list
        ``record['redshift']`` is appended here when *has_redshift*.
    has_redshift : bool
        Whether to read the 'redshift' field.
    meta : dict
        Every field not in ``_SKIP_FIELDS`` is stored here under its own
        name, passed through ``_as_quantity``.
    units_map : dict
        Field name to unit string mapping handed to ``_as_quantity``.
    """
    flux[:] = record['flux']
    for arr, field in optional_fields:
        arr[:] = record[field]
    if has_redshift:
        redshift.append(record['redshift'])
    for attr, val in record.items():
        if attr not in _SKIP_FIELDS:
            meta[attr] = _as_quantity(attr, val, units_map)
def _extract_row(k, record, flux, optional_fields, redshift,
                 has_redshift, collection, spectral_axis, meta, units_map):
    """Copy one record into row *k* of 2-D arrays and list metadata.

    Parameters
    ----------
    k : int
        Row index written in every 2-D target array.
    record : dict
        Record to extract. Fields are read with item access, so plain
        dicts work as well as ``_AttrDict`` records.
    flux : np.ndarray
        2-D array whose row *k* receives ``record['flux']``.
    optional_fields : list of (np.ndarray, str)
        ``(target_array, field_name)`` pairs; row *k* of each target
        receives ``record[field_name]``.
    redshift : list
        ``record['redshift']`` is appended here when *has_redshift*.
    has_redshift : bool
        Whether to read the 'redshift' field.
    collection : bool
        If true, row *k* of *spectral_axis* receives
        ``record['wavelength']``.
    spectral_axis : np.ndarray
        Per-record wavelength array; only written when *collection*.
    meta : dict
        Each field not in ``_SKIP_FIELDS`` is appended (after
        ``_as_quantity``) to the list stored under its name, which is
        created on first use.
    units_map : dict
        Field name to unit string mapping handed to ``_as_quantity``.
    """
    flux[k, :] = record['flux']
    for arr, field in optional_fields:
        arr[k, :] = record[field]
    if has_redshift:
        redshift.append(record['redshift'])
    if collection:
        spectral_axis[k, :] = record['wavelength']
    # NOTE: if a field is missing from some records, its list ends up
    # shorter than the number of records and positions no longer line up
    # with record indices.
    for attr, val in record.items():
        if attr not in _SKIP_FIELDS:
            meta.setdefault(attr, []).append(
                _as_quantity(attr, val, units_map))
def _extract_record_data(records, flux, uncertainty, mask, model, redshift,
                         meta, spectral_axis, has_ivar, has_mask, has_model,
                         has_redshift, collection, single_record,
                         units_map=None):
    """Fill pre-allocated arrays and containers from *records*, in place.

    Array fields (flux and the optional ivar/mask/model) are copied into
    the supplied arrays; every remaining field is recorded in *meta*.
    Nothing is returned.

    Parameters
    ----------
    records : list of dict
        Records holding flux and wavelength, plus optionally ivar, mask,
        model and redshift.
    flux : np.ndarray
        Destination for flux: shape (n_pixels,) when *single_record*,
        otherwise (n_records, n_pixels).
    uncertainty : np.ndarray or None
        Destination for inverse variance; used only if *has_ivar*.
    mask : np.ndarray or None
        Destination for data-quality masks; used only if *has_mask*.
    model : np.ndarray or None
        Destination for model values; used only if *has_model*.
    redshift : list
        Redshift values are appended here when *has_redshift*.
    meta : dict
        Receives metadata: scalar values for a single record, otherwise
        one list per field with one append per record.
    spectral_axis : np.ndarray
        For collections, a 2-D array whose rows receive each record's
        wavelengths; otherwise left untouched.
    has_ivar, has_mask, has_model, has_redshift : bool
        Whether the corresponding field is present in the records.
    collection : bool
        Whether to store per-record wavelengths in *spectral_axis*
        (multi-record path only).
    single_record : bool
        Use the 1-D / scalar-metadata path; requires exactly one record.
    units_map : dict, optional
        Field name to unit string (e.g. ``results.hdr['UNITS']`` for one
        data release). Metadata fields with a non-None unit are wrapped
        as :class:`~astropy.units.Quantity` where possible.

    Raises
    ------
    ValueError
        If *single_record* is true but ``len(records) != 1``.
    """
    # Pair each present optional field with its destination array so the
    # per-record helpers can fill them in one generic loop.
    candidates = (
        (has_ivar, uncertainty, 'ivar'),
        (has_mask, mask, 'mask'),
        (has_model, model, 'model'),
    )
    optional_fields = [(target, name)
                       for present, target, name in candidates if present]
    field_units = {} if not units_map else units_map

    if not single_record:
        for row, record in enumerate(records):
            _extract_row(row, record, flux, optional_fields, redshift,
                         has_redshift, collection, spectral_axis, meta,
                         field_units)
        return

    n_records = len(records)
    if n_records != 1:
        raise ValueError(
            f"single_record=True but {n_records} records were provided")
    _extract_single(records[0], flux, optional_fields, redshift,
                    has_redshift, meta, field_units)
[docs]
def to_Spectrum(results, *, collection=False, flux_unit=None, wave_unit=None):
    """Convert `results` to :class:`specutils.Spectrum`.

    Record fields are read with item access (``r0['flux']``), so plain
    dicts work as well as ``_AttrDict`` records.

    Parameters
    ----------
    results : :class:`sparcl.Results.Retrieved`
        Retrieved results, or a single record from a set of results.
    collection : bool, optional
        If ``True``, attempt to convert to a
        :class:`~specutils.SpectrumCollection` instead. Records may then
        have different wavelength values (but equal lengths). A single
        record yields a one-row collection.
    flux_unit : str, optional
        Unit string for flux. If None, resolved via _get_units().
    wave_unit : str, optional
        Unit string for wavelength. If None, resolved via _get_units().
        An explicitly supplied unit is always kept, even when the other
        one has to be resolved.

    Returns
    -------
    :class:`~specutils.Spectrum` or :class:`~specutils.SpectrumCollection`
        The requested object. Fields other than flux, ivar, mask, model
        and wavelength go into ``meta``: scalars for a single record
        converted to a Spectrum, otherwise per-field lists. ``model``, if
        present, is stored as ``meta['model']``.

    Raises
    ------
    ValueError
        If `results` can't be converted to a :class:`~specutils.Spectrum`
        object in a valid way. For example, if some of the spectra have a
        different wavelength solution, if there are no records, or if a
        needed unit is missing from the results header.
    """
    if flux_unit is None or wave_unit is None:
        # Fill in only what the caller left out; an explicit unit must not
        # be replaced by the header/default value.
        hdr_flux_unit, hdr_wave_unit = _get_units(results)
        if flux_unit is None:
            flux_unit = hdr_flux_unit
        if wave_unit is None:
            wave_unit = hdr_wave_unit

    # Normalise the input to a non-empty list of records.
    if isinstance(results, dict):
        records = [results]
    else:
        try:
            records = results.records
        except AttributeError as e:
            raise ValueError("No records found in results. Cannot "
                             "convert empty results to Spectrum.") from e
    if len(records) == 0:
        raise ValueError("No records found in results. Cannot "
                         "convert empty results to Spectrum.")
    r0 = records[0]

    _validate_records(records, r0, collection)

    # Optional components are detected from the first record only.
    has_redshift = 'redshift' in r0
    has_model = 'model' in r0
    has_ivar = 'ivar' in r0
    has_mask = 'mask' in r0

    # np.asarray is a no-op for arrays and also accepts plain lists.
    flux0 = np.asarray(r0['flux'])
    wave0 = np.asarray(r0['wavelength'])
    n_pixels = flux0.shape[0]

    # A SpectrumCollection needs 2-D flux and a filled 2-D spectral axis.
    # The single-record path writes neither (spectral_axis would stay
    # uninitialised np.empty memory), so collections always go per-row.
    single_record = len(records) == 1 and not collection
    flux_shape = (n_pixels,) if single_record else (len(records), n_pixels)

    if collection:
        # One wavelength row per record, filled by _extract_record_data.
        spectral_axis = np.empty((len(records), wave0.shape[0]),
                                 dtype=wave0.dtype)
    else:
        # _validate_records guarantees every record shares r0's grid.
        spectral_axis = wave0

    flux = np.empty(flux_shape, dtype=flux0.dtype)
    uncertainty = (np.empty(flux_shape, dtype=np.asarray(r0['ivar']).dtype)
                   if has_ivar else None)
    mask = (np.empty(flux_shape, dtype=np.asarray(r0['mask']).dtype)
            if has_mask else None)
    model = (np.empty(flux_shape, dtype=np.asarray(r0['model']).dtype)
             if has_model else None)
    redshift = []
    meta = {}

    # Metadata units come from the header entry for r0's data release.
    # NOTE(review): records from other data releases reuse r0's unit map;
    # confirm mixed-release input is not expected here.
    units_map = {}
    if hasattr(results, 'hdr'):
        try:
            units_map = (results.hdr.get('UNITS', {}).get(r0.get('_dr'), {})
                         or {})
        except (AttributeError, KeyError):
            pass

    _extract_record_data(records, flux, uncertainty, mask, model, redshift,
                         meta, spectral_axis, has_ivar, has_mask, has_model,
                         has_redshift, collection, single_record,
                         units_map=units_map)

    if not has_redshift:
        redshift = None
    elif single_record:
        # Exactly one value was collected; give Spectrum a scalar.
        redshift = np.array(redshift)[0]
    else:
        redshift = np.array(redshift)

    if has_model:
        meta['model'] = model

    # Arguments shared by Spectrum and SpectrumCollection.
    common_args = {
        'flux': flux * u.Unit(flux_unit),
        'spectral_axis': spectral_axis * u.Unit(wave_unit),
        'uncertainty': InverseVariance(uncertainty) if has_ivar else None,
        'mask': mask,
        'meta': meta}
    if collection:
        return SpectrumCollection(**common_args)
    return Spectrum(**common_args, redshift=redshift)
[docs]
def to_SpectrumList(results, *, flux_unit=None, wave_unit=None):
    """Convert `results` to :class:`specutils.SpectrumList`.

    Each record becomes its own :class:`~specutils.Spectrum`, so records
    may have different wavelength grids and lengths. Record fields are
    read with item access, so plain dicts work as well as ``_AttrDict``
    records.

    Parameters
    ----------
    results : :class:`sparcl.Results.Retrieved`
        Retrieved results, or a single record (dict).
    flux_unit : str, optional
        Unit string for flux. If None, resolved via _get_units().
    wave_unit : str, optional
        Unit string for wavelength. If None, resolved via _get_units().
        An explicitly supplied unit is always kept, even when the other
        one has to be resolved.

    Returns
    -------
    :class:`~specutils.SpectrumList`
        One Spectrum per record, in record order. Fields other than flux,
        wavelength, ivar, redshift and mask go into that Spectrum's
        ``meta``. A field becomes a Quantity when the record's data
        release has a unit for it in ``results.hdr['UNITS']``.

    Raises
    ------
    ValueError
        If a unit has to be resolved from a header that does not carry it
        (see _get_units()).
    """
    if flux_unit is None or wave_unit is None:
        # Fill in only what the caller left out; an explicit unit must not
        # be replaced by the header/default value.
        hdr_flux_unit, hdr_wave_unit = _get_units(results)
        if flux_unit is None:
            flux_unit = hdr_flux_unit
        if wave_unit is None:
            wave_unit = hdr_wave_unit
    # Parse the unit strings once instead of once per record.
    flux_u = u.Unit(flux_unit)
    wave_u = u.Unit(wave_unit)
    # Fields handed to Spectrum directly; everything else goes to meta.
    direct_fields = ('flux', 'wavelength', 'ivar', 'redshift', 'mask')

    records = [results] if isinstance(results, dict) else results.records
    s = SpectrumList()
    for r in records:
        # Units are looked up per record, keyed by its data release.
        dr_units = {}
        if hasattr(results, 'hdr'):
            try:
                dr_units = results.hdr.get('UNITS', {}).get(r.get('_dr'), {})
            except (AttributeError, KeyError):
                pass
        # A release mapped to None behaves like one with no units.
        dr_units = dr_units or {}
        meta = {attr: _as_quantity(attr, val, dr_units)
                for attr, val in r.items() if attr not in direct_fields}
        s.append(Spectrum(flux=r['flux'] * flux_u,
                          spectral_axis=r['wavelength'] * wave_u,
                          uncertainty=(InverseVariance(r['ivar'])
                                       if 'ivar' in r else None),
                          redshift=r.get('redshift'),
                          mask=r.get('mask'),
                          meta=meta))
    return s
[docs]
def to_specutils(results):
    """Convert `results` to the most natural specutils object.

    A plain :class:`~specutils.Spectrum` is tried first. If that fails
    because the records' wavelength arrays have different lengths, a
    :class:`~specutils.SpectrumList` is built instead. If the lengths
    match but the wavelength values differ, a
    :class:`~specutils.SpectrumCollection` is built. The fallback is
    chosen from the text of the ValueError raised by to_Spectrum(), and a
    one-line explanation is printed.

    Parameters
    ----------
    results : :class:`sparcl.Results.Retrieved`
        Retrieved results.

    Returns
    -------
    :class:`~specutils.Spectrum` or :class:`~specutils.SpectrumCollection`
    or :class:`~specutils.SpectrumList`
        The most natural conversion to a specutils object.

    Raises
    ------
    ValueError
        If no valid conversion can be performed, or if unit information
        is missing from the header.
    """
    # Resolve units once and reuse them for whichever conversion runs.
    flux_unit, wave_unit = _get_units(results)
    units = {'flux_unit': flux_unit, 'wave_unit': wave_unit}
    try:
        return to_Spectrum(results, **units)
    except ValueError as err:
        reason = str(err)
        if 'SpectrumList' in reason:
            # Wavelength arrays of different lengths.
            print("Returning a SpectrumList because the records have "
                  "different spectral lengths.")
            return to_SpectrumList(results, **units)
        if 'SpectrumCollection' in reason:
            # Same length, different wavelength values.
            print("Returning a SpectrumCollection because the records have "
                  "the same length but different spectral axes.")
            return to_Spectrum(results, collection=True, **units)
        raise