diff --git a/src/idl_gen_ts.cpp b/src/idl_gen_ts.cpp index 7bee3f80148..acd2a4febec 100644 --- a/src/idl_gen_ts.cpp +++ b/src/idl_gen_ts.cpp @@ -408,7 +408,7 @@ class TsGenerator : public BaseGenerator { switch (type.base_type) { case BASE_TYPE_BOOL: case BASE_TYPE_CHAR: return "Int8"; - case BASE_TYPE_UTYPE: + case BASE_TYPE_UTYPE: return GenType(GetUnionUnderlyingType(type)); case BASE_TYPE_UCHAR: return "Uint8"; case BASE_TYPE_SHORT: return "Int16"; case BASE_TYPE_USHORT: return "Uint16"; @@ -562,11 +562,26 @@ class TsGenerator : public BaseGenerator { } } + static Type GetUnionUnderlyingType(const Type &type) + { + if (type.enum_def != nullptr && + type.enum_def->underlying_type.base_type != type.base_type) { + return type.enum_def->underlying_type; + } else { + return Type(BASE_TYPE_UCHAR); + } + } + + static Type GetUnderlyingVectorType(const Type &vector_type) + { + return (vector_type.base_type == BASE_TYPE_UTYPE) ? GetUnionUnderlyingType(vector_type) : vector_type; + } + // Returns the method name for use with add/put calls. std::string GenWriteMethod(const Type &type) { // Forward to signed versions since unsigned versions don't exist switch (type.base_type) { - case BASE_TYPE_UTYPE: + case BASE_TYPE_UTYPE: return GenWriteMethod(GetUnionUnderlyingType(type)); case BASE_TYPE_UCHAR: return GenWriteMethod(Type(BASE_TYPE_CHAR)); case BASE_TYPE_USHORT: return GenWriteMethod(Type(BASE_TYPE_SHORT)); case BASE_TYPE_UINT: return GenWriteMethod(Type(BASE_TYPE_INT)); @@ -1763,7 +1778,8 @@ class TsGenerator : public BaseGenerator { auto vectortype = field.value.type.VectorType(); auto vectortypename = GenTypeName(imports, struct_def, vectortype, false); - auto inline_size = InlineSize(vectortype); + auto type = GetUnderlyingVectorType(vectortype); + auto inline_size = InlineSize(type); auto index = GenBBAccess() + ".__vector(this.bb_pos + offset) + index" + MaybeScale(inline_size); @@ -1994,8 +2010,9 @@ class TsGenerator : public BaseGenerator { if (IsVector(field.value.type)) { auto vector_type = field.value.type.VectorType(); - auto alignment = InlineAlignment(vector_type); - auto elem_size = InlineSize(vector_type); + auto type = GetUnderlyingVectorType(vector_type); + auto alignment = InlineAlignment(type); + auto elem_size = InlineSize(type); // Generate a method to create a vector from a JavaScript array if (!IsStruct(vector_type)) { diff --git a/src/idl_parser.cpp b/src/idl_parser.cpp index 4df4558ce95..2b401cf0c0b 100644 --- a/src/idl_parser.cpp +++ b/src/idl_parser.cpp @@ -2718,7 +2718,7 @@ bool Parser::Supports64BitOffsets() const { } bool Parser::SupportsUnionUnderlyingType() const { - return (opts.lang_to_generate & ~IDLOptions::kCpp) == 0; + return (opts.lang_to_generate & ~(IDLOptions::kCpp | IDLOptions::kTs)) == 0; } Namespace *Parser::UniqueNamespace(Namespace *ns) { diff --git a/tests/parser_test.cpp b/tests/parser_test.cpp index d63c0b180e4..33b1d6e1e6d 100644 --- a/tests/parser_test.cpp +++ b/tests/parser_test.cpp @@ -845,7 +845,7 @@ void ParseUnionTest() { // Test union underlying type const char *source = "table A {} table B {} union U : int {A, B} table C {test_union: U; test_vector_of_union: [U];}"; flatbuffers::Parser parser3; - parser3.opts.lang_to_generate = flatbuffers::IDLOptions::kCpp; + parser3.opts.lang_to_generate = flatbuffers::IDLOptions::kCpp | flatbuffers::IDLOptions::kTs; TEST_EQ(parser3.Parse(source), true); parser3.opts.lang_to_generate &= flatbuffers::IDLOptions::kJava; diff --git a/tests/ts/JavaScriptUnionUnderlyingTypeTest.js b/tests/ts/JavaScriptUnionUnderlyingTypeTest.js new file mode 100644 index 00000000000..6a324ca43ec --- /dev/null +++ b/tests/ts/JavaScriptUnionUnderlyingTypeTest.js @@ -0,0 +1,26 @@ +import assert from 'assert' +import * as flatbuffers from 'flatbuffers' +import {UnionUnderlyingType as Test} from './union_underlying_type_test.js' + +function main() { + let a = new Test.AT(); + a.a = 1; + let b = new Test.BT(); + b.b = "foo"; + let c = new Test.CT(); + c.c = true; + let d = new Test.DT(); + d.testUnionType = Test.ABC.A; + d.testUnion = a; + d.testVectorOfUnionType = [Test.ABC.A, Test.ABC.B, Test.ABC.C]; + d.testVectorOfUnion = [a, b, c]; + + let fbb = new flatbuffers.Builder(); + let offset = d.pack(fbb); + fbb.finish(offset); + + let unpacked = Test.D.getRootAsD(fbb.dataBuffer()).unpack(); + assert.equal(JSON.stringify(unpacked), JSON.stringify(d)); +} + +main() \ No newline at end of file diff --git a/tests/ts/TypeScriptTest.py b/tests/ts/TypeScriptTest.py index ae357ef09c4..c6c7cb7f993 100755 --- a/tests/ts/TypeScriptTest.py +++ b/tests/ts/TypeScriptTest.py @@ -117,6 +117,11 @@ def esbuild(input, output): ) esbuild("typescript_keywords.ts", "typescript_keywords_generated.cjs") +flatc( + options=["--ts", "--reflect-names", "--gen-name-strings", "--gen-mutable", "--gen-object-api", "--ts-entry-points", "--ts-flat-files"], + schema="../union_underlying_type_test.fbs" +) + print("Running TypeScript Compiler...") check_call(["tsc"]) print("Running TypeScript Compiler in old node resolution mode for no_import_ext...") @@ -129,6 +134,7 @@ def esbuild(input, output): check_call(NODE_CMD + ["JavaScriptUnionVectorTest"]) check_call(NODE_CMD + ["JavaScriptFlexBuffersTest"]) check_call(NODE_CMD + ["JavaScriptComplexArraysTest"]) +check_call(NODE_CMD + ["JavaScriptUnionUnderlyingTypeTest"]) print("Running old v1 TypeScript Tests...") check_call(NODE_CMD + ["JavaScriptTestv1.cjs", "./monster_test_generated.cjs"]) diff --git a/tests/ts/tsconfig.json b/tests/ts/tsconfig.json index d9ef7410c3a..eb08992b427 100644 --- a/tests/ts/tsconfig.json +++ b/tests/ts/tsconfig.json @@ -14,6 +14,7 @@ "optional_scalars/**/*.ts", "namespace_test/**/*.ts", "union_vector/**/*.ts", - "arrays_test_complex/**/*.ts" + "arrays_test_complex/**/*.ts", + "union_underlying_type_test.ts" ] }