Skip to content

Commit

Permalink
fix: Issues uncovered in UAT (#32)
Browse files Browse the repository at this point in the history
  • Loading branch information
kshyatt-aws authored Jul 5, 2024
1 parent 1c4da02 commit 458596c
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 20 deletions.
8 changes: 7 additions & 1 deletion src/BraketSimulator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,13 @@ end
)
is_single_task = length(task_specs) == 1
is_single_input = inputs isa Dict || length(inputs) == 1
is_single_input && is_single_task && return simulate(simulator, only(task_specs), only(task_args); inputs=inputs, kwargs...)
if is_single_input && is_single_task
if inputs isa Vector
return [simulate(simulator, only(task_specs), shots; inputs=only(inputs), kwargs...)]
else
return [simulate(simulator, only(task_specs), shots; inputs=inputs, kwargs...)]
end
end
if is_single_input
if inputs isa Dict
inputs = [deepcopy(inputs) for ix in 1:length(task_specs)]
Expand Down
11 changes: 6 additions & 5 deletions src/Quasar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1106,9 +1106,9 @@ mutable struct QasmFunctionVisitor <: AbstractVisitor
function QasmFunctionVisitor(parent::AbstractVisitor, declared_arguments::Vector{QasmExpression}, provided_arguments::Vector{QasmExpression})
v = new(parent,
classical_defs(parent),
Dict{String, Qubit}(),
Dict{String, Vector{Int}}(),
0,
deepcopy(parent.qubit_defs),
deepcopy(parent.qubit_mapping),
qubit_count(parent),
Instruction[],
)
arg_map = Dict(zip(declared_arguments, provided_arguments))
Expand Down Expand Up @@ -1411,7 +1411,8 @@ function evaluate(v::AbstractVisitor, expr::QasmExpression)
end
end
end
push!(v, Instruction[remap(ix, reverse_qubits_map) for ix in function_v.instructions])
remapped = isempty(reverse_qubits_map) ? function_v.instructions : Instruction[remap(ix, reverse_qubits_map) for ix in function_v.instructions]
push!(v, remapped)
return return_val
end
else
Expand Down Expand Up @@ -1741,7 +1742,7 @@ function (v::AbstractVisitor)(program_expr::QasmExpression)
v.classical_defs[var_name] = ClassicalVariable(var_name, var_type, v.classical_defs[var_name].val, true)
elseif head(program_expr) == :qubit_declaration
qubit_name::String = name(program_expr)
qubit_size::Int = program_expr.args[2].args[1]
qubit_size::Int = evaluate(v, program_expr.args[2])
qubit_defs(v)[qubit_name] = Qubit(qubit_name, qubit_size)
qubit_mapping(v)[qubit_name] = collect(qubit_count(v) : qubit_count(v) + qubit_size - 1)
for qubit_i in 0:qubit_size-1
Expand Down
10 changes: 5 additions & 5 deletions test/test_braket_integration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -592,11 +592,11 @@ end
tasks = (bell_circ, (()->ir(bell_circ(), Val(:OpenQASM))), (()->ir(bell_circ(), Val(:JAQCD))))
device = DEVICE
@testset for task in tasks
run_circuit(circuit) = result(device(circuit, shots = SHOTS))
task_array = [task() for ii = 1:Threads.nthreads()]
futures = [Threads.@spawn run_circuit(c) for c in task_array]
future_results = fetch.(futures)
for r in future_results
run_circuit(circuit) = result(simulate(device, circuit; shots = SHOTS))
batch_size = 5
task_array = [task() for ii = 1:batch_size]
batch_results = results(simulate(device, task_array; shots=SHOTS))
for r in batch_results
@test isapprox(
r.measurement_probabilities["00"],
0.5,
Expand Down
24 changes: 24 additions & 0 deletions test/test_openqasm3.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1494,6 +1494,30 @@ get_tol(shots::Int) = return (
err_msg = "Invalid observable specified: [[[-6.0, 0.0], [2.0, 1.0], [-3.0, 0.0], [-5.0, 2.0]], [[2.0, -1.0], [0.0, 0.0], [2.0, -1.0], [-5.0, 4.0]], [[-3.0, 0.0], [2.0, 1.0], [0.0, 0.0], [-4.0, 3.0]], [[-5.0, -2.0], [-5.0, -4.0], [-4.0, -3.0], [-6.0, 0.0]]], targets: [0]"
@test_throws Quasar.QasmVisitorError(err_msg, "ValueError") visitor(parsed)
end
@testset "Qubits with variable as size" begin
qasm_string = """
OPENQASM 3.0;
def ghz(int[32] n) {
h q[0];
for int i in [0:n - 1] {
cnot q[i], q[i + 1];
}
}
int[32] n = 5;
bit[n + 1] c;
qubit[n + 1] q;
ghz(n);
c = measure q;
"""
parsed = parse_qasm(qasm_string)
visitor = Quasar.QasmProgramVisitor()
visitor(parsed)
@test Quasar.qubit_count(visitor) == 6
end
@testset "String inputs" begin
qasm_string = """
const int[8] n = 4;
Expand Down
7 changes: 7 additions & 0 deletions test/test_python_ext.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,13 @@ using PythonCall: pyconvert
for oq3_result in oq3_results
@test pyconvert(Vector{Float64}, oq3_result.resultTypes[0].value) pyconvert(Vector{Float64}, oq3_result.resultTypes[1].value)
end
# test a "batch" of length 1
sv_simulator = StateVectorSimulator(n_qubits, 0)
py_inputs = PyList{Any}([pydict(Dict("a_in"=>2, "b_in"=>5))])
oq3_results = simulate(sv_simulator, PyList{Any}([oq3_program]); inputs=py_inputs, shots=0)
for oq3_result in oq3_results
@test pyconvert(Vector{Float64}, oq3_result.resultTypes[0].value) pyconvert(Vector{Float64}, oq3_result.resultTypes[1].value)
end
end
end
end
Expand Down
19 changes: 10 additions & 9 deletions test/test_sv_simulator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -358,15 +358,16 @@ LARGE_TESTS = get(ENV, "BRAKET_SV_LARGE_TESTS", false)
return ir(ghz)
end
num_qubits = 5
n_circuits = 100
shots = 1000
jl_ghz = [make_ghz(num_qubits) for ix in 1:n_circuits]
jl_sim = StateVectorSimulator(num_qubits, 0);
results = simulate(jl_sim, jl_ghz, shots)
for (r_ix, r) in enumerate(results)
@test length(r.measurements) == shots
@test 400 < count(m->m == fill(0, num_qubits), r.measurements) < 600
@test 400 < count(m->m == fill(1, num_qubits), r.measurements) < 600
@testset for n_circuits in (1, 100)
shots = 1000
jl_ghz = [make_ghz(num_qubits) for ix in 1:n_circuits]
jl_sim = StateVectorSimulator(num_qubits, 0);
results = simulate(jl_sim, jl_ghz, shots)
for (r_ix, r) in enumerate(results)
@test length(r.measurements) == shots
@test 400 < count(m->m == fill(0, num_qubits), r.measurements) < 600
@test 400 < count(m->m == fill(1, num_qubits), r.measurements) < 600
end
end
end
@testset "similar, copy and copyto!" begin
Expand Down

0 comments on commit 458596c

Please sign in to comment.