Skip to content

Commit

Permalink
fix: Memory frugal density matrix computation for state vectors (#49)
Browse files Browse the repository at this point in the history
  • Loading branch information
kshyatt-aws authored Sep 3, 2024
1 parent 579a0b9 commit df54745
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 5 deletions.
1 change: 1 addition & 0 deletions src/dm_simulator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -318,3 +318,4 @@ function partial_trace(
end
return final_ρ
end
partial_trace(sim::DensityMatrixSimulator, output_qubits = collect(0:qubit_count(sim)-1)) = partial_trace(density_matrix(sim), output_qubits)
7 changes: 3 additions & 4 deletions src/result_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -238,10 +238,9 @@ function calculate(variance::Variance, sim::AbstractSimulator)
end

function calculate(dm::DensityMatrix, sim::AbstractSimulator)
ρ = density_matrix(sim)
full_qubits = collect(0:qubit_count(sim)-1)
(collect(dm.targets) == full_qubits || isnothing(dm.targets) || isempty(dm.targets)) && return ρ
length(dm.targets) == sim.qubit_count && return permute_density_matrix(ρ, sim.qubit_count, collect(dm.targets))
(collect(dm.targets) == full_qubits || isnothing(dm.targets) || isempty(dm.targets)) && return density_matrix(sim)
length(dm.targets) == sim.qubit_count && return permute_density_matrix(density_matrix(sim), sim.qubit_count, collect(dm.targets))
# otherwise must compute a partial trace
return partial_trace(ρ, dm.targets)
return partial_trace(sim, dm.targets)
end
47 changes: 47 additions & 0 deletions src/sv_simulator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -272,3 +272,50 @@ end
Compute the observation probabilities of all amplitudes in the state vector in `svs`.
"""
probabilities(svs::StateVectorSimulator) = abs2.(svs.state_vector)

function partial_trace(
sv::AbstractVector{ComplexF64},
output_qubits = collect(0:Int(log2(size(sv, 1)))-1),
)
isempty(output_qubits) && return mapreduce(abs2, +, sv)
n_amps = length(sv)
n_qubits = Int(log2(n_amps))
length(unique(output_qubits)) == n_qubits && return kron(sv, adjoint(sv))

qubits = setdiff(collect(0:n_qubits-1), output_qubits)
endian_qubits = sort(n_qubits .- qubits .- 1)
qubit_combos = vcat([Int[]], collect(combinations(endian_qubits)))
final_ρ_dim = 2^(n_qubits - length(qubits))
final_ρ = zeros(ComplexF64, final_ρ_dim, final_ρ_dim)
# handle possibly permuted targets
needs_perm = !issorted(output_qubits)
final_n_qubits = length(output_qubits)
output_qubit_mapping = if needs_perm
original_outputs = final_n_qubits .- output_qubits .- 1
permuted_outputs = final_n_qubits .- collect(0:final_n_qubits-1) .- 1
Dict(zip(original_outputs, permuted_outputs))
else
Dict{Int,Int}()
end
for ix = 0:final_ρ_dim-1, jx = ix:final_ρ_dim-1
padded_ix = pad_bits(ix, endian_qubits)
padded_jx = pad_bits(jx, endian_qubits)
flipped_ixs = Vector{Int}(undef, length(qubit_combos))
flipped_jxs = Vector{Int}(undef, length(qubit_combos))
for (c_ix, flipped_qubits) in enumerate(qubit_combos)
flipped_ixs[c_ix] = flip_bits(padded_ix, flipped_qubits) + 1
flipped_jxs[c_ix] = flip_bits(padded_jx, flipped_qubits) + 1
end
# if the output qubits weren't in sorted order, we need to permute the
# final indices of ρ to match the desired qubit mapping
out_ix = needs_perm ? swap_bits(ix, output_qubit_mapping) : ix
out_jx = needs_perm ? swap_bits(jx, output_qubit_mapping) : jx
traced = @inbounds dot(sv[flipped_jxs], sv[flipped_ixs])
@inbounds final_ρ[out_ix+1, out_jx+1] += traced
if out_jx != out_ix
@inbounds final_ρ[out_jx+1, out_ix+1] += conj(traced)
end
end
return final_ρ
end
partial_trace(sim::StateVectorSimulator, output_qubits = collect(0:qubit_count(sim)-1)) = partial_trace(state_vector(sim), output_qubits)
11 changes: 10 additions & 1 deletion test/test_sv_simulator.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Test, BraketSimulator, DataStructures
using Test, BraketSimulator, DataStructures, LinearAlgebra, BraketSimulator.Combinatorics

LARGE_TESTS = get(ENV, "BRAKET_SIM_LARGE_TESTS", "false") == "true"

Expand Down Expand Up @@ -708,6 +708,15 @@ LARGE_TESTS = get(ENV, "BRAKET_SIM_LARGE_TESTS", "false") == "true"
@test new_sv_props.paradigm.qubitCount == new_sv_qubit_count
@test BraketSimulator.supported_result_types(sim) == BraketSimulator.supported_result_types(sim, Val(:OpenQASM))
end
@testset "partial trace $nq" for nq in 3:6
ψ = normalize(rand(ComplexF64, 2^nq))
full_ρ = kron(ψ, adjoint(ψ))
@testset "output qubits $q" for q in combinations(0:nq-1)
ρ = BraketSimulator.partial_trace(ψ, q)
full_pt = BraketSimulator.partial_trace(full_ρ, q)
@test full_pt ρ
end
end
@testset "inputs handling" begin
sv_adder_qasm = """
OPENQASM 3;
Expand Down

0 comments on commit df54745

Please sign in to comment.