Skip to content

Commit

Permalink
Add typehints where applicable
Browse files Browse the repository at this point in the history
  • Loading branch information
mjruckriegel committed May 13, 2020
1 parent 38b7e65 commit 37bae98
Show file tree
Hide file tree
Showing 11 changed files with 123 additions and 100 deletions.
39 changes: 25 additions & 14 deletions zhinst/toolkit/control/drivers/base/awg.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@

import numpy as np
import time
from typing import List, Union

from zhinst.toolkit.helpers import SequenceProgram, Waveform, SequenceType
from .base import ToolkitError
from .base import ToolkitError, BaseInstrument


class AWGCore:
Expand Down Expand Up @@ -89,7 +90,7 @@ class AWGCore:
"""

def __init__(self, parent, index):
def __init__(self, parent: BaseInstrument, index: int) -> None:
self._parent = parent
self._index = index
self._module = None
Expand Down Expand Up @@ -131,15 +132,15 @@ def __repr__(self):
s += f" {i}\n"
return s

def run(self):
def run(self) -> None:
"""Runs the AWG Core."""
self._parent._set(f"/awgs/{self._index}/enable", 1)

def stop(self):
def stop(self) -> None:
"""Stops the AWG Core."""
self._parent._set(f"/awgs/{self._index}/enable", 0)

def wait_done(self, timeout=10):
def wait_done(self, timeout: float = 10) -> None:
"""Waits until the AWG Core is finished.
Keyword Arguments:
Expand All @@ -153,9 +154,8 @@ def wait_done(self, timeout=10):
time.sleep(0.1)
if tik - tok > timeout:
break
return

def compile(self):
def compile(self) -> None:
"""Compiles the current SequenceProgram on the AWG Core.
Raises:
Expand Down Expand Up @@ -188,11 +188,16 @@ def compile(self):
print("Compilation successful")
self._wait_upload_done()

def reset_queue(self):
def reset_queue(self) -> None:
"""Resets the waveform queue to an empty list."""
self._waveforms = []

def queue_waveform(self, wave1, wave2, delay=0):
def queue_waveform(
self,
wave1: Union[List, np.array],
wave2: Union[List, np.array],
delay: float = 0,
) -> None:
"""Adds a new waveform to the queue.
Arguments:
Expand All @@ -216,7 +221,13 @@ def queue_waveform(self, wave1, wave2, delay=0):
self._waveforms.append(Waveform(wave1, wave2, delay=delay))
print(f"Current length of queue: {len(self._waveforms)}")

def replace_waveform(self, wave1, wave2, i=0, delay=0):
def replace_waveform(
self,
wave1: Union[List, np.array],
wave2: Union[List, np.array],
i: int = 0,
delay: float = 0,
) -> None:
"""Replaces a waveform in the queue at a given index.
Arguments:
Expand All @@ -236,7 +247,7 @@ def replace_waveform(self, wave1, wave2, i=0, delay=0):
raise ToolkitError("Index out of range!")
self._waveforms[i].replace_data(wave1, wave2, delay=delay)

def upload_waveforms(self):
def upload_waveforms(self) -> None:
"""Uploads all waveforms in the queue to the AWG Core.
This method only works as expected if the Sequence Program is in
Expand All @@ -252,7 +263,7 @@ def upload_waveforms(self):
tik = time.time()
print(f"Upload of {len(waveform_data)} waveforms took {tik - tok:.5} s")

def compile_and_upload_waveforms(self):
def compile_and_upload_waveforms(self) -> None:
"""Compiles the Sequence Program and uploads the queued waveforms.
Simply combines the two methods to make sure the sequence is compiled
Expand All @@ -262,7 +273,7 @@ def compile_and_upload_waveforms(self):
self.compile()
self.upload_waveforms()

def _wait_upload_done(self, timeout=10):
def _wait_upload_done(self, timeout: float = 10) -> None:
if self._module is None:
raise ToolkitError("This AWG is not connected to a awgModule!")
time.sleep(0.01)
Expand All @@ -278,7 +289,7 @@ def _wait_upload_done(self, timeout=10):
f"{self.name}: Sequencer status: {'ELF file uploaded' if status == 0 else 'FAILED!!'}"
)

def set_sequence_params(self, **kwargs):
def set_sequence_params(self, **kwargs) -> None:
"""Sets the parameters of the Sequence Program.
Passes all the keyword arguments to the `set_param(...)` method of the
Expand Down
15 changes: 9 additions & 6 deletions zhinst/toolkit/control/drivers/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# of the MIT license. See the LICENSE file for details.

import numpy as np
from typing import List, Dict

from zhinst.toolkit.control.connection import DeviceConnection, ZIConnection
from zhinst.toolkit.control.node_tree import NodeTree
Expand Down Expand Up @@ -60,7 +61,9 @@ class BaseInstrument:
"""

def __init__(self, name: str, device_type: DeviceTypes, serial: str, **kwargs):
def __init__(
self, name: str, device_type: DeviceTypes, serial: str, **kwargs
) -> None:
self._config = InstrumentConfiguration()
self._config._instrument._name = name
self._config._instrument._config._device_type = device_type
Expand All @@ -72,7 +75,7 @@ def __init__(self, name: str, device_type: DeviceTypes, serial: str, **kwargs):
self._controller = DeviceConnection(self)
self._nodetree = None

def setup(self, connection: ZIConnection = None):
def setup(self, connection: ZIConnection = None) -> None:
"""Sets up the data server connection.
The details of the connection (`host`, `port`, `api_level`) can be
Expand All @@ -88,7 +91,7 @@ def setup(self, connection: ZIConnection = None):
"""
self._controller.setup(connection=connection)

def connect_device(self, nodetree=True):
def connect_device(self, nodetree: bool = True) -> None:
"""Connects the device to the data server.
This method connects the device to the data server of its connection,
Expand Down Expand Up @@ -158,7 +161,7 @@ def _set(self, *args):
self._check_connected()
return self._controller.set(*args)

def _get(self, command, valueonly=True):
def _get(self, command: str, valueonly: bool = True):
"""Getter for the instrument.
This method gets a node value from the device, specified by a node
Expand Down Expand Up @@ -205,7 +208,7 @@ def _get(self, command, valueonly=True):
self._check_connected()
return self._controller.get(command, valueonly=valueonly)

def _get_nodetree(self, prefix, **kwargs):
def _get_nodetree(self, prefix: str, **kwargs) -> Dict:
"""Gets the entire nodetree from the instrument as a dictionary.
This method passes the arguments to the :class:`DeviceConnection` and
Expand All @@ -228,7 +231,7 @@ def _get_nodetree(self, prefix, **kwargs):
self._check_connected()
return self._controller.get_nodetree(prefix, **kwargs)

def _get_streamingnodes(self):
def _get_streamingnodes(self) -> List:
self._check_connected()
nodes = self._controller.get_nodetree(f"/{self.serial}/*", streamingonly=True)
nodes = list(nodes.keys())
Expand Down
58 changes: 32 additions & 26 deletions zhinst/toolkit/control/drivers/base/daq.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import time
import numpy as np
from typing import List, Dict

from .base import ToolkitError
from .base import ToolkitError, BaseInstrument
from zhinst.toolkit.control.node_tree import Parameter


Expand Down Expand Up @@ -117,7 +118,7 @@ class DAQModule:
"""

def __init__(self, parent, clk_rate=60e6):
def __init__(self, parent: BaseInstrument, clk_rate: float = 60e6) -> None:
self._parent = parent
self._module = None
self._signals = []
Expand Down Expand Up @@ -153,7 +154,7 @@ def __init__(self, parent, clk_rate=60e6):
self._trigger_signals = {}
self._trigger_types = {}

def _setup(self):
def _setup(self) -> None:
self._module = self._parent._controller._connection.daq_module
# add all parameters from nodetree
nodetree = self._module.get_nodetree("*")
Expand All @@ -168,7 +169,7 @@ def _set(self, *args):
raise ToolkitError("This DAQ is not connected to a dataAcquisitionModule!")
return self._module.set(*args, device=self._parent.serial)

def _get(self, *args, valueonly=True):
def _get(self, *args, valueonly: bool = True):
if self._module is None:
raise ToolkitError("This DAQ is not connected to a dataAcquisitionModule!")
data = self._module.get(*args, device=self._parent.serial)
Expand All @@ -183,7 +184,7 @@ def _init_settings(self):
self._set("clearhistory", 1)
self._set("bandwidth", 0)

def trigger_list(self, source=None):
def trigger_list(self, source=None) -> List:
"""Returns a list of all the available signal sources for triggering.
Keyword Arguments:
Expand All @@ -204,7 +205,7 @@ def trigger_list(self, source=None):
if signal in source:
return list(self._trigger_types[signal].keys())

def trigger(self, trigger_source, trigger_type):
def trigger(self, trigger_source: str, trigger_type: str) -> None:
"""Sets the trigger signal of the *DAQ Module*.
This method can be used to specify the signal used to trigger the data
Expand All @@ -223,7 +224,7 @@ def trigger(self, trigger_source, trigger_type):
self._set("/triggernode", trigger_node)
print(f"set trigger node to '{trigger_node}'")

def signals_list(self, source=None):
def signals_list(self, source=None) -> List:
"""Returns a list of all the available signal sources for data acquisition.
Keyword Arguments:
Expand All @@ -248,12 +249,12 @@ def signals_list(self, source=None):

def signals_add(
self,
signal_source,
signal_type="",
operation="avg",
fft=False,
complex_selector="abs",
):
signal_source: str,
signal_type: str = "",
operation: str = "avg",
fft: bool = False,
complex_selector: str = "abs",
) -> str:
"""Add a signal to the signals list to be subscribed to during measurement.
The specified signal is added to the property *signals* list. On
Expand Down Expand Up @@ -298,11 +299,11 @@ def signals_add(
self._signals.append(signal_node)
return signal_node

def signals_clear(self):
def signals_clear(self) -> None:
"""Resets the signals list."""
self._signals = []

def measure(self, verbose=True, timeout=20):
def measure(self, verbose: bool = True, timeout: float = 20) -> None:
"""Performs the measurement.
Starts a measurement and stores the result in `daq.results`. This
Expand Down Expand Up @@ -340,24 +341,29 @@ def measure(self, verbose=True, timeout=20):
self._get_result_from_dict(result)

def _parse_signals(
self, signal_source, signal_type, operation, fft, complex_selector,
):
self,
signal_source: str,
signal_type: str,
operation: str,
fft: bool,
complex_selector: str,
) -> str:
signal_node = "/" + self._parent.serial
signal_node += self._parse_signal_source(signal_source)
signal_node += self._parse_signal_type(signal_type, signal_source)
signal_node += self._parse_fft(fft, complex_selector)
signal_node += self._parse_operation(operation)
return signal_node.lower()

def _parse_signal_source(self, source):
def _parse_signal_source(self, source: str) -> str:
source = source.lower()
if source not in self._signal_sources:
raise ToolkitError(
f"Signal source must be in {self._signal_sources.keys()}"
)
return self._signal_sources[source]

def _parse_signal_type(self, signal_type, signal_source):
def _parse_signal_type(self, signal_type: str, signal_source: str) -> str:
signal_source = signal_source.lower()
signal_type = signal_type.lower()
types = {}
Expand All @@ -368,15 +374,15 @@ def _parse_signal_type(self, signal_type, signal_source):
raise ToolkitError(f"Signal type must be in {types.keys()}")
return types[signal_type]

def _parse_operation(self, operation):
def _parse_operation(self, operation: str) -> str:
operations = ["replace", "avg", "std"]
if operation not in operations:
raise ToolkitError(f"Operation must be in {operations}")
if operation == "replace":
operation = ""
return f".{operation}"

def _parse_fft(self, fft, selector):
def _parse_fft(self, fft: bool, selector: str) -> str:
if fft:
selectors = ["real", "imag", "abs", "phase"]
if selector not in selectors:
Expand All @@ -385,20 +391,20 @@ def _parse_fft(self, fft, selector):
else:
return ""

def _parse_trigger(self, trigger_source, trigger_type):
def _parse_trigger(self, trigger_source: str, trigger_type: str) -> str:
trigger_node = "/" + self._parent.serial
trigger_node += self._parse_trigger_source(trigger_source)
trigger_node += self._parse_trigger_type(trigger_source, trigger_type)
return trigger_node

def _parse_trigger_source(self, source):
def _parse_trigger_source(self, source: str) -> str:
source = source.lower()
sources = self._trigger_signals
if source not in sources:
raise ToolkitError(f"Signal source must be in {sources.keys()}")
return sources[source]

def _parse_trigger_type(self, trigger_source, trigger_type):
def _parse_trigger_type(self, trigger_source: str, trigger_type: str) -> str:
trigger_source = trigger_source.lower()
trigger_type = trigger_type.lower()
types = {}
Expand All @@ -409,7 +415,7 @@ def _parse_trigger_type(self, trigger_source, trigger_type):
raise ToolkitError(f"Signal type must be in {types.keys()}")
return types[trigger_type]

def _get_result_from_dict(self, result):
def _get_result_from_dict(self, result: Dict):
self._results = {}
for node in self.signals:
node = node.lower()
Expand Down Expand Up @@ -485,7 +491,7 @@ class DAQResult:
"""

def __init__(self, path, result_dict, clk_rate=60e6):
def __init__(self, path: str, result_dict: Dict, clk_rate: float = 60e6) -> None:
self._path = path
self._clk_rate = clk_rate
self._is_fft = "fft" in self._path
Expand Down
Loading

0 comments on commit 37bae98

Please sign in to comment.