Skip to content

Commit

Permalink
Remove allocations for better performance from Python, cleanup (#48)
Browse files Browse the repository at this point in the history
* change: Make things more inferrable

* change: Use JSON3 for writing/reading to Python

* fix: Make matrix generation less allocating

* fix: Make matrix generation less allocating for custom gates too

* fix: Handle empty/nothing targets properly

* fix: Remove timing statements

* fix: Allow subtypes of abstract result supertypes

* fix: Remove timings and add inputs to precompile

* fix: More validation tests

* change: Bump version

* fix: Semgrep

* change: More tests and a fix for DoubleExcitation

* Update src/gate_kernels.jl

Co-authored-by: Ryan Shaffer <3620100+rmshaffer@users.noreply.github.com>

---------

Co-authored-by: Ryan Shaffer <3620100+rmshaffer@users.noreply.github.com>
  • Loading branch information
kshyatt-aws and rmshaffer authored Aug 27, 2024
1 parent 598c8b6 commit 579a0b9
Show file tree
Hide file tree
Showing 16 changed files with 443 additions and 242 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "BraketSimulator"
uuid = "76d27892-9a0b-406c-98e4-7c178e9b3dff"
authors = ["Katharine Hyatt <hyatkath@amazon.com> and contributors"]
version = "0.0.3"
version = "0.0.4"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down
42 changes: 32 additions & 10 deletions ext/BraketSimulatorPythonExt/BraketSimulatorPythonExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,45 @@ module BraketSimulatorPythonExt
using PrecompileTools

@recompile_invalidations begin
using BraketSimulator, PythonCall, JSON3
using BraketSimulator, BraketSimulator.Quasar, PythonCall, JSON3
end

using BraketSimulator: simulate
using BraketSimulator: simulate, AbstractProgramResult

function BraketSimulator.simulate(simulator, task_spec::String, inputs::Dict{String, Any}, shots::Int; kwargs...)
jl_specs = BraketSimulator.OpenQasmProgram(BraketSimulator.braketSchemaHeader("braket.ir.openqasm.program", "1"), task_spec, inputs)
jl_results = simulate(simulator, jl_specs, shots; kwargs...)
json = JSON3.write(jl_results)
return json
function BraketSimulator.simulate(simulator_id::String, task_spec::String, py_inputs::String, shots::Int; kwargs...)
inputs = JSON3.read(py_inputs, Dict{String, Any})
jl_spec = BraketSimulator.OpenQasmProgram(BraketSimulator.braketSchemaHeader("braket.ir.openqasm.program", "1"), task_spec, inputs)
simulator = if simulator_id == "braket_sv_v2"
StateVectorSimulator(0, shots)
elseif simulator_id == "braket_dm_v2"
DensityMatrixSimulator(0, shots)
end
jl_results = simulate(simulator, jl_spec, shots; kwargs...)
# this is expensive due to allocations
py_results = JSON3.write(jl_results)
simulator = nothing
inputs = nothing
jl_spec = nothing
jl_results = nothing
return py_results
end
function BraketSimulator.simulate(simulator, task_specs::Vector{String}, inputs::Vector{Dict{String, Any}}, shots::Int; kwargs...)
jl_specs = map(zip(task_specs, inputs)) do (task_spec, input)
BraketSimulator.OpenQasmProgram(BraketSimulator.braketSchemaHeader("braket.ir.openqasm.program", "1"), task_spec, input)
function BraketSimulator.simulate(simulator_id::String, task_specs::PyList, py_inputs::String, shots::Int; kwargs...)
inputs = JSON3.read(py_inputs, Vector{Dict{String, Any}})
jl_specs = map(zip(task_specs, inputs)) do (task_spec, input)
jl_spec = task_spec isa Py ? pyconvert(String, task_spec) : task_spec
BraketSimulator.OpenQasmProgram(BraketSimulator.braketSchemaHeader("braket.ir.openqasm.program", "1"), jl_spec, input)
end
simulator = if simulator_id == "braket_sv_v2"
StateVectorSimulator(0, shots)
elseif simulator_id == "braket_dm_v2"
DensityMatrixSimulator(0, shots)
end
jl_results = simulate(simulator, jl_specs, shots; kwargs...)
jsons = [JSON3.write(r) for r in jl_results]
simulator = nothing
jl_results = nothing
inputs = nothing
jl_specs = nothing
return jsons
end

Expand Down
109 changes: 61 additions & 48 deletions src/BraketSimulator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ include("noise_kernels.jl")
include("Quasar.jl")
using .Quasar

const CHUNK_SIZE = 2^10
const LOG2_CHUNK_SIZE = 10
const CHUNK_SIZE = 2^LOG2_CHUNK_SIZE

function _index_to_endian_bits(ix::Int, qubit_count::Int)
bits = Vector{Int}(undef, qubit_count)
Expand Down Expand Up @@ -139,53 +140,52 @@ function _bundle_results(
end

function _generate_results(
results::Vector{<:AbstractProgramResult},
result_types::Vector,
result_types,
simulator::D,
) where {D<:AbstractSimulator}
result_values = map(result_type -> calculate(result_type, simulator), result_types)
final_results = Vector{ResultTypeValue}(undef, length(result_values))
for r_ix in 1:length(final_results)
final_results[r_ix] = ResultTypeValue(results[r_ix],
complex_matrix_to_ir(result_values[r_ix]))
ir_results = map(StructTypes.lower, result_types)
results = map(zip(ir_results, result_values)) do (ir, val)
ir_val = complex_matrix_to_ir(val)
return ResultTypeValue(ir, ir_val)
end
return final_results
return results
end

_translate_result_type(r::IR.Amplitude, qc::Int) = Amplitude(r.states)
_translate_result_type(r::IR.StateVector, qc::Int) = StateVector()
# The IR result types support `nothing` as a valid option for the `targets` field,
# however `Result`s represent this with an empty `QubitSet` for type
# stability reasons. Here we take a `nothing` value for `targets` and translate it
# to apply to all qubits.
_translate_result_type(r::IR.DensityMatrix, qc::Int) = isnothing(r.targets) ? DensityMatrix(collect(0:qc-1)) : DensityMatrix(r.targets)
_translate_result_type(r::IR.Probability, qc::Int) = isnothing(r.targets) ? Probability(collect(0:qc-1)) : Probability(r.targets)
_translate_result_type(r::IR.Amplitude) = Amplitude(r.states)
_translate_result_type(r::IR.StateVector) = StateVector()
_translate_result_type(r::IR.DensityMatrix) = DensityMatrix(r.targets)
_translate_result_type(r::IR.Probability) = Probability(r.targets)
for (RT, IRT) in ((:Expectation, :(IR.Expectation)), (:Variance, :(IR.Variance)), (:Sample, :(IR.Sample)))
@eval begin
function _translate_result_type(r::$IRT, qc::Int)
targets = isnothing(r.targets) ? collect(0:qc-1) : r.targets
obs = StructTypes.constructfrom(Observables.Observable, r.observable)
$RT(obs, QubitSet(targets))
function _translate_result_type(r::$IRT)
obs = StructTypes.constructfrom(Observables.Observable, r.observable)
$RT(obs, QubitSet(r.targets))
end
end
end
_translate_result_types(results::Vector{AbstractProgramResult}, qubit_count::Int) = map(result->_translate_result_type(result, qubit_count), results)
_translate_result_types(results::Vector{AbstractProgramResult}) = map(_translate_result_type, results)

function _compute_exact_results(d::AbstractSimulator, program::Program, qubit_count::Int)
result_types = _translate_result_types(program.results, qubit_count)
result_types = _translate_result_types(program.results)
_validate_result_types_qubits_exist(result_types, qubit_count)
return _generate_results(program.results, result_types, d)
return _generate_results(result_types, d)
end

function _compute_exact_results(d::AbstractSimulator, program::Circuit, qubit_count::Int)
_validate_result_types_qubits_exist(program.result_types, qubit_count)
return _generate_results(program.result_types, d)
end

"""
_get_measured_qubits(program::Program, qubit_count::Int) -> Vector{Int}
_get_measured_qubits(program, qubit_count::Int) -> Vector{Int}
Get the qubits measured by the program. If [`Measure`](@ref)
instructions are present in the program's instruction list,
their targets are used to form the list of measured qubits.
If not, all qubits from 0 to `qubit_count-1` are measured.
"""
function _get_measured_qubits(program::Program, qubit_count::Int)
function _get_measured_qubits(program, qubit_count::Int)
measure_inds = findall(ix->ix.operator isa Measure, program.instructions)
isempty(measure_inds) && return collect(0:qubit_count-1)
measure_ixs = program.instructions[measure_inds]
Expand All @@ -208,29 +208,29 @@ function _prepare_program(circuit_ir::OpenQasmProgram, inputs::Dict{String, <:An
_verify_openqasm_shots_observables(circuit, n_qubits)
basis_rotation_instructions!(circuit)
end
return Program(circuit), n_qubits
return circuit, n_qubits
end
"""
_prepare_program(circuit_ir::Program, inputs::Dict{String, <:Any}, shots::Int) -> (Program, Int)
Apply any `inputs` provided for the simulation. Return the `Program`
(with bound parameters) and the qubit count of the circuit.
"""
# nosemgrep
function _prepare_program(circuit_ir::Program, inputs::Dict{String, <:Any}, shots::Int)
function _prepare_program(circuit_ir::Program, inputs::Dict{String, <:Any}, shots::Int) # nosemgrep
operations::Vector{Instruction} = circuit_ir.instructions
symbol_inputs = Dict(Symbol(k) => v for (k, v) in inputs)
operations = [bind_value!(operation, symbol_inputs) for operation in operations]
qc = qubit_count(circuit_ir)
bound_program = Program(circuit_ir.braketSchemaHeader, operations, circuit_ir.results, circuit_ir.basis_rotation_instructions)
return bound_program, qubit_count(circuit_ir)
return bound_program, qc
end
"""
_combine_operations(program::Program, shots::Int) -> Program
_combine_operations(program, shots::Int) -> Program
Combine explicit instructions and basis rotation instructions (if necessary).
Validate that all operations are performed on qubits within `qubit_count`.
"""
function _combine_operations(program::Program, shots::Int)
function _combine_operations(program, shots::Int)
operations = program.instructions
if shots > 0 && !isempty(program.basis_rotation_instructions)
operations = vcat(operations, program.basis_rotation_instructions)
Expand All @@ -248,17 +248,19 @@ Compute the results once `simulator` has finished applying all the instructions.
the results array is populated with the parsed result types (to help the Braket SDK compute them from the sampled measurements)
and a placeholder zero value.
"""
function _compute_results(::Type{OpenQasmProgram}, simulator, program, n_qubits, shots) # nosemgrep
analytic_results = shots == 0 && !isnothing(program.results) && !isempty(program.results)
function _compute_results(simulator, program::Circuit, n_qubits, shots)
results = program.result_types
has_no_results = isnothing(results) || isempty(results)
analytic_results = shots == 0 && !has_no_results
if analytic_results
return _compute_exact_results(simulator, program, n_qubits)
elseif isnothing(program.results) || isempty(program.results)
elseif has_no_results
return ResultTypeValue[]
else
return ResultTypeValue[ResultTypeValue(result_type, 0.0) for result_type in program.results]
return ResultTypeValue[ResultTypeValue(StructTypes.lower(result_type), 0.0) for result_type in results]
end
end
function _compute_results(::Type{Program}, simulator, program, n_qubits, shots) # nosemgrep
function _compute_results(simulator, program::Program, n_qubits, shots)
analytic_results = shots == 0 && !isnothing(program.results) && !isempty(program.results)
if analytic_results
return _compute_exact_results(simulator, program, n_qubits)
Expand All @@ -272,6 +274,12 @@ function _validate_circuit_ir(simulator, circuit_ir::Program, qubit_count::Int,
_validate_shots_and_ir_results(shots, circuit_ir.results, qubit_count)
return
end
function _validate_circuit_ir(simulator, circuit_ir::Circuit, qubit_count::Int, shots::Int)
_validate_ir_results_compatibility(simulator, circuit_ir.result_types, Val(:JAQCD))
_validate_ir_instructions_compatibility(simulator, circuit_ir, Val(:JAQCD))
_validate_shots_and_ir_results(shots, circuit_ir.result_types, qubit_count)
return
end

"""
simulate(simulator::AbstractSimulator, circuit_ir::Union{OpenQasmProgram, Program}, shots::Int; kwargs...) -> GateModelTaskResult
Expand All @@ -296,7 +304,7 @@ function simulate(
reinit!(simulator, n_qubits, shots)
simulator = evolve!(simulator, operations)
measured_qubits = _get_measured_qubits(program, n_qubits)
results = _compute_results(T, simulator, program, n_qubits, shots)
results = _compute_results(simulator, program, n_qubits, shots)
return _bundle_results(results, circuit_ir, simulator, measured_qubits)
end

Expand Down Expand Up @@ -623,6 +631,7 @@ include("dm_simulator.jl")
"""
all_gates_qasm = """
OPENQASM 3.0;
input float theta;
bit[3] b;
qubit[3] q;
rx(0.1) q[0];
Expand Down Expand Up @@ -653,11 +662,11 @@ include("dm_simulator.jl")
swap q[0], q[1];
iswap q[0], q[1];
xx(6.249142469550989) q[0], q[1];
yy(6.249142469550989) q[0], q[1];
xy(6.249142469550989) q[0], q[1];
zz(6.249142469550989) q[0], q[1];
pswap(6.249142469550989) q[0], q[1];
xx(theta) q[0], q[1];
yy(theta) q[0], q[1];
xy(theta) q[0], q[1];
zz(theta) q[0], q[1];
pswap(theta) q[0], q[1];
ms(0.1, 0.2, 0.3) q[0], q[1];
cphaseshift(6.249142469550989) q[0], q[1];
Expand Down Expand Up @@ -701,7 +710,7 @@ include("dm_simulator.jl")
#pragma braket result sample x(q[0]) @ y(q[1])
"""
@compile_workload begin
using BraketSimulator, BraketSimulator.Quasar
using BraketSimulator, BraketSimulator.Quasar, BraketSimulator.StructTypes
simulator = StateVectorSimulator(5, 0)
oq3_program = OpenQasmProgram(braketSchemaHeader("braket.ir.openqasm.program", "1"), custom_qasm, nothing)
simulate(simulator, oq3_program, 100)
Expand Down Expand Up @@ -730,19 +739,23 @@ include("dm_simulator.jl")

sv_simulator = StateVectorSimulator(3, 0)
dm_simulator = DensityMatrixSimulator(3, 0)
oq3_program = OpenQasmProgram(braketSchemaHeader("braket.ir.openqasm.program", "1"), all_gates_qasm, nothing)
oq3_program = OpenQasmProgram(braketSchemaHeader("braket.ir.openqasm.program", "1"), all_gates_qasm, Dict("theta"=>0.665))
simulate(sv_simulator, oq3_program, 100)
simulate(dm_simulator, oq3_program, 100)

sv_simulator = StateVectorSimulator(2, 0)
dm_simulator = DensityMatrixSimulator(2, 0)
sv_oq3_program = OpenQasmProgram(braketSchemaHeader("braket.ir.openqasm.program", "1"), sv_exact_results_qasm, nothing)
dm_oq3_program = OpenQasmProgram(braketSchemaHeader("braket.ir.openqasm.program", "1"), dm_exact_results_qasm, nothing)
simulate(sv_simulator, sv_oq3_program, 0)
simulate(dm_simulator, dm_oq3_program, 0)
results = simulate(sv_simulator, sv_oq3_program, 0)
map(StructTypes.lower, results.resultTypes)
results = simulate(dm_simulator, dm_oq3_program, 0)
map(StructTypes.lower, results.resultTypes)
oq3_program = OpenQasmProgram(braketSchemaHeader("braket.ir.openqasm.program", "1"), shots_results_qasm, nothing)
simulate(sv_simulator, oq3_program, 10)
simulate(dm_simulator, oq3_program, 10)
results = simulate(sv_simulator, oq3_program, 10)
map(StructTypes.lower, results.resultTypes)
results = simulate(dm_simulator, oq3_program, 10)
map(StructTypes.lower, results.resultTypes)
end
end
end # module BraketSimulator
6 changes: 5 additions & 1 deletion src/circuit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,11 @@ julia> qubit_count(c)
qubit_count(c::Circuit) = length(qubits(c))
qubit_count(p::Program) = length(qubits(p))

Base.convert(::Type{Program}, c::Circuit) = (basis_rotation_instructions!(c); return Program(braketSchemaHeader("braket.ir.jaqcd.program" ,"1"), c.instructions, ir.(c.result_types, Val(:JAQCD)), c.basis_rotation_instructions))
function Base.convert(::Type{Program}, c::Circuit) # nosemgrep
lowered_rts = map(StructTypes.lower, c.result_types)
header = braketSchemaHeader("braket.ir.jaqcd.program" ,"1")
return Program(header, c.instructions, lowered_rts, c.basis_rotation_instructions)
end
Program(c::Circuit) = convert(Program, c)

extract_observable(rt::ObservableResult) = rt.observable
Expand Down
29 changes: 11 additions & 18 deletions src/custom_gates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,13 @@ mutable struct DoubleExcitation <: AngledGate{1}
new(angle, Float64(pow_exponent))
end
qubit_count(::Type{DoubleExcitation}) = 4
function matrix_rep_raw(g::DoubleExcitation)
cosϕ = cos(g.angle[1] / 2.0)
sinϕ = sin(g.angle[1] / 2.0)

mat = diagm(ones(ComplexF64, 16))
mat[4, 4] = cosϕ
mat[13, 13] = cosϕ
mat[4, 13] = -sinϕ
mat[13, 4] = sinϕ
function matrix_rep_raw(::DoubleExcitation, ϕ) # nosemgrep
sθ, cθ = sincos/2.0)
mat = diagm(ones(ComplexF64, 16))
mat[4, 4] =
mat[13, 13] =
mat[4, 13] = -
mat[13, 4] =
return SMatrix{16,16,ComplexF64}(mat)
end

Expand All @@ -24,11 +22,7 @@ mutable struct SingleExcitation <: AngledGate{1}
new(angle, Float64(pow_exponent))
end
qubit_count(::Type{SingleExcitation}) = 2
function matrix_rep_raw(g::SingleExcitation)
cosϕ = cos(g.angle[1] / 2.0)
sinϕ = sin(g.angle[1] / 2.0)
return SMatrix{4,4,ComplexF64}([1.0 0 0 0; 0 cosϕ sinϕ 0; 0 -sinϕ cosϕ 0; 0 0 0 1.0])
end
matrix_rep_raw(::SingleExcitation, ϕ) = ((sθ, cθ) = sincos/2.0); return SMatrix{4,4,ComplexF64}(complex(1.0), 0, 0, 0, 0, cθ, -sθ, 0, 0, sθ, cθ, 0, 0, 0, 0, complex(1.0)))
"""
MultiRz(angle)
Expand Down Expand Up @@ -95,13 +89,12 @@ function apply_gate!(
) where {T<:Complex}
n_amps, endian_ts = get_amps_and_qubits(state_vec, t1, t2, t3, t4)
ordered_ts = sort(collect(endian_ts))
cosϕ = cos(g.angle[1] / 2.0)
sinϕ = sin(g.angle[1] / 2.0)
sinϕ, cosϕ = sincos(g.angle[1] * g.pow_exponent / 2.0)
e_t1, e_t2, e_t3, e_t4 = endian_ts
Threads.@threads for ix = 0:div(n_amps, 2^4)-1
padded_ix = pad_bits(ix, ordered_ts)
i0011 = flip_bits(padded_ix, (e_t3, e_t4)) + 1
i1100 = flip_bits(padded_ix, (e_t1, e_t2)) + 1
i0011 = flip_bits(padded_ix, (e_t3, e_t4)) + 1
i1100 = flip_bits(padded_ix, (e_t1, e_t2)) + 1
@inbounds begin
amp0011 = state_vec[i0011]
amp1100 = state_vec[i1100]
Expand Down
Loading

2 comments on commit 579a0b9

@kshyatt-aws
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/113961

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.0.4 -m "<description of version>" 579a0b928563944163eab8dd40a1d28bd3916e87
git push origin v0.0.4

Please sign in to comment.