Source code for sparcl.specutils

"""Functions for converting SPARCL results to specutils objects.
"""
import numpy as np
try:
    # specutils >= 2.0
    from specutils import Spectrum
except ImportError:
    from specutils import Spectrum1D as Spectrum
from specutils import SpectrumCollection, SpectrumList
from astropy.nddata import InverseVariance
import astropy.units as u

# Default units used as fallback when results have no header (e.g. bare
# dict or _AttrDict records passed directly to conversion functions).
DEFAULT_FLUX_UNIT = '10**-17 erg cm-2 s-1 AA-1'
DEFAULT_WAVE_UNIT = 'AA'

def _get_units(results):
    """Extract flux and wavelength unit strings from a Results header.

    Falls back to module-level defaults when results is a bare dict or
    _AttrDict record (no header available). Raises if the header exists
    but unit values are None (i.e. units=False was used on retrieve/find).

    Parameters
    ----------
    results : Retrieved, dict, or _AttrDict
        Results object or single record.

    Returns
    -------
    flux_unit : str
        Unit string for flux, compatible with astropy.units.Unit().
    wave_unit : str
        Unit string for wavelength, compatible with astropy.units.Unit().

    Raises
    ------
    ValueError
        If the header is present but flux or wavelength units are None,
        meaning units=False was used on retrieve/find.
    """
    if not hasattr(results, 'hdr'):
        # Bare dict or _AttrDict, no header, use defaults
        return DEFAULT_FLUX_UNIT, DEFAULT_WAVE_UNIT

    flux_unit = results.unit_for('flux')
    wave_unit = results.unit_for('wavelength')

    if flux_unit is None:
        raise ValueError(
            "Flux unit is None in the results header. "
            "Cannot convert to specutils without unit information. "
            "Was units=False used on retrieve/find?")
    if wave_unit is None:
        raise ValueError(
            "Wavelength unit is None in the results header. "
            "Cannot convert to specutils without unit information. "
            "Was units=False used on retrieve/find?")

    return flux_unit, wave_unit

def _validate_records(records, r0, collection):
    """Validate that records can be converted to Spectrum.

    Parameters
    ----------
    records : list of dict
        All records to validate.
    r0 : dict
        First record, used as reference for validation.
    collection: bool
        If ``True``, attempt to convert to a
        :class:`~specutils.SpectrumCollection` instead.

    Raises
    ------
    ValueError
        If records lack 'wavelength' attribute.
        If records lack 'flux' attribute.
        If wavelength array lengths differ (suggests using SpectrumList).
        If wavelength pixel values differ and collection=False
        (suggests using SpectrumCollection).
    """

    # Check if the first record has wavelength data
    if 'wavelength' not in r0:
        raise ValueError("Results do not have a wavelength attribute. "
                         "Conversion is not possible.")

    if 'flux' not in r0:
        raise ValueError("Results do not have a flux attribute. "
                         "Conversion is not possible.")

    # Check if all records have the same number of wavelength points
    if not all([len(r.wavelength) == len(r0.wavelength) for r in records]):
        raise ValueError("Results do not have the same wavelength solution. "
                         "Consider using .to_SpectrumList instead.")

    # If not creating a SpectrumCollection, check that wavelength values
    # are identical across all records
    if not collection and not all([(r.wavelength == r0.wavelength).all()
                                   for r in records]):
        raise ValueError("Results do not have the same wavelength pixels. "
                         "Consider using SpectrumCollection instead.")


# Pre-compute the set of fields we've already handled so the metadata
# loop below can skip them quickly
_SKIP_FIELDS = {'flux', 'ivar', 'mask', 'model', 'wavelength'}

def _as_quantity(field, value, units_map):
    """Wrap value in a Quantity if a unit is known for field."""
    unit_str = units_map.get(field)
    if unit_str is not None:
        try:
            return value * u.Unit(unit_str)
        except (ValueError, TypeError):
            pass
    return value

def _extract_single(record, flux, optional_fields, redshift,
                    has_redshift, meta, units_map):
    """Extract one record into 1-D arrays and scalar metadata."""
    flux[:] = getattr(record, 'flux')
    for arr, field in optional_fields:
        arr[:] = getattr(record, field)
    if has_redshift:
        redshift.append(record.redshift)
    for attr, val in record.items():
        if attr not in _SKIP_FIELDS:
            meta[attr] = _as_quantity(attr, val, units_map)

def _extract_row(k, record, flux, optional_fields, redshift,
                 has_redshift, collection, spectral_axis, meta, units_map):
    """Extract one record into row *k* of 2-D arrays and list metadata."""
    flux[k, :] = getattr(record, 'flux')
    for arr, field in optional_fields:
        arr[k, :] = getattr(record, field)
    if has_redshift:
        redshift.append(record.redshift)
    if collection:
        spectral_axis[k, :] = record.wavelength
    for attr, val in record.items():
        if attr not in _SKIP_FIELDS:
            meta.setdefault(attr, []).append(_as_quantity(attr,
                                                          val,
                                                          units_map))

def _extract_record_data(records, flux, uncertainty, mask, model, redshift,
                         meta, spectral_axis, has_ivar, has_mask, has_model,
                         has_redshift, collection, single_record,
                         units_map=None):
    """Extract all data from records into arrays, then store any remaining
    attributes as metadata. All outputs are written in-place rather than
    returned.

    Parameters
    ----------
    records: list of dict
        Records containing flux, ivar, mask, wavelength, and optional
        model/redshift.
    flux : np.ndarray
        Pre-allocated array for flux values. Shape: (n_pixels,) if
        single_record, else (n_records, n_pixels).
    uncertainty : np.ndarray or None
        Pre-allocated array for inverse variance if has_ivar=True.
    mask : np.ndarray or None
        Pre-allocated array for data quality masks if has_mask=True.
    model : np.ndarray or None
        Pre-allocated array for model values if has_model=True.
    redshift : list
        Empty list to populate with redshift values.
    meta : dict
        Empty dict to populate with metadata. For a single record, values
        are stored as scalars. For multiple records, values are accumulated
        into lists.
    spectral_axis : np.ndarray
        For collections, pre-allocated 2D array to store per-record
        wavelength grids. For non-collections, this is just r0.wavelength
        (not modified).
    has_ivar : bool
        Whether records contain 'ivar' attribute.
    has_mask : bool
        Whether records contain 'mask' attribute.
    has_model : bool
        Whether records contain 'model' attribute.
    has_redshift : bool
        Whether records contain 'redshift' attribute.
    collection : bool
        If True, stores each record's wavelength array in spectral_axis.
    single_record : bool
        If True, writes to a 1D array and stores metadata as scalar values.
        If False, writes to rows of a 2D array and accumulates metadata
        values into lists.
    units_map : dict, optional
        Dict mapping field name to unit string, as returned by
        ``results.hdr['UNITS']`` for a given data release. When provided,
        metadata values whose field has a non-None unit entry are wrapped
        as :class:`~astropy.units.Quantity` objects.

    Returns
    -------
    None
        All outputs are written to the input arrays/containers in-place.
    """
    # Build a list of (array, field_name) pairs for optional fields so we
    # can loop over them instead of having a separate if-block for each one
    optional_fields = []
    if has_ivar:
        optional_fields.append((uncertainty, 'ivar'))
    if has_mask:
        optional_fields.append((mask, 'mask'))
    if has_model:
        optional_fields.append((model, 'model'))

    units_map = units_map or {}

    if single_record:
        if len(records) != 1:
            raise ValueError(
                f"single_record=True but {len(records)} records were provided")
        _extract_single(records[0], flux, optional_fields, redshift,
                        has_redshift, meta, units_map)
    else:
        for k, record in enumerate(records):
            _extract_row(k, record, flux, optional_fields, redshift,
                         has_redshift, collection, spectral_axis, meta,
                         units_map)

[docs] def to_Spectrum(results, *, collection=False, flux_unit=None, wave_unit=None): """Convert `results` to :class:`specutils.Spectrum`. Parameters ---------- results : :class:`sparcl.Results.Retrieved` Retrieved results, or a single record from a set of results. collection : bool, optional If ``True``, attempt to convert to a :class:`~specutils.SpectrumCollection` instead. flux_unit : str, optional Unit string for flux. If None, resolved via _get_units(). wave_unit : str, optional Unit string for wavelength. If None, resolved via _get_units(). Returns ------- :class:`~specutils.Spectrum` or :class:`~specutils.SpectrumCollection` The requested object. Raises ------ ValueError If `results` can't be converted to a :class:`~specutils.Spectrum` object in a valid way. For example, if some of the spectra have a different wavelength solution. """ if flux_unit is None or wave_unit is None: flux_unit, wave_unit = _get_units(results) # Prepare records if isinstance(results, dict): records = [results] r0 = results else: try: records = results.records if len(records) == 0: raise ValueError("No records found in results. Cannot " "convert empty results to Spectrum.") r0 = results.records[0] except (IndexError, AttributeError) as e: raise ValueError("No records found in results. Cannot " "convert empty results to Spectrum.") from e # Validate _validate_records(records, r0, collection) # Determine which optional data components exist in records has_redshift = 'redshift' in r0 has_model = 'model' in r0 has_ivar = 'ivar' in r0 has_mask = 'mask' in r0 n_pixels = r0.flux.shape[0] single_record = len(records) == 1 # Set flux shape based on number of records if single_record: flux_shape = (n_pixels,) else: flux_shape = (len(records), n_pixels) # Build spectral axis if collection: spectral_axis = np.empty((len(records), r0.wavelength.shape[0]), dtype=r0.wavelength.dtype) else: spectral_axis = r0.wavelength # Initialize arrays flux = np.empty(flux_shape, dtype=r0.flux.dtype) uncertainty = (np.empty(flux_shape, dtype=r0.ivar.dtype) if has_ivar else None) mask = np.empty(flux_shape, dtype=r0.mask.dtype) if has_mask else None model = np.empty(flux_shape, dtype=r0.model.dtype) if has_model else None redshift = [] meta = {} units_map = {} if hasattr(results, 'hdr'): try: dr = results.records[0].get('_dr') units_map = results.hdr.get('UNITS', {}).get(dr, {}) except (AttributeError, IndexError, KeyError): pass # Populate arrays by iterating through records _extract_record_data(records, flux, uncertainty, mask, model, redshift, meta, spectral_axis, has_ivar, has_mask, has_model, has_redshift, collection, single_record, units_map=units_map) # Convert redshift list to numpy array if exists if has_redshift: redshift = np.array(redshift) if single_record and len(redshift) == 1: # Convert to scalar if single record redshift = redshift[0] else: redshift = None # Add model to metadata if exists if has_model: meta['model'] = model # Prepare arguments common to both Spectrum and SpectrumCollection common_args = { 'flux': flux * u.Unit(flux_unit), 'spectral_axis': spectral_axis * u.Unit(wave_unit), 'uncertainty': InverseVariance(uncertainty) if has_ivar else None, 'mask': mask, 'meta': meta} if collection: return SpectrumCollection(**common_args) return Spectrum(**common_args, redshift=redshift)
[docs] def to_SpectrumList(results, *, flux_unit=None, wave_unit=None): """Convert `results` to :class:`specutils.SpectrumList`. Parameters ---------- results : :class:`sparcl.Results.Retrieved` Retrieved results. flux_unit : str, optional Unit string for flux. If None, resolved via _get_units(). wave_unit : str, optional Unit string for wavelength. If None, resolved via _get_units(). Returns ------- :class:`~specutils.SpectrumList` The requested object. """ if flux_unit is None or wave_unit is None: flux_unit, wave_unit = _get_units(results) s = SpectrumList() if isinstance(results, dict): records = [results] else: records = results.records for r in records: dr_units = {} if hasattr(results, 'hdr'): try: dr = r.get('_dr') dr_units = results.hdr.get('UNITS', {}).get(dr, {}) except (AttributeError, KeyError): pass redshift = r.redshift if 'redshift' in r else None uncertainty = InverseVariance(r.ivar) if 'ivar' in r else None mask = r.mask if 'mask' in r else None meta = {} for attribute in r: if attribute not in ('flux', 'wavelength', 'ivar', 'redshift', 'mask'): value = r[attribute] unit_str = dr_units.get(attribute) if unit_str is not None: try: value = value * u.Unit(unit_str) except (ValueError, TypeError): pass meta[attribute] = value s1 = Spectrum(flux=r.flux * u.Unit(flux_unit), spectral_axis=r.wavelength * u.Unit(wave_unit), uncertainty=uncertainty, redshift=redshift, mask=mask, meta=meta) s.append(s1) return s
[docs] def to_specutils(results): """Convert `results` to a specutils object. Parameters ---------- results : :class:`sparcl.Results.Retrieved` Retrieved results. Returns ------- :class:`~specutils.Spectrum` or :class:`~specutils.SpectrumCollection` or :class:`~specutils.SpectrumList` The most natural conversion to a specutils object. Raises ------ ValueError If no valid conversion can be performed, or if unit information is missing from the header. """ flux_unit, wave_unit = _get_units(results) try: # Try standard Spectrum conversion first s = to_Spectrum(results, flux_unit=flux_unit, wave_unit=wave_unit) except ValueError as ve: # Check the error message to determine appropriate alternative if 'SpectrumList' in str(ve): # Different wavelength array lengths use SpectrumList print("Returning a SpectrumList because the records have " "different spectral lengths.") s = to_SpectrumList(results, flux_unit=flux_unit, wave_unit=wave_unit) elif 'SpectrumCollection' in str(ve): # Same wavelength length but diff pixels use SpectrumCollection print("Returning a SpectrumCollection because the records have " "the same length but different spectral axes.") s = to_Spectrum(results, collection=True, flux_unit=flux_unit, wave_unit=wave_unit) else: raise return s