# Source code for vallenae.io.datatypes

# ruff: noqa: E501

from __future__ import annotations

from dataclasses import dataclass, field, fields
from enum import IntEnum, IntFlag
from typing import Any

import numpy as np

from .compression import decode_data_blob


def _to_volts(value: float | None):
    """Convert µV to V if not None."""
    if value is None:
        return None
    return float(value) / 1e6


def _to_seconds(value: float | None):
    """Convert µs to s if not None."""
    if value is None:
        return None
    return float(value) / 1e6


class SetType(IntEnum):
    """Kinds of records stored in a pridb."""

    # fmt: off
    NONE       = 0  #: No record type
    PARAMETRIC = 1  #: Parametric record
    HIT        = 2  #: Hit record
    STATUS     = 3  #: Status record
    LABEL      = 4  #: Label marker
    DATETIME   = 5  #: Datetime marker, inserted by the acquisition software whenever recording is started
    SECTION    = 6  #: Section marker, e.g. if acquisition settings changed
    # fmt: on
class HitFlags(IntFlag):
    """
    Bit flags of the ``Status`` value of hit records (also used by transient records in tradb).
    """

    # NOTE(review): the bit positions below are not contiguous (0, 4, 6, 8, 9, 11, 15);
    # decoding stored status words depends on them, so confirm them against the
    # Vallen pridb/tradb format specification.
    # fmt: off
    TIMEOUT        = 0x0001  #: Timeout aborted signal, indicates a very long hit (> 100 ms)
    AFTER_TIMEOUT  = 0x0100  #: After timeout signal, indicates an artificially started hit after hit timeout
    DDT_TOO_SHORT  = 0x0200  #: DDT too short, indicates that channel(s) switched into "long duration mode" because the buffer has run half full
    SATURATION     = 0x0010  #: Saturation, indicates that the signal exceeds 95% of the input range
    PULSE_SENT     = 0x0800  #: Pulse (sent), indicates that the signal was generated by the test pulse (sending channel)
    PULSE_RECEIVED = 0x0040  #: Pulse (received), indicates that the signal was generated by the test pulse (receiving channel)
    TR_TRIGGER     = 0x8000  #: TR trigger, indicates if transient data record belongs to a hit (only used in tradb)
    # fmt: on
class StatusFlags(IntFlag):
    """
    Bit flags of the ``Status`` value of status and parametric records.
    """

    # fmt: off
    PULSE_ACTIVE = 0b0010  #: Pulsing active
    AE_ENABLED   = 0b0100  #: AE recording enabled
    TR_ENABLED   = 0b1000  #: TR recording enabled
    # fmt: on
@dataclass
class HitRecord:
    """
    Hit record in pridb (`SetType.HIT`).
    """

    # fmt: off
    time: float  #: Time in seconds
    channel: int  #: Channel number
    param_id: int  #: Parameter ID of table ae_params for ADC value conversion
    amplitude: float  #: Peak amplitude in volts
    duration: float  #: Hit duration in seconds
    energy: float  #: Energy (EN 1330-9) in eu (1e-14 V²s)
    rms: float  #: RMS of the noise before the hit in volts
    # optional for creating:
    set_id: int | None = None  #: Unique identifier for data set in pridb
    status: HitFlags = field(default=HitFlags(0))  #: Status flags
    threshold: float | None = None  #: Threshold amplitude in volts
    rise_time: float | None = None  #: Rise time in seconds
    signal_strength: float | None = None  #: Signal strength in nVs (1e-9 Vs)
    counts: int | None = None  #: Number of positive threshold crossings
    trai: int | None = None  #: Transient recorder index (foreign key between pridb and tradb)
    cascade_hits: int | None = None  #: Total number of hits in the same hit-cascade
    cascade_counts: int | None = None  #: Summed counts of hits in the same hit-cascade
    cascade_energy: float | None = None  #: Summed energy of hits in the same hit-cascade
    cascade_signal_strength: float | None = None  #: Summed signal strength of hits in the same hit-cascade
    # fmt: on

    @classmethod
    def from_sql(cls, row: dict[str, Any]) -> "HitRecord":
        """
        Create `HitRecord` from SQL row.

        ``Amp``, ``Thr`` and ``RMS`` are converted from µV to V, ``Dur`` and
        ``RiseT`` from µs to s; all other values are passed through unchanged.
        The ``Status`` value is wrapped in `HitFlags`.

        Args:
            row: Dict of column names and values

        Raises:
            KeyError: If a mandatory column is missing
        """
        return cls(
            set_id=row["SetID"],
            time=row["Time"],
            channel=row["Chan"],
            # Wrap the raw integer so `status` matches its annotation and supports
            # flag tests like `HitFlags.SATURATION in record.status`; NULL -> no flags.
            status=HitFlags(int(row["Status"] or 0)),
            param_id=row["ParamID"],
            threshold=_to_volts(row.get("Thr")),  # optional
            amplitude=_to_volts(row["Amp"]),
            rise_time=_to_seconds(row.get("RiseT")),  # optional
            duration=_to_seconds(row["Dur"]),
            energy=row["Eny"],
            signal_strength=row.get("SS"),  # optional for spotWave
            rms=_to_volts(row["RMS"]),
            counts=row.get("Counts"),  # optional
            trai=row.get("TRAI"),  # optional
            cascade_hits=row.get("CHits"),  # optional
            cascade_counts=row.get("CCnt"),  # optional
            cascade_energy=row.get("CEny"),  # optional
            cascade_signal_strength=row.get("CSS"),  # optional
        )
@dataclass
class MarkerRecord:
    """
    Marker record in pridb (`SetType.LABEL`, `SetType.DATETIME`, `SetType.SECTION`).
    """

    time: float  #: Time in seconds
    set_type: SetType  #: Marker type (`SetType.LABEL`, `SetType.DATETIME` or `SetType.SECTION`)
    data: str  #: Content of marker (label text or datetime)
    # optional for creating:
    number: int | None = None  #: Marker number
    set_id: int | None = None  #: Unique identifier for data set in pridb

    @classmethod
    def from_sql(cls, row: dict[str, Any]) -> "MarkerRecord":
        """
        Build a `MarkerRecord` from one SQL result row.

        All five columns are mandatory; values are taken over unconverted.

        Args:
            row: Dict of column names and values
        """
        # attribute name -> SQL column name (looked up in this order)
        column_of = {
            "set_id": "SetID",
            "time": "Time",
            "set_type": "SetType",
            "number": "Number",
            "data": "Data",
        }
        return cls(**{attr: row[column] for attr, column in column_of.items()})
@dataclass
class StatusRecord:
    """
    Status data record in pridb (`SetType.STATUS`).
    """

    time: float  #: Time in seconds
    channel: int  #: Channel number
    param_id: int  #: Parameter ID of table ae_params for ADC value conversion
    energy: float  #: Energy (EN 1330-9) in eu (1e-14 V²s)
    rms: float  #: RMS in volts
    # optional for creating:
    set_id: int | None = None  #: Unique identifier for data set in pridb
    status: StatusFlags = field(default=StatusFlags(0))  #: Status flags
    threshold: float | None = None  #: Threshold amplitude in volts
    signal_strength: float | None = None  #: Signal strength in nVs (1e-9 Vs)

    @classmethod
    def from_sql(cls, row: dict[str, Any]) -> "StatusRecord":
        """
        Create `StatusRecord` from SQL row.

        ``Thr`` and ``RMS`` are converted from µV to V; the ``Status`` value is
        wrapped in `StatusFlags`; other values are passed through unchanged.

        Args:
            row: Dict of column names and values

        Raises:
            KeyError: If a mandatory column (including ``RMS``) is missing
        """
        return cls(
            set_id=row["SetID"],
            time=row["Time"],
            channel=row["Chan"],
            # Wrap the raw integer so `status` matches its annotation and supports
            # flag tests like `StatusFlags.AE_ENABLED in record.status`; NULL -> no flags.
            status=StatusFlags(int(row["Status"] or 0)),
            param_id=row["ParamID"],
            threshold=_to_volts(row.get("Thr")),  # optional
            energy=row["Eny"],
            signal_strength=row.get("SS"),  # optional for spotWave
            rms=_to_volts(row["RMS"]),  # mandatory: `rms` has no default
        )
@dataclass
class ParametricRecord:
    """
    Parametric data record in pridb (`SetType.PARAMETRIC`).
    """

    time: float  #: Time in seconds
    param_id: int  #: Parameter ID of table ae_params for ADC value conversion
    # optional for creating:
    set_id: int | None = None  #: Unique identifier for data set in pridb
    status: StatusFlags = field(default=StatusFlags(0))  #: Status flags
    pctd: int | None = None  #: Digital counter value
    pcta: int | None = None  #: Analog hysteresis counter
    # NOTE(review): pa0..pa7 are annotated `int` but documented in volts, and
    # `from_sql` passes the stored values through unconverted — confirm the unit.
    pa0: int | None = None  #: Amplitude of parametric input 0 in volts
    pa1: int | None = None  #: Amplitude of parametric input 1 in volts
    pa2: int | None = None  #: Amplitude of parametric input 2 in volts
    pa3: int | None = None  #: Amplitude of parametric input 3 in volts
    pa4: int | None = None  #: Amplitude of parametric input 4 in volts
    pa5: int | None = None  #: Amplitude of parametric input 5 in volts
    pa6: int | None = None  #: Amplitude of parametric input 6 in volts
    pa7: int | None = None  #: Amplitude of parametric input 7 in volts

    @classmethod
    def from_sql(cls, row: dict[str, Any]) -> "ParametricRecord":
        """
        Create `ParametricRecord` from SQL row.

        The ``Status`` value is wrapped in `StatusFlags`; all other values are
        passed through unchanged. Missing optional columns become ``None``.

        Args:
            row: Dict of column names and values

        Raises:
            KeyError: If a mandatory column is missing
        """
        return cls(
            set_id=row["SetID"],
            time=row["Time"],
            # Wrap the raw integer so `status` matches its annotation; NULL -> no flags.
            status=StatusFlags(int(row["Status"] or 0)),
            param_id=row["ParamID"],
            pctd=row.get("PCTD"),  # optional
            pcta=row.get("PCTA"),  # optional
            pa0=row.get("PA0"),  # optional
            pa1=row.get("PA1"),  # optional
            pa2=row.get("PA2"),  # optional
            pa3=row.get("PA3"),  # optional
            pa4=row.get("PA4"),  # optional
            pa5=row.get("PA5"),  # optional
            pa6=row.get("PA6"),  # optional
            pa7=row.get("PA7"),  # optional
        )
@dataclass
class TraRecord:
    """Transient data record in tradb."""

    time: float  #: Time in seconds
    channel: int  #: Channel number
    param_id: int  #: Parameter ID of table tr_params for ADC value conversion
    pretrigger: int  #: Pretrigger samples
    threshold: float  #: Threshold amplitude in volts
    samplerate: int  #: Samplerate in Hz
    samples: int  #: Number of samples
    data: np.ndarray  #: Transient signal in volts or ADC values if `raw` = `True`
    # optional for writing
    status: HitFlags = field(default=HitFlags(0))  #: Status flags
    trai: int | None = None  #: Transient recorder index (foreign key between pridb and tradb)
    rms: float | None = None  #: RMS of the noise before the hit
    # optional
    raw: bool = False  #: `data` is stored as ADC values (int16)

    def __eq__(self, other: object) -> bool:
        """
        Compare records field by field, using `np.array_equal` for `data`.

        The dataclass-generated `__eq__` compares `data` with `==`, whose
        element-wise array result cannot be reduced to a single bool and raises
        `ValueError` for arrays with more than one element.
        """
        if other.__class__ is not self.__class__:
            return NotImplemented
        for f in fields(self):
            mine, theirs = getattr(self, f.name), getattr(other, f.name)
            if f.name == "data":
                if not np.array_equal(mine, theirs):
                    return False
            elif mine != theirs:
                return False
        return True

    @classmethod
    def from_sql(cls, row: dict[str, Any], *, raw: bool = False) -> "TraRecord":
        """
        Create `TraRecord` from SQL row.

        The ``Data`` blob is decoded with `decode_data_blob` using the row's
        ``DataFormat`` and ``TR_mV`` columns; ``Thr`` is converted from µV to V;
        the ``Status`` value is wrapped in `HitFlags`. `rms` is not read from the
        row and keeps its default (``None``).

        Args:
            row: Dict of column names and values
            raw: Provide `data` as ADC values (int16)

        Raises:
            KeyError: If a mandatory column is missing
        """
        # `cls` (not the hard-coded class name) so subclasses get instances of themselves.
        return cls(
            time=row["Time"],
            channel=row["Chan"],
            # Wrap the raw integer so `status` matches its annotation; NULL -> no flags.
            status=HitFlags(int(row["Status"] or 0)),
            param_id=row["ParamID"],
            pretrigger=row["Pretrigger"],
            threshold=_to_volts(row["Thr"]),
            samplerate=row["SampleRate"],
            samples=row["Samples"],
            data=decode_data_blob(row["Data"], row["DataFormat"], row["TR_mV"], raw=raw),
            trai=row["TRAI"],
            raw=raw,
        )
@dataclass
class FeatureRecord:
    """
    Transient feature record in trfdb.
    """

    trai: int  #: Transient recorder index
    features: dict[str, float]  #: Feature dictionary (feature name -> value)

    @classmethod
    def from_sql(cls, row: dict[str, Any]) -> "FeatureRecord":
        """
        Create `FeatureRecord` from SQL row.

        Every column except ``TRAI`` is collected into `features`.
        The passed `row` is left unmodified.

        Args:
            row: Dict of column names and values

        Raises:
            KeyError: If `row` has no ``TRAI`` column
        """
        # Work on a copy so popping "TRAI" does not mutate the caller's dict.
        features = dict(row)
        trai = features.pop("TRAI")
        return cls(trai=trai, features=features)