Skip to content

Commit

Permalink
embind: Use optional return type for vector and maps.
Browse files Browse the repository at this point in the history
This helps create better TypeScript definitions for what
is actually returned.
  • Loading branch information
brendandahl committed Jan 31, 2024
1 parent c4d76b8 commit c9b106a
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 9 deletions.
59 changes: 52 additions & 7 deletions system/include/emscripten/bind.h
Original file line number Diff line number Diff line change
Expand Up @@ -1875,6 +1875,21 @@ class class_ {
}
};

#if __cplusplus >= 201703L
template<typename T>
void register_optional() {
// Optional types are automatically registered for some internal types so
// only run the register method once so we don't conflict with a user's
// bindings if they also register the optional type.
thread_local bool hasRun;
if (hasRun) {
return;
}
hasRun = true;
internal::_embind_register_optional(internal::TypeID<std::optional<T>>::get(), internal::TypeID<T>::get());
}
#endif

////////////////////////////////////////////////////////////////////////////////
// VECTORS
////////////////////////////////////////////////////////////////////////////////
Expand All @@ -1883,6 +1898,20 @@ namespace internal {

template<typename VectorType>
struct VectorAccess {
// This nearly duplicated code is used for generating more specific TypeScript
// types when using more modern C++ versions.
#if __cplusplus >= 201703L
static std::optional<typename VectorType::value_type> get(
const VectorType& v,
typename VectorType::size_type index
) {
if (index < v.size()) {
return v[index];
} else {
return {};
}
}
#else
static val get(
const VectorType& v,
typename VectorType::size_type index
Expand All @@ -1893,6 +1922,7 @@ struct VectorAccess {
return val::undefined();
}
}
#endif

static bool set(
VectorType& v,
Expand All @@ -1909,6 +1939,9 @@ struct VectorAccess {
template<typename T>
class_<std::vector<T>> register_vector(const char* name) {
typedef std::vector<T> VecType;
#if __cplusplus >= 201703L
register_optional<T>();
#endif

void (VecType::*push_back)(const T&) = &VecType::push_back;
void (VecType::*resize)(const size_t, const T&) = &VecType::resize;
Expand All @@ -1923,13 +1956,6 @@ class_<std::vector<T>> register_vector(const char* name) {
;
}

#if __cplusplus >= 201703L
template<typename T>
void register_optional() {
internal::_embind_register_optional(internal::TypeID<std::optional<T>>::get(), internal::TypeID<T>::get());
}
#endif

////////////////////////////////////////////////////////////////////////////////
// MAPS
////////////////////////////////////////////////////////////////////////////////
Expand All @@ -1938,6 +1964,21 @@ namespace internal {

template<typename MapType>
struct MapAccess {
// This nearly duplicated code is used for generating more specific TypeScript
// types when using more modern C++ versions.
#if __cplusplus >= 201703L
static std::optional<typename MapType::mapped_type> get(
const MapType& m,
const typename MapType::key_type& k
) {
auto i = m.find(k);
if (i == m.end()) {
return {};
} else {
return i->second;
}
}
#else
static val get(
const MapType& m,
const typename MapType::key_type& k
Expand All @@ -1949,6 +1990,7 @@ struct MapAccess {
return val(i->second);
}
}
#endif

static void set(
MapType& m,
Expand All @@ -1975,6 +2017,9 @@ struct MapAccess {
template<typename K, typename V>
class_<std::map<K, V>> register_map(const char* name) {
typedef std::map<K,V> MapType;
#if __cplusplus >= 201703L
register_optional<V>();
#endif

size_t (MapType::*size)() const = &MapType::size;
return class_<MapType>(name)
Expand Down
2 changes: 2 additions & 0 deletions test/other/embind_tsgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,8 @@ EMSCRIPTEN_BINDINGS(Test) {

register_vector<int>("IntVec");

register_map<int, int>("MapIntInt");

class_<Foo>("Foo").function("process", &Foo::process);

function("global_fn", &global_fn);
Expand Down
13 changes: 11 additions & 2 deletions test/other/embind_tsgen.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,16 @@ export interface IntVec {
push_back(_0: number): void;
resize(_0: number, _1: number): void;
size(): number;
get(_0: number): number | undefined;
set(_0: number, _1: number): boolean;
get(_0: number): any;
delete(): void;
}

export interface MapIntInt {
keys(): IntVec;
get(_0: number): number | undefined;
set(_0: number, _1: number): void;
size(): number;
delete(): void;
}

Expand Down Expand Up @@ -79,6 +87,7 @@ export interface MainModule {
EmptyEnum: {};
enum_returning_fn(): Bar;
IntVec: {new(): IntVec};
MapIntInt: {new(): MapIntInt};
Foo: {};
ClassWithConstructor: {new(_0: number, _1: ValArr): ClassWithConstructor};
ClassWithTwoConstructors: {new(): ClassWithTwoConstructors; new(_0: number): ClassWithTwoConstructors};
Expand All @@ -87,8 +96,8 @@ export interface MainModule {
DerivedClass: {};
a_bool: boolean;
an_int: number;
global_fn(_0: number, _1: number): number;
optional_test(_0: Foo | undefined): number | undefined;
global_fn(_0: number, _1: number): number;
smart_ptr_function(_0: ClassWithSmartPtrConstructor): number;
smart_ptr_function_with_params(foo: ClassWithSmartPtrConstructor): number;
function_with_callback_param(_0: (message: string) => void): number;
Expand Down

0 comments on commit c9b106a

Please sign in to comment.