From 8c7b8dd0ae74b36b7d42f77b0dd4096ebb7f4ab1 Mon Sep 17 00:00:00 2001 From: Sergei Izmailov Date: Wed, 13 Sep 2023 04:48:27 +0900 Subject: [PATCH] fix: Missing typed variants of `iterator` and `iterable` (#4832) --- include/pybind11/typing.h | 20 ++++++++++++++++++++ tests/test_pytypes.cpp | 2 ++ tests/test_pytypes.py | 14 ++++++++++++++ 3 files changed, 36 insertions(+) diff --git a/include/pybind11/typing.h b/include/pybind11/typing.h index 74fd82eace..b7b1a4e54a 100644 --- a/include/pybind11/typing.h +++ b/include/pybind11/typing.h @@ -45,6 +45,16 @@ class Set : public set { using set::set; }; +template +class Iterable : public iterable { + using iterable::iterable; +}; + +template +class Iterator : public iterator { + using iterator::iterator; +}; + template class Callable; @@ -85,6 +95,16 @@ struct handle_type_name> { static constexpr auto name = const_name("set[") + make_caster::name + const_name("]"); }; +template +struct handle_type_name> { + static constexpr auto name = const_name("Iterable[") + make_caster::name + const_name("]"); +}; + +template +struct handle_type_name> { + static constexpr auto name = const_name("Iterator[") + make_caster::name + const_name("]"); +}; + template struct handle_type_name> { using retval_type = conditional_t::value, void_type, Return>; diff --git a/tests/test_pytypes.cpp b/tests/test_pytypes.cpp index b03f5624e1..0a587680f8 100644 --- a/tests/test_pytypes.cpp +++ b/tests/test_pytypes.cpp @@ -828,6 +828,8 @@ TEST_SUBMODULE(pytypes, m) { m.def("annotate_dict_str_int", [](const py::typing::Dict &) {}); m.def("annotate_list_int", [](const py::typing::List &) {}); m.def("annotate_set_str", [](const py::typing::Set &) {}); + m.def("annotate_iterable_str", [](const py::typing::Iterable &) {}); + m.def("annotate_iterator_int", [](const py::typing::Iterator &) {}); m.def("annotate_fn", [](const py::typing::Callable, py::str)> &) {}); } diff --git a/tests/test_pytypes.py b/tests/test_pytypes.py index e88a9328f7..2b2027316c 100644 --- a/tests/test_pytypes.py +++ b/tests/test_pytypes.py @@ -926,6 +926,20 @@ def test_set_annotations(doc): assert doc(m.annotate_set_str) == "annotate_set_str(arg0: set[str]) -> None" +def test_iterable_annotations(doc): + assert ( + doc(m.annotate_iterable_str) + == "annotate_iterable_str(arg0: Iterable[str]) -> None" + ) + + +def test_iterator_annotations(doc): + assert ( + doc(m.annotate_iterator_int) + == "annotate_iterator_int(arg0: Iterator[int]) -> None" + ) + + def test_fn_annotations(doc): assert ( doc(m.annotate_fn)