diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj b/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj index 46b30483a0a39..9f2bdf4f25a57 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj @@ -4,10 +4,10 @@ Microsoft.ML.OnnxRuntime - - + netstandard1.1;netstandard2.0;xamarinios10;monoandroid11.0;net5.0;netcoreapp3.1 diff --git a/docs/ORTMobilePackageOperatorTypeSupport.md b/docs/ORTMobilePackageOperatorTypeSupport.md index 7e08e06890003..09de5d9d4cc39 100644 --- a/docs/ORTMobilePackageOperatorTypeSupport.md +++ b/docs/ORTMobilePackageOperatorTypeSupport.md @@ -17,104 +17,109 @@ NOTE: Operators used to manipulate dimensions and indices will support int32 and |Operator|Opsets| |--------|------| |**ai.onnx**|| -|ai.onnx:Abs|12, 13| -|ai.onnx:Add|12, 13| -|ai.onnx:And|12, 13| -|ai.onnx:ArgMax|12, 13| -|ai.onnx:ArgMin|12, 13| -|ai.onnx:AveragePool|12, 13| -|ai.onnx:Cast|12, 13| -|ai.onnx:Ceil|12, 13| -|ai.onnx:Clip|12, 13| -|ai.onnx:Concat|12, 13| -|ai.onnx:ConstantOfShape|12, 13| -|ai.onnx:Conv|12, 13| -|ai.onnx:ConvTranspose|12, 13| -|ai.onnx:Cos|12, 13| -|ai.onnx:CumSum|12, 13| -|ai.onnx:DepthToSpace|12, 13| -|ai.onnx:DequantizeLinear|12, 13| -|ai.onnx:Div|12, 13| -|ai.onnx:DynamicQuantizeLinear|12, 13| -|ai.onnx:Elu|12, 13| -|ai.onnx:Equal|12, 13| -|ai.onnx:Exp|12, 13| -|ai.onnx:Expand|12, 13| -|ai.onnx:Flatten|12, 13| -|ai.onnx:Floor|12, 13| -|ai.onnx:Gather|12, 13| -|ai.onnx:GatherND|12, 13| -|ai.onnx:Gemm|12, 13| -|ai.onnx:GlobalAveragePool|12, 13| -|ai.onnx:Greater|12, 13| -|ai.onnx:GreaterOrEqual|12, 13| -|ai.onnx:Identity|12, 13| -|ai.onnx:If|12, 13| -|ai.onnx:LRN|12, 13| -|ai.onnx:LeakyRelu|12, 13| -|ai.onnx:Less|12, 13| -|ai.onnx:LessOrEqual|12, 13| -|ai.onnx:Log|12, 13| -|ai.onnx:LogSoftmax|12, 13| -|ai.onnx:Loop|12, 13| -|ai.onnx:MatMul|12, 13| -|ai.onnx:MatMulInteger|12, 13| -|ai.onnx:Max|12, 13| -|ai.onnx:MaxPool|12, 13| -|ai.onnx:Mean|12, 13| -|ai.onnx:Min|12, 13| -|ai.onnx:Mul|12, 13| -|ai.onnx:Neg|12, 13| -|ai.onnx:NonMaxSuppression|12, 13| -|ai.onnx:NonZero|12, 13| -|ai.onnx:Not|12, 13| -|ai.onnx:Or|12, 13| -|ai.onnx:PRelu|12, 13| -|ai.onnx:Pad|12, 13| -|ai.onnx:Pow|12, 13| -|ai.onnx:QLinearConv|12, 13| -|ai.onnx:QLinearMatMul|12, 13| -|ai.onnx:QuantizeLinear|12, 13| -|ai.onnx:Range|12, 13| -|ai.onnx:Reciprocal|12, 13| -|ai.onnx:ReduceMax|12, 13| -|ai.onnx:ReduceMean|12, 13| -|ai.onnx:ReduceMin|12, 13| -|ai.onnx:ReduceProd|12, 13| -|ai.onnx:ReduceSum|12, 13| -|ai.onnx:Relu|12, 13| -|ai.onnx:Reshape|12, 13| -|ai.onnx:Resize|12, 13| -|ai.onnx:ReverseSequence|12, 13| -|ai.onnx:Round|12, 13| -|ai.onnx:Scan|12, 13| -|ai.onnx:ScatterND|12, 13| -|ai.onnx:Shape|12, 13| -|ai.onnx:Sigmoid|12, 13| -|ai.onnx:Sin|12, 13| -|ai.onnx:Size|12, 13| -|ai.onnx:Slice|12, 13| -|ai.onnx:Softmax|12, 13| -|ai.onnx:SpaceToDepth|12, 13| -|ai.onnx:Split|12, 13| -|ai.onnx:Sqrt|12, 13| -|ai.onnx:Squeeze|12, 13| -|ai.onnx:Sub|12, 13| -|ai.onnx:Sum|12, 13| -|ai.onnx:Tanh|12, 13| -|ai.onnx:ThresholdedRelu|12, 13| -|ai.onnx:Tile|12, 13| -|ai.onnx:TopK|12, 13| -|ai.onnx:Transpose|12, 13| -|ai.onnx:Unique|12, 13| -|ai.onnx:Unsqueeze|12, 13| -|ai.onnx:Where|12, 13| +|ai.onnx:Abs|12, 13, 14, 15| +|ai.onnx:Add|12, 13, 14, 15| +|ai.onnx:And|12, 13, 14, 15| +|ai.onnx:ArgMax|12, 13, 14, 15| +|ai.onnx:ArgMin|12, 13, 14, 15| +|ai.onnx:AveragePool|12, 13, 14, 15| +|ai.onnx:Cast|12, 13, 14, 15| +|ai.onnx:Ceil|12, 13, 14, 15| +|ai.onnx:Clip|12, 13, 14, 15| 
+|ai.onnx:Concat|12, 13, 14, 15| +|ai.onnx:ConstantOfShape|12, 13, 14, 15| +|ai.onnx:Conv|12, 13, 14, 15| +|ai.onnx:ConvTranspose|12, 13, 14, 15| +|ai.onnx:Cos|12, 13, 14, 15| +|ai.onnx:CumSum|12, 13, 14, 15| +|ai.onnx:DepthToSpace|12, 13, 14, 15| +|ai.onnx:DequantizeLinear|12, 13, 14, 15| +|ai.onnx:Div|12, 13, 14, 15| +|ai.onnx:DynamicQuantizeLinear|12, 13, 14, 15| +|ai.onnx:Elu|12, 13, 14, 15| +|ai.onnx:Equal|12, 13, 14, 15| +|ai.onnx:Erf|12, 13, 14, 15| +|ai.onnx:Exp|12, 13, 14, 15| +|ai.onnx:Expand|12, 13, 14, 15| +|ai.onnx:Flatten|12, 13, 14, 15| +|ai.onnx:Floor|12, 13, 14, 15| +|ai.onnx:Gather|12, 13, 14, 15| +|ai.onnx:GatherND|12, 13, 14, 15| +|ai.onnx:Gemm|12, 13, 14, 15| +|ai.onnx:GlobalAveragePool|12, 13, 14, 15| +|ai.onnx:Greater|12, 13, 14, 15| +|ai.onnx:GreaterOrEqual|12, 13, 14, 15| +|ai.onnx:HardSigmoid|12, 13, 14, 15| +|ai.onnx:Identity|12, 13, 14, 15| +|ai.onnx:If|12, 13, 14, 15| +|ai.onnx:InstanceNormalization|12, 13, 14, 15| +|ai.onnx:LRN|12, 13, 14, 15| +|ai.onnx:LayerNormalization|1| +|ai.onnx:LeakyRelu|12, 13, 14, 15| +|ai.onnx:Less|12, 13, 14, 15| +|ai.onnx:LessOrEqual|12, 13, 14, 15| +|ai.onnx:Log|12, 13, 14, 15| +|ai.onnx:LogSoftmax|12, 13, 14, 15| +|ai.onnx:Loop|12, 13, 14, 15| +|ai.onnx:MatMul|12, 13, 14, 15| +|ai.onnx:MatMulInteger|12, 13, 14, 15| +|ai.onnx:Max|12, 13, 14, 15| +|ai.onnx:MaxPool|12, 13, 14, 15| +|ai.onnx:Mean|12, 13, 14, 15| +|ai.onnx:Min|12, 13, 14, 15| +|ai.onnx:Mul|12, 13, 14, 15| +|ai.onnx:Neg|12, 13, 14, 15| +|ai.onnx:NonMaxSuppression|12, 13, 14, 15| +|ai.onnx:NonZero|12, 13, 14, 15| +|ai.onnx:Not|12, 13, 14, 15| +|ai.onnx:Or|12, 13, 14, 15| +|ai.onnx:PRelu|12, 13, 14, 15| +|ai.onnx:Pad|12, 13, 14, 15| +|ai.onnx:Pow|12, 13, 14, 15| +|ai.onnx:QLinearConv|12, 13, 14, 15| +|ai.onnx:QLinearMatMul|12, 13, 14, 15| +|ai.onnx:QuantizeLinear|12, 13, 14, 15| +|ai.onnx:Range|12, 13, 14, 15| +|ai.onnx:Reciprocal|12, 13, 14, 15| +|ai.onnx:ReduceMax|12, 13, 14, 15| +|ai.onnx:ReduceMean|12, 13, 14, 15| +|ai.onnx:ReduceMin|12, 13, 14, 15| +|ai.onnx:ReduceProd|12, 13, 14, 15| +|ai.onnx:ReduceSum|12, 13, 14, 15| +|ai.onnx:Relu|12, 13, 14, 15| +|ai.onnx:Reshape|12, 13, 14, 15| +|ai.onnx:Resize|12, 13, 14, 15| +|ai.onnx:ReverseSequence|12, 13, 14, 15| +|ai.onnx:Round|12, 13, 14, 15| +|ai.onnx:Scan|12, 13, 14, 15| +|ai.onnx:ScatterND|12, 13, 14, 15| +|ai.onnx:Shape|12, 13, 14, 15| +|ai.onnx:Sigmoid|12, 13, 14, 15| +|ai.onnx:Sin|12, 13, 14, 15| +|ai.onnx:Size|12, 13, 14, 15| +|ai.onnx:Slice|12, 13, 14, 15| +|ai.onnx:Softmax|12, 13, 14, 15| +|ai.onnx:SpaceToDepth|12, 13, 14, 15| +|ai.onnx:Split|12, 13, 14, 15| +|ai.onnx:Sqrt|12, 13, 14, 15| +|ai.onnx:Squeeze|12, 13, 14, 15| +|ai.onnx:Sub|12, 13, 14, 15| +|ai.onnx:Sum|12, 13, 14, 15| +|ai.onnx:Tanh|12, 13, 14, 15| +|ai.onnx:ThresholdedRelu|12, 13, 14, 15| +|ai.onnx:Tile|12, 13, 14, 15| +|ai.onnx:TopK|12, 13, 14, 15| +|ai.onnx:Transpose|12, 13, 14, 15| +|ai.onnx:Unique|12, 13, 14, 15| +|ai.onnx:Unsqueeze|12, 13, 14, 15| +|ai.onnx:Where|12, 13, 14, 15| ||| |**com.microsoft**|| |com.microsoft:DynamicQuantizeMatMul|1| |com.microsoft:FusedConv|1| |com.microsoft:FusedGemm|1| |com.microsoft:FusedMatMul|1| +|com.microsoft:Gelu|1| |com.microsoft:MatMulIntegerToFloat|1| |com.microsoft:NhwcMaxPool|1| |com.microsoft:QLinearAdd|1| diff --git a/docs/python/inference/examples/plot_common_errors.py b/docs/python/inference/examples/plot_common_errors.py index 0d98e17c45dff..b474574c0fdf6 100644 --- a/docs/python/inference/examples/plot_common_errors.py +++ b/docs/python/inference/examples/plot_common_errors.py @@ -21,7 +21,7 @@ 
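[Editor note on the ORT Mobile operator table above: the package now lists ONNX opsets 12 through 15 for the ai.onnx domain. A quick way to check whether an existing model falls inside that range is to read its opset imports and, if needed, run the ONNX version converter. This is a hedged sketch: "model.onnx" is a placeholder path, and the converter only helps when every operator has a definition at the target opset.]

```python
import onnx
from onnx import version_converter

model = onnx.load("model.onnx")  # placeholder path
for imp in model.opset_import:
    # An empty domain string means the default ai.onnx domain.
    print(imp.domain or "ai.onnx", imp.version)

# If the ai.onnx opset is below 12, try converting before targeting the mobile package.
converted = version_converter.convert_version(model, 15)
onnx.save(converted, "model_opset15.onnx")
```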
from onnxruntime.datasets import get_example example2 = get_example("logreg_iris.onnx") -sess = rt.InferenceSession(example2) +sess = rt.InferenceSession(example2, providers=rt.get_available_providers()) input_name = sess.get_inputs()[0].name output_name = sess.get_outputs()[0].name diff --git a/docs/python/inference/examples/plot_convert_pipeline_vectorizer.py b/docs/python/inference/examples/plot_convert_pipeline_vectorizer.py index 0de0b30e28de0..af1351d0c87ff 100644 --- a/docs/python/inference/examples/plot_convert_pipeline_vectorizer.py +++ b/docs/python/inference/examples/plot_convert_pipeline_vectorizer.py @@ -72,7 +72,7 @@ import onnxruntime as rt from onnxruntime.capi.onnxruntime_pybind11_state import InvalidArgument -sess = rt.InferenceSession("pipeline_vectorize.onnx") +sess = rt.InferenceSession("pipeline_vectorize.onnx", providers=rt.get_available_providers()) import numpy inp, out = sess.get_inputs()[0], sess.get_outputs()[0] diff --git a/docs/python/inference/examples/plot_load_and_predict.py b/docs/python/inference/examples/plot_load_and_predict.py index feb369feb2e27..9bfdc5795758d 100644 --- a/docs/python/inference/examples/plot_load_and_predict.py +++ b/docs/python/inference/examples/plot_load_and_predict.py @@ -21,7 +21,7 @@ # The model is available on github `onnx...test_sigmoid `_. example1 = get_example("sigmoid.onnx") -sess = rt.InferenceSession(example1) +sess = rt.InferenceSession(example1, providers=rt.get_available_providers()) ######################### # Let's see the input name and shape. diff --git a/docs/python/inference/examples/plot_metadata.py b/docs/python/inference/examples/plot_metadata.py index df5d15276c634..94c45e688f27f 100644 --- a/docs/python/inference/examples/plot_metadata.py +++ b/docs/python/inference/examples/plot_metadata.py @@ -31,8 +31,8 @@ ############################# # With *ONNX Runtime*: -from onnxruntime import InferenceSession -sess = InferenceSession(example) +import onnxruntime as rt +sess = rt.InferenceSession(example, providers=rt.get_available_providers()) meta = sess.get_modelmeta() print("custom_metadata_map={}".format(meta.custom_metadata_map)) diff --git a/docs/python/inference/examples/plot_profiling.py b/docs/python/inference/examples/plot_profiling.py index f0ea727ede1b2..402e7b3baee10 100644 --- a/docs/python/inference/examples/plot_profiling.py +++ b/docs/python/inference/examples/plot_profiling.py @@ -35,7 +35,7 @@ def change_ir_version(filename, ir_version=6): example1 = get_example("mul_1.onnx") onnx_model = change_ir_version(example1) onnx_model_str = onnx_model.SerializeToString() -sess = rt.InferenceSession(onnx_model_str) +sess = rt.InferenceSession(onnx_model_str, providers=rt.get_available_providers()) input_name = sess.get_inputs()[0].name x = numpy.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=numpy.float32) @@ -48,7 +48,7 @@ def change_ir_version(filename, ir_version=6): options = rt.SessionOptions() options.enable_profiling = True -sess_profile = rt.InferenceSession(onnx_model_str, options) +sess_profile = rt.InferenceSession(onnx_model_str, options, providers=rt.get_available_providers()) input_name = sess.get_inputs()[0].name x = numpy.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=numpy.float32) diff --git a/docs/python/inference/examples/plot_train_convert_predict.py b/docs/python/inference/examples/plot_train_convert_predict.py index 5b060c5f41ffe..4aa36b3dce25c 100644 --- a/docs/python/inference/examples/plot_train_convert_predict.py +++ 
b/docs/python/inference/examples/plot_train_convert_predict.py @@ -64,7 +64,7 @@ # its input and output. import onnxruntime as rt -sess = rt.InferenceSession("logreg_iris.onnx") +sess = rt.InferenceSession("logreg_iris.onnx", providers=rt.get_available_providers()) print("input name='{}' and shape={}".format( sess.get_inputs()[0].name, sess.get_inputs()[0].shape)) @@ -180,7 +180,7 @@ def sess_predict_proba(x): ################################### # We compare. -sess = rt.InferenceSession("rf_iris.onnx") +sess = rt.InferenceSession("rf_iris.onnx", providers=rt.get_available_providers()) def sess_predict_proba_rf(x): return sess.run([prob_name], {input_name: x.astype(numpy.float32)})[0] @@ -204,7 +204,7 @@ def sess_predict_proba_rf(x): onx = convert_sklearn(rf, initial_types=initial_type) with open("rf_iris_%d.onnx" % n_trees, "wb") as f: f.write(onx.SerializeToString()) - sess = rt.InferenceSession("rf_iris_%d.onnx" % n_trees) + sess = rt.InferenceSession("rf_iris_%d.onnx" % n_trees, providers=rt.get_available_providers()) def sess_predict_proba_loop(x): return sess.run([prob_name], {input_name: x.astype(numpy.float32)})[0] tsk = speed("loop(X_test, rf.predict_proba, 100)", number=5, repeat=5) diff --git a/docs/python/inference/tutorial.rst b/docs/python/inference/tutorial.rst index d00a378cfeedc..fccca9cbd1451 100644 --- a/docs/python/inference/tutorial.rst +++ b/docs/python/inference/tutorial.rst @@ -82,7 +82,7 @@ for this machine learning model. import numpy import onnxruntime as rt - sess = rt.InferenceSession("logreg_iris.onnx") + sess = rt.InferenceSession("logreg_iris.onnx", providers=rt.get_available_providers()) input_name = sess.get_inputs()[0].name pred_onx = sess.run(None, {input_name: X_test.astype(numpy.float32)})[0] print(pred_onx) @@ -97,7 +97,7 @@ by specifying its name into a list. 
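[Editor note: the recurring documentation change in these hunks is passing `providers` explicitly when constructing `InferenceSession`, since recent onnxruntime releases ask for an explicit provider list when more than one provider is available. A minimal sketch of the pattern, reusing the tutorial's `logreg_iris.onnx` model; the random input is only for illustration.]

```python
import numpy
import onnxruntime as rt

# Use every provider the installed build exposes (e.g. CUDA before CPU on a GPU build)...
sess = rt.InferenceSession("logreg_iris.onnx", providers=rt.get_available_providers())
# ...or pin an explicit subset instead:
# sess = rt.InferenceSession("logreg_iris.onnx", providers=["CPUExecutionProvider"])

input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name
X = numpy.random.rand(3, 4).astype(numpy.float32)  # the iris models take 4 float features
pred = sess.run([label_name], {input_name: X})[0]
print(pred)
```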
import numpy import onnxruntime as rt - sess = rt.InferenceSession("logreg_iris.onnx") + sess = rt.InferenceSession("logreg_iris.onnx", providers=rt.get_available_providers()) input_name = sess.get_inputs()[0].name label_name = sess.get_outputs()[0].name pred_onx = sess.run([label_name], {input_name: X_test.astype(numpy.float32)})[0] diff --git a/onnxruntime/core/providers/cuda/nn/instance_norm.cc b/onnxruntime/core/providers/cuda/nn/instance_norm.cc index c40c27cdf11da..da51ba8f90694 100644 --- a/onnxruntime/core/providers/cuda/nn/instance_norm.cc +++ b/onnxruntime/core/providers/cuda/nn/instance_norm.cc @@ -5,6 +5,7 @@ #include "instance_norm_impl.h" #include "core/providers/cpu/nn/instance_norm_helper.h" #include "core/providers/cpu/nn/batch_norm_helper.h" +#include "core/providers/cuda/math/unary_elementwise_ops_impl.h" namespace onnxruntime { namespace cuda { @@ -45,10 +46,10 @@ Status InstanceNorm::ComputeInternal(OpKernelContext* p_op_kernel_context) co const TensorShape& x_shape = X->Shape(); Tensor* Y = p_op_kernel_context->Output(0, x_shape); - auto y_data = reinterpret_cast(Y->template MutableData()); - auto x_data = reinterpret_cast(X->template Data()); - auto scale_data = reinterpret_cast(scale->template Data()); - auto bias_data = reinterpret_cast(bias->template Data()); + auto* y_data = reinterpret_cast(Y->template MutableData()); + const auto* x_data = reinterpret_cast(X->template Data()); + const auto* scale_data = reinterpret_cast(scale->template Data()); + const auto* bias_data = reinterpret_cast(bias->template Data()); const auto& x_dims = x_shape.GetDims(); const int64_t N = x_dims[0]; @@ -160,5 +161,150 @@ Status InstanceNorm::ComputeInternal(OpKernelContext* p_op_kernel_context) co return Status::OK(); } +template <> +Status InstanceNorm::ComputeInternal(OpKernelContext* p_op_kernel_context) const { + typedef typename ToCudaType::MappedType CudaT; + + const Tensor* X = p_op_kernel_context->Input(0); + const Tensor* scale = p_op_kernel_context->Input(1); + const Tensor* bias = p_op_kernel_context->Input(2); + + ORT_RETURN_IF_ERROR(InstanceNormHelper::ValidateInputs(X, scale, bias)); + + const TensorShape& x_shape = X->Shape(); + Tensor* Y = p_op_kernel_context->Output(0, x_shape); + + auto* y_data = reinterpret_cast(Y->template MutableData()); + const auto* x_data = reinterpret_cast(X->template Data()); + const auto* scale_data = reinterpret_cast(scale->template Data()); + const auto* bias_data = reinterpret_cast(bias->template Data()); + + const auto& x_dims = x_shape.GetDims(); + const int64_t N = x_dims[0]; + const int64_t C = x_dims[1]; + const auto one = Consts::One; + const auto zero = Consts::Zero; + + if (N == 1) { + // when N == 1, we can treat it as spatial batch normalization in training + // as the mean/variance would be computed from input + + CudnnTensor data_desc; + std::vector new_dims; + BatchNormHelper::NormalizeDims(x_shape, new_dims); + ORT_RETURN_IF_ERROR(data_desc.Set(new_dims, CudnnTensor::GetDataType())); + + CudnnTensor stats_desc; + ORT_RETURN_IF_ERROR(stats_desc.Set(data_desc, CUDNN_BATCHNORM_SPATIAL)); + + // For half input data type, alpha, beta, scale, bias need to be float type. + // alpha, beta will be of type float as the Consts struct specialization + // for MLFloat16 type take care of that. 
Only Convert the scale, bias to float) + + auto scale_data_fp32 = GetScratchBuffer(C); + Impl_Cast(Stream(), scale_data, scale_data_fp32.get(), C); + + auto bias_data_fp32 = GetScratchBuffer(C); + Impl_Cast(Stream(), bias_data, bias_data_fp32.get(), C); + + CUDNN_RETURN_IF_ERROR(cudnnBatchNormalizationForwardTraining( + CudnnHandle(), + CUDNN_BATCHNORM_SPATIAL, + &one, + &zero, + data_desc, + x_data, + data_desc, + y_data, + stats_desc, + scale_data_fp32.get(), + bias_data_fp32.get(), + 1.0f, + nullptr, + nullptr, + epsilon_, + nullptr, + nullptr)); + } else { + // we use cudnnBatchNormalizationForwardTraining to compute mean/variance + // so collapsing NC into channel + + auto input_count = x_shape.Size(); // N * C * H * W + auto stats_count = x_shape.SizeToDimension(2); // N * C + auto image_size = input_count / stats_count; + + CudnnTensor data_desc; + ORT_RETURN_IF_ERROR(data_desc.Set(std::array{1, stats_count, image_size, 1}, + CudnnTensor::GetDataType())); + + // stats_desc needs to be of 'float' type even for float16 input as the "stats" are of float type + CudnnTensor stats_desc; + ORT_RETURN_IF_ERROR(stats_desc.Set(std::array{1, stats_count, 1, 1}, + CudnnTensor::GetDataType())); + + // For half input data type, we need to allocate some "intermediate" + // float buffers for CuDNN to use. + const size_t stats_byte_count = stats_count * sizeof(float); + + // Mean & Variance are inputs & outputs and must be initialized to zero to work properly + auto mean = GetScratchBuffer(stats_count); + CUDA_RETURN_IF_ERROR(cudaMemsetAsync(mean.get(), 0, stats_byte_count, Stream())); + auto variance = GetScratchBuffer(stats_count); + CUDA_RETURN_IF_ERROR(cudaMemsetAsync(variance.get(), 0, stats_byte_count, Stream())); + + // We must set the scale & bias inputs to zero as they are inputs to the calculation + auto unused_scale = GetScratchBuffer(stats_count); + CUDA_RETURN_IF_ERROR(cudaMemsetAsync(unused_scale.get(), 0, stats_byte_count, Stream())); + auto unused_bias = GetScratchBuffer(stats_count); + CUDA_RETURN_IF_ERROR(cudaMemsetAsync(unused_bias.get(), 0, stats_byte_count, Stream())); + + // first, compute mean and variance per-instance per-channel using cudnnBatchNorm training + CUDNN_RETURN_IF_ERROR(cudnnBatchNormalizationForwardTraining( + CudnnHandle(), + CUDNN_BATCHNORM_SPATIAL, + &one, + &zero, + data_desc, + x_data, + data_desc, + y_data, // use y temporarily, would be rewritten later + stats_desc, + unused_scale.get(), + unused_bias.get(), + 1.0f, + mean.get(), + variance.get(), + CUDNN_BN_MIN_EPSILON, + nullptr, + nullptr)); + + // Y = scale * (x - mean) / sqrt (variance + epsilon) + B + // X/Y is (N,C,H,W) + // scale/bias is (1,C,1,1) + // mean/stddev is (N,C,1,1) + // NOTE cudnnBatchNormalization computes unbiased variance sum((Xi - mean)^2) / (count - 1) + // and it needs to be corrected with (count - 1) / count + fast_divmod fdm_HW(gsl::narrow_cast(image_size)); + fast_divmod fdm_C(gsl::narrow_cast(C)); + + // The InstanceNormImpl kernel handles the mean/variance in float32, so no casting required here + InstanceNormImpl( + Stream(), + x_data, + scale_data, + bias_data, + mean.get(), + variance.get(), + (image_size - 1.0) / image_size, + static_cast(epsilon_), + fdm_HW, + fdm_C, + y_data, + input_count); + } + + return Status::OK(); +} + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/nn/instance_norm_impl.cu b/onnxruntime/core/providers/cuda/nn/instance_norm_impl.cu index c0af3d0580d05..057c301dbde5f 100644 --- 
a/onnxruntime/core/providers/cuda/nn/instance_norm_impl.cu +++ b/onnxruntime/core/providers/cuda/nn/instance_norm_impl.cu @@ -7,18 +7,18 @@ namespace onnxruntime { namespace cuda { -template +template __global__ void _InstanceNormKernel( - const T* input_data, - const T* scale, - const T* bias, - const T* mean, - const T* variance, + const T1* __restrict__ input_data, + const T1* __restrict__ scale, + const T1* __restrict__ bias, + const T2* __restrict__ mean, + const T2* __restrict__ variance, const double variance_correction, const double epsilon, const fast_divmod fdm_HW, const fast_divmod fdm_C, - T* output_data, + T1* __restrict__ output_data, const CUDA_LONG N) { CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); int nc = fdm_HW.div(id); @@ -26,34 +26,35 @@ __global__ void _InstanceNormKernel( fdm_C.divmod(nc, n, c); // Y = scale * (x - mean) / sqrt (std * std + epsilon) + B - output_data[id] = scale[c] * (input_data[id] - mean[nc]) / _Sqrt(variance[nc] * (T)variance_correction + (T)epsilon) + bias[c]; + output_data[id] = scale[c] * (input_data[id] - (T1)mean[nc]) / _Sqrt((T1)variance[nc] * (T1)variance_correction + (T1)epsilon) + bias[c]; } -template +template void InstanceNormImpl( cudaStream_t stream, - const T* input_data, - const T* scale, - const T* bias, - const T* mean, - const T* variance, + const T1* input_data, + const T1* scale, + const T1* bias, + const T2* mean, + const T2* variance, const double variance_correction, const double epsilon, const fast_divmod& fdm_HW, const fast_divmod& fdm_C, - T* output_data, + T1* output_data, size_t N) { int blocksPerGrid = (int)(ceil(static_cast(N) / GridDim::maxThreadsPerBlock)); - _InstanceNormKernel<<>>( + _InstanceNormKernel<<>>( input_data, scale, bias, mean, variance, variance_correction, epsilon, fdm_HW, fdm_C, output_data, (CUDA_LONG)N); } -#define SPECIALIZED_IMPL(T) \ - template void InstanceNormImpl(cudaStream_t stream, const T* input_data, const T* scale, const T* bias, const T* mean, const T* stddev, const double variance_correction, const double epsilon, const fast_divmod& fdm_HW, const fast_divmod& fdm_C, T* output_data, size_t count); +#define SPECIALIZED_IMPL(T1, T2) \ + template void InstanceNormImpl(cudaStream_t stream, const T1* input_data, const T1* scale, const T1* bias, const T2* mean, const T2* stddev, const double variance_correction, const double epsilon, const fast_divmod& fdm_HW, const fast_divmod& fdm_C, T1* output_data, size_t count); -SPECIALIZED_IMPL(float) -SPECIALIZED_IMPL(double) -SPECIALIZED_IMPL(half) +SPECIALIZED_IMPL(float, float) +SPECIALIZED_IMPL(double, double) +// When the input data type is float16, the means and variances will flow in as float32 (special case) +SPECIALIZED_IMPL(half, float) } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/nn/instance_norm_impl.h b/onnxruntime/core/providers/cuda/nn/instance_norm_impl.h index cda9684416c61..35d754b297b54 100644 --- a/onnxruntime/core/providers/cuda/nn/instance_norm_impl.h +++ b/onnxruntime/core/providers/cuda/nn/instance_norm_impl.h @@ -6,19 +6,19 @@ namespace onnxruntime { namespace cuda { -template +template void InstanceNormImpl( cudaStream_t stream, - const T* input_data, - const T* scale, - const T* bias, - const T* mean, - const T* variance, + const T1* input_data, + const T1* scale, + const T1* bias, + const T2* mean, + const T2* variance, const double variance_correction, const double epsilon, const fast_divmod& fdm_HW, const fast_divmod& fdm_C, - T* output_data, + T1* output_data, size_t count); } 
// namespace cuda diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py index aa0ee6156fa65..38e3c5e30e4dd 100644 --- a/onnxruntime/python/tools/quantization/onnx_quantizer.py +++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py @@ -47,8 +47,6 @@ def __init__(self, model, per_channel, reduce_range, mode, static, weight_qType, is_weight_int8 = weight_qType == QuantType.QInt8 self.is_weight_symmetric = is_weight_int8 if 'WeightSymmetric' not in self.extra_options else self.extra_options['WeightSymmetric'] self.is_activation_symmetric = False if 'ActivationSymmetric' not in self.extra_options else self.extra_options['ActivationSymmetric'] - self.op_types_support_per_channel_quantization = [] if 'OpTypesSupportPerChannelQuantization' not in extra_options \ - else extra_options['OpTypesSupportPerChannelQuantization'] self.input_qType = onnx_proto.TensorProto.INT8 if input_qType == QuantType.QInt8 else onnx_proto.TensorProto.UINT8 self.weight_qType = onnx_proto.TensorProto.INT8 if weight_qType == QuantType.QInt8 else onnx_proto.TensorProto.UINT8 diff --git a/onnxruntime/python/tools/quantization/operators/matmul.py b/onnxruntime/python/tools/quantization/operators/matmul.py index 16ee6d12e5374..2d37eeb46e9ad 100644 --- a/onnxruntime/python/tools/quantization/operators/matmul.py +++ b/onnxruntime/python/tools/quantization/operators/matmul.py @@ -1,5 +1,7 @@ import onnx +import itertools from .base_operator import QuantOperatorBase +from .qdq_base_operator import QDQOperatorBase from ..quant_utils import find_by_name, get_mul_node, QuantizedValue, QuantizedValueType from onnx import onnx_pb as onnx_proto ''' @@ -98,3 +100,24 @@ def quantize(self): self.quantizer.quantized_value_map[node.output[0]] = q_output self.quantizer.new_nodes += nodes + +class QDQMatMul(QDQOperatorBase): + def __init__(self, onnx_quantizer, onnx_node): + super().__init__(onnx_quantizer, onnx_node) + + def quantize(self): + node = self.node + assert (node.op_type == "MatMul") + + if self.disable_qdq_for_node_output: + nodes_to_iterate = node.input + else: + nodes_to_iterate = itertools.chain(node.input, node.output) + + for tensor_name in nodes_to_iterate: + # only support per-channel quantization on weight + if self.quantizer.is_per_channel() and find_by_name(tensor_name, self.quantizer.model.initializer()) : + channel_axis = self.quantizer.qdq_op_type_per_channel_support_to_axis.get(node.op_type, 1) + self.quantizer.quantize_tensor_per_channel(tensor_name, channel_axis) + else: + self.quantizer.quantize_tensor(tensor_name) diff --git a/onnxruntime/python/tools/quantization/operators/qdq_base_operator.py b/onnxruntime/python/tools/quantization/operators/qdq_base_operator.py index ebe3b7c71a789..f8f5546b1512b 100644 --- a/onnxruntime/python/tools/quantization/operators/qdq_base_operator.py +++ b/onnxruntime/python/tools/quantization/operators/qdq_base_operator.py @@ -19,10 +19,4 @@ def quantize(self): nodes_to_iterate = itertools.chain(node.input, node.output) for tensor_name in nodes_to_iterate: - if self.quantizer.is_per_channel(): - if node.op_type in self.quantizer.op_types_support_per_channel_quantization : - self.quantizer.quantize_tensor_per_channel(tensor_name, self.quantizer.qdq_channel_axis) - else: - self.quantizer.quantize_tensor(tensor_name) - else: - self.quantizer.quantize_tensor(tensor_name) + self.quantizer.quantize_tensor(tensor_name) diff --git a/onnxruntime/python/tools/quantization/qdq_quantizer.py 
b/onnxruntime/python/tools/quantization/qdq_quantizer.py index 423e8d5c8d938..f5797282dda06 100644 --- a/onnxruntime/python/tools/quantization/qdq_quantizer.py +++ b/onnxruntime/python/tools/quantization/qdq_quantizer.py @@ -43,10 +43,8 @@ def __init__(self, model, per_channel, reduce_range, mode, static, weight_qType, self.op_types_to_exclude_output_quantization = [] if 'OpTypesToExcludeOutputQuantizatioin' not in extra_options \ else extra_options['OpTypesToExcludeOutputQuantizatioin'] - # In some cases, for example QDQ BERT model for TensorRT, - # QDQ should always appear as a pair. - # For our quantization tool, we do quantization on Dequantizelinear's input - # to remove Quantizelinear as optimization for weight. + # We do quantization on Dequantizelinear's input to remove Quantizelinear for weight as an optimization. + # In some cases, for example QDQ BERT model for TensorRT, QDQ should always appear as a pair. # Therefore, we need to disable this optimization and add qdq pair to weight. self.add_qdq_pair_to_weight = False if 'AddQDQPairToWeight' not in extra_options \ else extra_options['AddQDQPairToWeight'] @@ -57,8 +55,8 @@ def __init__(self, model, per_channel, reduce_range, mode, static, weight_qType, if self.dedicated_qdq_pair: self.tensor_to_its_receiving_nodes = {} - # Channel axis when per_channel is True - self.qdq_channel_axis = 0 if 'QDQChannelAxis' not in extra_options else extra_options['QDQChannelAxis'] + # Let user set channel axis for specific op type and it's effective only when per channel quantization is supported and per_channel is True. + self.qdq_op_type_per_channel_support_to_axis = {} if 'QDQOpTypePerChannelSupportToAxis' not in extra_options else extra_options['QDQOpTypePerChannelSupportToAxis'] def quantize_tensor(self, tensor_name): weight = find_by_name(tensor_name, self.model.initializer()) diff --git a/onnxruntime/python/tools/quantization/quantize.py b/onnxruntime/python/tools/quantization/quantize.py index bc0a57a425507..955a74e52562b 100644 --- a/onnxruntime/python/tools/quantization/quantize.py +++ b/onnxruntime/python/tools/quantization/quantize.py @@ -200,6 +200,10 @@ def quantize_static(model_input, the output of ops with this specific op types. DedicatedQDQPair = True/False : Default is False. When inserting QDQ pair, multiple nodes can share a single QDQ pair as their inputs. If True, it will create identical and dedicated QDQ pair for each node. + QDQOpTypePerChannelSupportToAxis = dictionary : Default is {}. Set channel axis for specific op type, for example: {'MatMul': 1}, + and it's effective only when per channel quantization is supported and per_channel is True. + If specific op type supports per channel quantization but not explicitly specified with channel axis, + default channel axis will be used. 
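[Editor note: the docstring above documents the new `QDQOpTypePerChannelSupportToAxis` option, which replaces the removed `OpTypesSupportPerChannelQuantization`/`QDQChannelAxis` pair. A hedged sketch of how a caller would use it with `quantize_static`; the model paths and the toy calibration reader (its "input" name and (1, 4) shape) are placeholders.]

```python
import numpy as np
from onnxruntime.quantization import QuantFormat, quantize_static
from onnxruntime.quantization.calibrate import CalibrationDataReader


class ToyDataReader(CalibrationDataReader):
    """Feeds a handful of random samples; real code would use representative data."""

    def __init__(self, n=8):
        self.samples = iter(
            [{"input": np.random.rand(1, 4).astype(np.float32)} for _ in range(n)])

    def get_next(self):
        return next(self.samples, None)


quantize_static(
    "model_fp32.onnx",        # placeholder input path
    "model_int8_qdq.onnx",    # placeholder output path
    ToyDataReader(),
    quant_format=QuantFormat.QDQ,
    per_channel=True,
    # Quantize MatMul weights per channel along axis 1; op types without an
    # entry fall back to the default channel axis.
    extra_options={"QDQOpTypePerChannelSupportToAxis": {"MatMul": 1}},
)
```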
''' mode = QuantizationMode.QLinearOps diff --git a/onnxruntime/python/tools/quantization/registry.py b/onnxruntime/python/tools/quantization/registry.py index 3628bd2ec963d..e5da380978feb 100644 --- a/onnxruntime/python/tools/quantization/registry.py +++ b/onnxruntime/python/tools/quantization/registry.py @@ -1,7 +1,7 @@ from .quant_utils import QuantizationMode from .operators.base_operator import QuantOperatorBase from .operators.qdq_base_operator import QDQOperatorBase -from .operators.matmul import MatMulInteger, QLinearMatMul +from .operators.matmul import MatMulInteger, QLinearMatMul, QDQMatMul from .operators.attention import AttentionQuant from .operators.embed_layernorm import EmbedLayerNormalizationQuant from .operators.gather import GatherQuant @@ -66,6 +66,7 @@ "MaxPool": QDQMaxPool, "AveragePool" : QDQDirect8BitOp, "Concat": QDQConcat, + "MatMul": QDQMatMul, } diff --git a/onnxruntime/test/providers/cpu/nn/instance_norm_op_test.cc b/onnxruntime/test/providers/cpu/nn/instance_norm_op_test.cc index 17afa96bd8506..45c8ed74f6a2e 100644 --- a/onnxruntime/test/providers/cpu/nn/instance_norm_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/instance_norm_op_test.cc @@ -41,11 +41,11 @@ TEST(InstanceNormalizationOpTest, InstanceNorm) { -0.14644464F, -0.82262872F, -0.66852817F, 1.63760153F, -1.65898662F, 0.27618144F, 0.64840618F, 0.734399F}; test.AddOutput("Y", input_dims, expected_output); -#if defined(OPENVINO_CONFIG_MYRIAD) //Disabling this test on MYRIADX temporarily due to a bug +#if defined(OPENVINO_CONFIG_MYRIAD) //Disabling this test on MYRIADX temporarily due to a bug test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); #else test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); -#endif +#endif } TEST(InstanceNormalizationOpTest, InstanceNormBatch1) { @@ -58,12 +58,10 @@ TEST(InstanceNormalizationOpTest, InstanceNormBatch1) { vector input_dims = {1, 3, 4}; test.AddInput("input", input_dims, input); - // vector scale = {2.1F, 0.1F, 1.F}; vector scale = {1.0F, 1.0F, 1.F}; vector scale_dims = {3}; test.AddInput("scale", scale_dims, scale); - // vector B = {2.3F, 1.5F, 0.F}; vector B = {0.0F, 0.0F, 0.F}; vector B_dims = {3}; test.AddInput("B", B_dims, B); @@ -72,13 +70,150 @@ TEST(InstanceNormalizationOpTest, InstanceNormBatch1) { 1.46688162F, -0.98600774F, -0.79911913F, 0.31824524F, 0.57370438F, 0.42193634F, 0.6525492F, -1.64818992F}; test.AddOutput("Y", input_dims, expected_output); -#if defined(OPENVINO_CONFIG_MYRIAD) //Disabling this test on MYRIADX temporarily due to a bug + +#if defined(OPENVINO_CONFIG_MYRIAD) //Disabling this test on MYRIADX temporarily due to a bug + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); +#else + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +#endif +} + +TEST(InstanceNormalizationOpTest, InstanceNormBatch2) { + OpTester test("InstanceNormalization"); + test.AddAttribute("epsilon", 0.3F); + + vector input = {3.1513367F, 9.283596F, 1.4546119F, 5.4617004F, + 8.519701F, 1.2382338F, 1.7930176F, 5.1099434F, + 7.9195533F, 7.638727F, 8.065445F, 3.8082376F, + + 3.1513367F, 9.283596F, 1.4546119F, 5.4617004F, + 8.519701F, 1.2382338F, 1.7930176F, 5.1099434F, + 7.9195533F, 7.638727F, 8.065445F, 3.8082376F}; + vector input_dims = {2, 3, 4}; + test.AddInput("input", input_dims, input); + + vector scale = {1.0F, 1.0F, 1.F}; + vector scale_dims = {3}; + test.AddInput("scale", scale_dims, scale); + + vector B = {0.0F, 0.0F, 0.F}; 
+ vector B_dims = {3}; + test.AddInput("B", B_dims, B); + + vector expected_output = {-0.56495477F, 1.48930046F, -1.13334329F, 0.20899761F, + 1.46688162F, -0.98600774F, -0.79911913F, 0.31824524F, + 0.57370438F, 0.42193634F, 0.6525492F, -1.64818992F, + + -0.56495477F, 1.48930046F, -1.13334329F, 0.20899761F, + 1.46688162F, -0.98600774F, -0.79911913F, 0.31824524F, + 0.57370438F, 0.42193634F, 0.6525492F, -1.64818992F}; + + test.AddOutput("Y", input_dims, expected_output); +#if defined(OPENVINO_CONFIG_MYRIAD) //Disabling this test on MYRIADX temporarily due to a bug test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); #else test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); -#endif +#endif } +// Only CUDA kernel has float 16 support +#ifdef USE_CUDA + +TEST(InstanceNormalizationOpTest, InstanceNormBatch1_fp16) { + OpTester test("InstanceNormalization"); + test.AddAttribute("epsilon", 0.3F); + + vector input = {3.1513367F, 9.283596F, 1.4546119F, 5.4617004F, + 8.519701F, 1.2382338F, 1.7930176F, 5.1099434F, + 7.9195533F, 7.638727F, 8.065445F, 3.8082376F}; + vector input_dims = {1, 3, 4}; + + vector scale = {1.0F, 1.0F, 1.F}; + vector scale_dims = {3}; + + vector B = {0.0F, 0.0F, 0.F}; + vector B_dims = {3}; + + vector expected_output = {-0.56495477F, 1.48930046F, -1.13334329F, 0.20899761F, + 1.46688162F, -0.98600774F, -0.79911913F, 0.31824524F, + 0.57370438F, 0.42193634F, 0.6525492F, -1.64818992F}; + + constexpr size_t input_size = 1 * 3 * 4; + + vector input_fp16(input_size); + vector scale_fp16(3); + vector B_fp16(3); + vector expected_output_fp16(input_size); + + ConvertFloatToMLFloat16(input.data(), input_fp16.data(), input_size); + ConvertFloatToMLFloat16(scale.data(), scale_fp16.data(), 3); + ConvertFloatToMLFloat16(B.data(), B_fp16.data(), 3); + ConvertFloatToMLFloat16(expected_output.data(), expected_output_fp16.data(), input_size); + + test.AddInput("X", input_dims, input_fp16); + test.AddInput("scale", {3}, scale_fp16); + test.AddInput("B", {3}, B_fp16); + test.AddOutput("Y", input_dims, expected_output_fp16); + +#if defined(OPENVINO_CONFIG_MYRIAD) //Disabling this test on MYRIADX temporarily due to a bug + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); +#else + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +#endif +} + +TEST(InstanceNormalizationOpTest, InstanceNormBatch2_fp16) { + OpTester test("InstanceNormalization"); + test.AddAttribute("epsilon", 0.3F); + + vector input = {3.1513367F, 9.283596F, 1.4546119F, 5.4617004F, + 8.519701F, 1.2382338F, 1.7930176F, 5.1099434F, + 7.9195533F, 7.638727F, 8.065445F, 3.8082376F, + + 3.1513367F, 9.283596F, 1.4546119F, 5.4617004F, + 8.519701F, 1.2382338F, 1.7930176F, 5.1099434F, + 7.9195533F, 7.638727F, 8.065445F, 3.8082376F}; + vector input_dims = {2, 3, 4}; + + vector scale = {1.0F, 1.0F, 1.F}; + vector scale_dims = {3}; + + vector B = {0.0F, 0.0F, 0.F}; + vector B_dims = {3}; + + vector expected_output = {-0.56495477F, 1.48930046F, -1.13334329F, 0.20899761F, + 1.46688162F, -0.98600774F, -0.79911913F, 0.31824524F, + 0.57370438F, 0.42193634F, 0.6525492F, -1.64818992F, + + -0.56495477F, 1.48930046F, -1.13334329F, 0.20899761F, + 1.46688162F, -0.98600774F, -0.79911913F, 0.31824524F, + 0.57370438F, 0.42193634F, 0.6525492F, -1.64818992F}; + + constexpr size_t input_size = 2 * 3 * 4; + + vector input_fp16(input_size); + vector scale_fp16(3); + vector B_fp16(3); + vector expected_output_fp16(input_size); + + 
ConvertFloatToMLFloat16(input.data(), input_fp16.data(), input_size); + ConvertFloatToMLFloat16(scale.data(), scale_fp16.data(), 3); + ConvertFloatToMLFloat16(B.data(), B_fp16.data(), 3); + ConvertFloatToMLFloat16(expected_output.data(), expected_output_fp16.data(), input_size); + + test.AddInput("X", input_dims, input_fp16); + test.AddInput("scale", {3}, scale_fp16); + test.AddInput("B", {3}, B_fp16); + test.AddOutput("Y", input_dims, expected_output_fp16); + +#if defined(OPENVINO_CONFIG_MYRIAD) //Disabling this test on MYRIADX temporarily due to a bug + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); +#else + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +#endif +} + +#endif TEST(InstanceNormalizationOpTest, InstanceNorm_2) { OpTester test("InstanceNormalization"); test.AddAttribute("epsilon", 0.3F); @@ -119,7 +254,7 @@ TEST(InstanceNormalizationOpTest, InstanceNorm_2) { 1.88028F, 2.353724F, -0.25549555F, 2.0837004F, 2.8466992F, 2.0773761F}; test.AddOutput("Y", input_dims, expected_output); -#if defined(OPENVINO_CONFIG_MYRIAD) //Disabling this test on MYRIADX temporarily due to a bug +#if defined(OPENVINO_CONFIG_MYRIAD) //Disabling this test on MYRIADX temporarily due to a bug test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); #else test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); diff --git a/onnxruntime/test/python/onnxruntime_test_python_keras.py b/onnxruntime/test/python/onnxruntime_test_python_keras.py index e2c4f2390d9da..02e7cdb8e7d71 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_keras.py +++ b/onnxruntime/test/python/onnxruntime_test_python_keras.py @@ -68,7 +68,7 @@ def testRunModelConv(self): # runtime content = converted_model.SerializeToString() - rt = onnxrt.InferenceSession(content) + rt = onnxrt.InferenceSession(content, providers=onnxrt.get_available_providers()) input = {rt.get_inputs()[0].name: x} actual_rt = rt.run(None, input) self.assertEqual(len(actual_rt), 1) diff --git a/orttraining/tools/scripts/layer_norm_transform.py b/orttraining/tools/scripts/layer_norm_transform.py index 15b2b4ae07352..6355118709ff8 100644 --- a/orttraining/tools/scripts/layer_norm_transform.py +++ b/orttraining/tools/scripts/layer_norm_transform.py @@ -163,11 +163,11 @@ def main(): input_mask = np.ones((batch, sq_length), dtype=np.int64) # Do forward using the original model. - sess = ort.InferenceSession(model_file_path) + sess = ort.InferenceSession(model_file_path, providers=ort.get_available_providers()) result = sess.run(None, {'input1': input_ids, 'input2': segment_ids, 'input3': input_mask}) # Do forward using the new model. - new_sess = ort.InferenceSession(new_model_file_path) + new_sess = ort.InferenceSession(new_model_file_path, providers=ort.get_available_providers()) new_result = new_sess.run(None, {'input1': input_ids, 'input2': segment_ids, 'input3': input_mask}) # Compare the outcomes from the two models. diff --git a/orttraining/tools/scripts/model_transform.py b/orttraining/tools/scripts/model_transform.py index 26424db66d3b3..de23df13a1963 100644 --- a/orttraining/tools/scripts/model_transform.py +++ b/orttraining/tools/scripts/model_transform.py @@ -298,11 +298,11 @@ def add_expand_shape(model): input_mask = np.ones((batch, sq_length), dtype=np.int64) # Do forward using the original model. 
-sess = ort.InferenceSession(input_model_name) +sess = ort.InferenceSession(input_model_name, providers=ort.get_available_providers()) result = sess.run(None, {'input1': input_ids, 'input2': segment_ids, 'input3': input_mask}) # Do forward using the new model. -new_sess = ort.InferenceSession(output_model_name) +new_sess = ort.InferenceSession(output_model_name, providers=ort.get_available_providers()) new_result = new_sess.run(None, {'input1': input_ids, 'input2': segment_ids, 'input3': input_mask}) # Compare the outcomes from the two models. diff --git a/orttraining/tools/scripts/nv_run_pretraining.py b/orttraining/tools/scripts/nv_run_pretraining.py index 1b3ec4a247e36..c7c03be161c08 100644 --- a/orttraining/tools/scripts/nv_run_pretraining.py +++ b/orttraining/tools/scripts/nv_run_pretraining.py @@ -528,7 +528,7 @@ def main(): is_model_exported = False import onnxruntime as ort - sess = ort.InferenceSession(onnx_path) + sess = ort.InferenceSession(onnx_path, providers=ort.get_available_providers()) result = sess.run(None, {'input1': input_ids.cpu().numpy(), 'input2': segment_ids.cpu().numpy(), 'input3': input_mask.cpu().numpy()}) print('---ORT result---')
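[Editor note: the CUDA changes above add a float16 path for InstanceNormalization (the N==1 cuDNN branch plus the re-templated `InstanceNormImpl` that keeps mean/variance in float32). Below is a hedged end-to-end sketch exercising that path from Python; it assumes a CUDA-enabled onnxruntime build (per the test comment, only the CUDA kernel supports float16) and uses arbitrary shapes and epsilon.]

```python
import numpy as np
import onnx
from onnx import TensorProto, helper
import onnxruntime as rt

N, C, H, W = 2, 3, 4, 4
node = helper.make_node("InstanceNormalization", ["X", "scale", "B"], ["Y"], epsilon=1e-5)
graph = helper.make_graph(
    [node], "instnorm_fp16",
    [helper.make_tensor_value_info("X", TensorProto.FLOAT16, [N, C, H, W]),
     helper.make_tensor_value_info("scale", TensorProto.FLOAT16, [C]),
     helper.make_tensor_value_info("B", TensorProto.FLOAT16, [C])],
    [helper.make_tensor_value_info("Y", TensorProto.FLOAT16, [N, C, H, W])])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])

if "CUDAExecutionProvider" in rt.get_available_providers():
    sess = rt.InferenceSession(model.SerializeToString(),
                               providers=["CUDAExecutionProvider"])
    x = np.random.rand(N, C, H, W).astype(np.float16)
    scale = np.ones(C, dtype=np.float16)
    bias = np.zeros(C, dtype=np.float16)
    y = sess.run(None, {"X": x, "scale": scale, "B": bias})[0]

    # Float32 reference: per-instance, per-channel normalization with biased variance.
    xf = x.astype(np.float32)
    mean = xf.mean(axis=(2, 3), keepdims=True)
    var = xf.var(axis=(2, 3), keepdims=True)
    ref = (xf - mean) / np.sqrt(var + 1e-5)
    np.testing.assert_allclose(y.astype(np.float32), ref, rtol=1e-2, atol=1e-2)
```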