Skip to content

Commit

Permalink
feature: Add parsing support for barrier/reset/delay/duration (#52)
Browse files Browse the repository at this point in the history
* change: Add parsing for delay

* change: Add Reset, Barrier, Delay instructions and tests

* change: Test for Barrier

* change: Tests for other no-ops and reorg

* change: More tests for no-ops

* fix: Support for mu-s, dt

* fix: Include new docstrings
  • Loading branch information
kshyatt-aws authored Sep 25, 2024
1 parent 23991af commit 4780fdf
Show file tree
Hide file tree
Showing 10 changed files with 202 additions and 39 deletions.
3 changes: 3 additions & 0 deletions docs/src/circuits.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ BraketSimulator.Circuit
BraketSimulator.Operator
BraketSimulator.QuantumOperator
BraketSimulator.FreeParameter
BraketSimulator.Reset
BraketSimulator.Barrier
BraketSimulator.Delay
BraketSimulator.Measure
BraketSimulator.Instruction
BraketSimulator.QubitSet
Expand Down
98 changes: 70 additions & 28 deletions src/Quasar.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module Quasar

using ..BraketSimulator
using Automa, AbstractTrees, DataStructures
using Automa, AbstractTrees, DataStructures, Dates
using DataStructures: Stack
using BraketSimulator: Control, Instruction, Result, bind_value!, remap, qubit_count, Circuit

Expand Down Expand Up @@ -71,9 +71,7 @@ const qasm_tokens = [
:arrow_token => re"->",
:reset_token => re"reset",
:delay_token => re"delay",
:stretch_token => re"stretch",
:barrier_token => re"barrier",
:duration_token => re"duration",
:barrier_token => re"barrier",
:void => re"void",
:const_token => re"const",
:assignment => re"=|-=|\+=|\*=|/=|^=|&=|\|=|<<=|>>=",
Expand Down Expand Up @@ -110,12 +108,18 @@ const qasm_tokens = [
:string_token => '"' * rep(re"[ !#-~]" | re"\\\\\"") * '"' | '\'' * rep(re"[ -&(-~]" | ('\\' * re"[ -~]")) * '\'',
:newline => re"\r?\n",
:spaces => re"[\t ]+",
:durationof_token => re"durationof", # this MUST be lower than duration_token to preempt duration
:classical_type => re"bool|uint|int|float|angle|complex|array|bit",
:duration_value => (float | integer) * re"ns|µs|us|ms|s",
:classical_type => re"bool|uint|int|float|angle|complex|array|bit|stretch|duration",
:durationof_token => re"durationof", # this MUST be lower than classical_type to preempt duration
:duration_literal => (float | integer) * re"dt|ns|us|ms|s|\xce\xbc\x73", # transcode'd μs
:forbidden_keyword => re"cal|defcal|extern",
]

const dt_type = Ref{DataType}()

function __init__()
dt_type[] = Nanosecond
end

@eval @enum Token error $(first.(qasm_tokens)...)
make_tokenizer((error,
[Token(i) => j for (i,j) in enumerate(last.(qasm_tokens))]
Expand Down Expand Up @@ -369,6 +373,15 @@ function parse_classical_type(tokens, stack, start, qasm)
first(array_tokens)[end] == comma && popfirst!(array_tokens)
size = parse_expression(array_tokens, stack, start, qasm)
return QasmExpression(:classical_type, SizedArray(eltype, size))
elseif var_type == "duration"
@warn "duration expression encountered -- currently `duration` is a no-op"
# TODO: add proper parsing of duration expressions, including
# support for units and algebraic durations like 2*a.
return QasmExpression(:classical_type, :duration)
elseif var_type == "stretch"
@warn "stretch expression encountered -- currently `stretch` is a no-op"
# TODO: add proper parsing of stretch expressions
return QasmExpression(:classical_type, :stretch)
else
!any(triplet->triplet[end] == semicolon, tokens) && push!(tokens, (-1, Int32(-1), semicolon))
size = is_sized ? parse_expression(tokens, stack, start, qasm) : QasmExpression(:integer_literal, -1)
Expand Down Expand Up @@ -427,6 +440,21 @@ parse_oct_literal(token, qasm) = QasmExpression(:integer_literal, tryparse(I
parse_bin_literal(token, qasm) = QasmExpression(:integer_literal, tryparse(Int, qasm[token[1]:token[1]+token[2]-1]))
parse_float_literal(token, qasm) = QasmExpression(:float_literal, tryparse(Float64, qasm[token[1]:token[1]+token[2]-1]))
parse_boolean_literal(token, qasm) = QasmExpression(:boolean_literal, tryparse(Bool, qasm[token[1]:token[1]+token[2]-1]))
function parse_duration_literal(token, qasm)
str = String(codeunits(qasm)[token[1]:token[1]+token[2]-1])
duration = if endswith(str, "ns")
Nanosecond(tryparse(Int, chop(str, tail=2)))
elseif endswith(str, "ms")
Millisecond(tryparse(Int, chop(str, tail=2)))
elseif endswith(str, "us") || endswith(str, "μs")
Microsecond(tryparse(Int, chop(str, tail=2)))
elseif endswith(str, "s")
Second(tryparse(Int, chop(str, tail=1)))
elseif endswith(str, "dt")
dt_type[](tryparse(Int, chop(str, tail=2)))
end
QasmExpression(:duration_literal, duration)
end
function parse_irrational_literal(token, qasm)
raw_string = String(codeunits(qasm)[token[1]:token[1]+token[2]-1])
raw_string == "pi" && return QasmExpression(:irrational_literal, π)
Expand Down Expand Up @@ -463,7 +491,7 @@ function extract_braced_block(tokens::Vector{Tuple{Int64, Int32, Token}}, stack,
next_token[end] == rbracket && (closers_met += 1)
push!(braced_tokens, next_token)
end
pop!(braced_tokens) # closing }
pop!(braced_tokens) # closing ]
push!(braced_tokens, (-1, Int32(-1), semicolon))
return braced_tokens
end
Expand Down Expand Up @@ -511,6 +539,7 @@ function parse_list_expression(tokens::Vector{Tuple{Int64, Int32, Token}}, stack
end

function parse_literal(tokens::Vector{Tuple{Int64, Int32, Token}}, stack, start, qasm)
tokens[1][end] == duration_literal && return parse_duration_literal(popfirst!(tokens), qasm)
tokens[1][end] == string_token && return parse_string_literal(popfirst!(tokens), qasm)
tokens[1][end] == hex && return parse_hex_literal(popfirst!(tokens), qasm)
tokens[1][end] == oct && return parse_oct_literal(popfirst!(tokens), qasm)
Expand Down Expand Up @@ -634,7 +663,7 @@ function parse_expression(tokens::Vector{Tuple{Int64, Int32, Token}}, stack, sta
token_name = parse_bracketed_expression(pushfirst!(tokens, start_token), stack, start, qasm)
elseif start_token[end] == classical_type
token_name = parse_classical_type(pushfirst!(tokens, start_token), stack, start, qasm)
elseif start_token[end] (string_token, integer_token, float_token, hex, oct, bin, irrational, dot, boolean)
elseif start_token[end] (string_token, integer_token, float_token, hex, oct, bin, irrational, dot, boolean, duration_literal)
token_name = parse_literal(pushfirst!(tokens, start_token), stack, start, qasm)
elseif start_token[end] (mutable, readonly, const_token)
token_name = parse_identifier(start_token, qasm)
Expand All @@ -644,7 +673,6 @@ function parse_expression(tokens::Vector{Tuple{Int64, Int32, Token}}, stack, sta
token_name = QasmExpression(:n_dims, QasmExpression(:integer_literal, parse(Int, dim)))
end
head(token_name) == :empty && throw(QasmParseError("unable to parse line with start token $(start_token[end])", stack, start, qasm))

next_token = first(tokens)
if next_token[end] == semicolon || next_token[end] == comma || start_token[end] (lbracket, lbrace)
expr = token_name
Expand All @@ -657,7 +685,7 @@ function parse_expression(tokens::Vector{Tuple{Int64, Int32, Token}}, stack, sta
unary_op_symbol (:~, :!, :-) || throw(QasmParseError("invalid unary operator $unary_op_symbol.", stack, start, qasm))
next_expr = parse_expression(tokens, stack, start, qasm)
# apply unary op to next_expr
if head(next_expr) (:identifier, :indexed_identifier, :integer_literal, :float_literal, :string_literal, :irrational_literal, :boolean_literal, :complex_literal, :function_call, :cast)
if head(next_expr) (:identifier, :indexed_identifier, :integer_literal, :float_literal, :string_literal, :irrational_literal, :boolean_literal, :complex_literal, :function_call, :cast, :duration_literal)
expr = QasmExpression(:unary_op, unary_op_symbol, next_expr)
elseif head(next_expr) == :binary_op
# replace first argument
Expand Down Expand Up @@ -1005,36 +1033,37 @@ function parse_qasm(clean_tokens::Vector{Tuple{Int64, Int32, Token}}, qasm::Stri
line_exprs = collect(Iterators.reverse(line_body))[2:end]
push!(stack, QasmExpression(:return, line_exprs))
elseif token == box
@warn "box expression encountered -- currently boxed and delayed expressions are not supported"
@warn "box expression encountered -- currently `box` is a no-op"
box_expr = QasmExpression(:box)
parse_block_body(box_expr, clean_tokens, stack, start, qasm)
push!(stack, box_expr)
elseif token == reset_token
@warn "reset expression encountered -- currently `reset` is a no-op"
eol = findfirst(triplet->triplet[end] == semicolon, clean_tokens)
reset_tokens = splice!(clean_tokens, 1:eol)
targets = parse_expression(reset_tokens, stack, start, qasm)
targets = parse_list_expression(reset_tokens, stack, start, qasm)
push!(stack, QasmExpression(:reset, targets))
elseif token == barrier_token
@warn "barrier expression encountered -- currently `barrier` is a no-op"
eol = findfirst(triplet->triplet[end] == semicolon, clean_tokens)
barrier_tokens = splice!(clean_tokens, 1:eol)
targets = parse_expression(barrier_tokens, stack, start, qasm)
targets = parse_list_expression(barrier_tokens, stack, start, qasm)
push!(stack, QasmExpression(:barrier, targets))
elseif token == duration_token
@warn "duration expression encountered -- currently `duration` is a no-op"
eol = findfirst(triplet->triplet[end] == semicolon, clean_tokens)
duration_tokens = splice!(clean_tokens, 1:eol)
# TODO: add proper parsing of duration expressions, including
# support for units and algebraic durations like 2*a.
#dur_expr = parse_expression(duration_tokens, stack, start, qasm)
push!(stack, QasmExpression(:duration))
elseif token == stretch_token
@warn "stretch expression encountered -- currently `stretch` is a no-op"
elseif token == delay_token
@warn "delay expression encountered -- currently `delay` is a no-op"
eol = findfirst(triplet->triplet[end] == semicolon, clean_tokens)
stretch_tokens = splice!(clean_tokens, 1:eol)
stretch_expr = parse_expression(stretch_tokens, stack, start, qasm)
push!(stack, QasmExpression(:stretch, stretch_expr))
delay_tokens = splice!(clean_tokens, 1:eol)
delay_expr = QasmExpression(:delay)
# format is delay[duration]; or delay[duration] targets;
delay_duration = extract_braced_block(delay_tokens, stack, start, qasm)
push!(delay_expr, QasmExpression(:duration, parse_expression(delay_duration, stack, start, qasm)))
target_expr = QasmExpression(:targets)
if first(delay_tokens)[end] != semicolon # targets present
targets = parse_list_expression(delay_tokens, stack, start, qasm)
push!(target_expr, targets)
end
push!(delay_expr, target_expr)
push!(stack, delay_expr)
elseif token == end_token
push!(stack, QasmExpression(:end))
elseif token == identifier || token == builtin_gate
Expand Down Expand Up @@ -1327,7 +1356,7 @@ function evaluate(v::V, expr::QasmExpression) where {V<:AbstractVisitor}
step::Int = evaluate(v, raw_step)
stop::Int = evaluate(v, raw_stop)
return StepRange(start, step, stop)
elseif head(expr) (:integer_literal, :float_literal, :string_literal, :complex_literal, :irrational_literal, :boolean_literal)
elseif head(expr) (:integer_literal, :float_literal, :string_literal, :complex_literal, :irrational_literal, :boolean_literal, :duration_literal)
return expr.args[1]
elseif head(expr) == :array_literal
return [evaluate(v, arg) for arg in convert(Vector{QasmExpression}, expr.args)]
Expand Down Expand Up @@ -1630,8 +1659,21 @@ function (v::AbstractVisitor)(program_expr::QasmExpression)
elseif head(program_expr) == :version
return v
elseif head(program_expr) == :reset
targets = program_expr.args[1]::QasmExpression
target_qubits = evaluate(v, targets)
push!(v, [BraketSimulator.Instruction(BraketSimulator.Reset(), t) for t in target_qubits])
return v
elseif head(program_expr) == :barrier
targets = program_expr.args[1]::QasmExpression
target_qubits = evaluate(v, targets)
push!(v, [BraketSimulator.Instruction(BraketSimulator.Barrier(), t) for t in target_qubits])
return v
elseif head(program_expr) == :delay
duration_expr = program_expr.args[1].args[1]::QasmExpression
targets = program_expr.args[2].args[1]::QasmExpression
target_qubits = evaluate(v, targets)
duration = evaluate(v, duration_expr)
push!(v, [BraketSimulator.Instruction(BraketSimulator.Delay(duration), t) for t in target_qubits])
return v
elseif head(program_expr) == :stretch
return v
Expand Down
12 changes: 8 additions & 4 deletions src/dm_simulator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,15 @@ function _evolve_op!(
) where {T<:Complex,S<:AbstractDensityMatrix{T},N<:Noise}
apply_noise!(op, dms.density_matrix, target...)
end
# Measure operators are no-ops for now as measurement is handled at the end
# of simulation, in the results computation step. If/when mid-circuit
# measurement is supported, this operation will collapse the density
# matrix on the measured qubits.
# Measure, barrier, reset, and delay operators are no-ops for now as
# measurement is handled at the end of simulation, in the results
# computation step. If/when mid-circuit # measurement is supported,
# this operation will collapse the density # matrix on the measured qubits.
# Barrier, reset, and delay are also to-do implementations.
_evolve_op!(dms::DensityMatrixSimulator{T,S}, m::Measure, args...) where {T<:Complex,S<:AbstractDensityMatrix{T}} = return
_evolve_op!(dms::DensityMatrixSimulator{T,S}, b::Barrier, args...) where {T<:Complex,S<:AbstractDensityMatrix{T}} = return
_evolve_op!(dms::DensityMatrixSimulator{T,S}, r::Reset, args...) where {T<:Complex,S<:AbstractDensityMatrix{T}} = return
_evolve_op!(dms::DensityMatrixSimulator{T,S}, d::Delay, args...) where {T<:Complex,S<:AbstractDensityMatrix{T}} = return

"""
evolve!(dms::DensityMatrixSimulator{T, S<:AbstractMatrix{T}}, operations::Vector{Instruction}) -> DensityMatrixSimulator{T, S}
Expand Down
3 changes: 3 additions & 0 deletions src/gate_kernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,9 @@ apply_gate!(::Val{false}, g::I, state_vec::AbstractStateVector{T}, qubits::Int..
apply_gate!(::Val{true}, g::I, state_vec::AbstractStateVector{T}, qubits::Int...) where {T<:Complex} =
return
apply_gate!(::Measure, state_vec, args...) = return
apply_gate!(::Reset, state_vec, args...) = return
apply_gate!(::Barrier, state_vec, args...) = return
apply_gate!(::Delay, state_vec, args...) = return

function apply_gate!(
g_matrix::Union{SMatrix{2,2,T}, Diagonal{T,SVector{2,T}}},
Expand Down
3 changes: 0 additions & 3 deletions src/gates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,6 @@ qubit_count(g::Unitary) = convert(Int, log2(size(g.matrix, 1)))
Parametrizable(g::AngledGate) = Parametrized()
Parametrizable(g::Gate) = NonParametrized()
parameters(g::AngledGate) = collect(filter(a->a isa FreeParameter, angles(g)))
parameters(g::Gate) = FreeParameter[]
bind_value!(g::G, params::Dict{Symbol, <:Real}) where {G<:Gate} = bind_value!(Parametrizable(g), g, params)
bind_value!(::NonParametrized, g::G, params::Dict{Symbol, <:Real}) where {G<:Gate} = g
# nosemgrep
function bind_value!(::Parametrized, g::G, params::Dict{Symbol, <:Real}) where {G<:AngledGate}
new_angles = map(angles(g)) do angle
Expand Down
4 changes: 1 addition & 3 deletions src/noises.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,9 @@ end
Base.:(==)(c1::MultiQubitPauliChannel{N}, c2::MultiQubitPauliChannel{M}) where {N,M} = N == M && c1.probabilities == c2.probabilities

Parametrizable(g::Noise) = NonParametrized()
parameters(g::Noise) = parameters(Parametrizable(g), g)
parameters(g::Noise) = parameters(Parametrizable(g), g)
parameters(::Parametrized, g::N) where {N<:Noise} = filter(x->x isa FreeParameter, [getproperty(g, fn) for fn in fieldnames(N)])
parameters(::NonParametrized, g::Noise) = FreeParameter[]
bind_value!(n::N, params::Dict{Symbol, <:Real}) where {N<:Noise} = bind_value!(Parametrizable(n), n, params)
bind_value!(::NonParametrized, n::N, params::Dict{Symbol, <:Real}) where {N<:Noise} = n

# nosemgrep
function bind_value!(::Parametrized, g::N, params::Dict{Symbol, <:Real}) where {N<:Noise}
Expand Down
46 changes: 46 additions & 0 deletions src/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ function Base.getindex(p::PauliEigenvalues{N}, i::Int)::Float64 where N
end
Base.getindex(p::PauliEigenvalues{N}, ix::Vector{Int}) where {N} = [p[i] for i in ix]

parameters(o::QuantumOperator) = FreeParameter[]

"""
Measure(index) <: QuantumOperator
Expand All @@ -66,3 +68,47 @@ Measure() = Measure(-1)
Parametrizable(m::Measure) = NonParametrized()
qubit_count(::Type{Measure}) = 1
qubit_count(m::Measure) = qubit_count(Measure)

"""
Reset(index) <: QuantumOperator
Represents an active reset operation on targeted qubit, stored in the classical register at `index`.
For now, this is a no-op.
"""
struct Reset <: QuantumOperator
index::Int
end
Reset() = Reset(-1)
Parametrizable(m::Reset) = NonParametrized()
qubit_count(::Type{Reset}) = 1
qubit_count(m::Reset) = qubit_count(Reset)

"""
Barrier(index) <: QuantumOperator
Represents a barrier operation on targeted qubit, stored in the classical register at `index`.
For now, this is a no-op.
"""
struct Barrier <: QuantumOperator
index::Int
end
Barrier() = Barrier(-1)
Parametrizable(m::Barrier) = NonParametrized()
qubit_count(::Type{Barrier}) = 1
qubit_count(m::Barrier) = qubit_count(Barrier)

"""
Delay(index, duration::Time) <: QuantumOperator
Represents a delay operation for `duration` on targeted qubit,
stored in the classical register at `index`.
For now, this is a no-op.
"""
struct Delay <: QuantumOperator
index::Int
duration::Dates.Period
end
Delay(duration::Dates.Period) = Delay(-1, duration)
Parametrizable(m::Delay) = NonParametrized()
qubit_count(::Type{Delay}) = 1
qubit_count(m::Delay) = qubit_count(Delay)
2 changes: 2 additions & 0 deletions src/schemas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,6 @@ end
Instruction(o::O, target) where {O<:Operator} = Instruction{O}(o, QubitSet(target...))
Base.:(==)(ix1::Instruction{O}, ix2::Instruction{O}) where {O<:Operator} = (ix1.operator == ix2.operator && ix1.target == ix2.target)
bind_value!(ix::Instruction{O}, param_values::Dict{Symbol, <:Real}) where {O<:Operator} = Instruction{O}(bind_value!(ix.operator, param_values), ix.target)
bind_value!(o::O, params::Dict{Symbol, <:Real}) where {O<:Operator} = bind_value!(Parametrizable(o), o, params)
bind_value!(::NonParametrized, o::O, params::Dict{Symbol, <:Real}) where {O<:Operator} = o
remap(@nospecialize(ix::Instruction{O}), mapping::Dict{<:Integer, <:Integer}) where {O} = Instruction{O}(copy(ix.operator), [mapping[q] for q in ix.target])
Loading

0 comments on commit 4780fdf

Please sign in to comment.