Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: onboard onto ruff #38

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 5 additions & 17 deletions benchmark/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,7 @@ def test_exact_shots(benchmark, device_id, nq, exact_results, circuit):
exact_results in ("state_vector",) or nq > 10
):
pytest.skip()
if (
device_id in ("braket_sv",)
and exact_results in ("density_matrix q[0], q[1]",)
and nq >= 17
):
if device_id in ("braket_sv",) and exact_results in ("density_matrix q[0], q[1]",) and nq >= 17:
pytest.skip()
result_type = exact_results
oq3_prog = Program(source=circuit(nq, result_type))
Expand All @@ -86,9 +82,7 @@ def test_exact_shots(benchmark, device_id, nq, exact_results, circuit):
@pytest.mark.parametrize("batch_size", batch_size)
@pytest.mark.parametrize("exact_results", exact_shots_results)
@pytest.mark.parametrize("circuit", generators)
def test_exact_shots_batched(
benchmark, device_id, nq, batch_size, exact_results, circuit
):
def test_exact_shots_batched(benchmark, device_id, nq, batch_size, exact_results, circuit):
if device_id in ("braket_dm_v2", "braket_dm") and (
exact_results in ("state_vector,") or nq >= 5
):
Expand All @@ -100,9 +94,7 @@ def test_exact_shots_batched(
result_type = exact_results
oq3_prog = [Program(source=circuit(nq, result_type)) for _ in range(batch_size)]
sim = LocalSimulator(device_id)
benchmark.pedantic(
run_sim_batch, args=(oq3_prog, sim, 0), iterations=5, warmup_rounds=1
)
benchmark.pedantic(run_sim_batch, args=(oq3_prog, sim, 0), iterations=5, warmup_rounds=1)


shots = (100,)
Expand All @@ -119,9 +111,7 @@ def test_nonzero_shots(benchmark, device_id, nq, shots, nonzero_shots_results, c
result_type = nonzero_shots_results
oq3_prog = Program(source=circuit(nq, result_type))
sim = LocalSimulator(device_id)
benchmark.pedantic(
run_sim, args=(oq3_prog, sim, shots), iterations=5, warmup_rounds=1
)
benchmark.pedantic(run_sim, args=(oq3_prog, sim, shots), iterations=5, warmup_rounds=1)
del sim


Expand All @@ -145,7 +135,5 @@ def test_nonzero_shots_batched(
result_type = nonzero_shots_results
oq3_prog = [Program(source=circuit(nq, result_type)) for _ in range(batch_size)]
sim = LocalSimulator(device_id)
benchmark.pedantic(
run_sim_batch, args=(oq3_prog, sim, shots), iterations=5, warmup_rounds=1
)
benchmark.pedantic(run_sim_batch, args=(oq3_prog, sim, shots), iterations=5, warmup_rounds=1)
del sim
52 changes: 50 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,53 @@ package-data = {"*" = ["*.json"]}
dependencies = {file = "requirements.txt"}
optional-dependencies.test = { file = "requirements-test.txt" }

[tool.isort]
profile = "black"

[tool.ruff]
target-version = "py39"
line-length = 100
format.preview = true
format.docstring-code-line-length = 100
lint.select = [
"ALL",
]
lint.ignore = [
"ANN101", # Missing type annotation for `self` in method
"ANN102", # Missing type annotation for `cls` in classmethod"
"ANN401", # Dynamically typed expressions (typing.Any) are disallowed in `arg`"
"BLE001", # This needs to be cleaned up later.
"COM812", # conflicts with formatter
"CPY", # No copyright header
"D", # ignore documentation for now
"D203", # `one-blank-line-before-class` (D203) and `no-blank-line-before-class` (D211) are incompatible
"D212", # `multi-line-summary-first-line` (D212) and `multi-line-summary-second-line` (D213) are incompatible
"DOC201", # no restructuredtext support yet
"DOC402", # no restructuredtext support yet
"DOC501", # broken with sphinx docs
"INP001", # no implicit namespaces here
"ISC001", # conflicts with formatter
"PLR0914", ## Too many local variables
"PLR0917", ## Too many positional arguments
"PLW0603", # Allow usage of global vars
"S104", # Possible binding to all interfaces
"S404", # Using subprocess is alright.
"S603", # Using subprocess is alright.
]
lint.per-file-ignores."tests/**/*.py" = [
"D", # don't care about documentation in tests
"FBT", # don"t care about booleans as positional arguments in tests
"INP001", # no implicit namespace
"PLR2004", # Magic value used in comparison, consider replacing with a constant variable
"S101", # asserts allowed in tests...
"S603", # `subprocess` call: check for execution of untrusted input
]
lint.isort = { known-first-party = [
"simulator_v2",
"tests",
] }
lint.preview = true

[tool.ruff.lint.pydocstyle]
convention = "google"

[tool.ruff.lint.flake8-annotations]
mypy-init-return = false
9 changes: 1 addition & 8 deletions requirements-test.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
amazon-braket-pennylane-plugin
black
flake8
flake8-rst-docstrings
isort
pre-commit
pylint
pytest==7.1.2
pytest-benchmark
pytest-cov
Expand All @@ -14,7 +9,5 @@ pytest-xdist
qiskit==1.2.0
qiskit-braket-provider==0.4.1
qiskit-algorithms
sphinx
sphinx-rtd-theme
sphinxcontrib-apidoc
ruff
tox
20 changes: 0 additions & 20 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,3 @@ addopts =
--verbose -n auto --durations=0 --durations-min=1 --dist worksteal
testpaths = test/unit_tests

[flake8]
ignore =
# not pep8, black adds whitespace before ':'
E203,
# not pep8, black adds line break before binary operator
W503,
# Google Python style is not RST until after processed by Napoleon
# See https://github.com/peterjc/flake8-rst-docstrings/issues/17
RST201,RST203,RST301,
max_line_length = 100
max-complexity = 10
exclude =
__pycache__
.tox
.git
bin
dist
examples
build
venv
125 changes: 62 additions & 63 deletions src/braket/simulator_v2/base_simulator_v2.py
Original file line number Diff line number Diff line change
@@ -1,71 +1,76 @@
from __future__ import annotations

import atexit
import json
from collections.abc import Sequence
import os
import sys
from itertools import starmap
from multiprocessing.pool import Pool
from typing import Optional, Union
from typing import TYPE_CHECKING

import numpy as np

from braket.default_simulator.simulator import BaseLocalSimulator
from braket.ir.jaqcd import DensityMatrix, Probability, StateVector
from braket.ir.openqasm import Program as OpenQASMProgram
from braket.task_result import GateModelTaskResult

from braket.simulator_v2.julia_workers import (
_handle_julia_error,
_handle_julia_error, # noqa: PLC2701
translate_and_run,
translate_and_run_multiple,
)
from braket.task_result import GateModelTaskResult

__JULIA_POOL__ = None
if TYPE_CHECKING:
from collections.abc import Sequence

from braket.ir.openqasm import Program as OpenQASMProgram

__JULIA_POOL__ = None

def setup_julia():
import os
import sys

def setup_julia() -> None:
# don't reimport if we don't have to
if "juliacall" in sys.modules:
os.environ["PYTHON_JULIACALL_HANDLE_SIGNALS"] = "yes"
return
else:
for k, default in (
("PYTHON_JULIACALL_HANDLE_SIGNALS", "yes"),
("PYTHON_JULIACALL_THREADS", "auto"),
("PYTHON_JULIACALL_OPTLEVEL", "3"),
# let the user's Conda/Pip handle installing things
("JULIA_CONDAPKG_BACKEND", "Null"),
):
os.environ[k] = os.environ.get(k, default)

import juliacall

jl = juliacall.Main
jl.seval("using BraketSimulator, JSON3")
stock_oq3 = """
OPENQASM 3.0;
qubit[2] q;
h q[0];
cphaseshift(1.5707963267948966) q[1], q[0];
cnot q;
#pragma braket noise bit_flip(0.1) q[0]
#pragma braket result variance y(q[0])
#pragma braket result density_matrix q[0], q[1]
#pragma braket result probability
"""
jl.BraketSimulator.simulate("braket_dm_v2", stock_oq3, "{}", 0)
return
for k, default in (
("PYTHON_JULIACALL_HANDLE_SIGNALS", "yes"),
("PYTHON_JULIACALL_THREADS", "auto"),
("PYTHON_JULIACALL_OPTLEVEL", "3"),
# let the user's Conda/Pip handle installing things
("JULIA_CONDAPKG_BACKEND", "Null"),
):
os.environ[k] = os.environ.get(k, default)

import juliacall # noqa: PLC0415

jl = juliacall.Main
jl.seval("using BraketSimulator, JSON3")
stock_oq3 = """
OPENQASM 3.0;
qubit[2] q;
h q[0];
cphaseshift(1.5707963267948966) q[1], q[0];
cnot q;
#pragma braket noise bit_flip(0.1) q[0]
#pragma braket result variance y(q[0])
#pragma braket result density_matrix q[0], q[1]
#pragma braket result probability
"""
jl.BraketSimulator.simulate("braket_dm_v2", stock_oq3, "{}", 0)
return


def setup_pool():
def setup_pool() -> None:
global __JULIA_POOL__
__JULIA_POOL__ = Pool(processes=1)
__JULIA_POOL__.apply(setup_julia)
atexit.register(__JULIA_POOL__.join)
atexit.register(__JULIA_POOL__.close)
return


def _handle_mmaped_result(raw_result, mmap_paths, obj_lengths):
def _handle_mmaped_result(
raw_result: dict, mmap_paths: list, obj_lengths: list
) -> GateModelTaskResult:
result = GateModelTaskResult(**raw_result)
if mmap_paths:
mmap_files = mmap_paths
Expand All @@ -89,20 +94,19 @@ def _handle_mmaped_result(raw_result, mmap_paths, obj_lengths):


class BaseLocalSimulatorV2(BaseLocalSimulator):
def __init__(self, device: str):
global __JULIA_POOL__
def __init__(self, device: str) -> None:
if __JULIA_POOL__ is None:
setup_pool()
self._device = device

def initialize_simulation(self, **kwargs):
return
def initialize_simulation(self, **kwargs: dict) -> None:
pass

def run_openqasm(
self,
openqasm_ir: OpenQASMProgram,
shots: int = 0,
batch_size: int = 1, # unused
batch_size: int = 1, # noqa: ARG002
) -> GateModelTaskResult:
"""Executes the circuit specified by the supplied `openqasm_ir` on the simulator.

Expand All @@ -118,8 +122,7 @@ def run_openqasm(
ValueError: If result types are not specified in the IR or sample is specified
as a result type when shots=0. Or, if StateVector and Amplitude result types
are requested when shots>0.
"""
global __JULIA_POOL__
""" # noqa: DOC502
try:
jl_result = __JULIA_POOL__.apply(
translate_and_run,
Expand All @@ -142,9 +145,9 @@ def run_openqasm(
def run_multiple(
self,
programs: Sequence[OpenQASMProgram],
max_parallel: Optional[int] = -1,
shots: Optional[int] = 0,
inputs: Optional[Union[dict, Sequence[dict]]] = {},
max_parallel: int = -1, # noqa: ARG002
shots: int = 0,
inputs: dict | Sequence[dict] | None = None,
) -> list[GateModelTaskResult]:
"""
Run the tasks specified by the given IR programs.
Expand All @@ -154,11 +157,13 @@ def run_multiple(
programs (Sequence[OQ3Program]): The IR representations of the programs
max_parallel (Optional[int]): The maximum number of programs to run in parallel.
Default is the number of logical CPUs.

Returns:
list[GateModelTaskResult]: A list of result objects, with the ith object being
the result of the ith program.
"""
global __JULIA_POOL__
if inputs is None:
inputs = {}
try:
jl_results = __JULIA_POOL__.apply(
translate_and_run_multiple,
Expand All @@ -173,10 +178,7 @@ def run_multiple(
(loaded_result[r_ix], paths_and_lens[r_ix][0], paths_and_lens[r_ix][1])
for r_ix in range(len(loaded_result))
]
results = [
_handle_mmaped_result(*result_path_len)
for result_path_len in results_paths_lens
]
results = list(starmap(_handle_mmaped_result, results_paths_lens))
jl_results = None
for p_ix, program in enumerate(programs):
results[p_ix].additionalMetadata.action = program
Expand All @@ -198,11 +200,10 @@ def _result_value_to_ndarray(
with the pydantic specification for ResultTypeValues.
"""

def reconstruct_complex(v):
def reconstruct_complex(v: list | float) -> complex | float:
if isinstance(v, list):
return complex(v[0], v[1])
else:
return v
return v

for result_ind, result_type in enumerate(task_result.resultTypes):
# Amplitude
Expand All @@ -211,19 +212,17 @@ def reconstruct_complex(v):
task_result.resultTypes[result_ind].value = {
k: reconstruct_complex(v) for (k, v) in val.items()
}
if isinstance(result_type.type, StateVector):
elif isinstance(result_type.type, StateVector):
val = task_result.resultTypes[result_ind].value
if isinstance(val, list):
fixed_val = [reconstruct_complex(v) for v in val]
task_result.resultTypes[result_ind].value = np.asarray(fixed_val)
if isinstance(result_type.type, DensityMatrix):
val = task_result.resultTypes[result_ind].value
# complex are stored as tuples of reals
fixed_val = [
[reconstruct_complex(v) for v in inner_val] for inner_val in val
]
fixed_val = [[reconstruct_complex(v) for v in inner_val] for inner_val in val]
task_result.resultTypes[result_ind].value = np.asarray(fixed_val)
if isinstance(result_type.type, Probability):
elif isinstance(result_type.type, Probability):
val = task_result.resultTypes[result_ind].value
task_result.resultTypes[result_ind].value = np.asarray(val)

Expand Down
Loading
Loading