diff --git a/src/BraketSimulator.jl b/src/BraketSimulator.jl index 75a0be0..c140e11 100644 --- a/src/BraketSimulator.jl +++ b/src/BraketSimulator.jl @@ -329,7 +329,13 @@ end ) is_single_task = length(task_specs) == 1 is_single_input = inputs isa Dict || length(inputs) == 1 - is_single_input && is_single_task && return simulate(simulator, only(task_specs), only(task_args); inputs=inputs, kwargs...) + if is_single_input && is_single_task + if inputs isa Vector + return [simulate(simulator, only(task_specs), shots; inputs=only(inputs), kwargs...)] + else + return [simulate(simulator, only(task_specs), shots; inputs=inputs, kwargs...)] + end + end if is_single_input if inputs isa Dict inputs = [deepcopy(inputs) for ix in 1:length(task_specs)] diff --git a/src/Quasar.jl b/src/Quasar.jl index 31ce1e3..2114766 100644 --- a/src/Quasar.jl +++ b/src/Quasar.jl @@ -1106,9 +1106,9 @@ mutable struct QasmFunctionVisitor <: AbstractVisitor function QasmFunctionVisitor(parent::AbstractVisitor, declared_arguments::Vector{QasmExpression}, provided_arguments::Vector{QasmExpression}) v = new(parent, classical_defs(parent), - Dict{String, Qubit}(), - Dict{String, Vector{Int}}(), - 0, + deepcopy(parent.qubit_defs), + deepcopy(parent.qubit_mapping), + qubit_count(parent), Instruction[], ) arg_map = Dict(zip(declared_arguments, provided_arguments)) @@ -1411,7 +1411,8 @@ function evaluate(v::AbstractVisitor, expr::QasmExpression) end end end - push!(v, Instruction[remap(ix, reverse_qubits_map) for ix in function_v.instructions]) + remapped = isempty(reverse_qubits_map) ? function_v.instructions : Instruction[remap(ix, reverse_qubits_map) for ix in function_v.instructions] + push!(v, remapped) return return_val end else @@ -1741,7 +1742,7 @@ function (v::AbstractVisitor)(program_expr::QasmExpression) v.classical_defs[var_name] = ClassicalVariable(var_name, var_type, v.classical_defs[var_name].val, true) elseif head(program_expr) == :qubit_declaration qubit_name::String = name(program_expr) - qubit_size::Int = program_expr.args[2].args[1] + qubit_size::Int = evaluate(v, program_expr.args[2]) qubit_defs(v)[qubit_name] = Qubit(qubit_name, qubit_size) qubit_mapping(v)[qubit_name] = collect(qubit_count(v) : qubit_count(v) + qubit_size - 1) for qubit_i in 0:qubit_size-1 diff --git a/test/test_braket_integration.jl b/test/test_braket_integration.jl index 9002d3e..b597034 100644 --- a/test/test_braket_integration.jl +++ b/test/test_braket_integration.jl @@ -592,11 +592,11 @@ end tasks = (bell_circ, (()->ir(bell_circ(), Val(:OpenQASM))), (()->ir(bell_circ(), Val(:JAQCD)))) device = DEVICE @testset for task in tasks - run_circuit(circuit) = result(device(circuit, shots = SHOTS)) - task_array = [task() for ii = 1:Threads.nthreads()] - futures = [Threads.@spawn run_circuit(c) for c in task_array] - future_results = fetch.(futures) - for r in future_results + run_circuit(circuit) = result(simulate(device, circuit; shots = SHOTS)) + batch_size = 5 + task_array = [task() for ii = 1:batch_size] + batch_results = results(simulate(device, task_array; shots=SHOTS)) + for r in batch_results @test isapprox( r.measurement_probabilities["00"], 0.5, diff --git a/test/test_openqasm3.jl b/test/test_openqasm3.jl index 5cd8287..3601154 100644 --- a/test/test_openqasm3.jl +++ b/test/test_openqasm3.jl @@ -1494,6 +1494,30 @@ get_tol(shots::Int) = return ( err_msg = "Invalid observable specified: [[[-6.0, 0.0], [2.0, 1.0], [-3.0, 0.0], [-5.0, 2.0]], [[2.0, -1.0], [0.0, 0.0], [2.0, -1.0], [-5.0, 4.0]], [[-3.0, 0.0], [2.0, 1.0], [0.0, 0.0], [-4.0, 3.0]], [[-5.0, -2.0], [-5.0, -4.0], [-4.0, -3.0], [-6.0, 0.0]]], targets: [0]" @test_throws Quasar.QasmVisitorError(err_msg, "ValueError") visitor(parsed) end + @testset "Qubits with variable as size" begin + qasm_string = """ + OPENQASM 3.0; + + def ghz(int[32] n) { + h q[0]; + for int i in [0:n - 1] { + cnot q[i], q[i + 1]; + } + } + + int[32] n = 5; + bit[n + 1] c; + qubit[n + 1] q; + + ghz(n); + + c = measure q; + """ + parsed = parse_qasm(qasm_string) + visitor = Quasar.QasmProgramVisitor() + visitor(parsed) + @test Quasar.qubit_count(visitor) == 6 + end @testset "String inputs" begin qasm_string = """ const int[8] n = 4; diff --git a/test/test_python_ext.jl b/test/test_python_ext.jl index 66bcf58..37c1ae0 100644 --- a/test/test_python_ext.jl +++ b/test/test_python_ext.jl @@ -166,6 +166,13 @@ using PythonCall: pyconvert for oq3_result in oq3_results @test pyconvert(Vector{Float64}, oq3_result.resultTypes[0].value) ≠ pyconvert(Vector{Float64}, oq3_result.resultTypes[1].value) end + # test a "batch" of length 1 + sv_simulator = StateVectorSimulator(n_qubits, 0) + py_inputs = PyList{Any}([pydict(Dict("a_in"=>2, "b_in"=>5))]) + oq3_results = simulate(sv_simulator, PyList{Any}([oq3_program]); inputs=py_inputs, shots=0) + for oq3_result in oq3_results + @test pyconvert(Vector{Float64}, oq3_result.resultTypes[0].value) ≠ pyconvert(Vector{Float64}, oq3_result.resultTypes[1].value) + end end end end diff --git a/test/test_sv_simulator.jl b/test/test_sv_simulator.jl index 094eff8..ecb0106 100644 --- a/test/test_sv_simulator.jl +++ b/test/test_sv_simulator.jl @@ -358,15 +358,16 @@ LARGE_TESTS = get(ENV, "BRAKET_SV_LARGE_TESTS", false) return ir(ghz) end num_qubits = 5 - n_circuits = 100 - shots = 1000 - jl_ghz = [make_ghz(num_qubits) for ix in 1:n_circuits] - jl_sim = StateVectorSimulator(num_qubits, 0); - results = simulate(jl_sim, jl_ghz, shots) - for (r_ix, r) in enumerate(results) - @test length(r.measurements) == shots - @test 400 < count(m->m == fill(0, num_qubits), r.measurements) < 600 - @test 400 < count(m->m == fill(1, num_qubits), r.measurements) < 600 + @testset for n_circuits in (1, 100) + shots = 1000 + jl_ghz = [make_ghz(num_qubits) for ix in 1:n_circuits] + jl_sim = StateVectorSimulator(num_qubits, 0); + results = simulate(jl_sim, jl_ghz, shots) + for (r_ix, r) in enumerate(results) + @test length(r.measurements) == shots + @test 400 < count(m->m == fill(0, num_qubits), r.measurements) < 600 + @test 400 < count(m->m == fill(1, num_qubits), r.measurements) < 600 + end end end @testset "similar, copy and copyto!" begin