Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the arithmetic for wait_until_done(); add start_time kwarg in seconds, instead of ns. #104

Merged
merged 2 commits into from
Jan 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 45 additions & 21 deletions src/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ const PATH_ASYNC_TRANSACTIONS = "/transactions"
const PATH_USERS = "/users"
const ARROW_CONTENT_TYPE = "application/vnd.apache.arrow.stream"

const TXN_POLLING_OVERHEAD = 0.10

struct HTTPError <: Exception
status_code::Int
status_text::String
Expand Down Expand Up @@ -65,29 +67,49 @@ finished. A transaction has finished once it has reached one of the terminal sta
`COMPLETED` or `ABORTED`. The polling uses a low-overhead exponential backoff in order to
ensure low-latency results without overloading network traffic.
"""
function wait_until_done(ctx::Context, rsp::TransactionResponse; start_time_ns = nothing)
wait_until_done(ctx, rsp.transaction; start_time_ns)
function wait_until_done(ctx::Context, rsp::TransactionResponse;
start_time_ns = nothing, # deprecated
start_time = nothing,
)
wait_until_done(ctx, rsp.transaction; start_time_ns, start_time)
end
function wait_until_done(ctx::Context, txn::JSON3.Object; start_time_ns = nothing)
function wait_until_done(ctx::Context, txn::JSON3.Object;
start_time_ns = nothing, # deprecated
start_time = nothing,
)
if start_time_ns !== nothing
start_time = start_time_ns / 1e9,
@warn "wait_until_done(): start_time_ns= is deprecated; please pass start_time= as a unix timestamp instead."
end

# If the user is calling this manually, read the start time from the transaction object.
if start_time_ns === nothing &&
if start_time === nothing &&
# NOTE: the fast-path txn may not include the created_on key.
haskey(txn, :created_on)
start_time_ns = _transaction_start_time_ns(txn)
start_time = _transaction_start_time(txn)
end
wait_until_done(ctx, transaction_id(txn); start_time_ns)
wait_until_done(ctx, transaction_id(txn); start_time)
end
function _transaction_start_time_ns(txn::JSON3.Object)
return txn[:created_on] ÷ 1_000_000_000
function _transaction_start_time(txn::JSON3.Object)
# The API returns *milliseconds* since the epoch
return txn[:created_on] / 1e3
end
function wait_until_done(ctx::Context, id::AbstractString; start_time_ns = nothing)
function wait_until_done(ctx::Context, id::AbstractString;
start_time_ns = nothing, # deprecated
start_time = nothing,
)
if start_time_ns !== nothing
start_time = start_time_ns / 1e9,
@warn "wait_until_done(): start_time_ns= is deprecated; please pass start_time= as a unix timestamp instead."
end

# If the user is calling this manually, read the start time from the transaction object.
if start_time_ns === nothing
if start_time === nothing
txn = get_transaction(ctx, id)
start_time_ns = _transaction_start_time_ns(txn)
start_time = _transaction_start_time(txn)
end
try
_poll_with_specified_overhead(; overhead_rate = 0.10, start_time_ns) do
_poll_with_specified_overhead(; overhead_rate = TXN_POLLING_OVERHEAD, start_time) do
txn = get_transaction(ctx, id)
return transaction_is_done(txn)
end
Expand Down Expand Up @@ -125,14 +147,14 @@ end
function _poll_with_specified_overhead(
f;
overhead_rate, # Add xx% overhead through polling.
start_time_ns = time_ns(), # Optional start time, otherwise defaults to now()
start_time = time(), # Optional start time, otherwise defaults to now()
n = typemax(Int), # Maximum number of polls
max_delay = 120, # 2 min
timeout_secs = Inf, # no timeout by default
throw_on_timeout = false,
)
@debug "start time: $start_time"
@assert overhead_rate >= 0.0
timeout_ns = timeout_secs * 1e9
local iter
for i in 1:n
iter = i
Expand All @@ -142,17 +164,19 @@ function _poll_with_specified_overhead(
if done
return nothing
end
current_delay = time_ns() - start_time_ns
if current_delay > timeout_ns
t = @mock(time())
@debug "time: $t"
current_delay_s = t - start_time
if current_delay_s > timeout_secs
break
end
duration = (current_delay * overhead_rate) / 1e9
duration = current_delay_s * overhead_rate
duration = min(duration, max_delay) # clamp the duration as specified.
sleep(duration)
@mock sleep(duration)
end

# We have exhausted the iterator.
current_delay_secs = (time_ns() - start_time_ns) * 1e9
current_delay_secs = time() - start_time
throw_on_timeout && error("Timed out after $iter iterations, $current_delay_secs seconds in `_poll_with_specified_overhead`.")

return nothing
Expand Down Expand Up @@ -526,14 +550,14 @@ Dict{String, Any} with 4 entries:
function exec(ctx::Context, database::AbstractString, engine::AbstractString, source; inputs = nothing, readonly = false, kw...)
# Record the initial start time so that we include the time to create the transaction
# in our exponential backoff in `wait_until_done()`.
start_time_ns = time_ns()
start_time = time()
# Create an Async transaction:
transactionResponse = exec_async(ctx, database, engine, source; inputs=inputs, readonly=readonly, kw...)
if transactionResponse.results !== nothing
return transactionResponse
end
# Poll until the transaction is done, and return the results.
return wait_until_done(ctx, transactionResponse; start_time_ns = start_time_ns)
return wait_until_done(ctx, transactionResponse; start_time = start_time)
end

function exec_async(ctx::Context, database::AbstractString, engine::AbstractString, source; inputs = nothing, readonly = false, kw...)
Expand Down
4 changes: 3 additions & 1 deletion test/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import ProtoBuf

Mocking.activate()

include("wait_until_done.jl")

# -----------------------------------
# v2 transactions

Expand Down Expand Up @@ -269,7 +271,7 @@ end
)

apply(metadata_404_patch) do
RAI.wait_until_done(ctx, "<txn-id>", start_time_ns=0)
RAI.wait_until_done(ctx, "<txn-id>", start_time=0)
end
end

Expand Down
8 changes: 4 additions & 4 deletions test/integration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ function with_engine(f, ctx; existing_engine=nothing)
engine_name = rnd_test_name()
if isnothing(existing_engine)
custom_headers = get(ENV, "CUSTOM_HEADERS", nothing)
start_time_ns = time_ns()
start_time = time()
if isnothing(custom_headers)
create_engine(ctx, engine_name)
else
Expand All @@ -66,7 +66,7 @@ function with_engine(f, ctx; existing_engine=nothing)
headers = JSON3.read(custom_headers, Dict{String, String})
create_engine(ctx, engine_name; nothing, headers)
end
_poll_with_specified_overhead(; POLLING_KWARGS..., start_time_ns) do
_poll_with_specified_overhead(; POLLING_KWARGS..., start_time) do
state = get_engine(ctx, engine_name)[:state]
state == "PROVISION_FAILED" && throw("Failed to provision engine $engine_name")
state == "PROVISIONED"
Expand All @@ -80,8 +80,8 @@ function with_engine(f, ctx; existing_engine=nothing)
# Engines cannot be deleted if they are still provisioning. We have to at least wait
# until they are ready.
if isnothing(existing_engine)
start_time_ns = time_ns() - 2e9 # assume we started 2 seconds ago
_poll_with_specified_overhead(; POLLING_KWARGS..., start_time_ns) do
start_time = time() - 2 # assume we started 2 seconds ago
_poll_with_specified_overhead(; POLLING_KWARGS..., start_time) do
state = get_engine(ctx, engine_name)[:state]
state == "PROVISION_FAILED" && throw("Failed to provision engine $engine_name")
state == "PROVISIONED"
Expand Down
72 changes: 72 additions & 0 deletions test/wait_until_done.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
using ExceptionUnwrapping: unwrap_exception_to_root

# This test is _pretty complicated_ since it's trying to test something that depends on
# timing: testing that wait_until_done() polls for the expected amount of time in between
# calls to get_transaction.
# Testing anything to do with timing is always complicated. We tackle it here by mocking
# both sleep() and time(), and injecting fake times, and then making sure that the
# function is computing the correct duration to sleep, based on those times.
@testset "wait_until_done polls correctly" begin
now_ms = round(Int, time() * 1e3)
txn_str = """{
"id": "a3e3bc91-0a98-50ba-733c-0987e160eb7d",
"results_format_version": "2.0.1",
"state": "RUNNING",
"created_on": $(now_ms)
}"""
txn = JSON3.read(txn_str)

ctx = Context("region", "scheme", "host", "2342", nothing, "audience")

start = now_ms / 1e3
# Simulate OVERHEAD of 0.1 + round-trip-time of 0.5
times = [
start + 2, # First call takes 2 seconds then returns async
start + 2.2 + 0.5, # So we slept 0.2 seconds, then get_txn takes 0.5 secs
start + 2.97 + 0.5 # Now we sleep 2.7 * 1.1 ≈ 2.97, then again 0.5 RTT.
]
i = 1
time_patch = @patch function Base.time()
v = times[i]
i += 1
return v
end
# Here, we test that each call to sleep is the correct calculation of current "time"
# minus start time * the overhead.
sleep_patch = @patch function Base.sleep(duration)
@info "Mock sleep for $duration"
@test duration ≈ (times[i-1] - start) * RAI.TXN_POLLING_OVERHEAD
end

# This is returned on each get_txn() request.
unfinished_response = HTTP.Response(
200,
["Content-Type" => "application/json"],
body = """{"transaction": $(txn_str)}"""
)

# Stop the test after 3 polls.
ABORT = :ABORT_TEST

request_patch = @patch function RAI.request(ctx::Context, args...; kw...)
if i <= 3
return unfinished_response
else
# Finish the test
throw(ABORT)
end
end

# Call the function with the patches. Assert that it ends with our ABORT exception.
apply([time_patch, sleep_patch, request_patch]) do
try
wait_until_done(ctx, txn)
catch e
@assert unwrap_exception_to_root(e) == ABORT
end
end

# Test that we made it through all the expected polls, so that we know the above
# `@test`s all triggered.
@test i == 4
end