# ruff: noqa: E501
from __future__ import annotations
from dataclasses import dataclass, field
from enum import IntEnum, IntFlag
from typing import Any
import numpy as np
from .compression import decode_data_blob
def _to_volts(value: float | None):
"""Convert µV to V if not None."""
if value is None:
return None
return float(value) / 1e6
def _to_seconds(value: float | None):
"""Convert µs to s if not None."""
if value is None:
return None
return float(value) / 1e6
class SetType(IntEnum):
    """
    Record types of the pridb.

    Each member's integer value is the record type code of a pridb data set.
    """

    # fmt: off
    NONE       = 0  #: None
    PARAMETRIC = 1  #: Parametric record
    HIT        = 2  #: Hit record
    STATUS     = 3  #: Status record
    LABEL      = 4  #: Label marker
    DATETIME   = 5  #: Datetime marker, inserted by the acquisition software whenever recording is started
    SECTION    = 6  #: Section marker, e.g. if acquisition settings changed
    # fmt: on
class HitFlags(IntFlag):
    """
    Hit record flags.

    Individual flags can be combined with bitwise operators (`IntFlag`).
    """

    # NOTE(review): the bit positions below are not listed in ascending order
    # (0, 8, 9, 4, 11, 6, 15). They must match the bits stored by the acquisition
    # system, so confirm them against the Vallen documentation before changing anything.
    # fmt: off
    TIMEOUT        = 1 << 0   #: Timeout aborted signal, indicates a very long hit (> 100 ms)
    AFTER_TIMEOUT  = 1 << 8   #: After timeout signal, indicates an artificially started hit after hit timeout
    DDT_TOO_SHORT  = 1 << 9   #: DDT too short, indicates that channel(s) switched into "long duration mode" because the buffer has run half full
    SATURATION     = 1 << 4   #: Saturation, indicates that the signal exceeds 95% of the input range
    PULSE_SENT     = 1 << 11  #: Pulse (sent), indicates that the signal was generated by the test pulse (sending channel)
    PULSE_RECEIVED = 1 << 6   #: Pulse (received), indicates that the signal was generated by the test pulse (receiving channel)
    TR_TRIGGER     = 1 << 15  #: TR trigger, indicates if transient data record belongs to a hit (only used in tradb)
    # fmt: on
class StatusFlags(IntFlag):
    """
    Status flags.

    Individual flags can be combined with bitwise operators (`IntFlag`).
    Bit 0 is not assigned to any flag here.
    """

    # fmt: off
    PULSE_ACTIVE = 1 << 1  #: Pulsing active
    AE_ENABLED   = 1 << 2  #: AE recording enabled
    TR_ENABLED   = 1 << 3  #: TR recording enabled
    # fmt: on
@dataclass
class HitRecord:
    """
    Hit record in pridb (`SetType.HIT`).
    """

    # fmt: off
    time: float                                   #: Time in seconds
    channel: int                                  #: Channel number
    param_id: int                                 #: Parameter ID of table ae_params for ADC value conversion
    amplitude: float                              #: Peak amplitude in volts
    duration: float                               #: Hit duration in seconds
    energy: float                                 #: Energy (EN 1330-9) in eu (1e-14 V²s)
    rms: float                                    #: RMS of the noise before the hit in volts
    # optional for creating:
    set_id: int | None = None                     #: Unique identifier for data set in pridb
    status: HitFlags = field(default=HitFlags(0)) #: Status flags
    threshold: float | None = None                #: Threshold amplitude in volts
    rise_time: float | None = None                #: Rise time in seconds
    signal_strength: float | None = None          #: Signal strength in nVs (1e-9 Vs)
    counts: int | None = None                     #: Number of positive threshold crossings
    trai: int | None = None                       #: Transient recorder index (foreign key between pridb and tradb)
    cascade_hits: int | None = None               #: Total number of hits in the same hit-cascade
    cascade_counts: int | None = None             #: Summed counts of hits in the same hit-cascade
    cascade_energy: float | None = None           #: Summed energy of hits in the same hit-cascade
    cascade_signal_strength: float | None = None  #: Summed signal strength of hits in the same hit-cascade
    # fmt: on

    @classmethod
    def from_sql(cls, row: dict[str, Any]) -> "HitRecord":
        """
        Create `HitRecord` from SQL row.

        Amplitude, threshold and RMS are converted from µV to V.
        Duration and rise time are converted from µs to s.
        All other values are taken over unchanged.

        Args:
            row: Dict of column names and values.
                Required keys: SetID, Time, Chan, Status, ParamID, Amp, Dur, Eny, RMS.
                Optional keys (Thr, RiseT, SS, Counts, TRAI, CHits, CCnt, CEny, CSS)
                default to `None` if missing.

        Raises:
            KeyError: If a required key is missing from `row`
        """
        return cls(
            set_id=row["SetID"],
            time=row["Time"],
            channel=row["Chan"],
            # NOTE(review): stored as the raw value from the row, not wrapped in HitFlags
            status=row["Status"],
            param_id=row["ParamID"],
            threshold=_to_volts(row.get("Thr")),  # optional
            amplitude=_to_volts(row["Amp"]),
            rise_time=_to_seconds(row.get("RiseT")),  # optional
            duration=_to_seconds(row["Dur"]),
            energy=row["Eny"],
            signal_strength=row.get("SS"),  # optional for spotWave
            rms=_to_volts(row["RMS"]),
            counts=row.get("Counts"),  # optional
            trai=row.get("TRAI"),  # optional
            cascade_hits=row.get("CHits"),  # optional
            cascade_counts=row.get("CCnt"),  # optional
            cascade_energy=row.get("CEny"),  # optional
            cascade_signal_strength=row.get("CSS"),  # optional
        )
@dataclass
class MarkerRecord:
    """
    Marker record in pridb (`SetType.LABEL`, `SetType.DATETIME`, `SetType.SECTION`).
    """

    # fmt: off
    time: float                # Time in seconds
    set_type: SetType          #: Marker type (see above)
    data: str                  #: Content of marker (label text or datetime)
    # optional for creating:
    number: int | None = None  #: Marker number
    set_id: int | None = None  #: Unique identifier for data set in pridb
    # fmt: on

    @classmethod
    def from_sql(cls, row: dict[str, Any]) -> "MarkerRecord":
        """
        Create `MarkerRecord` from SQL row.

        All values are taken over unchanged (no unit or enum conversion).

        Args:
            row: Dict of column names and values.
                Required keys: SetID, Time, SetType, Number, Data.

        Raises:
            KeyError: If a required key is missing from `row`
        """
        return cls(
            set_id=row["SetID"],
            time=row["Time"],
            set_type=row["SetType"],
            number=row["Number"],
            data=row["Data"],
        )
@dataclass
class StatusRecord:
    """
    Status data record in pridb (`SetType.STATUS`).
    """

    # fmt: off
    time: float                                         #: Time in seconds
    channel: int                                        #: Channel number
    param_id: int                                       #: Parameter ID of table ae_params for ADC value conversion
    energy: float                                       #: Energy (EN 1330-9) in eu (1e-14 V²s)
    rms: float                                          #: RMS in volts
    # optional for creating:
    set_id: int | None = None                           #: Unique identifier for data set in pridb
    status: StatusFlags = field(default=StatusFlags(0)) #: Status flags
    threshold: float | None = None                      #: Threshold amplitude in volts
    signal_strength: float | None = None                #: Signal strength in nVs (1e-9 Vs)
    # fmt: on

    @classmethod
    def from_sql(cls, row: dict[str, Any]) -> "StatusRecord":
        """
        Create `StatusRecord` from SQL row.

        Threshold and RMS are converted from µV to V.
        All other values are taken over unchanged.

        Args:
            row: Dict of column names and values.
                Required keys: SetID, Time, Chan, Status, ParamID, Eny, RMS.
                Optional keys (Thr, SS) default to `None` if missing.

        Raises:
            KeyError: If a required key is missing from `row`
        """
        return cls(
            set_id=row["SetID"],
            time=row["Time"],
            channel=row["Chan"],
            status=row["Status"],
            param_id=row["ParamID"],
            threshold=_to_volts(row.get("Thr")),  # optional
            energy=row["Eny"],
            signal_strength=row.get("SS"),  # optional for spotWave
            rms=_to_volts(row["RMS"]),  # required (unlike Thr/SS)
        )
@dataclass
class ParametricRecord:
    """
    Parametric data record in pridb (`SetType.PARAMETRIC`).
    """

    # NOTE(review): pa0..pa7 are annotated `int` but documented in volts, and
    # `from_sql` passes the row values through without conversion; confirm the unit.
    # fmt: off
    time: float                                         #: Time in seconds
    param_id: int                                       #: Parameter ID of table ae_params for ADC value conversion
    # optional for creating:
    set_id: int | None = None                           #: Unique identifier for data set in pridb
    status: StatusFlags = field(default=StatusFlags(0)) #: Status flags
    pctd: int | None = None                             #: Digital counter value
    pcta: int | None = None                             #: Analog hysteresis counter
    pa0: int | None = None                              #: Amplitude of parametric input 0 in volts
    pa1: int | None = None                              #: Amplitude of parametric input 1 in volts
    pa2: int | None = None                              #: Amplitude of parametric input 2 in volts
    pa3: int | None = None                              #: Amplitude of parametric input 3 in volts
    pa4: int | None = None                              #: Amplitude of parametric input 4 in volts
    pa5: int | None = None                              #: Amplitude of parametric input 5 in volts
    pa6: int | None = None                              #: Amplitude of parametric input 6 in volts
    pa7: int | None = None                              #: Amplitude of parametric input 7 in volts
    # fmt: on

    @classmethod
    def from_sql(cls, row: dict[str, Any]) -> "ParametricRecord":
        """
        Create `ParametricRecord` from SQL row.

        All values are taken over unchanged.

        Args:
            row: Dict of column names and values.
                Required keys: SetID, Time, Status, ParamID.
                Optional keys (PCTD, PCTA, PA0..PA7) default to `None` if missing.

        Raises:
            KeyError: If a required key is missing from `row`
        """
        return cls(
            set_id=row["SetID"],
            time=row["Time"],
            status=row["Status"],
            param_id=row["ParamID"],
            pctd=row.get("PCTD"),  # optional
            pcta=row.get("PCTA"),  # optional
            pa0=row.get("PA0"),  # optional
            pa1=row.get("PA1"),  # optional
            pa2=row.get("PA2"),  # optional
            pa3=row.get("PA3"),  # optional
            pa4=row.get("PA4"),  # optional
            pa5=row.get("PA5"),  # optional
            pa6=row.get("PA6"),  # optional
            pa7=row.get("PA7"),  # optional
        )
@dataclass
class TraRecord:
    """Transient data record in tradb."""

    # fmt: off
    time: float                                   #: Time in seconds
    channel: int                                  #: Channel number
    param_id: int                                 #: Parameter ID of table tr_params for ADC value conversion
    pretrigger: int                               #: Pretrigger samples
    threshold: float                              #: Threshold amplitude in volts
    samplerate: int                               #: Samplerate in Hz
    samples: int                                  #: Number of samples
    data: np.ndarray                              #: Transient signal in volts or ADC values if `raw` = `True`
    # optional for writing
    status: HitFlags = field(default=HitFlags(0)) #: Status flags
    trai: int | None = None                       #: Transient recorder index (foreign key between pridb and tradb)
    rms: float | None = None                      #: RMS of the noise before the hit
    # optional
    raw: bool = False                             #: `data` is stored as ADC values (int16)
    # fmt: on

    @classmethod
    def from_sql(cls, row: dict[str, Any], *, raw: bool = False) -> "TraRecord":
        """
        Create `TraRecord` from SQL row.

        The threshold is converted from µV to V. The `Data` blob is decoded with
        `decode_data_blob` using the row's `DataFormat` and `TR_mV` values.

        Args:
            row: Dict of column names and values.
                Required keys: Time, Chan, Status, ParamID, Pretrigger, Thr,
                SampleRate, Samples, Data, DataFormat, TR_mV, TRAI.
            raw: Provide `data` as ADC values (int16)

        Raises:
            KeyError: If a required key is missing from `row`
        """
        # Use `cls` (not a hard-coded class name) so subclasses get instances of themselves.
        # NOTE(review): `rms` is not read from the row, so records created here keep rms=None.
        return cls(
            time=row["Time"],
            channel=row["Chan"],
            status=row["Status"],
            param_id=row["ParamID"],
            pretrigger=row["Pretrigger"],
            threshold=_to_volts(row["Thr"]),
            samplerate=row["SampleRate"],
            samples=row["Samples"],
            data=decode_data_blob(row["Data"], row["DataFormat"], row["TR_mV"], raw=raw),
            trai=row["TRAI"],
            raw=raw,
        )
@dataclass
class FeatureRecord:
    """
    Transient feature record in trfdb.
    """

    trai: int  #: Transient recorder index
    features: dict[str, float]  #: Feature dictionary (feature name -> value)

    @classmethod
    def from_sql(cls, row: dict[str, Any]) -> "FeatureRecord":
        """
        Create `FeatureRecord` from SQL row.

        The `TRAI` column becomes `trai`. All remaining columns become the `features` dict.
        The input `row` is not modified.

        Args:
            row: Dict of column names and values (must contain `TRAI`)

        Raises:
            KeyError: If `row` has no `TRAI` key
        """
        # Work on a shallow copy so the caller's row is neither mutated nor aliased.
        features = dict(row)
        trai = features.pop("TRAI")
        return cls(
            trai=trai,
            features=features,
        )