Merge pull request xdf-modules#106 from cbrnr/gha
Run tests with GitHub Action
cboulay committed Apr 26, 2024
2 parents 99dfccf + 4d1ff90 commit 09ace34
Showing 6 changed files with 178 additions and 114 deletions.
38 changes: 38 additions & 0 deletions .github/workflows/test.yml
@@ -0,0 +1,38 @@
+name: Test
+
+on: [push, pull_request]
+
+jobs:
+  style:
+    name: Check style
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: chartboost/ruff-action@v1
+      - uses: chartboost/ruff-action@v1
+        with:
+          args: 'format --check'
+
+  test:
+    needs: style
+    strategy:
+      matrix:
+        python-version: ['3.9', '3.10', '3.11', '3.12']
+    name: Run tests (Python ${{ matrix.python-version }})
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Build sdist and wheel
+        run: |
+          pip install build
+          python -m build
+      - name: Install wheel
+        run: pip install "$(pwd)/$(echo dist/pyxdf*.whl)[dev]"
+      - name: Run tests
+        run: |
+          git clone https://github.com/xdf-modules/example-files.git
+          pip install pytest
+          pytest
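
For readers who want to reproduce the CI steps outside Actions: a rough local equivalent of the test job, sketched with subprocess. The wheel glob and the [dev] extra come from the workflow above; everything else is an assumption about a clean local checkout.

import glob
import subprocess

# Build sdist and wheel, install the wheel with dev extras, fetch test data, run pytest.
subprocess.run(["python", "-m", "pip", "install", "build"], check=True)
subprocess.run(["python", "-m", "build"], check=True)
wheel = glob.glob("dist/pyxdf*.whl")[0]  # exact wheel name varies by version
subprocess.run(["python", "-m", "pip", "install", f"{wheel}[dev]"], check=True)
subprocess.run(["git", "clone", "https://github.com/xdf-modules/example-files.git"], check=True)
subprocess.run(["python", "-m", "pytest"], check=True)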
77 changes: 56 additions & 21 deletions pyxdf/examples/playback_lsl.py
@@ -1,11 +1,12 @@
 import argparse
-import time
 import sys
-from typing import List, Optional
+import time
 from dataclasses import dataclass
+from typing import List, Optional
 
 import numpy as np
 import pylsl
 
 import pyxdf

@@ -53,22 +54,33 @@ class Streamer:


 class LSLPlaybackClock:
-    def __init__(self, rate: float = 1.0, loop_time: float = 0.0, max_sample_rate: Optional[float] = None):
+    def __init__(
+        self,
+        rate: float = 1.0,
+        loop_time: float = 0.0,
+        max_sample_rate: Optional[float] = None,
+    ):
         if rate != 1.0:
-            print("WARNING!! rate != 1.0; It is impossible to synchronize playback streams "
-                  "with real time streams.")
+            print(
+                "WARNING!! rate != 1.0; It is impossible to synchronize playback streams "
+                "with real time streams."
+            )
         self.rate: float = rate  # Maximum rate is loop_time / avg_update_interval, whatever that might be.
         self._boundary = loop_time
         self._max_srate = max_sample_rate
         decr = (1 / self._max_srate) if self._max_srate else 2 * sys.float_info.epsilon
         self._wall_start: float = pylsl.local_clock() - decr / 2
         self._file_read_s: float = 0  # File read header in seconds
-        self._prev_file_read_s: float = 0  # File read header in seconds for previous iteration
+        self._prev_file_read_s: float = (
+            0  # File read header in seconds for previous iteration
+        )
         self._n_loop: int = 0
 
     def reset(self, reset_file_position: bool = False) -> None:
         decr = (1 / self._max_srate) if self._max_srate else 2 * sys.float_info.epsilon
-        self._wall_start = pylsl.local_clock() - decr / 2 - self._file_read_s / self.rate
+        self._wall_start = (
+            pylsl.local_clock() - decr / 2 - self._file_read_s / self.rate
+        )
         self._n_loop = 0
         if reset_file_position:
             self._file_read_s = 0
@@ -117,8 +129,13 @@ def sleep(self, duration: Optional[float] = None) -> None:
         time.sleep(duration / self.rate)
 
 
-def main(fname: str, playback_speed: float = 1.0, loop: bool = True, wait_for_consumer: bool = False):
-    streams, header = pyxdf.load_xdf(fname)
+def main(
+    fname: str,
+    playback_speed: float = 1.0,
+    loop: bool = True,
+    wait_for_consumer: bool = False,
+):
+    streams, _ = pyxdf.load_xdf(fname)
 
     # First iterate over all streams to calculate some globals.
     xdf_t0 = np.inf
@@ -140,20 +157,32 @@ def main(fname: str, playback_speed: float = 1.0, loop: bool = True, wait_for_co
         tvec = strm["time_stamps"]
         srate = float(strm["info"]["nominal_srate"][0])
         if len(tvec) > 0:
-            new_info: pylsl.StreamInfo = _create_info_from_xdf_stream_header(strm["info"])
+            new_info: pylsl.StreamInfo = _create_info_from_xdf_stream_header(
+                strm["info"]
+            )
             new_outlet: pylsl.StreamOutlet = pylsl.StreamOutlet(new_info)
-            streamers.append(Streamer(strm_ix, new_info.name(), tvec - xdf_t0, new_info, new_outlet, srate))
+            streamers.append(
+                Streamer(
+                    strm_ix, new_info.name(), tvec - xdf_t0, new_info, new_outlet, srate
+                )
+            )
 
     # Create timer to manage playback.
-    timer = LSLPlaybackClock(rate=playback_speed, loop_time=wrap_dur if loop else None, max_sample_rate=max_rate)
+    timer = LSLPlaybackClock(
+        rate=playback_speed,
+        loop_time=wrap_dur if loop else None,
+        max_sample_rate=max_rate,
+    )
     read_heads = {_.name: 0 for _ in streamers}
     b_push = not wait_for_consumer  # A flag to indicate we can push samples.
     try:
         while True:
             if not b_push:
                 # We are looking for consumers.
                 time.sleep(0.01)
-                have_consumers = [streamer.outlet.have_consumers() for streamer in streamers]
+                have_consumers = [
+                    streamer.outlet.have_consumers() for streamer in streamers
+                ]
                 # b_push = any(have_consumers)
                 b_push = all(have_consumers)
                 if b_push:
@@ -175,8 +204,9 @@ def main(fname: str, playback_speed: float = 1.0, loop: bool = True, wait_for_co
                     # Irregular rate, like events and markers
                     for dat_idx in range(start_idx, stop_idx):
                         sample = streams[streamer.stream_ix]["time_series"][dat_idx]
-                        streamer.outlet.push_sample(sample,
-                                                    timestamp=timer.t0 + streamer.tvec[dat_idx])
+                        streamer.outlet.push_sample(
+                            sample, timestamp=timer.t0 + streamer.tvec[dat_idx]
+                        )
                         # print(f"Pushed sample: {sample}")
                 read_heads[streamer.name] = stop_idx
             timer.sleep()
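
The reflowed call above is pylsl's explicit-timestamp form of push_sample, which is what keeps replayed samples on their original time grid. A standalone sketch of that call; the stream name, type, and pushed value are illustrative, not from this repo:

import pylsl

# One-channel, irregular-rate marker stream; push one sample with an explicit timestamp.
info = pylsl.StreamInfo("demo", "Markers", 1, 0, "string", "demo-uid-0001")
outlet = pylsl.StreamOutlet(info)
outlet.push_sample(["event-1"], timestamp=pylsl.local_clock())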
@@ -186,14 +216,19 @@ def main(fname: str, playback_speed: float = 1.0, loop: bool = True, wait_for_co


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Playback an XDF file over LSL streams.")
+    parser = argparse.ArgumentParser(
+        description="Playback an XDF file over LSL streams."
+    )
+    parser.add_argument("filename", type=str, help="Path to the XDF file")
     parser.add_argument(
-        "filename",
-        type=str,
-        help="Path to the XDF file"
+        "--playback_speed", type=float, default=1.0, help="Playback speed multiplier."
     )
-    parser.add_argument("--playback_speed", type=float, default=1.0, help="Playback speed multiplier.")
     parser.add_argument("--loop", action="store_false")
     parser.add_argument("--wait_for_consumer", action="store_true")
     args = parser.parse_args()
-    main(args.filename, playback_speed=args.playback_speed, loop=args.loop, wait_for_consumer=args.wait_for_consumer)
+    main(
+        args.filename,
+        playback_speed=args.playback_speed,
+        loop=args.loop,
+        wait_for_consumer=args.wait_for_consumer,
+    )
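
For orientation, a hedged sketch of driving the reformatted entry point directly. The XDF path is a placeholder, and the import assumes the examples directory is on sys.path; running the file as a script with the same arguments is equivalent.

from playback_lsl import main  # assumes pyxdf/examples is on sys.path

main("recording.xdf", playback_speed=1.0, loop=True, wait_for_consumer=False)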
29 changes: 16 additions & 13 deletions pyxdf/examples/print_metadata.py
@@ -3,30 +3,33 @@
 # Chadwick Boulay
 #
 # License: BSD (2-clause)
-from os.path import abspath, join, dirname
-import logging
 import argparse
+import logging
+from os.path import abspath, dirname, join
 
 import pyxdf
 
 
 def main(fname: str):
     logging.basicConfig(level=logging.DEBUG)  # Use logging.INFO to reduce output
-    streams, fileheader = pyxdf.load_xdf(fname)
+    streams, _ = pyxdf.load_xdf(fname)
 
     print("Found {} streams:".format(len(streams)))
     for ix, stream in enumerate(streams):
         msg = "Stream {}: {} - type {} - uid {} - shape {} at {} (effective {}) Hz"
-        print(msg.format(
-            ix + 1, stream['info']['name'][0],
-            stream['info']['type'][0],
-            stream['info']['uid'][0],
-            (int(stream['info']['channel_count'][0]), len(stream['time_stamps'])),
-            stream['info']['nominal_srate'][0],
-            stream['info']['effective_srate'])
+        print(
+            msg.format(
+                ix + 1,
+                stream["info"]["name"][0],
+                stream["info"]["type"][0],
+                stream["info"]["uid"][0],
+                (int(stream["info"]["channel_count"][0]), len(stream["time_stamps"])),
+                stream["info"]["nominal_srate"][0],
+                stream["info"]["effective_srate"],
+            )
         )
-        if any(stream['time_stamps']):
-            duration = stream['time_stamps'][-1] - stream['time_stamps'][0]
+        if any(stream["time_stamps"]):
+            duration = stream["time_stamps"][-1] - stream["time_stamps"][0]
             print("\tDuration: {} s".format(duration))
     print("Done.")
 
@@ -37,7 +40,7 @@ def main(fname: str):
"-f",
type=str,
help="Path to the XDF file",
default=abspath(join(dirname(__file__), "..", "..", "..", "xdf_sample.xdf"))
default=abspath(join(dirname(__file__), "..", "..", "..", "xdf_sample.xdf")),
)
args = parser.parse_args()
main(args.f)
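
As a quick summary of what this example does, it reduces to the following pyxdf call and per-stream metadata access. The path is a placeholder; the [0] indexing reflects that XDF metadata fields arrive as lists of strings:

import pyxdf

streams, _ = pyxdf.load_xdf("recording.xdf")  # placeholder path
for stream in streams:
    print(stream["info"]["name"][0], stream["info"]["nominal_srate"][0])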
59 changes: 18 additions & 41 deletions pyxdf/pyxdf.py
@@ -9,18 +9,18 @@
 This function is closely following the load_xdf reference implementation.
 """
 
+import gzip
 import io
-import struct
 import itertools
-import gzip
-from xml.etree.ElementTree import fromstring, ParseError
-from collections import OrderedDict, defaultdict
 import logging
+import struct
+from collections import OrderedDict, defaultdict
 from pathlib import Path
+from xml.etree.ElementTree import ParseError, fromstring
 
 import numpy as np
 
 
 __all__ = ["load_xdf"]
 
 logger = logging.getLogger(__name__)
@@ -81,7 +81,7 @@ def load_xdf(
     clock_reset_threshold_offset_seconds=1,
     clock_reset_threshold_offset_stds=10,
     winsor_threshold=0.0001,
-    verbose=None
+    verbose=None,
 ):
     """Import an XDF file.
@@ -210,9 +210,7 @@
     elif isinstance(select_streams, int):
         select_streams = [select_streams]
     elif all([isinstance(elem, dict) for elem in select_streams]):
-        select_streams = match_streaminfos(
-            resolve_streams(filename), select_streams
-        )
+        select_streams = match_streaminfos(resolve_streams(filename), select_streams)
         if not select_streams:  # no streams found
             raise ValueError("No matching streams found.")
     elif not all([isinstance(elem, int) for elem in select_streams]):
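
Per the isinstance checks in this hunk, select_streams accepts an int, a list of ints, or a list of property dictionaries resolved through match_streaminfos. A sketch of the dictionary form; the path and stream properties are illustrative:

import pyxdf

# Load only streams whose metadata matches; load_xdf raises ValueError if none match.
streams, header = pyxdf.load_xdf("recording.xdf", select_streams=[{"type": "EEG"}])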
@@ -306,9 +304,7 @@
                # noinspection PyBroadException
                try:
                    nsamples, stamps, values = _read_chunk3(f, temp[StreamId])
-                    logger.debug(
-                        " reading [%s,%s]" % (temp[StreamId].nchns, nsamples)
-                    )
+                    logger.debug(" reading [%s,%s]" % (temp[StreamId].nchns, nsamples))
                    # optionally send through the on_chunk function
                    if on_chunk is not None:
                        values, stamps, streams[StreamId] = on_chunk(
@@ -337,12 +333,8 @@
                        )
            elif tag == 4:
                # read [ClockOffset] chunk
-                temp[StreamId].clock_times.append(
-                    struct.unpack("<d", f.read(8))[0]
-                )
-                temp[StreamId].clock_values.append(
-                    struct.unpack("<d", f.read(8))[0]
-                )
+                temp[StreamId].clock_times.append(struct.unpack("<d", f.read(8))[0])
+                temp[StreamId].clock_values.append(struct.unpack("<d", f.read(8))[0])
            else:
                # skip other chunk types (Boundary, ...)
                f.read(chunklen - 2)
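
Each [ClockOffset] chunk carries two little-endian doubles, a collection time and an offset value, which is what the struct.unpack("<d", ...) calls above decode one double at a time. A self-contained round trip with fabricated values:

import struct

# Fabricated ClockOffset payload: collection time, then offset value.
payload = struct.pack("<dd", 12345.678, -0.0125)
collection_time, offset_value = struct.unpack("<dd", payload)
print(collection_time, offset_value)  # 12345.678 -0.0125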
@@ -509,7 +501,7 @@ def _xml2dict(t):

 def _scan_forward(f):
     """Scan forward through file object until after the next boundary chunk."""
-    blocklen = 2 ** 20
+    blocklen = 2**20
     signature = bytes(
         [
             0x43,
@@ -577,28 +569,18 @@ def _clock_sync(

            # points where a glitch in the timing of successive clock
            # measurements happened
-            mad = (
-                np.median(np.abs(time_diff - median_ival))
-                + np.finfo(float).eps
-            )
+            mad = np.median(np.abs(time_diff - median_ival)) + np.finfo(float).eps
            cond1 = time_diff < 0
            cond2 = (time_diff - median_ival) / mad > reset_threshold_stds
            cond3 = time_diff - median_ival > reset_threshold_seconds
            time_glitch = cond1 | (cond2 & cond3)
 
            # Points where a glitch in successive clock value estimates
            # happened
-            mad = (
-                np.median(np.abs(value_diff - median_slope))
-                + np.finfo(float).eps
-            )
+            mad = np.median(np.abs(value_diff - median_slope)) + np.finfo(float).eps
            cond1 = value_diff < 0
-            cond2 = (
-                value_diff - median_slope
-            ) / mad > reset_threshold_offset_stds
-            cond3 = (
-                value_diff - median_slope > reset_threshold_offset_seconds
-            )
+            cond2 = (value_diff - median_slope) / mad > reset_threshold_offset_stds
+            cond3 = value_diff - median_slope > reset_threshold_offset_seconds
            value_glitch = cond1 | (cond2 & cond3)
            resets_at = time_glitch & value_glitch
 
@@ -607,9 +589,7 @@
                ranges = [(0, len(clock_times) - 1)]
            else:
                indices = np.where(resets_at)[0]
-                indices = np.hstack(
-                    (0, indices, indices + 1, len(resets_at) - 1)
-                )
+                indices = np.hstack((0, indices, indices + 1, len(resets_at) - 1))
                ranges = np.reshape(indices, (2, -1)).T
 
        # Otherwise we just assume that there are no clock resets
@@ -624,8 +604,7 @@
            X = np.column_stack(
                [
                    np.ones((stop - start,)),
-                    np.array(clock_times[start:stop])
-                    / winsor_threshold,
+                    np.array(clock_times[start:stop]) / winsor_threshold,
                ]
            )
            y = np.array(clock_values[start:stop]) / winsor_threshold
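
For readers skimming this hunk: X and y set up a two-parameter (intercept and slope) linear mapping from clock times to clock offsets, scaled by winsor_threshold; pyxdf then solves it with a robust fit. A toy version of the same idea, using plain least squares on made-up measurements:

import numpy as np

# Made-up clock-offset measurements: the offset drifts roughly linearly with time.
clock_times = np.array([0.0, 10.0, 20.0, 30.0])
clock_values = np.array([0.5000, 0.5001, 0.5003, 0.5004])
X = np.column_stack([np.ones(len(clock_times)), clock_times])
coef, *_ = np.linalg.lstsq(X, clock_values, rcond=None)
corrected = clock_times + coef[0] + coef[1] * clock_times  # same form as the update below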
@@ -638,9 +617,7 @@

        # Apply the correction to all time stamps
        if len(ranges) == 1:
-            stream.time_stamps += coef[0][0] + (
-                coef[0][1] * stream.time_stamps
-            )
+            stream.time_stamps += coef[0][0] + (coef[0][1] * stream.time_stamps)
        else:
            for coef_i, range_i in zip(coef, ranges):
                r = slice(range_i[0], range_i[1])