Skip to content

Commit

Permalink
remove Ref allocation on task switch
Browse files Browse the repository at this point in the history
  • Loading branch information
JeffBezanson committed Apr 27, 2020
1 parent 86ee57c commit a36eab6
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 16 deletions.
23 changes: 14 additions & 9 deletions base/task.jl
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,8 @@ function yield()
end
end

@inline set_next_task(t::Task) = ccall(:jl_set_next_task, Cvoid, (Any,), t)

"""
yield(t::Task, arg = nothing)
Expand All @@ -624,7 +626,8 @@ immediately yields to `t` before calling the scheduler.
function yield(t::Task, @nospecialize(x=nothing))
t.result = x
enq_work(current_task())
return try_yieldto(ensure_rescheduled, Ref(t))
set_next_task(t)
return try_yieldto(ensure_rescheduled)
end

"""
Expand All @@ -637,14 +640,15 @@ or scheduling in any way. Its use is discouraged.
"""
function yieldto(t::Task, @nospecialize(x=nothing))
t.result = x
return try_yieldto(identity, Ref(t))
set_next_task(t)
return try_yieldto(identity)
end

function try_yieldto(undo, reftask::Ref{Task})
function try_yieldto(undo)
try
ccall(:jl_switchto, Cvoid, (Any,), reftask)
ccall(:jl_switch, Cvoid, ())
catch
undo(reftask[])
undo(ccall(:jl_get_next_task, Ref{Task}, ()))
rethrow()
end
ct = current_task()
Expand Down Expand Up @@ -696,18 +700,19 @@ function trypoptask(W::StickyWorkqueue)
return t
end

@noinline function poptaskref(W::StickyWorkqueue)
@noinline function poptask(W::StickyWorkqueue)
task = trypoptask(W)
if !(task isa Task)
task = ccall(:jl_task_get_next, Ref{Task}, (Any, Any), trypoptask, W)
end
return Ref(task)
set_next_task(task)
nothing
end

function wait()
W = Workqueues[Threads.threadid()]
reftask = poptaskref(W)
result = try_yieldto(ensure_rescheduled, reftask)
poptask(W)
result = try_yieldto(ensure_rescheduled)
process_events()
# return when we come out of the queue
return result
Expand Down
10 changes: 10 additions & 0 deletions src/ccall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1619,6 +1619,16 @@ static jl_cgval_t emit_ccall(jl_codectx_t &ctx, jl_value_t **args, size_t nargs)
tbaa_decorate(tbaa_const, ctx.builder.CreateLoad(pct)),
retboxed, rt, unionall, static_rt);
}
else if (is_libjulia_func(jl_set_next_task)) {
assert(lrt == T_void);
assert(!isVa && !llvmcall && nccallargs == 1);
JL_GC_POP();
Value *ptls_pv = emit_bitcast(ctx, ctx.ptlsStates, T_ppjlvalue);
const int nt_offset = offsetof(jl_tls_states_t, next_task);
Value *pnt = ctx.builder.CreateGEP(ptls_pv, ConstantInt::get(T_size, nt_offset / sizeof(void*)));
ctx.builder.CreateStore(emit_pointer_from_objref(ctx, boxed(ctx, argv[0])), pnt);
return ghostValue(jl_nothing_type);
}
else if (is_libjulia_func(jl_sigatomic_begin)) {
assert(lrt == T_void);
assert(!isVa && !llvmcall && nccallargs == 0);
Expand Down
2 changes: 2 additions & 0 deletions src/gc.c
Original file line number Diff line number Diff line change
Expand Up @@ -2645,6 +2645,8 @@ static void jl_gc_queue_thread_local(jl_gc_mark_cache_t *gc_cache, jl_gc_mark_sp
{
gc_mark_queue_obj(gc_cache, sp, ptls2->current_task);
gc_mark_queue_obj(gc_cache, sp, ptls2->root_task);
if (ptls2->next_task)
gc_mark_queue_obj(gc_cache, sp, ptls2->next_task);
if (ptls2->previous_exception)
gc_mark_queue_obj(gc_cache, sp, ptls2->previous_exception);
}
Expand Down
1 change: 1 addition & 0 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -1011,6 +1011,7 @@ JL_DLLEXPORT int jl_array_isassigned(jl_array_t *a, size_t i);

JL_DLLEXPORT uintptr_t jl_object_id_(jl_value_t *tv, jl_value_t *v) JL_NOTSAFEPOINT;
JL_DLLEXPORT jl_value_t *jl_get_current_task(void);
JL_DLLEXPORT void jl_set_next_task(jl_task_t *task);

// -- synchronization utilities -- //

Expand Down
1 change: 1 addition & 0 deletions src/julia_threads.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ struct _jl_tls_states_t {
uv_cond_t wake_signal;
volatile sig_atomic_t defer_signal;
struct _jl_task_t *current_task;
struct _jl_task_t *next_task;
#ifdef MIGRATE_TASKS
struct _jl_task_t *previous_task;
#endif
Expand Down
34 changes: 27 additions & 7 deletions src/task.c
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ static void NOINLINE save_stack(jl_ptls_t ptls, jl_task_t *lastt, jl_task_t **pt
else {
buf = lastt->stkbuf;
}
*pt = lastt; // clear the gc-root for the target task before copying the stack for saving
*pt = NULL; // clear the gc-root for the target task before copying the stack for saving
lastt->copy_stack = nb;
lastt->sticky = 1;
memcpy_a16((uint64_t*)buf, (uint64_t*)frame_addr, nb);
Expand Down Expand Up @@ -248,10 +248,24 @@ JL_DLLEXPORT void julia_init(JL_IMAGE_SEARCH rel)
_julia_init(rel);
}

JL_DLLEXPORT void jl_set_next_task(jl_task_t *task)
{
jl_get_ptls_states()->next_task = task;
}

JL_DLLEXPORT jl_task_t *jl_get_next_task(void)
{
jl_ptls_t ptls = jl_get_ptls_states();
if (ptls->next_task)
return ptls->next_task;
return ptls->current_task;
}

void jl_release_task_stack(jl_ptls_t ptls, jl_task_t *task);

static void ctx_switch(jl_ptls_t ptls, jl_task_t **pt)
static void ctx_switch(jl_ptls_t ptls)
{
jl_task_t **pt = &ptls->next_task;
jl_task_t *t = *pt;
assert(t != ptls->current_task);
jl_task_t *lastt = ptls->current_task;
Expand Down Expand Up @@ -283,7 +297,7 @@ static void ctx_switch(jl_ptls_t ptls, jl_task_t **pt)
}

if (killed) {
*pt = lastt; // can't fail after here: clear the gc-root for the target task now
*pt = NULL; // can't fail after here: clear the gc-root for the target task now
lastt->gcstack = NULL;
if (!lastt->copy_stack && lastt->stkbuf) {
// early free of stkbuf back to the pool
Expand All @@ -302,7 +316,7 @@ static void ctx_switch(jl_ptls_t ptls, jl_task_t **pt)
}
else
#endif
*pt = lastt; // can't fail after here: clear the gc-root for the target task now
*pt = NULL; // can't fail after here: clear the gc-root for the target task now
lastt->gcstack = ptls->pgcstack;
}

Expand Down Expand Up @@ -366,10 +380,10 @@ static jl_ptls_t NOINLINE refetch_ptls(void)
return jl_get_ptls_states();
}

JL_DLLEXPORT void jl_switchto(jl_task_t **pt)
JL_DLLEXPORT void jl_switch(void)
{
jl_ptls_t ptls = jl_get_ptls_states();
jl_task_t *t = *pt;
jl_task_t *t = ptls->next_task;
jl_task_t *ct = ptls->current_task;
if (t == ct) {
return;
Expand Down Expand Up @@ -401,7 +415,7 @@ JL_DLLEXPORT void jl_switchto(jl_task_t **pt)
jl_timing_block_stop(blk);
#endif

ctx_switch(ptls, pt);
ctx_switch(ptls);

#ifdef MIGRATE_TASKS
ptls = refetch_ptls();
Expand Down Expand Up @@ -432,6 +446,12 @@ JL_DLLEXPORT void jl_switchto(jl_task_t **pt)
jl_sigint_safepoint(ptls);
}

JL_DLLEXPORT void jl_switchto(jl_task_t **pt)
{
jl_set_next_task(*pt);
jl_switch();
}

JL_DLLEXPORT JL_NORETURN void jl_no_exc_handler(jl_value_t *e)
{
jl_printf(JL_STDERR, "fatal: error thrown and no exception handler available.\n");
Expand Down
1 change: 1 addition & 0 deletions src/threading.c
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ void jl_init_threadtls(int16_t tid)
ptls->bt_data = bt_data;
ptls->sig_exception = NULL;
ptls->previous_exception = NULL;
ptls->next_task = NULL;
#ifdef _OS_WINDOWS_
ptls->needs_resetstkoflw = 0;
#endif
Expand Down

0 comments on commit a36eab6

Please sign in to comment.