Skip to content

Commit

Permalink
Add support for allow_version_mismatch (L1-2516)
Browse files Browse the repository at this point in the history
Expose the flag so that once it's changed to False by default, it can be
overridden.
  • Loading branch information
Fabio Rossetto authored and fabiorossetto committed Nov 25, 2024
1 parent 49c6282 commit 90aa8a0
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 43 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# zhinst-toolkit Changelog

## Version 0.7.1
* Added support for the `allow_version_mismatch` in the `Session` constructor. When set to False, an exception will be raised when attempting to connect to a data-server on a different version than that of the zhinst.core library.

## Version 0.7.0
* Add QHub driver

Expand Down
7 changes: 1 addition & 6 deletions src/zhinst/toolkit/driver/modules/shfqa_sweeper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from pathlib import Path

import numpy as np
from zhinst.core import ziDAQServer
from zhinst.utils.shf_sweeper import AvgConfig, EnvelopeConfig, RfConfig
from zhinst.utils.shf_sweeper import ShfSweeper as CoreSweeper
from zhinst.utils.shf_sweeper import SweepConfig, TriggerConfig
Expand Down Expand Up @@ -71,11 +70,7 @@ def __init__(self, session: "Session"):
"force_sw_trigger": "sw_trigger_mode",
}
super().__init__(self._create_nodetree(), tuple())
self._daq_server = ziDAQServer(
session.daq_server.host,
session.daq_server.port,
6,
)
self._daq_server = session.clone_underlying_session()
self._raw_module = CoreSweeper(self._daq_server, "")
self._session = session
self.root.update_nodes(
Expand Down
48 changes: 44 additions & 4 deletions src/zhinst/toolkit/session.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Module for managing a session to a Data Server through zhinst.core."""

import json
import typing as t
from collections.abc import MutableMapping
Expand Down Expand Up @@ -654,6 +655,12 @@ class Session(Node):
connection: Existing DAQ server object. If specified the session will
not create a new session to the data server but reuse the passed
one. (default = None)
allow_version_mismatch: When set to False, an exception will be raised
when attempting to connect to a data-server on a different version
than that of the zhinst.core library. (default = True)
.. versionchanged:: 0.7.1
Added `allow_version_mismatch` argument.
"""

def __init__(
Expand All @@ -663,6 +670,7 @@ def __init__(
*,
hf2: t.Optional[bool] = None,
connection: t.Optional[core.ziDAQServer] = None,
allow_version_mismatch: bool = True,
):
self._is_hf2_server = bool(hf2)
if connection is not None:
Expand All @@ -683,10 +691,8 @@ def __init__(
if self._is_hf2_server and server_port == 8004:
server_port = 8005
try:
self._daq_server = core.ziDAQServer(
server_host,
server_port,
1 if self._is_hf2_server else 6,
self._daq_server = self._create_daq(
server_host, server_port, allow_version_mismatch
)
except RuntimeError as error:
if "Unsupported API level" not in error.args[0]:
Expand Down Expand Up @@ -987,3 +993,37 @@ def server_host(self) -> str:
def server_port(self) -> int:
"""Server port."""
return self._daq_server.port

def clone_underlying_session(self) -> core.ziDAQServer:
"""Create a new session to the data server.
Create a new core.ziDAQServer connected to the same data-server this
session is connected to.
"""
# Don't execute version checking. When clone_underlying_session is called,
# a connection has already been made, so checking again would be redundant.
return self._create_daq(self.server_host, self.server_port, True)

def _create_daq(
self,
server_host: str,
server_port: int,
allow_version_mismatch: bool,
):
"""Create a new session to the data server.
Attempt to pass the allow_version_mismatch flag. Fallback in case
zhinst.core does not support it yet.
"""
api_level = 1 if self._is_hf2_server else 6
try:
return core.ziDAQServer(
server_host,
server_port,
api_level,
allow_version_mismatch=allow_version_mismatch,
)
except TypeError as error:
if "allow_version_mismatch" not in error.args[0]:
raise
return core.ziDAQServer(server_host, server_port, api_level)
11 changes: 0 additions & 11 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def hf2_session(mock_connection):

@pytest.fixture()
def shfqa(data_dir, mock_connection, session):

json_path = data_dir / "nodedoc_dev1234_shfqa.json"
with json_path.open("r", encoding="UTF-8") as file:
nodes_json = file.read()
Expand All @@ -62,7 +61,6 @@ def shfqa(data_dir, mock_connection, session):

@pytest.fixture()
def shfsg(data_dir, mock_connection, session):

json_path = data_dir / "nodedoc_dev1234_shfsg.json"
with json_path.open("r", encoding="UTF-8") as file:
nodes_json = file.read()
Expand All @@ -74,7 +72,6 @@ def shfsg(data_dir, mock_connection, session):

@pytest.fixture()
def shfqc(data_dir, mock_connection, session):

json_path = data_dir / "nodedoc_dev1234_shfqc.json"
with json_path.open("r", encoding="UTF-8") as file:
nodes_json = file.read()
Expand All @@ -96,11 +93,3 @@ def nodedoc_dev1234_json(data_dir):
json_path = data_dir / "nodedoc_dev1234.json"
with json_path.open("r", encoding="UTF-8") as file:
return file.read()


@pytest.fixture()
def mock_sweeper_daq():
with patch(
"zhinst.toolkit.driver.modules.shfqa_sweeper.ziDAQServer", autospec=True
) as connection:
yield connection
61 changes: 54 additions & 7 deletions tests/test_data_server_session.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from unittest.mock import patch
from unittest.mock import MagicMock, patch

import pytest

Expand All @@ -9,7 +9,9 @@


def test_setup(mock_connection, session):
mock_connection.assert_called_once_with("localhost", 8004, 6)
mock_connection.assert_called_once_with(
"localhost", 8004, 6, allow_version_mismatch=True
)
mock_connection.return_value.listNodesJSON.assert_called_once_with("/zi/*")
assert repr(session) == "DataServerSession(localhost:8004)"
assert not session.is_hf2_server
Expand All @@ -18,12 +20,60 @@ def test_setup(mock_connection, session):


def test_setup_hf2(mock_connection, hf2_session):
mock_connection.assert_called_once_with("localhost", 8005, 1)
mock_connection.assert_called_once_with(
"localhost", 8005, 1, allow_version_mismatch=True
)
mock_connection.return_value.listNodesJSON.assert_not_called()
assert repr(hf2_session) == "HF2DataServerSession(localhost:8005)"
assert hf2_session.is_hf2_server


def test_allow_mismatch_not_supported(mock_connection, nodedoc_zi_json):
mock_daq = MagicMock()
mock_daq.listNodesJSON.return_value = nodedoc_zi_json

def create_daq(*args, **kwargs):
if "allow_version_mismatch" in kwargs:
raise TypeError("allow_version_mismatch not recognized")
return mock_daq

mock_connection.side_effect = create_daq
Session("localhost", 8004)
mock_connection.assert_any_call("localhost", 8004, 6, allow_version_mismatch=True)
mock_connection.assert_called_with("localhost", 8004, 6)


# Passing "allow_version_mismatch" does not cause an error, even if the underlying
# zhinst.core does not recognize this flag.
def test_allow_mismatch_passed_but_not_supported(mock_connection, nodedoc_zi_json):
mock_daq = MagicMock()
mock_daq.listNodesJSON.return_value = nodedoc_zi_json

def create_daq(*args, **kwargs):
if "allow_version_mismatch" in kwargs:
raise TypeError("allow_version_mismatch not recognized")
return mock_daq

mock_connection.side_effect = create_daq
Session("localhost", 8004, allow_version_mismatch=True)
mock_connection.assert_any_call("localhost", 8004, 6, allow_version_mismatch=True)
mock_connection.assert_called_with("localhost", 8004, 6)


def test_allow_mismatch_default(mock_connection, nodedoc_zi_json):
mock_daq = MagicMock()
mock_daq.listNodesJSON.return_value = nodedoc_zi_json

def create_daq(*args, **kwargs):
return mock_daq

mock_connection.side_effect = create_daq
Session("localhost", 8004)
mock_connection.assert_called_once_with(
"localhost", 8004, 6, allow_version_mismatch=True
)


def test_existing_connection(nodedoc_zi_json, mock_connection):
mock_connection.listNodesJSON.return_value = nodedoc_zi_json
mock_connection.getString.return_value = "DataServer"
Expand Down Expand Up @@ -77,7 +127,6 @@ def test_unkown_init_error(mock_connection):
def test_connect_device(
zi_devices_json, mock_connection, session, nodedoc_dev1234_json
):

connected_devices = ""

def get_string_side_effect(arg):
Expand Down Expand Up @@ -137,7 +186,6 @@ def connect_device_side_effect(serial, _):
def test_connect_device_autodetection(
zi_devices_json, mock_connection, session, nodedoc_dev1234_json
):

connected_devices = ""
selected_interface = ""

Expand Down Expand Up @@ -478,7 +526,7 @@ def test_sweeper_module(data_dir, mock_connection, session):
assert isinstance(sweeper_module.device, Node)


def test_shfqa_sweeper(session, mock_sweeper_daq):
def test_shfqa_sweeper(session):
sweeper = session.modules.shfqa_sweeper
assert sweeper == session.modules.shfqa_sweeper
assert isinstance(sweeper, tk_modules.SHFQASweeper)
Expand All @@ -487,7 +535,6 @@ def test_shfqa_sweeper(session, mock_sweeper_daq):
def test_session_wide_transaction(
mock_connection, nodedoc_dev1234_json, session, shfqa, shfsg
):

# Hack devices into the created once
session._devices._devices = {"dev1": shfqa, "dev2": shfsg}

Expand Down
21 changes: 6 additions & 15 deletions tests/test_shfqa_sweeper.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,15 @@ def mock_TriggerConfig():


@pytest.fixture()
def sweeper_module(session, mock_sweeper_daq, mock_shf_sweeper):
def sweeper_module(session, mock_shf_sweeper):
yield SHFQASweeper(session)


def test_repr(sweeper_module):
assert "SHFQASweeper(DataServerSession(localhost:8004))" in repr(sweeper_module)


def test_missing_node(mock_connection, mock_shf_sweeper, mock_sweeper_daq, session):
def test_missing_node(mock_connection, mock_shf_sweeper, session):
with patch(
"zhinst.toolkit.driver.modules.shfqa_sweeper.SweepConfig",
make_dataclass("Y", fields=[("s", str, 0)], bases=(SweepConfig,)),
Expand All @@ -50,13 +50,11 @@ def test_missing_node(mock_connection, mock_shf_sweeper, mock_sweeper_daq, sessi
assert sweeper_module.sweep.s() == 0


def test_device(
mock_connection, sweeper_module, session, mock_shf_sweeper, mock_sweeper_daq
):
def test_device(mock_connection, sweeper_module, session, mock_shf_sweeper):
assert sweeper_module.device() == ""

sweeper_module.device("dev1234")
mock_shf_sweeper.assert_called_with(mock_sweeper_daq(), "dev1234")
mock_shf_sweeper.assert_called_with(sweeper_module._daq_server, "dev1234")
assert sweeper_module.device() == "dev1234"

connected_devices = "dev1234"
Expand All @@ -73,17 +71,15 @@ def get_string_side_effect(arg):
assert sweeper_module.device() == session.devices["dev1234"]


def test_update_settings(
mock_connection, sweeper_module, mock_shf_sweeper, mock_sweeper_daq
):
def test_update_settings(sweeper_module, mock_shf_sweeper):
assert not sweeper_module.envelope.enable()
mock_shf_sweeper.assert_called_with(sweeper_module._daq_server, "")
# device needs to be set first
with pytest.raises(RuntimeError) as e_info:
sweeper_module._update_settings()

sweeper_module.device("dev1234")
mock_shf_sweeper.assert_called_with(mock_sweeper_daq(), "dev1234")
mock_shf_sweeper.assert_called_with(sweeper_module._daq_server, "dev1234")
sweeper_module._update_settings()
mock_shf_sweeper.return_value.configure.assert_called_once()
assert "sweep_config" in mock_shf_sweeper.return_value.configure.call_args[1]
Expand Down Expand Up @@ -135,7 +131,6 @@ def test_update_settings_broken(
mock_shf_sweeper,
caplog,
):

sweeper_module.device("dev1234")
sweeper_module._update_settings()
assert (
Expand All @@ -157,31 +152,27 @@ def test_update_settings_broken(


def test_run(sweeper_module, mock_shf_sweeper):

sweeper_module.device("dev1234")
sweeper_module.run()
mock_shf_sweeper.return_value.configure.assert_called_once()
mock_shf_sweeper.return_value.run.assert_called_once()


def test_get_result(sweeper_module, mock_shf_sweeper):

sweeper_module.device("dev1234")
sweeper_module.get_result()
mock_shf_sweeper.return_value.configure.assert_called_once()
mock_shf_sweeper.return_value.get_result.assert_called_once()


def test_plot(sweeper_module, mock_shf_sweeper):

sweeper_module.device("dev1234")
sweeper_module.plot()
mock_shf_sweeper.return_value.configure.assert_called_once()
mock_shf_sweeper.return_value.plot.assert_called_once()


def test_get_offset_freq_vector(sweeper_module, mock_shf_sweeper):

sweeper_module.device("dev1234")
sweeper_module.get_offset_freq_vector()
mock_shf_sweeper.return_value.configure.assert_called_once()
Expand Down

0 comments on commit 90aa8a0

Please sign in to comment.