diff --git a/src/Microsoft.ML.Transforms/Text/TextCatalog.cs b/src/Microsoft.ML.Transforms/Text/TextCatalog.cs
index 7b4b554b7c..9ae8888791 100644
--- a/src/Microsoft.ML.Transforms/Text/TextCatalog.cs
+++ b/src/Microsoft.ML.Transforms/Text/TextCatalog.cs
@@ -334,7 +334,7 @@ public static CustomStopWordsRemovingEstimator RemoveStopWords(this TransformsCa
             => new CustomStopWordsRemovingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), outputColumnName, inputColumnName, stopwords);

         ///
-        /// Create a , which maps the column specified in
+        /// Create a , which maps the column specified in
         /// to a vector of n-gram counts in a new column named .
         ///
         ///
@@ -363,7 +363,7 @@ public static WordBagEstimator ProduceWordBags(this TransformsCatalog.TextTransf
             outputColumnName, inputColumnName, ngramLength, skipLength, useAllLengths, maximumNgramsCount, weighting);

         ///
-        /// Create a , which maps the multiple columns specified in
+        /// Create a , which maps the multiple columns specified in
         /// to a vector of n-gram counts in a new column named .
         ///
         ///
diff --git a/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs b/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs
index 1eac17ccaa..3af7a1e471 100644
--- a/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs
+++ b/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs
@@ -415,10 +415,10 @@ public void SaveAsOnnx(OnnxContext ctx)
                     string[] separators = column.SeparatorsArray.Select(c => c.ToString()).ToArray();
                     tokenizerNode.AddAttribute("separators", separators);

-                    opType = "Squeeze";
-                    var squeezeOutput = ctx.AddIntermediateVariable(_type, column.Name);
-                    var squeezeNode = ctx.CreateNode(opType, intermediateVar, squeezeOutput, ctx.GetNodeName(opType), "");
-                    squeezeNode.AddAttribute("axes", new long[] { 1 });
+                    opType = "Reshape";
+                    var shape = ctx.AddInitializer(new long[] { 1, -1 }, new long[] { 2 }, "Shape");
+                    var reshapeOutput = ctx.AddIntermediateVariable(new VectorDataViewType(TextDataViewType.Instance, 1), column.Name);
+                    var reshapeNode = ctx.CreateNode(opType, new[] { intermediateVar, shape }, new[] { reshapeOutput }, ctx.GetNodeName(opType), "");
                 }
             }
         }
diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs
index 69dbbe57e5..9f97f21ad4 100644
--- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs
+++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs
@@ -1323,9 +1323,12 @@ public void NgramOnnxConversionTest(
                     weighting: weighting)),

                 mlContext.Transforms.Text.ProduceWordBags("Tokens", "Text",
-                ngramLength: ngramLength,
-                useAllLengths: useAllLength,
-                weighting: weighting)
+                    ngramLength: ngramLength,
+                    useAllLengths: useAllLength,
+                    weighting: weighting),
+
+                mlContext.Transforms.Text.TokenizeIntoWords("Tokens0", "Text")
+                .Append(mlContext.Transforms.Text.ProduceWordBags("Tokens", "Tokens0"))
             };

             for (int i = 0; i < pipelines.Length; i++)
@@ -1346,7 +1349,7 @@ public void NgramOnnxConversionTest(
                 var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(onnxFilePath, gpuDeviceId: _gpuDeviceId, fallbackToCpu: _fallbackToCpu);
                 var onnxTransformer = onnxEstimator.Fit(dataView);
                 var onnxResult = onnxTransformer.Transform(dataView);
-                var columnName = i == pipelines.Length - 1 ? "Tokens" : "NGrams";
+                var columnName = i >= pipelines.Length - 2 ? "Tokens" : "NGrams";
                 CompareResults(columnName, columnName, transformedData, onnxResult, 3);

                 VBuffer<ReadOnlyMemory<char>> mlNetSlots = default;
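
A minimal usage sketch (not part of the patch) of the scenario the new test exercises: tokenize into words, produce word bags over the tokens, then export the fitted chain to ONNX. It assumes the Microsoft.ML and Microsoft.ML.OnnxConverter packages; the sample data, column names, class names, and output path are illustrative.

using System.IO;
using Microsoft.ML;

public class TextInput
{
    public string Text { get; set; }
}

public static class WordBagOnnxSketch
{
    public static void Run()
    {
        var mlContext = new MLContext();
        var dataView = mlContext.Data.LoadFromEnumerable(new[]
        {
            new TextInput { Text = "bbb ccc zzz" },
            new TextInput { Text = "ccc xxx aaa" }
        });

        // Same shape as the pipeline added to NgramOnnxConversionTest:
        // tokenize first, then build word bags over the token column.
        var pipeline = mlContext.Transforms.Text.TokenizeIntoWords("Tokens0", "Text")
            .Append(mlContext.Transforms.Text.ProduceWordBags("Tokens", "Tokens0"));

        var model = pipeline.Fit(dataView);

        // With the Reshape-based export, the tokenizer's ONNX output keeps a
        // 2-D [1, -1] shape, which the downstream word-bag node consumes.
        using (var stream = File.Create("wordbags.onnx"))
            mlContext.Model.ConvertToOnnx(model, dataView, stream);
    }
}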