diff --git a/src/Microsoft.ML.Transforms/Text/TextCatalog.cs b/src/Microsoft.ML.Transforms/Text/TextCatalog.cs
index 7b4b554b7c..9ae8888791 100644
--- a/src/Microsoft.ML.Transforms/Text/TextCatalog.cs
+++ b/src/Microsoft.ML.Transforms/Text/TextCatalog.cs
@@ -334,7 +334,7 @@ public static CustomStopWordsRemovingEstimator RemoveStopWords(this TransformsCa
=> new CustomStopWordsRemovingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), outputColumnName, inputColumnName, stopwords);
/// <summary>
- /// Create a <see cref="WordBagEstimator"/>, which maps the column specified in <paramref name="inputColumnName"/>
+ /// Create a <see cref="WordBagEstimator"/>, which maps the column specified in <paramref name="inputColumnName"/>
/// to a vector of n-gram counts in a new column named <paramref name="outputColumnName"/>.
/// </summary>
///
@@ -363,7 +363,7 @@ public static WordBagEstimator ProduceWordBags(this TransformsCatalog.TextTransf
outputColumnName, inputColumnName, ngramLength, skipLength, useAllLengths, maximumNgramsCount, weighting);
/// <summary>
- /// Create a <see cref="WordBagEstimator"/>, which maps the multiple columns specified in <paramref name="inputColumnNames"/>
+ /// Create a <see cref="WordBagEstimator"/>, which maps the multiple columns specified in <paramref name="inputColumnNames"/>
/// to a vector of n-gram counts in a new column named <paramref name="outputColumnName"/>.
/// </summary>
///
diff --git a/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs b/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs
index 1eac17ccaa..3af7a1e471 100644
--- a/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs
+++ b/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs
@@ -415,10 +415,10 @@ public void SaveAsOnnx(OnnxContext ctx)
string[] separators = column.SeparatorsArray.Select(c => c.ToString()).ToArray();
tokenizerNode.AddAttribute("separators", separators);
- opType = "Squeeze";
- var squeezeOutput = ctx.AddIntermediateVariable(_type, column.Name);
- var squeezeNode = ctx.CreateNode(opType, intermediateVar, squeezeOutput, ctx.GetNodeName(opType), "");
- squeezeNode.AddAttribute("axes", new long[] { 1 });
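+ // Reshape the tokenizer output to shape [1, -1] (one row, inferred number of tokens)
+ // instead of squeezing axis 1, so the exported column keeps a two-dimensional shape.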
+ opType = "Reshape";
+ var shape = ctx.AddInitializer(new long[] { 1, -1 }, new long[] { 2 }, "Shape");
+ var reshapeOutput = ctx.AddIntermediateVariable(new VectorDataViewType(TextDataViewType.Instance, 1), column.Name);
+ var reshapeNode = ctx.CreateNode(opType, new[] { intermediateVar, shape }, new[] { reshapeOutput }, ctx.GetNodeName(opType), "");
}
}
}
diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs
index 69dbbe57e5..9f97f21ad4 100644
--- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs
+++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs
@@ -1323,9 +1323,12 @@ public void NgramOnnxConversionTest(
weighting: weighting)),
mlContext.Transforms.Text.ProduceWordBags("Tokens", "Text",
- ngramLength: ngramLength,
- useAllLengths: useAllLength,
- weighting: weighting)
+ ngramLength: ngramLength,
+ useAllLengths: useAllLength,
+ weighting: weighting),
+
+ mlContext.Transforms.Text.TokenizeIntoWords("Tokens0", "Text")
+ .Append(mlContext.Transforms.Text.ProduceWordBags("Tokens", "Tokens0"))
};
for (int i = 0; i < pipelines.Length; i++)
@@ -1346,7 +1349,7 @@ public void NgramOnnxConversionTest(
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(onnxFilePath, gpuDeviceId: _gpuDeviceId, fallbackToCpu: _fallbackToCpu);
var onnxTransformer = onnxEstimator.Fit(dataView);
var onnxResult = onnxTransformer.Transform(dataView);
- var columnName = i == pipelines.Length - 1 ? "Tokens" : "NGrams";
+ var columnName = i >= pipelines.Length - 2 ? "Tokens" : "NGrams";
CompareResults(columnName, columnName, transformedData, onnxResult, 3);
VBuffer<ReadOnlyMemory<char>> mlNetSlots = default;