"""
I/O engine for ASN HDF5 files and live ZMQ streams.
Includes :class:`ASNEngine` for HDF5 files, plus a ZMQ-based real-time
subscriber (:class:`ZMQSubscriber`) and publisher (:class:`ZMQPublisher`)
for live ASN streams.
"""
import json
from bisect import bisect_left, bisect_right
import h5py
import numpy as np
import zmq
from ..coordinates.core import Coordinate, get_sampling_interval
from ..core.dataarray import DataArray
from ..virtual import VirtualSource
from .core import Engine
class ASNEngine(Engine, name="asn"):
    """Engine for reading ASN HDF5 files."""

    _supported_vtypes = ["hdf5"]
    _supported_ctypes = {
        "time": ["interpolated", "sampled", "dense"],
        "distance": ["interpolated"],
    }

    def open_dataarray(self, fname):
        """
        Read an ASN HDF5 file *fname* and return a virtual :class:`DataArray`.

        Parameters
        ----------
        fname : str
            Path of the ASN HDF5 file to read.

        Returns
        -------
        DataArray
            Array backed by a :class:`VirtualSource` on the file's ``data``
            dataset, with a ``time`` coordinate built from the header start
            time and sampling interval, and a ``distance`` coordinate given
            as tie points with two ties (start and end) per ROI.

        Raises
        ------
        IndexError
            If an ROI cannot be located in the recorded sensor distances
            (see :meth:`_get_roi_bound_indices`).
        """
        with h5py.File(fname, "r") as file:
            header = file["header"]
            demod = file["demodSpec"]
            t0 = np.datetime64(round(header["time"][()] * 1e9), "ns")
            dt = np.timedelta64(round(1e9 * header["dt"][()]), "ns")
            dx = float(header["dx"][()])  # Note: dx before (internal) downsampling!
            data = VirtualSource(file["data"])
            nt = data.shape[0]

            # Optical distance of all the recorded channels (after downsampling).
            # Note that this vector is not continuous for more than one ROI.
            all_dists = file["cableSpec"]["sensorDistances"][...]

            # Data indices / optical distances at which each ROI starts and stops.
            dist_tie_inds = []
            dist_tie_vals = []

            # The ROI bounds are stored as indices before downsampling.
            for n_start, n_end in zip(demod["roiStart"], demod["roiEnd"]):
                # Locate the ROI from its position in the distance vector rather
                # than from decimated indices: this avoids the rounding that
                # happens during decimation.
                i_start, i_end = self._get_roi_bound_indices(
                    all_dists, n_start, n_end, dx
                )
                dist_tie_inds.extend((i_start, i_end))
                dist_tie_vals.extend(
                    (float(all_dists[i_start]), float(all_dists[i_end]))
                )

        time = Coordinate[self.ctype["time"]].from_block(t0, nt, dt, dim="time")
        distance = {"tie_indices": dist_tie_inds, "tie_values": dist_tie_vals}
        return DataArray(data, {"time": time, "distance": distance})

    def _get_roi_bound_indices(self, all_dists, n_start, n_end, dx):
        """
        Return the first and last index of *all_dists* that belong to an ROI.

        Parameters
        ----------
        all_dists : sequence of float
            Sensor distances, searched by bisection (so expected to be sorted
            in increasing order).
        n_start, n_end : int
            ROI start/end expressed in units of *dx*.
        dx : float
            Spacing used to convert *n_start* / *n_end* into distances.

        Returns
        -------
        tuple of int
            ``(start_index, end_index)``, both inclusive.

        Raises
        ------
        IndexError
            If the ROI starts beyond the last distance, ends before the first
            one, or does not contain any distance at all.
        """
        # First sample at or after the ROI start.
        start_index = bisect_left(all_dists, n_start * dx)
        if start_index >= len(all_dists):
            raise IndexError("ROI start lies beyond available sensor distances")
        # ASN stores the ROI end as an upper boundary: use the last sampled
        # distance that does not exceed it instead of the insertion point.
        end_index = bisect_right(all_dists, n_end * dx) - 1
        if end_index < 0:
            raise IndexError("ROI end lies before available sensor distances")
        # Both bounds can be valid on their own while no sample falls between
        # them, which would yield decreasing tie indices in the caller.
        if end_index < start_index:
            raise IndexError("ROI does not contain any sensor distance")
        return start_index, end_index
# Lookup between the ``dataType`` string carried in a ZMQ stream header and the
# NumPy scalar type used to decode (subscriber) or advertise (publisher) the
# sample buffer. The publisher performs the reverse lookup by scanning the
# items in insertion order.
type_map = {
    # Integer sample formats.
    "short": np.int16,
    "int": np.int32,
    "long": np.int64,
    # Floating-point sample formats.
    "float": np.float32,
    "double": np.float64,
}
[docs]
class ZMQSubscriber:
    """
    Iterator that pulls :class:`DataArray` chunks from a live ASN ZMQ publisher.

    Parameters
    ----------
    address : str
        ZMQ address of the publisher (e.g. ``"tcp://localhost:5555"``).
    """

    def __init__(self, address):
        """
        Connect to *address* and read the first message as the stream header.

        Parameters
        ----------
        address : str
            The address to connect to.

        Examples
        --------
        >>> import time
        >>> import threading
        >>> import xdas as xd
        >>> from xdas.io.asn import ZMQPublisher, ZMQSubscriber
        >>> port = xd.io.get_free_port()
        >>> address = f"tcp://localhost:{port}"
        >>> publisher = ZMQPublisher(address)
        >>> da = xd.synthetics.dummy()
        >>> chunks = xd.split(da, 10)
        >>> def publish():
        ...     for chunk in chunks:
        ...         time.sleep(0.001)  # so that the subscriber can connect in time
        ...         publisher.submit(chunk)
        >>> threading.Thread(target=publish).start()
        >>> subscriber = ZMQSubscriber(address)
        >>> for nchunk in range(10):
        ...     chunk = next(subscriber)
        ...     # do something with the chunk
        """
        self.address = address
        self._connect(address)
        # The first message received is parsed as the JSON header.
        self._update_header(self._get_message())

    def __iter__(self):
        """Return the subscriber itself: it is its own iterator."""
        return self

    def __next__(self):
        """Block until a data packet arrives and return it as a :class:`DataArray`."""
        while True:
            message = self._get_message()
            if self._is_packet(message):
                return self._unpack(message)
            # Anything that does not have the size of a data packet is parsed
            # as a new header, then the next message is awaited.
            self._update_header(message)

    def _connect(self, address):
        """Open a SUB socket connected to *address*, subscribed to every topic."""
        context = zmq.Context()
        subscriber = context.socket(zmq.SUB)
        subscriber.connect(address)
        subscriber.setsockopt_string(zmq.SUBSCRIBE, "")
        self._socket = subscriber

    def _get_message(self):
        """Receive the next raw message from the socket."""
        return self._socket.recv()

    def _is_packet(self, message):
        """Tell whether *message* has the byte size of a data packet."""
        return self.packet_size == len(message)

    def _update_header(self, message):
        """Parse a JSON header *message* and refresh the decoding parameters."""
        header = json.loads(message.decode("utf-8"))

        bytes_per_package = header["bytesPerPackage"]
        n_packages = header["nPackagesPerMessage"]
        # A data packet is an 8-byte timestamp followed by the samples.
        self.packet_size = 8 + bytes_per_package * n_packages
        n_channels = header["nChannels"]
        self.shape = (n_packages, n_channels)
        self.dtype = type_map[header["dataType"]]

        # Only the first entry of the ROI table is used.
        roi = header["roiTable"][0]
        first_channel = roi["roiStart"] // roi["roiDec"]
        start_distance = first_channel * header["dx"]
        last_channel = roi["roiEnd"] // roi["roiDec"]
        end_distance = last_channel * header["dx"]
        self.distance = {  # TODO: use from_block
            "tie_indices": [0, n_channels - 1],
            "tie_values": [start_distance, end_distance],
        }

        self.delta = float_to_timedelta(header["dt"], header["dtUnit"])

    def _unpack(self, message):
        """Decode a data packet *message* into a :class:`DataArray`."""
        # The first 8 bytes hold the start time, the rest holds the samples.
        t0 = np.frombuffer(message[:8], "datetime64[ns]").reshape(())
        data = np.frombuffer(message[8:], self.dtype).reshape(self.shape)
        last = self.shape[0] - 1
        time = {  # TODO: use from_block
            "tie_indices": [0, last],
            "tie_values": [t0, t0 + last * self.delta],
        }
        return DataArray(data, {"time": time, "distance": self.distance})
[docs]
class ZMQPublisher:
    """
    A class to stream data using ZeroMQ.

    Parameters
    ----------
    address : str
        The address to bind the ZeroMQ socket.

    Attributes
    ----------
    address : str
        The address where the ZeroMQ is bound to.

    Methods
    -------
    submit(da)
        Submits the data array for publishing.
    write(da)
        Same as ``submit``.

    Examples
    --------
    >>> import xdas as xd
    >>> from xdas.io.asn import ZMQPublisher

    >>> da = xd.synthetics.dummy()
    >>> port = xd.io.get_free_port()
    >>> address = f"tcp://localhost:{port}"
    >>> publisher = ZMQPublisher(address)
    >>> chunks = xd.split(da, 10)
    >>> for chunk in chunks:
    ...     publisher.submit(chunk)
    """

    def __init__(self, address):
        """Bind an XPUB socket to *address*; no header is known yet."""
        self.address = address
        self._connect(address)
        self._header = None

    @property
    def header(self):
        """The last welcome-message header dict sent to new subscribers."""
        return self._header

    @header.setter
    def header(self, header):
        """Set the welcome-message header and push it to the ZMQ socket option."""
        self._header = header
        welcome = json.dumps(header).encode("utf-8")
        self.socket.setsockopt(zmq.XPUB_WELCOME_MSG, welcome)

    def submit(self, da):
        """Publish *da* over ZMQ."""
        self._send(da)

    def write(self, da):
        """Alias for :meth:`submit`."""
        self._send(da)

    def _connect(self, address):
        """Create a verbose XPUB socket bound to *address*."""
        context = zmq.Context()
        publisher = context.socket(zmq.XPUB)
        publisher.setsockopt(zmq.XPUB_VERBOSE, True)
        publisher.bind(address)
        self.socket = publisher

    @staticmethod
    def _get_header(da):
        """Build the JSON-serializable stream header describing *da*."""
        da = da.transpose("time", "distance")

        # Reverse lookup of the header type name; stays ``None`` when the
        # dtype of *da* is not listed in ``type_map``.
        type_name = None
        for name, numpy_type in type_map.items():
            if numpy_type == da.dtype:
                type_name = name
                break

        # The whole distance axis is described as a single, undecimated ROI.
        roi = {"roiStart": 0, "roiEnd": da.shape[1] - 1, "roiDec": 1}
        return {
            "bytesPerPackage": da.dtype.itemsize * da.shape[1],
            "nPackagesPerMessage": da.shape[0],
            "nChannels": da.shape[1],
            "dataType": type_name,
            "dx": get_sampling_interval(da, "distance"),
            "dt": get_sampling_interval(da, "time"),
            "dtUnit": "s",
            "dxUnit": "m",
            "roiTable": [roi],
        }

    def _send(self, da):
        """Send *da*, preceded by a header message when the header changed."""
        da = da.transpose("time", "distance")
        header = self._get_header(da)
        if self.header is None:
            # First chunk: only register the header as welcome message.
            self.header = header
        elif header != self.header:
            # The stream layout changed: update the welcome message and
            # notify the subscribers that are already connected.
            self.header = header
            self._send_header()
        self._send_data(da)

    def _send_header(self):
        """Send the current header as a UTF-8 encoded JSON message."""
        self._send_message(json.dumps(self.header).encode("utf-8"))

    def _send_data(self, da):
        """Send *da* as one packet: start time (ns) followed by the samples."""
        da = da.transpose("time", "distance")
        start = da["time"][0].values.astype("datetime64[ns]")
        samples = da.values
        self._send_message(start.tobytes() + samples.tobytes())

    def _send_message(self, message):
        """Send the raw bytes *message* on the socket."""
        self.socket.send(message)
def float_to_timedelta(value, unit):
    """
    Convert a duration expressed as a number into a nanosecond timedelta.

    Parameters
    ----------
    value : float
        The duration, expressed in *unit*.
    unit : str
        The unit of *value*. Valid units are 'ns' (nanoseconds),
        'us' (microseconds), 'ms' (milliseconds), and 's' (seconds).

    Returns
    -------
    numpy.timedelta64
        The duration rounded to the nearest nanosecond.

    Raises
    ------
    KeyError
        If *unit* is not one of the valid units.

    Examples
    --------
    >>> float_to_timedelta(1.5, 'ms')  # doctest: +SKIP
    np.timedelta64(1500000,'ns')
    """
    nanoseconds_per_unit = {
        "ns": 1e0,
        "us": 1e3,
        "ms": 1e6,
        "s": 1e9,
    }
    scale = nanoseconds_per_unit[unit]
    nanoseconds = round(value * scale)
    return np.timedelta64(nanoseconds, "ns")