Skip to content

Commit

Permalink
Reduce duplication between TaskPromise specializations
Browse files Browse the repository at this point in the history
Summary: Make the `TaskPromise<>` leaf classes as simple as they can be. This shortens the code by 30 lines, and makes it easier to follow.

Reviewed By: yfeldblum

Differential Revision: D61249849

fbshipit-source-id: c12ca1d96c2c99fd85d53df9a77146450d525ca6
  • Loading branch information
Alexey Spiridonov authored and facebook-github-bot committed Aug 17, 2024
1 parent a2aae6d commit 31a4de6
Showing 1 changed file with 57 additions and 87 deletions.
144 changes: 57 additions & 87 deletions folly/coro/Task.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,10 @@ class TaskPromiseBase {
}

private:
template <typename T>
template <typename>
friend class folly::coro::TaskWithExecutor;

template <typename T>
template <typename>
friend class folly::coro::Task;

friend coroutine_handle<ScopeExitTaskPromiseBase> tag_invoke(
Expand All @@ -196,45 +196,19 @@ class TaskPromiseBase {
} bypassExceptionThrowing_{BypassExceptionThrowing::INACTIVE};
};

template <typename T>
class TaskPromise final : public TaskPromiseBase,
public ExtendedCoroutinePromiseImpl<TaskPromise<T>> {
// Separate from `TaskPromiseBase` so the compiler has less to specialize.
template <typename Promise, typename T>
class TaskPromiseCrtpBase : public TaskPromiseBase,
public ExtendedCoroutinePromiseImpl<Promise> {
public:
static_assert(
!std::is_rvalue_reference_v<T>,
"Task<T&&> is not supported. "
"Consider using Task<T> or Task<std::unique_ptr<T>> instead.");
friend class TaskPromiseBase;

using StorageType = detail::lift_lvalue_reference_t<T>;

TaskPromise() noexcept = default;

Task<T> get_return_object() noexcept;

void unhandled_exception() noexcept {
result_.emplaceException(exception_wrapper{current_exception()});
}

template <typename U = T>
void return_value(U&& value) {
if constexpr (std::is_same_v<remove_cvref_t<U>, Try<StorageType>>) {
DCHECK(value.hasValue() || (value.hasException() && value.exception()));
result_ = static_cast<U&&>(value);
} else if constexpr (
std::is_same_v<remove_cvref_t<U>, Try<void>> &&
std::is_same_v<remove_cvref_t<T>, Unit>) {
// special-case to make task -> semifuture -> task preserve void type
DCHECK(value.hasValue() || (value.hasException() && value.exception()));
result_ = static_cast<Try<Unit>>(static_cast<U&&>(value));
} else {
static_assert(
std::is_convertible<U&&, StorageType>::value,
"cannot convert return value to type T");
result_.emplace(static_cast<U&&>(value));
}
}

Try<StorageType>& result() { return result_; }

auto yield_value(co_error ex) {
Expand All @@ -253,81 +227,80 @@ class TaskPromise final : public TaskPromiseBase,
return do_safe_point(*this);
}

protected:
TaskPromiseCrtpBase() noexcept = default;
~TaskPromiseCrtpBase() = default;

std::pair<ExtendedCoroutineHandle, AsyncStackFrame*> getErrorHandle(
exception_wrapper& ex) override {
auto& me = *static_cast<Promise*>(this);
if (bypassExceptionThrowing_ == BypassExceptionThrowing::ACTIVE) {
auto finalAwaiter = yield_value(co_error(std::move(ex)));
DCHECK(!finalAwaiter.await_ready());
return {
finalAwaiter.await_suspend(
coroutine_handle<TaskPromise>::from_promise(*this)),
coroutine_handle<Promise>::from_promise(me)),
// finalAwaiter.await_suspend pops a frame
getAsyncFrame().getParentFrame()};
}
return {coroutine_handle<TaskPromise>::from_promise(*this), nullptr};
return {coroutine_handle<Promise>::from_promise(me), nullptr};
}

private:
Try<StorageType> result_;
};

template <>
class TaskPromise<void> final
: public TaskPromiseBase,
public ExtendedCoroutinePromiseImpl<TaskPromise<void>> {
template <typename T>
class TaskPromise final : public TaskPromiseCrtpBase<TaskPromise<T>, T> {
public:
static_assert(
!std::is_rvalue_reference_v<T>,
"Task<T&&> is not supported. "
"Consider using Task<T> or Task<std::unique_ptr<T>> instead.");
friend class TaskPromiseBase;

using StorageType = void;
using StorageType =
typename TaskPromiseCrtpBase<TaskPromise<T>, T>::StorageType;

TaskPromise() noexcept = default;

Task<void> get_return_object() noexcept;

void unhandled_exception() noexcept {
result_.emplaceException(exception_wrapper{current_exception()});
template <typename U = T>
void return_value(U&& value) {
if constexpr (std::is_same_v<remove_cvref_t<U>, Try<StorageType>>) {
DCHECK(value.hasValue() || (value.hasException() && value.exception()));
this->result_ = static_cast<U&&>(value);
} else if constexpr (
std::is_same_v<remove_cvref_t<U>, Try<void>> &&
std::is_same_v<remove_cvref_t<T>, Unit>) {
// special-case to make task -> semifuture -> task preserve void type
DCHECK(value.hasValue() || (value.hasException() && value.exception()));
this->result_ = static_cast<Try<Unit>>(static_cast<U&&>(value));
} else {
static_assert(
std::is_convertible<U&&, StorageType>::value,
"cannot convert return value to type T");
this->result_.emplace(static_cast<U&&>(value));
}
}
};

void return_void() noexcept { result_.emplace(); }

Try<void>& result() { return result_; }
template <>
class TaskPromise<void> final
: public TaskPromiseCrtpBase<TaskPromise<void>, void> {
public:
friend class TaskPromiseBase;

auto yield_value(co_error ex) {
result_.emplaceException(std::move(ex.exception()));
return final_suspend();
}
using StorageType = void;

auto yield_value(co_result<void>&& result) {
result_ = std::move(result.result());
return final_suspend();
}
auto yield_value(co_result<Unit>&& result) {
result_ = std::move(result.result());
return final_suspend();
}
TaskPromise() noexcept = default;

using TaskPromiseBase::await_transform;
void return_void() noexcept { this->result_.emplace(); }

auto await_transform(co_safe_point_t) noexcept {
return do_safe_point(*this);
}
using TaskPromiseCrtpBase<TaskPromise<void>, void>::yield_value;

std::pair<ExtendedCoroutineHandle, AsyncStackFrame*> getErrorHandle(
exception_wrapper& ex) override {
if (bypassExceptionThrowing_ == BypassExceptionThrowing::ACTIVE) {
auto finalAwaiter = yield_value(co_error(std::move(ex)));
DCHECK(!finalAwaiter.await_ready());
return {
finalAwaiter.await_suspend(
coroutine_handle<TaskPromise>::from_promise(*this)),
// finalAwaiter.await_suspend pops a frame
getAsyncFrame().getParentFrame()};
}
return {coroutine_handle<TaskPromise>::from_promise(*this), nullptr};
auto yield_value(co_result<Unit>&& result) {
this->result_ = std::move(result.result());
return final_suspend();
}

private:
Try<void> result_;
};

} // namespace detail
Expand Down Expand Up @@ -786,7 +759,7 @@ class FOLLY_NODISCARD Task {

private:
friend class detail::TaskPromiseBase;
friend class detail::TaskPromise<T>;
friend class detail::TaskPromiseCrtpBase<detail::TaskPromise<T>, T>;
friend class TaskWithExecutor<T>;

class Awaiter {
Expand Down Expand Up @@ -890,14 +863,11 @@ Task<drop_unit_t<T>> makeResultTask(Try<T> t) {
co_yield co_result(std::move(t));
}

template <typename T>
Task<T> detail::TaskPromise<T>::get_return_object() noexcept {
return Task<T>{coroutine_handle<detail::TaskPromise<T>>::from_promise(*this)};
}

inline Task<void> detail::TaskPromise<void>::get_return_object() noexcept {
return Task<void>{
coroutine_handle<detail::TaskPromise<void>>::from_promise(*this)};
template <typename Promise, typename T>
inline Task<T>
detail::TaskPromiseCrtpBase<Promise, T>::get_return_object() noexcept {
return Task<T>{
coroutine_handle<Promise>::from_promise(*static_cast<Promise*>(this))};
}

} // namespace coro
Expand Down

0 comments on commit 31a4de6

Please sign in to comment.