From df54745b3a35938b4dc52043bc8b6d74d55e1f5c Mon Sep 17 00:00:00 2001 From: Katharine Hyatt <67932820+kshyatt-aws@users.noreply.github.com> Date: Tue, 3 Sep 2024 10:49:30 -0400 Subject: [PATCH] fix: Memory frugal density matrix computation for state vectors (#49) --- src/dm_simulator.jl | 1 + src/result_types.jl | 7 +++--- src/sv_simulator.jl | 47 +++++++++++++++++++++++++++++++++++++++ test/test_sv_simulator.jl | 11 ++++++++- 4 files changed, 61 insertions(+), 5 deletions(-) diff --git a/src/dm_simulator.jl b/src/dm_simulator.jl index f2792c4..dbcaf19 100644 --- a/src/dm_simulator.jl +++ b/src/dm_simulator.jl @@ -318,3 +318,4 @@ function partial_trace( end return final_ρ end +partial_trace(sim::DensityMatrixSimulator, output_qubits = collect(0:qubit_count(sim)-1)) = partial_trace(density_matrix(sim), output_qubits) diff --git a/src/result_types.jl b/src/result_types.jl index 15dccc1..b0547d4 100644 --- a/src/result_types.jl +++ b/src/result_types.jl @@ -238,10 +238,9 @@ function calculate(variance::Variance, sim::AbstractSimulator) end function calculate(dm::DensityMatrix, sim::AbstractSimulator) - ρ = density_matrix(sim) full_qubits = collect(0:qubit_count(sim)-1) - (collect(dm.targets) == full_qubits || isnothing(dm.targets) || isempty(dm.targets)) && return ρ - length(dm.targets) == sim.qubit_count && return permute_density_matrix(ρ, sim.qubit_count, collect(dm.targets)) + (collect(dm.targets) == full_qubits || isnothing(dm.targets) || isempty(dm.targets)) && return density_matrix(sim) + length(dm.targets) == sim.qubit_count && return permute_density_matrix(density_matrix(sim), sim.qubit_count, collect(dm.targets)) # otherwise must compute a partial trace - return partial_trace(ρ, dm.targets) + return partial_trace(sim, dm.targets) end diff --git a/src/sv_simulator.jl b/src/sv_simulator.jl index 67bbbc3..572d8f2 100644 --- a/src/sv_simulator.jl +++ b/src/sv_simulator.jl @@ -272,3 +272,50 @@ end Compute the observation probabilities of all amplitudes in the state vector in `svs`. """ probabilities(svs::StateVectorSimulator) = abs2.(svs.state_vector) + +function partial_trace( + sv::AbstractVector{ComplexF64}, + output_qubits = collect(0:Int(log2(size(sv, 1)))-1), +) + isempty(output_qubits) && return mapreduce(abs2, +, sv) + n_amps = length(sv) + n_qubits = Int(log2(n_amps)) + length(unique(output_qubits)) == n_qubits && return kron(sv, adjoint(sv)) + + qubits = setdiff(collect(0:n_qubits-1), output_qubits) + endian_qubits = sort(n_qubits .- qubits .- 1) + qubit_combos = vcat([Int[]], collect(combinations(endian_qubits))) + final_ρ_dim = 2^(n_qubits - length(qubits)) + final_ρ = zeros(ComplexF64, final_ρ_dim, final_ρ_dim) + # handle possibly permuted targets + needs_perm = !issorted(output_qubits) + final_n_qubits = length(output_qubits) + output_qubit_mapping = if needs_perm + original_outputs = final_n_qubits .- output_qubits .- 1 + permuted_outputs = final_n_qubits .- collect(0:final_n_qubits-1) .- 1 + Dict(zip(original_outputs, permuted_outputs)) + else + Dict{Int,Int}() + end + for ix = 0:final_ρ_dim-1, jx = ix:final_ρ_dim-1 + padded_ix = pad_bits(ix, endian_qubits) + padded_jx = pad_bits(jx, endian_qubits) + flipped_ixs = Vector{Int}(undef, length(qubit_combos)) + flipped_jxs = Vector{Int}(undef, length(qubit_combos)) + for (c_ix, flipped_qubits) in enumerate(qubit_combos) + flipped_ixs[c_ix] = flip_bits(padded_ix, flipped_qubits) + 1 + flipped_jxs[c_ix] = flip_bits(padded_jx, flipped_qubits) + 1 + end + # if the output qubits weren't in sorted order, we need to permute the + # final indices of ρ to match the desired qubit mapping + out_ix = needs_perm ? swap_bits(ix, output_qubit_mapping) : ix + out_jx = needs_perm ? swap_bits(jx, output_qubit_mapping) : jx + traced = @inbounds dot(sv[flipped_jxs], sv[flipped_ixs]) + @inbounds final_ρ[out_ix+1, out_jx+1] += traced + if out_jx != out_ix + @inbounds final_ρ[out_jx+1, out_ix+1] += conj(traced) + end + end + return final_ρ +end +partial_trace(sim::StateVectorSimulator, output_qubits = collect(0:qubit_count(sim)-1)) = partial_trace(state_vector(sim), output_qubits) diff --git a/test/test_sv_simulator.jl b/test/test_sv_simulator.jl index 25b0e6d..f481e62 100644 --- a/test/test_sv_simulator.jl +++ b/test/test_sv_simulator.jl @@ -1,4 +1,4 @@ -using Test, BraketSimulator, DataStructures +using Test, BraketSimulator, DataStructures, LinearAlgebra, BraketSimulator.Combinatorics LARGE_TESTS = get(ENV, "BRAKET_SIM_LARGE_TESTS", "false") == "true" @@ -708,6 +708,15 @@ LARGE_TESTS = get(ENV, "BRAKET_SIM_LARGE_TESTS", "false") == "true" @test new_sv_props.paradigm.qubitCount == new_sv_qubit_count @test BraketSimulator.supported_result_types(sim) == BraketSimulator.supported_result_types(sim, Val(:OpenQASM)) end + @testset "partial trace $nq" for nq in 3:6 + ψ = normalize(rand(ComplexF64, 2^nq)) + full_ρ = kron(ψ, adjoint(ψ)) + @testset "output qubits $q" for q in combinations(0:nq-1) + ρ = BraketSimulator.partial_trace(ψ, q) + full_pt = BraketSimulator.partial_trace(full_ρ, q) + @test full_pt ≈ ρ + end + end @testset "inputs handling" begin sv_adder_qasm = """ OPENQASM 3;