diff --git a/src/enclave/enclave_init.c b/src/enclave/enclave_init.c index 50101b89b..580d8dc40 100644 --- a/src/enclave/enclave_init.c +++ b/src/enclave/enclave_init.c @@ -93,9 +93,6 @@ static int app_main_thread(void* args) { SGXLKL_VERBOSE("enter\n"); - lthread_set_funcname(lthread_self(), "app-main"); - lthread_set_app_main(); - /* Set locale for userspace components using it */ SGXLKL_VERBOSE("Setting locale\n"); pthread_t self = __pthread_self(); @@ -131,7 +128,17 @@ static int kernel_main_thread(void* args) lthread_set_funcname(lthread_self(), "sgx-lkl-init"); struct lthread* lt; - if (lthread_create(<, NULL, app_main_thread, NULL) != 0) + struct lthread_attr lt_attr = {0}; + + // Denote this thread to be the main application thread + lt_attr.state = BIT(LT_ST_APP_MAIN); + oe_strncpy_s( + lt_attr.funcname, + sizeof(lt_attr.funcname), + "app_name_thread", + sizeof("app_name_thread")); + + if (lthread_create(<, <_attr, app_main_thread, NULL) != 0) { sgxlkl_fail("Failed to create lthread for app_main_thread\n"); } diff --git a/src/enclave/enclave_signal.c b/src/enclave/enclave_signal.c index a571b56e9..b29f032c6 100644 --- a/src/enclave/enclave_signal.c +++ b/src/enclave/enclave_signal.c @@ -184,7 +184,7 @@ static uint64_t sgxlkl_enclave_signal_handler( "ret=%i)\n", trap_info.description, lt ? lt->tid : -1, - lt ? lt->funcname : "(?)", + lt ? lt->attr.funcname : "(?)", exception_record->code, (void*)exception_record->address, opcode, @@ -216,7 +216,7 @@ static uint64_t sgxlkl_enclave_signal_handler( "ret=%i)\n", trap_info.description, lt ? lt->tid : -1, - lt ? lt->funcname : "(?)", + lt ? lt->attr.funcname : "(?)", exception_record->code, (void*)exception_record->address, opcode, diff --git a/src/include/enclave/lthread.h b/src/include/enclave/lthread.h index 3e76b1a90..79c77667f 100644 --- a/src/include/enclave/lthread.h +++ b/src/include/enclave/lthread.h @@ -110,6 +110,7 @@ struct lthread_attr _Atomic(int) state; /* current lthread state */ void* stack; /* ptr to lthread_stack */ int thread_type; /* type of thread: usermode or lkl kernel */ + char funcname[64]; /* optional func name */ }; typedef void (*sig_handler)(int sig, siginfo_t* si, void* unused); @@ -138,7 +139,6 @@ struct lthread void* arg; /* func args passed to func */ struct lthread_attr attr; /* various attributes */ int tid; /* lthread id */ - char funcname[64]; /* optional func name */ struct lthread* lt_join; /* lthread we want to join on */ void** lt_exit_ptr; /* exit ptr for lthread_join */ uint32_t ops; /* num of ops since yield */ @@ -249,12 +249,10 @@ extern "C" /** * Make the current scheduler also terminate and exit the enclave after the - * calling lthread has returned. + * calling lthread next yields. */ void lthread_terminate_this_scheduler(void); - void lthread_set_app_main(void); - /** * Run the main scheduler loop. * diff --git a/src/sched/lthread.c b/src/sched/lthread.c index 2551fd054..99aff4802 100644 --- a/src/sched/lthread.c +++ b/src/sched/lthread.c @@ -209,13 +209,6 @@ void lthread_terminate_this_scheduler(void) lt->attr.state |= BIT(LT_ST_TERMINATE); } -void lthread_set_app_main(void) -{ - struct lthread* lt = lthread_self(); - SGXLKL_ASSERT(lt); - lt->attr.state |= BIT(LT_ST_APP_MAIN); -} - int lthread_run(void) { const struct lthread_sched* const sched = lthread_get_sched(); @@ -361,6 +354,8 @@ void _lthread_yield(struct lthread* lt) void _lthread_free(struct lthread* lt) { + // Only run the destructors if this is not the main application thread, + // otherwise it would get deallocated twice if (lthread_self() != NULL && !(lt->attr.state & BIT(LT_ST_APP_MAIN))) { lthread_rundestructors(lt); @@ -391,9 +386,9 @@ static void init_tp(struct lthread *lt, unsigned char *mem, size_t sz) { mem += sz - sizeof(struct lthread_tcb_base); mem -= (uintptr_t)mem & (TLS_ALIGN - 1); - lt->tp = mem; - struct lthread_tcb_base *tcb = (struct lthread_tcb_base *)mem; - tcb->self = mem; + lt->tp = (uintptr_t*)mem; + struct lthread_tcb_base* tcb = (struct lthread_tcb_base*)mem; + tcb->self = mem; } static void set_fsbase(void* tp){ @@ -522,16 +517,10 @@ static void _lthread_init(struct lthread* lt) lt->ctx.esp = (void*)((uintptr_t)stack - (4 * sizeof(void*))); lt->ctx.ebp = (void*)((uintptr_t)stack - (3 * sizeof(void*))); lt->ctx.eip = (void*)_exec; - /* this is equivalent to unlock */ - a_barrier(); - if (lt->attr.state & BIT(LT_ST_DETACH)) - { - lt->attr.state = BIT(LT_ST_READY) | BIT(LT_ST_DETACH); - } - else - { - lt->attr.state = BIT(LT_ST_READY); - } + + lt->attr.state &= CLEARBIT(LT_ST_NEW); + + _lthread_unlock(lt); } int _lthread_sched_init(size_t stack_size) @@ -582,7 +571,7 @@ int lthread_create_primitive( static unsigned long long n = 0; oe_snprintf( - lt->funcname, + lt->attr.funcname, 64, "cloned host task %llu", __atomic_fetch_add(&n, 1, __ATOMIC_SEQ_CST)); @@ -659,16 +648,25 @@ int lthread_create( lt->attr.state = BIT(LT_ST_NEW) | (attrp ? attrp->state : 0); lt->attr.thread_type = LKL_KERNEL_THREAD; + lt->attr.funcname[0] = '\0'; lt->tid = a_fetch_add(&spawned_lthreads, 1); lt->fun = fun; lt->arg = arg; LIST_INIT(<->tls); - // Inherit name from parent - if (lthread_self() && lthread_self()->funcname) + // Did we get a thread name? + if (attrp && attrp->funcname) { - lthread_set_funcname(lt, lthread_self()->funcname); + lthread_set_funcname(lt, attrp->funcname); + } + else + { + // Inherit the thread name from the parent + if (lthread_self() && lthread_self()->attr.funcname) + { + lthread_set_funcname(lt, lthread_self()->attr.funcname); + } } if (new_lt) @@ -815,8 +813,12 @@ void lthread_detach2(struct lthread* lt) void lthread_set_funcname(struct lthread* lt, const char* f) { - oe_strncpy_s(lt->funcname, 64, f, 64); - lt->funcname[64 - 1] = 0; + oe_strncpy_s( + lt->attr.funcname, + sizeof(lt->attr.funcname), + f, + sizeof(lt->attr.funcname)); + lt->attr.funcname[sizeof(lt->attr.funcname) - 1] = '\0'; } uint64_t lthread_id(void) @@ -1009,6 +1011,7 @@ static void lthread_state_to_string( STRINGIFY_LT_STATE(DETACH) STRINGIFY_LT_STATE(PINNED) STRINGIFY_LT_STATE(TERMINATE) + STRINGIFY_LT_STATE(APP_MAIN) lt_state_str[offset - 1] = '\0'; } @@ -1036,7 +1039,7 @@ void lthread_dump_all_threads(bool is_lthread) if (lt) { int tid = lt->tid; - char* funcname = lt->funcname; + char* funcname = lt->attr.funcname; lthread_state_to_string(lt, lt_state_str, 1024);