From 3ac9e2900f2fcb684bebabf0e4abbfba3048da1f Mon Sep 17 00:00:00 2001 From: Zeeshan Ahmed Date: Tue, 12 Mar 2019 15:31:16 -0700 Subject: [PATCH] Added support for inserting batch dimension in inputs in TensorFlow. --- .../TensorFlowStaticExtensions.cs | 27 +++--- .../TensorFlowModel.cs | 20 +++-- .../TensorflowCatalog.cs | 2 +- .../TensorflowTransform.cs | 88 +++++++++++++------ .../Common/EntryPoints/core_manifest.json | 9 ++ .../TensorflowTests.cs | 4 +- 6 files changed, 104 insertions(+), 46 deletions(-) diff --git a/src/Microsoft.ML.TensorFlow.StaticPipe/TensorFlowStaticExtensions.cs b/src/Microsoft.ML.TensorFlow.StaticPipe/TensorFlowStaticExtensions.cs index 94d244fcd4..213fe996a5 100644 --- a/src/Microsoft.ML.TensorFlow.StaticPipe/TensorFlowStaticExtensions.cs +++ b/src/Microsoft.ML.TensorFlow.StaticPipe/TensorFlowStaticExtensions.cs @@ -14,14 +14,14 @@ private sealed class OutColumn : Vector { public PipelineColumn Input { get; } - public OutColumn(Vector input, string modelFile) - : base(new Reconciler(modelFile), input) + public OutColumn(Vector input, string modelFile, bool addBatchDimensionInput) + : base(new Reconciler(modelFile, addBatchDimensionInput), input) { Input = input; } - public OutColumn(Vector input, TensorFlowModel tensorFlowModel) - : base(new Reconciler(tensorFlowModel), input) + public OutColumn(Vector input, TensorFlowModel tensorFlowModel, bool addBatchDimensionInput) + : base(new Reconciler(tensorFlowModel, addBatchDimensionInput), input) { Input = input; } @@ -31,20 +31,23 @@ private sealed class Reconciler : EstimatorReconciler { private readonly string _modelFile; private readonly TensorFlowModel _tensorFlowModel; + private readonly bool _addBatchDimensionInput; - public Reconciler(string modelFile) + public Reconciler(string modelFile, bool addBatchDimensionInput) { Contracts.AssertNonEmpty(modelFile); _modelFile = modelFile; _tensorFlowModel = null; + _addBatchDimensionInput = addBatchDimensionInput; } - public Reconciler(TensorFlowModel tensorFlowModel) + public Reconciler(TensorFlowModel tensorFlowModel, bool addBatchDimensionInput) { Contracts.CheckValue(tensorFlowModel, nameof(tensorFlowModel)); _modelFile = null; _tensorFlowModel = tensorFlowModel; + _addBatchDimensionInput = addBatchDimensionInput; } public override IEstimator Reconcile(IHostEnvironment env, @@ -57,9 +60,9 @@ public override IEstimator Reconcile(IHostEnvironment env, var outCol = (OutColumn)toOutput[0]; if (_modelFile == null) - return new TensorFlowEstimator(env, new[] { outputNames[outCol] }, new[] { inputNames[outCol.Input] }, _tensorFlowModel); + return new TensorFlowEstimator(env, new[] { outputNames[outCol] }, new[] { inputNames[outCol.Input] }, _tensorFlowModel, _addBatchDimensionInput); else - return new TensorFlowEstimator(env, new[] { outputNames[outCol] }, new[] { inputNames[outCol.Input] }, _modelFile); + return new TensorFlowEstimator(env, new[] { outputNames[outCol] }, new[] { inputNames[outCol.Input] }, _modelFile, _addBatchDimensionInput); } } @@ -70,22 +73,22 @@ public override IEstimator Reconcile(IHostEnvironment env, /// Load the TensorFlow model from and run it on the input column and extract one output column. /// The inputs and outputs are matched to TensorFlow graph nodes by name. /// - public static Vector ApplyTensorFlowGraph(this Vector input, string modelFile) + public static Vector ApplyTensorFlowGraph(this Vector input, string modelFile, bool addBatchDimensionInput = false) { Contracts.CheckValue(input, nameof(input)); Contracts.CheckNonEmpty(modelFile, nameof(modelFile)); - return new OutColumn(input, modelFile); + return new OutColumn(input, modelFile, addBatchDimensionInput); } /// /// Run a TensorFlow model provided through on the input column and extract one output column. /// The inputs and outputs are matched to TensorFlow graph nodes by name. /// - public static Vector ApplyTensorFlowGraph(this Vector input, TensorFlowModel tensorFlowModel) + public static Vector ApplyTensorFlowGraph(this Vector input, TensorFlowModel tensorFlowModel, bool addBatchDimensionInput = false) { Contracts.CheckValue(input, nameof(input)); Contracts.CheckValue(tensorFlowModel, nameof(tensorFlowModel)); - return new OutColumn(input, tensorFlowModel); + return new OutColumn(input, tensorFlowModel, addBatchDimensionInput); } } } diff --git a/src/Microsoft.ML.TensorFlow/TensorFlowModel.cs b/src/Microsoft.ML.TensorFlow/TensorFlowModel.cs index b72c2bc0a8..b21ddbf87d 100644 --- a/src/Microsoft.ML.TensorFlow/TensorFlowModel.cs +++ b/src/Microsoft.ML.TensorFlow/TensorFlowModel.cs @@ -55,6 +55,8 @@ public DataViewSchema GetInputSchema() /// /// The name of the model input. /// The name of the requested model output. + /// Add a batch dimension to the input e.g. input = [224, 224, 3] => [-1, 224, 224, 3]. + /// This parameter is used to deal with models that have unknown shape but the internal operators in the model require data to have batch dimension as well. /// /// /// /// /// - public TensorFlowEstimator ScoreTensorFlowModel(string outputColumnName, string inputColumnName) - => new TensorFlowEstimator(_env, new[] { outputColumnName }, new[] { inputColumnName }, this); + public TensorFlowEstimator ScoreTensorFlowModel(string outputColumnName, string inputColumnName, bool addBatchDimensionInput = false) + => new TensorFlowEstimator(_env, new[] { outputColumnName }, new[] { inputColumnName }, this, addBatchDimensionInput); /// /// Scores a dataset using a pre-traiend TensorFlow model. /// /// The names of the model inputs. /// The names of the requested model outputs. + /// Add a batch dimension to the input e.g. input = [224, 224, 3] => [-1, 224, 224, 3]. + /// This parameter is used to deal with models that have unknown shape but the internal operators in the model require data to have batch dimension as well. /// /// /// /// /// - public TensorFlowEstimator ScoreTensorFlowModel(string[] outputColumnNames, string[] inputColumnNames) - => new TensorFlowEstimator(_env, outputColumnNames, inputColumnNames, this); + public TensorFlowEstimator ScoreTensorFlowModel(string[] outputColumnNames, string[] inputColumnNames, bool addBatchDimensionInput = false) + => new TensorFlowEstimator(_env, outputColumnNames, inputColumnNames, this, addBatchDimensionInput); /// /// Retrain the TensorFlow model on new data. @@ -97,6 +101,8 @@ public TensorFlowEstimator ScoreTensorFlowModel(string[] outputColumnNames, stri /// The name of the operation in the TensorFlow graph to compute performance metric during training (Optional). /// The name of the operation in the TensorFlow graph which sets optimizer learning rate (Optional). /// Learning rate to use during optimization (Optional). + /// Add a batch dimension to the input e.g. input = [224, 224, 3] => [-1, 224, 224, 3]. + /// This parameter is used to deal with models that have unknown shape but the internal operators in the model require data to have batch dimension as well. /// /// The support for retraining is experimental. /// @@ -111,7 +117,8 @@ public TensorFlowEstimator RetrainTensorFlowModel( string lossOperation = null, string metricOperation = null, string learningRateOperation = null, - float learningRate = 0.01f) + float learningRate = 0.01f, + bool addBatchDimensionInput = false) { var options = new TensorFlowEstimator.Options() { @@ -127,7 +134,8 @@ public TensorFlowEstimator RetrainTensorFlowModel( LearningRateOperation = learningRateOperation, LearningRate = learningRate, BatchSize = batchSize, - ReTrain = true + ReTrain = true, + AddBatchDimensionInputs = addBatchDimensionInput }; return new TensorFlowEstimator(_env, options, this); } diff --git a/src/Microsoft.ML.TensorFlow/TensorflowCatalog.cs b/src/Microsoft.ML.TensorFlow/TensorflowCatalog.cs index be922166d5..832e2f47b4 100644 --- a/src/Microsoft.ML.TensorFlow/TensorflowCatalog.cs +++ b/src/Microsoft.ML.TensorFlow/TensorflowCatalog.cs @@ -14,7 +14,7 @@ public static class TensorflowCatalog { /// /// Load TensorFlow model into memory. This is the convenience method that allows the model to be loaded once and subsequently use it for querying schema and creation of - /// using . + /// using . /// /// The transform's catalog. /// Location of the TensorFlow model. diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs index 89d3b1d0b7..3022f53dec 100644 --- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs +++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs @@ -38,7 +38,7 @@ public sealed class TensorFlowTransformer : RowToRowTransformerBase { private readonly string _savedModelPath; private readonly bool _isTemporarySavedModel; - + private readonly bool _addBatchDimensionInput; internal readonly TFSession Session; internal readonly DataViewType[] OutputTypes; internal readonly TFDataType[] TFOutputTypes; @@ -69,8 +69,9 @@ private static VersionInfo GetVersionInfo() return new VersionInfo( modelSignature: "TENSFLOW", //verWrittenCur: 0x00010001, // Initial - verWrittenCur: 0x00010002, // Added Support for Multiple Outputs and SavedModel. - verReadableCur: 0x00010002, + //verWrittenCur: 0x00010002, // Added Support for Multiple Outputs and SavedModel. + verWrittenCur: 0x00010003, // Added Support for adding batch dimension in inputs. + verReadableCur: 0x00010003, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, loaderAssemblyName: typeof(TensorFlowTransformer).Assembly.FullName); @@ -79,28 +80,32 @@ private static VersionInfo GetVersionInfo() /// /// Transform for scoring Tensorflow models. Input data column names/types must exactly match /// all model input names. Only the output columns specified will be generated. - /// If the model is already loaded please to avoid reloading of model. + /// If the model is already loaded please to avoid reloading of model. /// /// The environment to use. /// Model file path. /// The output columns to generate. Names must match model specifications. Data types are inferred from model. /// The name of the input data column. Must match model input name. If set to , the value of the will be used as source. - internal TensorFlowTransformer(IHostEnvironment env, string modelFile, string outputColumnName, string inputColumnName = null) - : this(env, TensorFlowUtils.GetSession(env, modelFile), new[] { outputColumnName }, new[] { inputColumnName ?? outputColumnName }, TensorFlowUtils.IsSavedModel(env, modelFile) ? modelFile : null, false) + /// Add a batch dimension to the input e.g. input = [224, 224, 3] => [-1, 224, 224, 3]. + /// This parameter is used to deal with models that have unknown shape but the internal operators in the model require data to have batch dimension as well. + internal TensorFlowTransformer(IHostEnvironment env, string modelFile, string outputColumnName, string inputColumnName = null, bool addBatchDimensionInput = false) + : this(env, TensorFlowUtils.GetSession(env, modelFile), new[] { outputColumnName }, new[] { inputColumnName ?? outputColumnName }, TensorFlowUtils.IsSavedModel(env, modelFile) ? modelFile : null, false, addBatchDimensionInput) { } /// /// Transform for scoring Tensorflow models. Input data column names/types must exactly match /// all model input names. Only the output columns specified will be generated. - /// If the model is already loaded please to avoid reloading of model. + /// If the model is already loaded please to avoid reloading of model. /// /// The environment to use. /// Model file path. /// The name of the input data columns. Must match model's input names. /// The output columns to generate. Names must match model specifications. Data types are inferred from model. - internal TensorFlowTransformer(IHostEnvironment env, string modelFile, string[] outputColumnNames, string[] inputColumnNames) - : this(env, TensorFlowUtils.GetSession(env, modelFile), outputColumnNames, inputColumnNames, TensorFlowUtils.IsSavedModel(env, modelFile) ? modelFile : null, false) + /// Add a batch dimension to the input e.g. input = [224, 224, 3] => [-1, 224, 224, 3]. + /// This parameter is used to deal with models that have unknown shape but the internal operators in the model require data to have batch dimension as well. + internal TensorFlowTransformer(IHostEnvironment env, string modelFile, string[] outputColumnNames, string[] inputColumnNames, bool addBatchDimensionInput = false) + : this(env, TensorFlowUtils.GetSession(env, modelFile), outputColumnNames, inputColumnNames, TensorFlowUtils.IsSavedModel(env, modelFile) ? modelFile : null, false, addBatchDimensionInput) { } @@ -114,8 +119,10 @@ internal TensorFlowTransformer(IHostEnvironment env, string modelFile, string[] /// object created with . /// The output columns to generate. Names must match model specifications. Data types are inferred from model. /// The name of the input data columns. Must match model's input names. If set to , the value of the will be used as source. - internal TensorFlowTransformer(IHostEnvironment env, TensorFlowModel tfModelInfo, string outputColumnName, string inputColumnName = null) - : this(env, tfModelInfo.Session, new[] { outputColumnName }, new[] { inputColumnName ?? outputColumnName }, TensorFlowUtils.IsSavedModel(env, tfModelInfo.ModelPath) ? tfModelInfo.ModelPath : null, false) + /// Add a batch dimension to the input e.g. input = [224, 224, 3] => [-1, 224, 224, 3]. + /// This parameter is used to deal with models that have unknown shape but the internal operators in the model require data to have batch dimension as well. + internal TensorFlowTransformer(IHostEnvironment env, TensorFlowModel tfModelInfo, string outputColumnName, string inputColumnName = null, bool addBatchDimensionInput = false) + : this(env, tfModelInfo.Session, new[] { outputColumnName }, new[] { inputColumnName ?? outputColumnName }, TensorFlowUtils.IsSavedModel(env, tfModelInfo.ModelPath) ? tfModelInfo.ModelPath : null, false, addBatchDimensionInput) { } @@ -129,8 +136,10 @@ internal TensorFlowTransformer(IHostEnvironment env, TensorFlowModel tfModelInfo /// object created with . /// The name of the input data columns. Must match model's input names. /// The output columns to generate. Names must match model specifications. Data types are inferred from model. - internal TensorFlowTransformer(IHostEnvironment env, TensorFlowModel tfModelInfo, string[] outputColumnNames, string[] inputColumnNames) - : this(env, tfModelInfo.Session, outputColumnNames, inputColumnNames, TensorFlowUtils.IsSavedModel(env, tfModelInfo.ModelPath) ? tfModelInfo.ModelPath : null, false) + /// Add a batch dimension to the input e.g. input = [224, 224, 3] => [-1, 224, 224, 3]. + /// This parameter is used to deal with models that have unknown shape but the internal operators in the model require data to have batch dimension as well. + internal TensorFlowTransformer(IHostEnvironment env, TensorFlowModel tfModelInfo, string[] outputColumnNames, string[] inputColumnNames, bool addBatchDimensionInput = false) + : this(env, tfModelInfo.Session, outputColumnNames, inputColumnNames, TensorFlowUtils.IsSavedModel(env, tfModelInfo.ModelPath) ? tfModelInfo.ModelPath : null, false, addBatchDimensionInput) { } @@ -143,6 +152,7 @@ private static TensorFlowTransformer Create(IHostEnvironment env, ModelLoadConte // *** Binary format *** // byte: indicator for frozen models + // byte: indicator for adding batch dimension in input // stream: tensorFlow model. // int: number of input columns // for each input column @@ -150,13 +160,13 @@ private static TensorFlowTransformer Create(IHostEnvironment env, ModelLoadConte // int: number of output columns // for each output column // int: id of output column name - GetModelInfo(env, ctx, out string[] inputs, out string[] outputs, out bool isFrozen); + GetModelInfo(env, ctx, out string[] inputs, out string[] outputs, out bool isFrozen, out bool addBatchDimensionInput); if (isFrozen) { byte[] modelBytes = null; if (!ctx.TryLoadBinaryStream("TFModel", r => modelBytes = r.ReadByteArray())) throw env.ExceptDecode(); - return new TensorFlowTransformer(env, TensorFlowUtils.LoadTFSession(env, modelBytes), outputs, inputs, null, false); + return new TensorFlowTransformer(env, TensorFlowUtils.LoadTFSession(env, modelBytes), outputs, inputs, null, false, addBatchDimensionInput); } var tempDirPath = Path.GetFullPath(Path.Combine(Path.GetTempPath(), nameof(TensorFlowTransformer) + "_" + Guid.NewGuid())); @@ -185,7 +195,7 @@ private static TensorFlowTransformer Create(IHostEnvironment env, ModelLoadConte } }); - return new TensorFlowTransformer(env, TensorFlowUtils.GetSession(env, tempDirPath), outputs, inputs, tempDirPath, true); + return new TensorFlowTransformer(env, TensorFlowUtils.GetSession(env, tempDirPath), outputs, inputs, tempDirPath, true, addBatchDimensionInput); } catch (Exception) { @@ -212,7 +222,7 @@ internal TensorFlowTransformer(IHostEnvironment env, TensorFlowEstimator.Options } internal TensorFlowTransformer(IHostEnvironment env, TensorFlowEstimator.Options options, TensorFlowModel tensorFlowModel, IDataView input) - : this(env, tensorFlowModel.Session, options.OutputColumns, options.InputColumns, TensorFlowUtils.IsSavedModel(env, options.ModelLocation) ? options.ModelLocation : null, false) + : this(env, tensorFlowModel.Session, options.OutputColumns, options.InputColumns, TensorFlowUtils.IsSavedModel(env, options.ModelLocation) ? options.ModelLocation : null, false, options.AddBatchDimensionInputs) { Contracts.CheckValue(env, nameof(env)); @@ -500,13 +510,18 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema) => Create(env, ctx).MakeRowMapper(inputSchema); - private static void GetModelInfo(IHostEnvironment env, ModelLoadContext ctx, out string[] inputs, out string[] outputs, out bool isFrozen) + private static void GetModelInfo(IHostEnvironment env, ModelLoadContext ctx, out string[] inputs, out string[] outputs, out bool isFrozen, out bool addBatchDimensionInput) { isFrozen = true; bool isNonFrozenModelSupported = ctx.Header.ModelVerReadable >= 0x00010002; if (isNonFrozenModelSupported) isFrozen = ctx.Reader.ReadBoolByte(); + addBatchDimensionInput = false; + bool isAddingBatchDimensionSupported = ctx.Header.ModelVerReadable >= 0x00010003; + if (isAddingBatchDimensionSupported) + addBatchDimensionInput = ctx.Reader.ReadBoolByte(); + var numInputs = ctx.Reader.ReadInt32(); env.CheckDecode(numInputs > 0); inputs = new string[numInputs]; @@ -524,7 +539,7 @@ private static void GetModelInfo(IHostEnvironment env, ModelLoadContext ctx, out outputs[j] = ctx.LoadNonEmptyString(); } - internal TensorFlowTransformer(IHostEnvironment env, TFSession session, string[] outputColumnNames, string[] inputColumnNames, string savedModelPath, bool isTemporarySavedModel) : + internal TensorFlowTransformer(IHostEnvironment env, TFSession session, string[] outputColumnNames, string[] inputColumnNames, string savedModelPath, bool isTemporarySavedModel, bool addBatchDimensionInput) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(TensorFlowTransformer))) { @@ -535,6 +550,7 @@ internal TensorFlowTransformer(IHostEnvironment env, TFSession session, string[] Session = session; _savedModelPath = savedModelPath; _isTemporarySavedModel = isTemporarySavedModel; + _addBatchDimensionInput = addBatchDimensionInput; Inputs = inputColumnNames; Outputs = outputColumnNames; @@ -610,6 +626,7 @@ private protected override void SaveModel(ModelSaveContext ctx) // *** Binary format *** // byte: indicator for frozen models + // byte: indicator for adding batch dimension in input // stream: tensorFlow model. // int: number of input columns // for each input column @@ -619,6 +636,7 @@ private protected override void SaveModel(ModelSaveContext ctx) // int: id of output column name var isFrozen = string.IsNullOrEmpty(_savedModelPath); ctx.Writer.WriteBoolByte(isFrozen); + ctx.Writer.WriteBoolByte(_addBatchDimensionInput); if (isFrozen) { var buffer = new TFBuffer(); @@ -764,6 +782,15 @@ public Mapper(TensorFlowTransformer parent, DataViewSchema inputSchema) : l[ishape] = originalShape[ishape] == -1 ? (int)d : originalShape[ishape]; _fullySpecifiedShapes[i] = new TFShape(l); } + + if (_parent._addBatchDimensionInput) + { + var l = new long[_fullySpecifiedShapes[i].NumDimensions + 1]; + l[0] = 1L; + for (int ishape = 1; ishape < l.Length; ishape++) + l[ishape] = _fullySpecifiedShapes[i][ishape-1]; + _fullySpecifiedShapes[i] = new TFShape(l); + } } } @@ -1098,6 +1125,16 @@ internal sealed class Options : TransformInputBase /// [Argument(ArgumentType.AtMostOnce, HelpText = "Retrain TensorFlow model.", SortOrder = 15)] public bool ReTrain = false; + + /// + /// Add a batch dimension to the input e.g. input = [224, 224, 3] => [-1, 224, 224, 3]. + /// + /// + /// This parameter is used to deal with models that have unknown shape but the internal operators in the model require data to have batch dimension as well. + /// In this case, there is no way to induce shape from the model's inputs or input data. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Add a batch dimension to the input e.g. input = [224, 224, 3] => [-1, 224, 224, 3].", SortOrder = 16)] + public bool AddBatchDimensionInputs = false; } private readonly IHost _host; @@ -1108,13 +1145,13 @@ internal sealed class Options : TransformInputBase private TensorFlowTransformer _transformer; [BestFriend] - internal TensorFlowEstimator(IHostEnvironment env, string[] outputColumnNames, string[] inputColumnNames, string modelLocation) - : this(env, outputColumnNames, inputColumnNames, TensorFlowUtils.LoadTensorFlowModel(env, modelLocation)) + internal TensorFlowEstimator(IHostEnvironment env, string[] outputColumnNames, string[] inputColumnNames, string modelLocation, bool addBatchDimensionInput) + : this(env, outputColumnNames, inputColumnNames, TensorFlowUtils.LoadTensorFlowModel(env, modelLocation), addBatchDimensionInput) { } - internal TensorFlowEstimator(IHostEnvironment env, string[] outputColumnNames, string[] inputColumnNames, TensorFlowModel tensorFlowModel) - : this(env, CreateArguments(tensorFlowModel, outputColumnNames, inputColumnNames), tensorFlowModel) + internal TensorFlowEstimator(IHostEnvironment env, string[] outputColumnNames, string[] inputColumnNames, TensorFlowModel tensorFlowModel, bool addBatchDimensionInput) + : this(env, CreateArguments(tensorFlowModel, outputColumnNames, inputColumnNames, addBatchDimensionInput), tensorFlowModel) { } @@ -1134,13 +1171,14 @@ internal TensorFlowEstimator(IHostEnvironment env, Options options, TensorFlowMo _outputTypes = outputTuple.outputTypes; } - private static Options CreateArguments(TensorFlowModel tensorFlowModel, string[] outputColumnNames, string[] inputColumnName) + private static Options CreateArguments(TensorFlowModel tensorFlowModel, string[] outputColumnNames, string[] inputColumnName, bool addBatchDimensionInput) { var options = new Options(); options.ModelLocation = tensorFlowModel.ModelPath; options.InputColumns = inputColumnName; options.OutputColumns = outputColumnNames; options.ReTrain = false; + options.AddBatchDimensionInputs = addBatchDimensionInput; return options; } @@ -1183,7 +1221,7 @@ public TensorFlowTransformer Fit(IDataView input) { _transformer = _options.ReTrain ? new TensorFlowTransformer(_host, _options, _tensorFlowModel, input) : new TensorFlowTransformer(_host, _tensorFlowModel.Session, _options.OutputColumns, _options.InputColumns, - TensorFlowUtils.IsSavedModel(_host, _options.ModelLocation) ? _options.ModelLocation : null, false); + TensorFlowUtils.IsSavedModel(_host, _options.ModelLocation) ? _options.ModelLocation : null, false, _options.AddBatchDimensionInputs); } // Validate input schema. _transformer.GetOutputSchema(input.Schema); diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json index 5c8414eca2..30e8accc9a 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json +++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json @@ -22268,6 +22268,15 @@ "SortOrder": 15.0, "IsNullable": false, "Default": false + }, + { + "Name": "AddBatchDimensionInputs", + "Type": "Bool", + "Desc": "Add a batch dimension to the input e.g. input = [224, 224, 3] => [-1, 224, 224, 3].", + "Required": false, + "SortOrder": 16.0, + "IsNullable": false, + "Default": false } ], "Outputs": [ diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs index 8014e71816..ea6c9446d3 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs @@ -394,8 +394,8 @@ public void TensorFlowTransformInceptionTest() var data = reader.Load(new MultiFileSource(dataFile)); var images = mlContext.Transforms.LoadImages(imageFolder, ("ImageReal", "ImagePath")).Fit(data).Transform(data); var cropped = mlContext.Transforms.ResizeImages("ImageCropped", 224, 224, "ImageReal").Fit(images).Transform(images); - var pixels = mlContext.Transforms.ExtractPixels(inputName, "ImageCropped").Fit(cropped).Transform(cropped); - var tf = mlContext.Model.LoadTensorFlowModel(modelLocation).ScoreTensorFlowModel(outputName, inputName).Fit(pixels).Transform(pixels); + var pixels = mlContext.Transforms.ExtractPixels(inputName, "ImageCropped", interleavePixelColors: true).Fit(cropped).Transform(cropped); + var tf = mlContext.Model.LoadTensorFlowModel(modelLocation).ScoreTensorFlowModel(outputName, inputName, true).Fit(pixels).Transform(pixels); tf.Schema.TryGetColumnIndex(inputName, out int input); tf.Schema.TryGetColumnIndex(outputName, out int b);