From f5ec984666017f0afe124195889178fb62fba6d7 Mon Sep 17 00:00:00 2001 From: WingCode Date: Wed, 28 Aug 2024 19:01:43 +0200 Subject: [PATCH] feature: support serde pickle v4 (#88) --- Project.toml | 2 ++ src/jobs.jl | 7 +++++-- test/Project.toml | 1 + test/jobs.jl | 19 +++++++++++++++---- 4 files changed, 23 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index 7e29d2b3..a7d05380 100644 --- a/Project.toml +++ b/Project.toml @@ -25,6 +25,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" NamedTupleTools = "d9ec5142-1e00-5aa0-9d6a-321866360f50" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +Pickle = "fbb45041-c46e-462f-888f-7c521cafbc2c" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -58,6 +59,7 @@ Markdown = "=0.7.5" Mocking = "=0.8.1" NamedTupleTools = "=0.14.3" OrderedCollections = "=1.6.3" +Pickle = "0.3.5" Pkg = "1.6" Random = "1.6" SparseArrays = "1.6" diff --git a/src/jobs.jl b/src/jobs.jl index e8bca6eb..cc78d82d 100644 --- a/src/jobs.jl +++ b/src/jobs.jl @@ -1,3 +1,6 @@ +using Base64 +using Pickle + const JOB_DEFAULT_RESULTS_POLL_TIMEOUT = 864000 const JOB_DEFAULT_RESULTS_POLL_INTERVAL = 5 const JOB_TERMINAL_STATES = ["COMPLETED", "FAILED", "CANCELLED"] @@ -67,13 +70,13 @@ function get_hyperparameters() end function serialize_values(data_dictionary::Dict{String, Any}, data_format::PersistedJobDataFormat) - data_format == pickled_v4 && throw(ArgumentError("pickling data not yet supported!")) + data_format == pickled_v4 && return Dict(k => base64encode(Pickle.stores(v)) for (k, v) in data_dictionary) return data_dictionary end function deserialize_values(data_dictionary::Dict{String, Any}, data_format::PersistedJobDataFormat) data_format == plaintext && return data_dictionary - throw(ArgumentError("unpickling results not yet supported!")) + return Dict(k => Pickle.loads(base64decode(v)) for (k, v) in data_dictionary) end deserialize_values(data_dictionary::Dict{String, Any}, data_format::String) = deserialize_values(data_dictionary, PersistedJobDataFormatDict[data_format]) diff --git a/test/Project.toml b/test/Project.toml index 4f33b7a6..74fd56d1 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -15,6 +15,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" Mocking = "78c3b35d-d492-501b-9361-3d52fe80e533" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +Pickle = "fbb45041-c46e-462f-888f-7c521cafbc2c" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" diff --git a/test/jobs.jl b/test/jobs.jl index 964ed4a4..e432700e 100644 --- a/test/jobs.jl +++ b/test/jobs.jl @@ -1,4 +1,4 @@ -using Braket, Test, Mocking, Random, Dates, Tar, JSON3 +using Braket, Pickle, Base64, Test, Mocking, Random, Dates, Tar, JSON3 Mocking.activate() Base.parse(d::Dict) = d @@ -6,14 +6,25 @@ Base.parse(d::Dict) = d dev_arn = "arn:aws:braket:::device/quantum-simulator/amazon/sv1" @testset "Jobs" begin - @testset "deserialization errors" begin - @test_throws ArgumentError Braket.deserialize_values(Dict{String, Any}(), Braket.pickled_v4) + @testset "serialization pickled_v4" begin + data_dictionary = Dict{String, Any}("key1" => "value1", "key2" => "value2") + data_format = Braket.pickled_v4 + result = Braket.serialize_values(data_dictionary, data_format) + @test result == Dict{String, Any}("key1" => base64encode(Pickle.stores("value1")), "key2" => base64encode(Pickle.stores("value2"))) + end + @testset "deserialization pickled_v4" begin + data_dictionary = Dict{String, Any}("key1" => base64encode(Pickle.stores("value1")), "key2" => base64encode(Pickle.stores("value2"))) + data_format = Braket.pickled_v4 + result = Braket.deserialize_values(data_dictionary, data_format) + @test result == Dict{String, Any}("key1" => "value1", "key2" => "value2") + end + @testset "deserialization" begin + @test Dict{Any,Any}() == Braket.deserialize_values(Dict{String,Any}(), Braket.pickled_v4) mktempdir() do d job = Braket.AwsQuantumJob("arn:fake") @test Braket._read_and_deserialize_results(job, d) == [] pjd = Braket.PersistedJobData(Braket.header_dict[Braket.PersistedJobData], Dict{String, Any}(), Braket.pickled_v4) write(joinpath(d, Braket.RESULTS_FILENAME), JSON3.write(pjd)) - @test_throws ArgumentError Braket._read_and_deserialize_results(job, d) end end @testset "logs" begin