Skip to content

Commit

Permalink
Fix the arithmetic for wait_until_done(); add start_time kwarg in s…
Browse files Browse the repository at this point in the history
…econds, instead of ns. (#104)

Correctly account for the time sent from the API (ms since epoch), and
add a new parameter, start_time, which is in *seconds* not nanoseconds.

Deprecate the old parameter, to be removed in a future release.

Also adds a complex mocked test for wait_until_done, which tests that it really is waiting for the expected time, based on the created_on field from the API, which comes back in milliseconds.

--------------

The actual src/ changes are pretty straightforward:
- change from ns to seconds: the API returns the unix timestamp in ms and julia's `time()` function returns timestamp in seconds (to ~microsecond precision).

But then I made the PR more complicated by:
- keeping the old interface to make the change backwards compatible, and
- adding a complicated mocked unit test to make sure we're testing the behavior here and it's working as we expected

... sorry for the extra complexity.

## Commits

* Fix the arithmetic for wait_until_done(); add API in seconds.

Correctly account for the time sent from the API (ms since epoch), and
add a new parameter, start_time, which is in *seconds* not nanoseconds.

Deprecate the old parameter, to be removed in a future release.

* Add test for wait_until_done that it counts correctly
  • Loading branch information
NHDaly authored Jan 19, 2023
1 parent 59d6e20 commit 3958c4d
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 26 deletions.
66 changes: 45 additions & 21 deletions src/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ const PATH_ASYNC_TRANSACTIONS = "/transactions"
const PATH_USERS = "/users"
const ARROW_CONTENT_TYPE = "application/vnd.apache.arrow.stream"

const TXN_POLLING_OVERHEAD = 0.10

struct HTTPError <: Exception
status_code::Int
status_text::String
Expand Down Expand Up @@ -65,29 +67,49 @@ finished. A transaction has finished once it has reached one of the terminal sta
`COMPLETED` or `ABORTED`. The polling uses a low-overhead exponential backoff in order to
ensure low-latency results without overloading network traffic.
"""
function wait_until_done(ctx::Context, rsp::TransactionResponse; start_time_ns = nothing)
wait_until_done(ctx, rsp.transaction; start_time_ns)
function wait_until_done(ctx::Context, rsp::TransactionResponse;
start_time_ns = nothing, # deprecated
start_time = nothing,
)
wait_until_done(ctx, rsp.transaction; start_time_ns, start_time)
end
function wait_until_done(ctx::Context, txn::JSON3.Object; start_time_ns = nothing)
function wait_until_done(ctx::Context, txn::JSON3.Object;
start_time_ns = nothing, # deprecated
start_time = nothing,
)
if start_time_ns !== nothing
    # Backwards compatibility: convert the deprecated nanosecond value to seconds.
    # There must be no trailing comma after this assignment: `x = a,` followed by an
    # expression on the next line parses as a tuple, which would make `start_time` a
    # `(seconds, nothing)` tuple and break the time arithmetic during polling.
    start_time = start_time_ns / 1e9
    @warn "wait_until_done(): start_time_ns= is deprecated; please pass start_time= as a unix timestamp instead."
end

# If the user is calling this manually, read the start time from the transaction object.
if start_time_ns === nothing &&
if start_time === nothing &&
# NOTE: the fast-path txn may not include the created_on key.
haskey(txn, :created_on)
start_time_ns = _transaction_start_time_ns(txn)
start_time = _transaction_start_time(txn)
end
wait_until_done(ctx, transaction_id(txn); start_time_ns)
wait_until_done(ctx, transaction_id(txn); start_time)
end
function _transaction_start_time_ns(txn::JSON3.Object)
return txn[:created_on] ÷ 1_000_000_000
function _transaction_start_time(txn::JSON3.Object)
# The API returns *milliseconds* since the epoch
return txn[:created_on] / 1e3
end
function wait_until_done(ctx::Context, id::AbstractString; start_time_ns = nothing)
function wait_until_done(ctx::Context, id::AbstractString;
start_time_ns = nothing, # deprecated
start_time = nothing,
)
if start_time_ns !== nothing
    # Backwards compatibility: convert the deprecated nanosecond value to seconds.
    # There must be no trailing comma after this assignment: `x = a,` followed by an
    # expression on the next line parses as a tuple, which would make `start_time` a
    # `(seconds, nothing)` tuple and break `t - start_time` in the polling loop.
    start_time = start_time_ns / 1e9
    @warn "wait_until_done(): start_time_ns= is deprecated; please pass start_time= as a unix timestamp instead."
end

# If the user is calling this manually, read the start time from the transaction object.
if start_time_ns === nothing
if start_time === nothing
txn = get_transaction(ctx, id)
start_time_ns = _transaction_start_time_ns(txn)
start_time = _transaction_start_time(txn)
end
try
_poll_with_specified_overhead(; overhead_rate = 0.10, start_time_ns) do
_poll_with_specified_overhead(; overhead_rate = TXN_POLLING_OVERHEAD, start_time) do
txn = get_transaction(ctx, id)
return transaction_is_done(txn)
end
Expand Down Expand Up @@ -125,14 +147,14 @@ end
function _poll_with_specified_overhead(
f;
overhead_rate, # Add xx% overhead through polling.
start_time_ns = time_ns(), # Optional start time, otherwise defaults to now()
start_time = time(), # Optional start time, otherwise defaults to now()
n = typemax(Int), # Maximum number of polls
max_delay = 120, # 2 min
timeout_secs = Inf, # no timeout by default
throw_on_timeout = false,
)
@debug "start time: $start_time"
@assert overhead_rate >= 0.0
timeout_ns = timeout_secs * 1e9
local iter
for i in 1:n
iter = i
Expand All @@ -142,17 +164,19 @@ function _poll_with_specified_overhead(
if done
return nothing
end
current_delay = time_ns() - start_time_ns
if current_delay > timeout_ns
t = @mock(time())
@debug "time: $t"
current_delay_s = t - start_time
if current_delay_s > timeout_secs
break
end
duration = (current_delay * overhead_rate) / 1e9
duration = current_delay_s * overhead_rate
duration = min(duration, max_delay) # clamp the duration as specified.
sleep(duration)
@mock sleep(duration)
end

# We have exhausted the iterator.
current_delay_secs = (time_ns() - start_time_ns) * 1e9
current_delay_secs = time() - start_time
throw_on_timeout && error("Timed out after $iter iterations, $current_delay_secs seconds in `_poll_with_specified_overhead`.")

return nothing
Expand Down Expand Up @@ -526,14 +550,14 @@ Dict{String, Any} with 4 entries:
function exec(ctx::Context, database::AbstractString, engine::AbstractString, source; inputs = nothing, readonly = false, kw...)
# Record the initial start time so that we include the time to create the transaction
# in our exponential backoff in `wait_until_done()`.
start_time_ns = time_ns()
start_time = time()
# Create an Async transaction:
transactionResponse = exec_async(ctx, database, engine, source; inputs=inputs, readonly=readonly, kw...)
if transactionResponse.results !== nothing
return transactionResponse
end
# Poll until the transaction is done, and return the results.
return wait_until_done(ctx, transactionResponse; start_time_ns = start_time_ns)
return wait_until_done(ctx, transactionResponse; start_time = start_time)
end

function exec_async(ctx::Context, database::AbstractString, engine::AbstractString, source; inputs = nothing, readonly = false, kw...)
Expand Down
4 changes: 3 additions & 1 deletion test/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import ProtoBuf

Mocking.activate()

include("wait_until_done.jl")

# -----------------------------------
# v2 transactions

Expand Down Expand Up @@ -269,7 +271,7 @@ end
)

apply(metadata_404_patch) do
RAI.wait_until_done(ctx, "<txn-id>", start_time_ns=0)
RAI.wait_until_done(ctx, "<txn-id>", start_time=0)
end
end

Expand Down
8 changes: 4 additions & 4 deletions test/integration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ function with_engine(f, ctx; existing_engine=nothing)
engine_name = rnd_test_name()
if isnothing(existing_engine)
custom_headers = get(ENV, "CUSTOM_HEADERS", nothing)
start_time_ns = time_ns()
start_time = time()
if isnothing(custom_headers)
create_engine(ctx, engine_name)
else
Expand All @@ -66,7 +66,7 @@ function with_engine(f, ctx; existing_engine=nothing)
headers = JSON3.read(custom_headers, Dict{String, String})
create_engine(ctx, engine_name; nothing, headers)
end
_poll_with_specified_overhead(; POLLING_KWARGS..., start_time_ns) do
_poll_with_specified_overhead(; POLLING_KWARGS..., start_time) do
state = get_engine(ctx, engine_name)[:state]
state == "PROVISION_FAILED" && throw("Failed to provision engine $engine_name")
state == "PROVISIONED"
Expand All @@ -80,8 +80,8 @@ function with_engine(f, ctx; existing_engine=nothing)
# Engines cannot be deleted if they are still provisioning. We have to at least wait
# until they are ready.
if isnothing(existing_engine)
start_time_ns = time_ns() - 2e9 # assume we started 2 seconds ago
_poll_with_specified_overhead(; POLLING_KWARGS..., start_time_ns) do
start_time = time() - 2 # assume we started 2 seconds ago
_poll_with_specified_overhead(; POLLING_KWARGS..., start_time) do
state = get_engine(ctx, engine_name)[:state]
state == "PROVISION_FAILED" && throw("Failed to provision engine $engine_name")
state == "PROVISIONED"
Expand Down
72 changes: 72 additions & 0 deletions test/wait_until_done.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
using ExceptionUnwrapping: unwrap_exception_to_root

# This test is _pretty complicated_ since it's trying to test something that depends on
# timing: testing that wait_until_done() polls for the expected amount of time in between
# calls to get_transaction.
# Testing anything to do with timing is always complicated. We tackle it here by mocking
# both sleep() and time(), and injecting fake times, and then making sure that the
# function is computing the correct duration to sleep, based on those times.
@testset "wait_until_done polls correctly" begin
now_ms = round(Int, time() * 1e3)
txn_str = """{
"id": "a3e3bc91-0a98-50ba-733c-0987e160eb7d",
"results_format_version": "2.0.1",
"state": "RUNNING",
"created_on": $(now_ms)
}"""
txn = JSON3.read(txn_str)

ctx = Context("region", "scheme", "host", "2342", nothing, "audience")

start = now_ms / 1e3
# Simulate OVERHEAD of 0.1 + round-trip-time of 0.5
times = [
start + 2, # First call takes 2 seconds then returns async
start + 2.2 + 0.5, # So we slept 0.2 seconds, then get_txn takes 0.5 secs
start + 2.97 + 0.5 # Now we sleep 2.7 * 1.1 ≈ 2.97, then again 0.5 RTT.
]
i = 1
time_patch = @patch function Base.time()
v = times[i]
i += 1
return v
end
# Here, we test that each call to sleep is the correct calculation of current "time"
# minus start time * the overhead.
sleep_patch = @patch function Base.sleep(duration)
    @info "Mock sleep for $duration"
    # `times[i-1]` is the fake "now" most recently handed out by the time() patch.
    # The requested sleep should be the elapsed time since `start`, scaled by the
    # polling overhead. Compare with ≈ because these are floating-point values.
    @test duration ≈ (times[i-1] - start) * RAI.TXN_POLLING_OVERHEAD
end

# This is returned on each get_txn() request.
unfinished_response = HTTP.Response(
200,
["Content-Type" => "application/json"],
body = """{"transaction": $(txn_str)}"""
)

# Stop the test after 3 polls.
ABORT = :ABORT_TEST

request_patch = @patch function RAI.request(ctx::Context, args...; kw...)
if i <= 3
return unfinished_response
else
# Finish the test
throw(ABORT)
end
end

# Call the function with the patches. Assert that it ends with our ABORT exception.
apply([time_patch, sleep_patch, request_patch]) do
try
wait_until_done(ctx, txn)
catch e
@assert unwrap_exception_to_root(e) == ABORT
end
end

# Test that we made it through all the expected polls, so that we know the above
# `@test`s all triggered.
@test i == 4
end

0 comments on commit 3958c4d

Please sign in to comment.