Skip to content

Commit

Permalink
Even more small cleanups (#22)
Browse files Browse the repository at this point in the history
* Case mini cleanup

* Small cleanup for gphase
  • Loading branch information
kshyatt-aws authored Nov 27, 2024
1 parent 2d8e369 commit 3a3b69e
Showing 1 changed file with 14 additions and 17 deletions.
31 changes: 14 additions & 17 deletions src/visitor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -370,15 +370,14 @@ function handle_gate_modifiers(ixs, mods::Vector{QasmExpression}, control_qubits
if head(mod) (:negctrl, :ctrl)
control_qubit = pop!(control_qubits)
for (ii, ix) in enumerate(ixs)
exp = ix.exponent
targets = ix.targets
controls = ix.controls
bit = head(mod) == :ctrl ? 1 : 0
controls = pushfirst!(controls, control_qubit=>bit)
if !is_gphase
targets = pushfirst!(targets, control_qubit)
end
ixs[ii] = (type=ix.type, arguments=ix.arguments, targets=targets, controls=controls, exponent=exp)
ixs[ii] = (type=ix.type, arguments=ix.arguments, targets=targets, controls=controls, exponent=ix.exponent)
end
elseif head(mod) == :inv
reverse!(ixs)
Expand All @@ -400,9 +399,9 @@ function handle_gate_modifiers(ixs, mods::Vector{QasmExpression}, control_qubits
end

function splat_gate_targets(gate_targets::Vector{Vector{Int}})
target_lengths::Vector{Int} = Int[length(t) for t in gate_targets]
longest = maximum(target_lengths)
must_splat::Bool = any(len->len!=1 || len != longest, target_lengths)
target_lengths::Vector{Int} = map(length, gate_targets)::Vector{Int}
longest = maximum(target_lengths)
must_splat::Bool = any(len->len!=1 || len != longest, target_lengths)
!must_splat && return longest, gate_targets
for target_ix in filter(ix->target_lengths[ix] == 1, 1:length(gate_targets))
append!(gate_targets[target_ix], fill(only(gate_targets[target_ix]), longest-1))
Expand All @@ -415,11 +414,11 @@ function visit_gphase_call(v::AbstractVisitor, program_expr::QasmExpression)
n_called_with::Int = qubit_count(v)
gate_targets::Vector{Int} = collect(0:n_called_with-1)
provided_arg::QasmExpression = only(program_expr.args[2].args)
evaled_arg = v(provided_arg)
applied_arguments = CircuitInstruction[(type="gphase", arguments=[evaled_arg], targets=gate_targets, controls=Pair{Int,Int}[], exponent=1.0)]
mods::Vector{QasmExpression} = has_modifiers ? program_expr.args[4].args : QasmExpression[]
applied_arguments = handle_gate_modifiers(applied_arguments, mods, deepcopy(gate_targets), true)
target_mapper = Dict{Int, Int}(g_ix=>gate_targets[g_ix+1][1] for g_ix in 0:n_called_with-1)
evaled_arg = v(provided_arg)
applied_arguments = CircuitInstruction[(type="gphase", arguments=[evaled_arg], targets=gate_targets, controls=Pair{Int,Int}[], exponent=1.0)]
mods::Vector{QasmExpression} = length(program_expr.args) == 4 ? program_expr.args[4].args : QasmExpression[]
applied_arguments = handle_gate_modifiers(applied_arguments, mods, deepcopy(gate_targets), true)
target_mapper = Dict{Int, Int}(g_ix=>gate_targets[g_ix+1][1] for g_ix in 0:n_called_with-1)
push!(v, map(ix->remap(ix, target_mapper), applied_arguments))
return
end
Expand Down Expand Up @@ -601,12 +600,10 @@ function (v::AbstractVisitor)(program_expr::QasmExpression)
all_cases = convert(Vector{QasmExpression}, program_expr.args[2:end])
default = findfirst(expr->head(expr) == :default, all_cases)
case_found = false
for case in all_cases
if head(case) == :case && case_val v(case.args[1])
case_found = true
foreach(v, convert(Vector{QasmExpression}, case.args[2:end]))
break
end
for case in filter(case->head(case) == :case && case_val v(case.args[1]), all_cases)
case_found = true
foreach(v, convert(Vector{QasmExpression}, case.args[2:end]))
break
end
if !case_found
isnothing(default) && throw(QasmVisitorError("no case matched and no default defined."))
Expand Down Expand Up @@ -809,7 +806,7 @@ function (v::AbstractVisitor)(program_expr::QasmExpression)
end
v.qubit_count += qubit_size
elseif head(program_expr) (:power_mod, :inverse_mod, :control_mod, :negctrl_mod)
mods = QasmExpression(:modifiers)
mods = QasmExpression(:modifiers)
mod_expr, inner = evaluate_modifiers(v, program_expr)
push!(mods, mod_expr)
while head(inner) != :gate_call # done
Expand Down

0 comments on commit 3a3b69e

Please sign in to comment.