Skip to content

Commit

Permalink
fix: Catch up to BDK for IonQ error mitigation, image URIs, and new a…
Browse files Browse the repository at this point in the history
…dditionalMetadata (#66)

* Catch up to IonQ error mitigation and new Jobs images tags

* Fix integ tests

* More tests and implementation for queue depth

* Add integ tests and methods for queue info

* Add support and tests for error mitigation

* Add some more logging for docker commands and fix call to pull_image
  • Loading branch information
kshyatt-aws authored Oct 30, 2023
1 parent f42a5ec commit de79455
Show file tree
Hide file tree
Showing 20 changed files with 420 additions and 73 deletions.
3 changes: 3 additions & 0 deletions src/Braket.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ export apply_gate_noise!, apply
export logs, log_metric, metrics
export depth, qubit_count, qubits, ir, IRType, OpenQASMSerializationProperties
export OpenQasmProgram
export QueueDepthInfo, QueueType, Normal, Priority, queue_depth, queue_position

export AdjointGradient, Expectation, Sample, Variance, Amplitude, Probability, StateVector, DensityMatrix, Result

Expand Down Expand Up @@ -120,12 +121,14 @@ Base.show(io::IO, fp::FreeParameter) = print(io, string(fp.name))
include("compiler_directive.jl")
include("gates.jl")
include("noises.jl")
include("error_mitigation.jl")
include("results.jl")
include("schemas.jl")
include("moments.jl")
include("circuit.jl")
include("noise_model.jl")
include("ahs.jl")
include("queue_information.jl")
include("device.jl")
include("gate_applicators.jl")
include("noise_applicators.jl")
Expand Down
9 changes: 9 additions & 0 deletions src/aws_jobs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@ function log_metric(metric_name::String, value::Union{Float64, Int}; timestamp=t
return
end

function queue_position(j::AwsQuantumJob)
md = metadata(j)
response = md["queueInfo"]
queue_position = get(response, "position", "None") == "None" ? "" : get(response, "position", "")
message = get(response, "message", "")
return HybridJobQueueInfo(queue_position, message)
end

function log_stream(ch::Channel, log_group::String, stream_name::String, start_time::Int=0, skip::Int=0)
next_token = nothing
event_count = 1
Expand Down Expand Up @@ -558,3 +566,4 @@ function AwsQuantumJob(
wait_until_complete && logs(job, wait=true)
return job
end
AwsQuantumJob(device::BraketDevice, source_module::String; kwargs...) = AwsQuantumJob(convert(String, device), source_module; kwargs...)
86 changes: 86 additions & 0 deletions src/device.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,68 @@ const _GET_DEVICES_ORDER_BY_KEYS = Set(("arn", "name", "type", "provider_name",
@enum AwsDeviceType SIMULATOR QPU
const AwsDeviceTypeDict = Dict("SIMULATOR"=>SIMULATOR, "QPU"=>QPU)

abstract type BraketDevice end
for provider in (:AmazonDevice, :_XanaduDevice, :_DWaveDevice, :OQCDevice, :QuEraDevice, :IonQDevice, :RigettiDevice)
@eval begin
abstract type $provider <: BraketDevice end
end
end

for (d, d_arn) in zip((:SV1, :DM1, :TN1), ("sv1", "dm1", "tn1"))
@eval begin
struct $d <: AmazonDevice end
Base.convert(::Type{String}, d::$d) = "arn:aws:braket:::device/quantum-simulator/amazon/" * $d_arn
end
end

for (d, d_arn) in zip((:_Advantage1, :_Advantage3, :_Advantage4, :_Advantage6, :_DW2000Q6),
("arn:aws:braket:::device/qpu/d-wave/Advantage_system1",
"arn:aws:braket:::device/qpu/d-wave/Advantage_system2",
"arn:aws:braket:::device/qpu/d-wave/Advantage_system3",
"arn:aws:braket:us-west-2::device/qpu/d-wave/Advantage_system6",
"arn:aws:braket:::device/qpu/d-wave/DW_2000Q_6",
))
@eval begin
struct $d <: _DWaveDevice end
Base.convert(::Type{String}, d::$d) = $d_arn
end
end

struct _Borealis <: _XanaduDevice end
Base.convert(::String, d::_Borealis) = "arn:aws:braket:us-east-1::device/qpu/xanadu/Borealis"

for (d, d_arn) in zip((:Harmony, :Aria1, :Aria2),
("arn:aws:braket:us-east-1::device/qpu/ionq/Harmony",
"arn:aws:braket:us-east-1::device/qpu/ionq/Aria-1",
"arn:aws:braket:us-east-1::device/qpu/ionq/Aria-2",
))
@eval begin
struct $d <: IonQDevice end
Base.convert(::Type{String}, d::$d) = $d_arn
end
end

struct Aquila <: QuEraDevice end
Base.convert(::Type{String}, d::Aquila) = "arn:aws:braket:us-east-1::device/qpu/quera/Aquila"

struct Lucy <: OQCDevice end
Base.convert(::Type{String}, d::Lucy) = "arn:aws:braket:eu-west-2::device/qpu/oqc/Lucy"

for (d, d_arn) in zip((:_Aspen8, :_Aspen9, :_Aspen10, :_Aspen11, :_AspenM1, :_AspenM2, :AspenM3),
("arn:aws:braket:::device/qpu/rigetti/Aspen-8",
"arn:aws:braket:::device/qpu/rigetti/Aspen-9",
"arn:aws:braket:::device/qpu/rigetti/Aspen-10",
"arn:aws:braket:::device/qpu/rigetti/Aspen-11",
"arn:aws:braket:us-west-1::device/qpu/rigetti/Aspen-M-1",
"arn:aws:braket:us-west-1::device/qpu/rigetti/Aspen-M-2",
"arn:aws:braket:us-west-1::device/qpu/rigetti/Aspen-M-3",
))
@eval begin
struct $d <: RigettiDevice end
Base.convert(::Type{String}, d::$d) = $d_arn
end
end

"""
AwsDevice <: Device
Expand Down Expand Up @@ -62,6 +124,29 @@ function _construct_topology_graph(d::AwsDevice)
end
end

"""
queue_depth(d::AwsDevice)
"""
function queue_depth(d::AwsDevice)
dev_name = d._arn
metadata = parse(BRAKET.get_device(HTTP.escapeuri(dev_name), aws_config=d._config))
queue_metadata = get(metadata, "deviceQueueInfo", nothing)
queue_info = Dict{String, Any}()
for response in queue_metadata
queue_name = get(response, "queue", "")
queue_priority = get(response, "queuePriority", "")
queue_size = get(response, "queueSize", "")
if queue_name == "QUANTUM_TASKS_QUEUE"
priority_enum = QueueType(queue_priority)
!haskey(queue_info, "quantum_tasks") && (queue_info["quantum_tasks"] = Dict{QueueType, String}())
queue_info["quantum_tasks"][priority_enum] = queue_size
else
queue_info["jobs"] = queue_size
end
end
return QueueDepthInfo(get(queue_info, "quantum_tasks", Dict{QueueType, String}()), get(queue_info, "jobs", ""))
end

"""
refresh_metadata!(d::AwsDevice)
Expand Down Expand Up @@ -118,6 +203,7 @@ function AwsDevice(device_arn::String; config::AWSConfig=global_aws_config())
return d
end
end
AwsDevice(d::BraketDevice; kwargs...) = AwsDevice(convert(String, d); kwargs...)

"""
isavailable(d::AwsDevice) -> Bool
Expand Down
5 changes: 5 additions & 0 deletions src/error_mitigation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
abstract type ErrorMitigation end

struct DeBias <: ErrorMitigation end

ir(db::DeBias) = [Debias(StructTypes.defaults(Debias)[:type])]
8 changes: 4 additions & 4 deletions src/jobs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ Abstract type representing a Braket Job.
abstract type Job end

get_job(j::Job) = get_job(arn(j))
get_job(arn::String) = startswith(arn, "local") ? LocalQuantumJob(arn) : BRAKET.get_job(HTTP.escapeuri(arn))
get_job(arn::String) = startswith(arn, "local") ? LocalQuantumJob(arn) : BRAKET.get_job(HTTP.escapeuri(arn) * "?additionalAttributeNames=QueueInfo")

config_fname(f::Framework) = joinpath(@__DIR__, "image_uri_config", lowercase(string(f))*".json")
config_for_framework(f::Framework) = JSON3.read(read(config_fname(f), String), Dict)
Expand All @@ -44,11 +44,11 @@ function retrieve_image(f::Framework, config::AWSConfig)
version_config = conf["versions"][framework_version]
registry = registry_for_region(version_config, aws_region)
tag = if f == BASE
string(version_config["repository"]) * ":" * string(framework_version) * "-cpu-py37-ubuntu18.04"
string(version_config["repository"]) * ":" * "latest"
elseif f == PL_TENSORFLOW
string(version_config["repository"]) * ":" * string(framework_version) * "-gpu-py37-cu110-ubuntu18.04"
string(version_config["repository"]) * ":" * "latest"
elseif f == PL_PYTORCH
string(version_config["repository"]) * ":" * string(framework_version) * "-gpu-py38-cu111-ubuntu20.04"
string(version_config["repository"]) * ":" * "latest"
end
return string(registry) * ".dkr.ecr.$aws_region.amazonaws.com/$tag"
end
46 changes: 33 additions & 13 deletions src/local_jobs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,6 @@ mutable struct LocalJobContainer
function LocalJobContainer(image_uri::String, create_job_args; config::AWSConfig=global_aws_config(), container_name::String="", container_code_path::String="/opt/ml/code", force_update::Bool=false)
c = new(image_uri, container_name, container_code_path, Dict{String, String}(), "", config)
c = start_container!(c, force_update)
finalizer(c) do c
# check that the container is still running
c_list = read(`docker container ls -q`, String)
stop_flag = occursin(first(c.container_name, 10), c_list)
stop_flag && read(Cmd(["docker", "stop", c.container_name]), String)
return
end
return setup_container!(c, create_job_args)
end
end
Expand Down Expand Up @@ -146,12 +139,17 @@ function run_local_job!(c::LocalJobContainer)
end

function login_to_ecr(account_id::String, ecr_uri::String, config::AWSConfig)
@debug "Attempting to log in to ECR..."
@debug "Getting authorization token"
authorization_data_result = EcR.get_authorization_token(Dict("registryIds"=>[account_id]), aws_config=config)
isnothing(authorization_data_result) && throw(ErrorException("unable to get permissions to access ECR in order to log in to docker. Please pull down the container before proceeding."))
raw_token = base64decode(authorization_data_result["authorizationData"][1]["authorizationToken"])
token = String(raw_token)
token = replace(token, "AWS:"=>"")
@debug "Performing docker login"
proc_out, proc_err, code = capture_docker_cmd(`docker login -u AWS -p $token $ecr_uri`)
@debug "docker login complete"
code != 0 && throw(ErrorException("Unable to docker login to ECR with error $proc_err"))
return
end

Expand All @@ -161,8 +159,9 @@ function pull_image(image_uri::String, config::AWSConfig)
ecr_uri = String(m[1])
account_id = String(m[2])
login_to_ecr(account_id, ecr_uri, config)
@warn "Pulling docker image. This may take a while."
@warn "Pulling docker image $image_uri. This may take a while."
proc_out, proc_err, code = capture_docker_cmd(`docker pull $image_uri`)
code != 0 && error(proc_err)
return
end

Expand All @@ -182,15 +181,19 @@ function start_container!(c::LocalJobContainer, force_update::Bool)
get_image_name(image_uri) = capture_docker_cmd(`docker images -q $image_uri`)[1]
image_name = get_image_name(image_uri)
if isempty(image_name) || isnothing(image_name)
pull_image(image_uri, c.config)
image_name = get_image_name(image_uri)
(isempty(image_name) || isnothing(image_name)) && throw(ErrorException("The URI $(c.image_uri) is not available locally and cannot be pulled from Amazon ECR. Please pull down the container before proceeding."))
try
pull_image(image_uri, c.config)
image_name = get_image_name(image_uri)
(isempty(image_name) || isnothing(image_name)) && throw(ErrorException("The URI $(c.image_uri) is not available locally and cannot be pulled from Amazon ECR. Please pull down the container before proceeding."))
catch ex
throw(ErrorException("The URI $(c.image_uri) is not available locally and cannot be pulled from Amazon ECR due to $ex. Please pull down the container before proceeding."))
end
elseif force_update
try
pull_image(image_uri)
pull_image(image_uri, c.config)
image_name = get_image_name(image_uri)
catch e
@warn "Unable to update $(c.image_uri)"
@warn "Unable to update $(c.image_uri) with error $e"
end
end
container_name, err, code = capture_docker_cmd(`docker run -d --rm $image_name tail -f /dev/null`)
Expand All @@ -199,6 +202,21 @@ function start_container!(c::LocalJobContainer, force_update::Bool)
return c
end

function stop_container!(c::LocalJobContainer)
# check that the container is still running
cmd = `docker container ls -q`
c_list, c_err, code = capture_docker_cmd(cmd)
if code == 0 && occursin(first(c.container_name, 10), c_list)
stop_out, stop_err, stop_code = capture_docker_cmd(Cmd(["docker", "stop", c.container_name]))
if stop_code != 0
error("unable to stop docker container $(c.contianer_name)! Error: $stop_err")
end
else
error("unable to read docker container list! Error: $c_err")
end
return
end

function copy_from_container!(c::LocalJobContainer, src::String, dst::String)
c_name = c.container_name
cmd = `docker cp $c_name:$src $dst`
Expand Down Expand Up @@ -335,9 +353,11 @@ function LocalQuantumJob(
copy_from_container!(local_job_container, checkpoint_path, joinpath(job_name, "checkpoints"))
end
run_log = local_job_container.run_log
stop_container!(local_job_container)
end
return LocalQuantumJob("local:job/$job_name", run_log=run_log)
end
LocalQuantumJob(device::BraketDevice, source_module::String; kwargs...) = LocalQuantumJob(convert(String, device), source_module; kwargs...)

"""
arn(j::LocalQuantumJob)
Expand Down
22 changes: 22 additions & 0 deletions src/queue_information.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
@enum QueueType Normal Priority
QueueTypeDict = Dict("Normal"=>Normal, "Priority"=>Priority)
QueueType(s::String) = QueueTypeDict[s]

struct QueueDepthInfo
quantum_tasks::Dict{QueueType, String}
jobs::String
end
Base.:(==)(ix::QueueDepthInfo, iy::QueueDepthInfo) = (ix.jobs == iy.jobs && ix.quantum_tasks == iy.quantum_tasks)

mutable struct QuantumTaskQueueInfo
queue_type::QueueType
queue_position::String
message::String
end
QuantumTaskQueueInfo(queue_type::QueueType) = QuantumTaskQueueInfo(queue_type, "", "")

mutable struct HybridJobQueueInfo
queue_position::String
message::String
end
HybridJobQueueInfo() = HybridJobQueueInfo("", "")
Loading

0 comments on commit de79455

Please sign in to comment.