From 674c359c783a3aef7ac0fedb29c940d0179f0f4f Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Wed, 20 Mar 2024 13:52:00 -0400 Subject: [PATCH] String Tensor SplitToSequence fix (#19942) --- .../core/providers/cpu/sequence/sequence_ops.cc | 2 +- .../providers/cpu/sequence/sequence_ops_test.cc | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/cpu/sequence/sequence_ops.cc b/onnxruntime/core/providers/cpu/sequence/sequence_ops.cc index 8064bc0a58cb1..2913f4ac32b6e 100644 --- a/onnxruntime/core/providers/cpu/sequence/sequence_ops.cc +++ b/onnxruntime/core/providers/cpu/sequence/sequence_ops.cc @@ -453,7 +453,7 @@ Status SplitToSequence::ComputeImpl(OpKernelContext& context, const Tensor& inpu int num_remaining_splits = 0; InlinedVector split_sizes; const bool is_string_type = input.IsDataTypeString(); - const size_t element_size = (is_string_type) ? 0U : input.DataType()->Size(); + const size_t element_size = input.DataType()->Size(); // figure out split_scalar or split_sizes if (p_split_input) { diff --git a/onnxruntime/test/providers/cpu/sequence/sequence_ops_test.cc b/onnxruntime/test/providers/cpu/sequence/sequence_ops_test.cc index 60e75811e4333..c2d64b8e5ee4a 100644 --- a/onnxruntime/test/providers/cpu/sequence/sequence_ops_test.cc +++ b/onnxruntime/test/providers/cpu/sequence/sequence_ops_test.cc @@ -442,6 +442,19 @@ TEST(SequenceOpsTest, SplitToSequence_PositiveAxisScalarSplit) { test.Run(); } +TEST(SequenceOpsTest, SplitToSequence_StringSplit) { + OpTester test("SplitToSequence", 11); + test.AddInput("input", {3}, std::vector({"Test string", "Another string", "A third and much longer string"})); + int64_t axis = 0; + test.AddAttribute("axis", axis); + SeqTensors output; + output.AddTensor({1}, {"Test string"}); + output.AddTensor({1}, {"Another string"}); + output.AddTensor({1}, {"A third and much longer string"}); + test.AddSeqOutput("S2", output); + test.Run(); +} + TEST(SequenceOpsTest, SplitToSequence_DefaultAxis0UnevenSplitFloat) { OpTester test("SplitToSequence", 11); test.AddInput("input", {5, 2}, GetConsecutiveVector(1.f, 10));