Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

<functional>: Constrain functions used by std::bind #3577

Merged
merged 4 commits into from
Mar 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 28 additions & 12 deletions stl/inc/functional
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,8 @@ struct _Unforced { // tag to distinguish bind() from bind<R>()
// helper to give INVOKE an explicit return type; avoids undesirable Expression SFINAE
template <class _Rx>
struct _Invoker_ret { // selected for all _Rx other than _Unforced
template <class _Fx, class... _Valtys>
template <class _Fx, class... _Valtys,
enable_if_t<_Select_invoke_traits<_Fx, _Valtys...>::template _Is_invocable_r<_Rx>::value, int> = 0>
static _CONSTEXPR20 _Rx _Call(_Fx&& _Func, _Valtys&&... _Vals) noexcept(_Select_invoke_traits<_Fx,
_Valtys...>::template _Is_nothrow_invocable_r<_Rx>::value) { // INVOKE, implicitly converted
if constexpr (is_void_v<_Rx>) {
Expand Down Expand Up @@ -760,7 +761,11 @@ private:
}

_Rx _Do_call(_Types&&... _Args) override { // call wrapped function
return _Invoker_ret<_Rx>::_Call(_Mypair._Myval2, _STD forward<_Types>(_Args)...);
if constexpr (is_void_v<_Rx>) {
(void) _STD invoke(_Mypair._Myval2, _STD forward<_Types>(_Args)...);
} else {
return _STD invoke(_Mypair._Myval2, _STD forward<_Types>(_Args)...);
}
}

const type_info& _Target_type() const noexcept override {
Expand Down Expand Up @@ -817,7 +822,11 @@ private:
}

_Rx _Do_call(_Types&&... _Args) override { // call wrapped function
return _Invoker_ret<_Rx>::_Call(_Callee, _STD forward<_Types>(_Args)...);
if constexpr (is_void_v<_Rx>) {
(void) _STD invoke(_Callee, _STD forward<_Types>(_Args)...);
} else {
return _STD invoke(_Callee, _STD forward<_Types>(_Args)...);
}
}

const type_info& _Target_type() const noexcept override {
Expand Down Expand Up @@ -1232,14 +1241,20 @@ template <class _Rx, class... _Types>

template <class _Vt, class _VtInvQuals, class _Rx, bool _Noex, class... _Types>
_NODISCARD _Rx __stdcall _Function_inv_small(const _Move_only_function_data& _Self, _Types&&... _Args) noexcept(_Noex) {
return _Invoker_ret<_Rx>::_Call(
static_cast<_VtInvQuals>(*_Self._Small_fn_ptr<_Vt>()), _STD forward<_Types>(_Args)...);
if constexpr (is_void_v<_Rx>) {
(void) _STD invoke(static_cast<_VtInvQuals>(*_Self._Small_fn_ptr<_Vt>()), _STD forward<_Types>(_Args)...);
} else {
return _STD invoke(static_cast<_VtInvQuals>(*_Self._Small_fn_ptr<_Vt>()), _STD forward<_Types>(_Args)...);
}
}

template <class _Vt, class _VtInvQuals, class _Rx, bool _Noex, class... _Types>
_NODISCARD _Rx __stdcall _Function_inv_large(const _Move_only_function_data& _Self, _Types&&... _Args) noexcept(_Noex) {
return _Invoker_ret<_Rx>::_Call(
static_cast<_VtInvQuals>(*_Self._Large_fn_ptr<_Vt>()), _STD forward<_Types>(_Args)...);
if constexpr (is_void_v<_Rx>) {
(void) _STD invoke(static_cast<_VtInvQuals>(*_Self._Large_fn_ptr<_Vt>()), _STD forward<_Types>(_Args)...);
} else {
return _STD invoke(static_cast<_VtInvQuals>(*_Self._Large_fn_ptr<_Vt>()), _STD forward<_Types>(_Args)...);
}
}

template <class _Vt>
Expand Down Expand Up @@ -1910,7 +1925,8 @@ struct _Select_fixer<_Cv_TiD, true, false, 0> { // reference_wrapper fixer

template <class _Cv_TiD>
struct _Select_fixer<_Cv_TiD, false, true, 0> { // nested bind fixer
template <class _Untuple, size_t... _Jx>
template <class _Untuple, size_t... _Jx,
enable_if_t<conjunction_v<bool_constant<(_Jx < tuple_size_v<_Untuple>)>...>, int> = 0>
static constexpr auto _Apply(_Cv_TiD& _Tid, _Untuple&& _Ut, index_sequence<_Jx...>) noexcept(
noexcept(_Tid(_STD get<_Jx>(_STD move(_Ut))...))) -> decltype(_Tid(_STD get<_Jx>(_STD move(_Ut))...)) {
// call a nested bind expression
Expand Down Expand Up @@ -1939,7 +1955,7 @@ template <class _Cv_TiD, int _Jx>
struct _Select_fixer<_Cv_TiD, false, false, _Jx> { // placeholder fixer
static_assert(_Jx > 0, "invalid is_placeholder value");

template <class _Untuple>
template <class _Untuple, enable_if_t<(_Jx <= tuple_size_v<_Untuple>), int> = 0>
static constexpr auto _Fix(_Cv_TiD&, _Untuple&& _Ut) noexcept
-> decltype(_STD get<_Jx - 1>(_STD move(_Ut))) { // choose the Jth unbound argument (1-based indexing)
return _STD get<_Jx - 1>(_STD move(_Ut));
Expand All @@ -1956,10 +1972,10 @@ constexpr auto _Fix_arg(_Cv_TiD& _Tid, _Untuple&& _Ut) noexcept(
template <class _Ret, size_t... _Ix, class _Cv_FD, class _Cv_tuple_TiD, class _Untuple>
_CONSTEXPR20 auto _Call_binder(_Invoker_ret<_Ret>, index_sequence<_Ix...>, _Cv_FD& _Obj, _Cv_tuple_TiD& _Tpl,
_Untuple&& _Ut) noexcept(noexcept(_Invoker_ret<_Ret>::_Call(_Obj,
_Fix_arg(_STD get<_Ix>(_Tpl), _STD move(_Ut))...)))
-> decltype(_Invoker_ret<_Ret>::_Call(_Obj, _Fix_arg(_STD get<_Ix>(_Tpl), _STD move(_Ut))...)) {
_STD _Fix_arg(_STD get<_Ix>(_Tpl), _STD move(_Ut))...)))
-> decltype(_Invoker_ret<_Ret>::_Call(_Obj, _STD _Fix_arg(_STD get<_Ix>(_Tpl), _STD move(_Ut))...)) {
// bind() and bind<R>() invocation
return _Invoker_ret<_Ret>::_Call(_Obj, _Fix_arg(_STD get<_Ix>(_Tpl), _STD move(_Ut))...);
return _Invoker_ret<_Ret>::_Call(_Obj, _STD _Fix_arg(_STD get<_Ix>(_Tpl), _STD move(_Ut))...);
}

template <class _Ret>
Expand Down
1 change: 1 addition & 0 deletions tests/std/test.lst
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ tests\GH_000690_overaligned_function
tests\GH_000890_pow_template
tests\GH_000935_complex_numerical_accuracy
tests\GH_000940_missing_valarray_copy
tests\GH_000952_bind_constraints
tests\GH_000990_any_link_without_exceptions
tests\GH_001001_random_rejection_rounding
tests\GH_001010_filesystem_error_encoding
Expand Down
4 changes: 4 additions & 0 deletions tests/std/tests/GH_000952_bind_constraints/env.lst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

RUNALL_INCLUDE ..\usual_matrix.lst
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <functional>
#include <type_traits>

#define STATIC_ASSERT(...) static_assert(__VA_ARGS__, #__VA_ARGS__)

using namespace std;

void test() { // COMPILE-ONLY
{
auto lambda = [](int) {};
auto f = bind(lambda, placeholders::_1);
STATIC_ASSERT(!is_convertible_v<decltype(f), function<void()>>);
STATIC_ASSERT(!is_convertible_v<decltype(f), function<void*(int)>>);
STATIC_ASSERT(is_convertible_v<decltype(f), function<void(int)>>);
STATIC_ASSERT(!is_convertible_v<decltype(f), function<char(int)>>);
}
{
auto lambda = [](int) { return 42; };
auto f = bind<void>(lambda, placeholders::_1);
STATIC_ASSERT(!is_convertible_v<decltype(f), function<void()>>);
STATIC_ASSERT(!is_convertible_v<decltype(f), function<void*(int)>>);
STATIC_ASSERT(is_convertible_v<decltype(f), function<void(int)>>);
STATIC_ASSERT(!is_convertible_v<decltype(f), function<char(int)>>);
}
{
auto lambda = [](int) { return 42; };
auto f = bind<long>(lambda, placeholders::_1);
STATIC_ASSERT(!is_convertible_v<decltype(f), function<void()>>);
STATIC_ASSERT(!is_convertible_v<decltype(f), function<void*(int)>>);
STATIC_ASSERT(is_convertible_v<decltype(f), function<void(int)>>);
STATIC_ASSERT(is_convertible_v<decltype(f), function<char(int)>>);
}
{
auto lambda0 = [](int, int) { return true; };
auto lambda1 = [](const char*) { return 0; };
auto f = bind(lambda0, placeholders::_1, bind(lambda1, placeholders::_2));
STATIC_ASSERT(!is_convertible_v<decltype(f), function<void()>>);
STATIC_ASSERT(!is_convertible_v<decltype(f), function<void*(int, const char*)>>);
STATIC_ASSERT(is_convertible_v<decltype(f), function<void(int, const char*)>>);
StephanTLavavej marked this conversation as resolved.
Show resolved Hide resolved
STATIC_ASSERT(is_convertible_v<decltype(f), function<bool(int, const char*)>>);
}
}