Merge branch 'main' into yuwenzho/int4
yuwenzho committed Nov 22, 2023
2 parents 21fc0c4 + 3bc9efc commit 8f6ea60
Showing 45 changed files with 1,301 additions and 419 deletions.
18 changes: 18 additions & 0 deletions docs/ORTModule_Training_Guidelines.md
@@ -269,6 +269,15 @@ data sparsity based performance optimizations.
unset ORTMODULE_CACHE_DIR # Disable
```

#### ORTMODULE_USE_EFFICIENT_ATTENTION

- **Feature Area**: *ORTMODULE/Optimizations*
- **Description**: By default, this is disabled. Set this env var to enable attention fusion and fall back to PyTorch's `efficient_attention` ATen kernel for execution. Note that it requires PyTorch 2.1.1 or above. There are some built-in patterns for attention fusion; if none of them matches your model, you can add a custom one in your user script manually (a usage sketch follows the snippet below).

```bash
export ORTMODULE_USE_EFFICIENT_ATTENTION=1
```
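
A minimal usage sketch (not from the original docs), assuming a CUDA machine with PyTorch >= 2.1.1 and onnxruntime-training installed; the toy model is a placeholder, and whether fusion actually triggers depends on the model matching one of the built-in patterns:

```python
import os

# Set the flag before ORTModule wraps the model (assumption: the env var is
# read when the module is initialized, so set it as early as possible).
os.environ["ORTMODULE_USE_EFFICIENT_ATTENTION"] = "1"

import torch
from onnxruntime.training.ortmodule import ORTModule

class TinyAttention(torch.nn.Module):  # placeholder model for illustration
    def __init__(self, dim: int = 64, heads: int = 4):
        super().__init__()
        self.attn = torch.nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, x):
        out, _ = self.attn(x, x, x)
        return out

model = ORTModule(TinyAttention().cuda())
x = torch.randn(2, 16, 64, device="cuda")
model(x).sum().backward()
```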

### 2.2 Memory Optimization

Q: *Want to run a bigger batch size?*
@@ -397,6 +406,15 @@ Check [FP16_Optimizer implementation](../orttraining/orttraining/python/training
export ORTMODULE_TUNING_RESULTS_PATH=/tmp/tuning_results
```

#### ORTMODULE_USE_FLASH_ATTENTION

- **Feature Area**: *ORTMODULE/TritonOp*
- **Description**: By default, this is disabled. Set this env var to enable attention fusion and use the Triton implementation of Flash Attention as the kernel. Note that it requires ORTMODULE_USE_TRITON to be enabled and a CUDA device with compute capability 8.0 or above. There are some built-in patterns for attention fusion; if none of them matches your model, you can add a custom one in your user script manually (a prerequisite-check sketch follows the snippet below).

```bash
export ORTMODULE_USE_FLASH_ATTENTION=1
```
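
A minimal prerequisite-check sketch (not from the original docs), assuming a CUDA build of PyTorch; the flags and the compute-capability threshold come from the description above:

```python
import os
import torch

# Flash Attention via TritonOp needs ORTMODULE_USE_TRITON as well, and an
# Ampere-or-newer GPU (compute capability >= 8.0), per the description above.
major, minor = torch.cuda.get_device_capability()
if (major, minor) >= (8, 0):
    os.environ["ORTMODULE_USE_TRITON"] = "1"
    os.environ["ORTMODULE_USE_FLASH_ATTENTION"] = "1"
else:
    print(f"Compute capability {major}.{minor} < 8.0; leaving Flash Attention disabled.")
```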

#### ORTMODULE_TRITON_DEBUG

- **Feature Area**: *ORTMODULE/TritonOp*
49 changes: 11 additions & 38 deletions include/onnxruntime/core/framework/tensor_shape.h
@@ -2,34 +2,17 @@
// Licensed under the MIT License.

#pragma once
#include <iosfwd>
#include <vector>

#include <algorithm>
#include <string>
#include <cstring>
#include "core/common/gsl.h"
#include "onnxruntime_config.h"

#ifndef DISABLE_ABSEIL
// Need to include abseil inlined_vector.h header directly here
// as hash tables cause CUDA 10.2 compilers to fail. inlined_vector.h is fine.
#ifdef _MSC_VER
#pragma warning(push)
// C4127: conditional expression is constant
#pragma warning(disable : 4127)
// C4324: structure was padded due to alignment specifier
// Usage of alignas causes some internal padding in places.
#pragma warning(disable : 4324)
#endif

#include <absl/container/inlined_vector.h>

#ifdef _MSC_VER
#pragma warning(pop)
#endif
#endif // DISABLE_ABSEIL
#include <iosfwd>
#include <string>
#include <vector>

#include "core/common/gsl.h"
#include "core/common/inlined_containers_fwd.h"
#include "core/common/span_utils.h"
#include "onnxruntime_config.h"

namespace onnxruntime {
#ifdef __GNUC__
@@ -41,18 +24,10 @@ namespace onnxruntime {

constexpr size_t kTensorShapeSmallBufferElementsSize = 5;

#ifndef DISABLE_ABSEIL
// Use this type to build a shape and then create TensorShape.
using TensorShapeVector = absl::InlinedVector<int64_t, kTensorShapeSmallBufferElementsSize>;
#else
class TensorShapeVector : public std::vector<int64_t> {
using Base = std::vector<int64_t>;

public:
using Base::Base;
};

#endif // DISABLE_ABSEIL
// We opt to re-use a common instantiation instead of a typedef with kTensorShapeSmallBufferElementsSize
// to reduce binary size.
using TensorShapeVector = InlinedVector<int64_t>;

inline TensorShapeVector ToShapeVector(const gsl::span<const int64_t>& span) {
TensorShapeVector out;
@@ -194,9 +169,7 @@ class TensorShape {

friend struct ProviderHostImpl; // So that the shared provider interface can access Allocate
};
#ifdef __GNUC__
#pragma GCC diagnostic pop
#endif

// operator<< to nicely output to a stream
std::ostream& operator<<(std::ostream& out, const TensorShape& shape);

4 changes: 4 additions & 0 deletions include/onnxruntime/core/providers/dml/dml_provider_factory.h
@@ -37,9 +37,13 @@ enum OrtDmlPerformancePreference {
};

enum OrtDmlDeviceFilter : uint32_t {
#ifdef ENABLE_NPU_ADAPTER_ENUMERATION
Any = 0xffffffff,
Gpu = 1 << 0,
Npu = 1 << 1,
#else
Gpu = 1 << 0,
#endif
};

inline OrtDmlDeviceFilter operator~(OrtDmlDeviceFilter a) { return (OrtDmlDeviceFilter) ~(int)a; }
1 change: 0 additions & 1 deletion js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
@@ -55,7 +55,6 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['BiasSplitGelu', [biasSplitGelu]],
['Cast', [unaryOps.cast, unaryOps.parseCastAttributes]],
['Ceil', [unaryOps.ceil]],
['ClipV10', [unaryOps.clipV10]],
['Clip', [unaryOps.clip]],
['Concat', [concat, parseConcatAttributes]],
['Conv', [conv, parseConvAttributes]],
19 changes: 8 additions & 11 deletions js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
@@ -124,7 +124,14 @@ export interface ClipAttributes extends AttributeWithCacheKey {
readonly max: number;
}

export const clipV10 = (context: ComputeContext, attributes: ClipAttributes): void => {
const generateClipAttributesFromInputs = (inputs: readonly TensorView[]): ClipAttributes => {
const min = (inputs.length >= 2) ? inputs[1].getFloat32Array()[0] : MIN_CLIP;
const max = (inputs.length >= 3) ? inputs[2].getFloat32Array()[0] : MAX_CLIP;
return createAttributeWithCacheKey({min, max});
};

export const clip = (context: ComputeContext, clipAttributes: ClipAttributes): void => {
const attributes = context.inputs.length === 1 ? clipAttributes : generateClipAttributesFromInputs(context.inputs);
const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
context.compute(
createElementwiseProgramInfo(
@@ -135,16 +142,6 @@ export const clipV10 = (context: ComputeContext, attributes: ClipAttributes): vo
attributes.cacheKey),
{inputs: [0]});
};
const generateClipAttributesFromInputs = (inputs: readonly TensorView[]): ClipAttributes => {
const min = (inputs.length >= 2) ? inputs[1].getFloat32Array()[0] : MIN_CLIP;
const max = (inputs.length >= 3) ? inputs[2].getFloat32Array()[0] : MAX_CLIP;
return createAttributeWithCacheKey({min, max});
};

export const clip = (context: ComputeContext): void => {
const attributes = generateClipAttributesFromInputs(context.inputs);
clipV10(context, attributes);
};

export const ceil = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Ceil', 'ceil'));
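For context on the unified `clip` handler above: ONNX `Clip` carries `min`/`max` as node attributes up to opset 10 and as optional inputs from opset 11 onward, which is why the WebGPU kernel now rebuilds `ClipAttributes` from `context.inputs` whenever more than one input is present. A small Python sketch of the two node forms, assuming the standard `onnx.helper` API:

```python
from onnx import helper

# Opset 10 and earlier: min/max are node attributes.
clip_v10 = helper.make_node("Clip", inputs=["x"], outputs=["y"], min=-1.0, max=1.0)

# Opset 11+: min/max are optional inputs; an empty name omits an input.
clip_v13 = helper.make_node("Clip", inputs=["x", "min_tensor", ""], outputs=["y"])
```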
30 changes: 18 additions & 12 deletions onnxruntime/contrib_ops/cuda/math/gemm_float8.cc
@@ -14,17 +14,23 @@ namespace onnxruntime {
namespace contrib {
namespace cuda {

#define REGISTER_KERNEL() \
ONNX_OPERATOR_KERNEL_EX( \
GemmFloat8, \
kMSDomain, \
1, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("TA", BuildKernelDefConstraints<Float8E4M3FN, Float8E5M2, MLFloat16, BFloat16, float>()) \
.TypeConstraint("TB", BuildKernelDefConstraints<Float8E4M3FN, Float8E5M2, MLFloat16, BFloat16, float>()) \
.TypeConstraint("TR", BuildKernelDefConstraints<Float8E4M3FN, Float8E5M2, MLFloat16, BFloat16, float>()) \
.TypeConstraint("TS", BuildKernelDefConstraints<float>()), \
#if !defined(DISABLE_FLOAT8_TYPES)
#define GEMM_FLOAT8_CONSTRAINTS BuildKernelDefConstraints<Float8E4M3FN, Float8E5M2, MLFloat16, BFloat16, float>()
#else
#define GEMM_FLOAT8_CONSTRAINTS BuildKernelDefConstraints<MLFloat16, BFloat16, float>()
#endif

#define REGISTER_KERNEL() \
ONNX_OPERATOR_KERNEL_EX( \
GemmFloat8, \
kMSDomain, \
1, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("TA", GEMM_FLOAT8_CONSTRAINTS) \
.TypeConstraint("TB", GEMM_FLOAT8_CONSTRAINTS) \
.TypeConstraint("TR", GEMM_FLOAT8_CONSTRAINTS) \
.TypeConstraint("TS", BuildKernelDefConstraints<float>()), \
GemmFloat8);

REGISTER_KERNEL()
@@ -38,7 +44,7 @@ GemmFloat8::GemmFloat8(const OpKernelInfo& info) : CudaKernel(info) {
alpha_ = info.GetAttrOrDefault<float>("alpha", 1);
beta_ = info.GetAttrOrDefault<float>("beta", 0);

#if (CUDA_VERSION <= 12000)
#if (CUDA_VERSION < 12000)
ORT_ENFORCE(beta_ == 0, "CUDA < 12.0 does not support bias, beta must be 0.");
#endif

27 changes: 18 additions & 9 deletions onnxruntime/contrib_ops/cuda/math/gemm_float8.cu
@@ -28,7 +28,7 @@ int32_t TypeSize(int32_t element_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
return 2;
#if (!defined(DISABLE_FLOAT8_TYPES) && (CUDA_VERSION >= 11080))
#if !defined(DISABLE_FLOAT8_TYPES)
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN:
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2:
return 1;
@@ -97,12 +97,16 @@ Status GemmFloat8::ComputeInternal(OpKernelContext* ctx) const {
}

auto first_type = input_A->GetElementType();
#if !defined(DISABLE_FLOAT8_TYPES)
bool is_float8 = first_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN || first_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2;
if (!is_float8)
#endif
return ComputeRowMajor(ctx, n_inputs, has_bias, has_scales, input_A, input_B,
input_C, scale_A, scale_B, scale_Y);
#if !defined(DISABLE_FLOAT8_TYPES)
return ComputeColMajor(ctx, n_inputs, has_bias, has_scales, input_A, input_B,
input_C, scale_A, scale_B, scale_Y);
#endif
}

Status GemmFloat8::ComputeRowMajor(
@@ -197,10 +201,15 @@ Status GemmFloat8::ComputeGemm(
switch (d_cuda_type) {
case CUDA_R_16F:
switch (a_cuda_type) {
#if !defined(DISABLE_FLOAT8_TYPES)
#if CUDA_VERSION < 11080
#error CUDA_R_8F_E4M3 (float 8 types) is defined with CUDA>=11.8. Set flag DISABLE_FLOAT8_TYPES.
#endif
case CUDA_R_8F_E4M3:
case CUDA_R_8F_E5M2:
compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
break;
#endif
default:
compute_type = CUBLAS_COMPUTE_32F_FAST_16F;
break;
@@ -267,7 +276,7 @@ Status GemmFloat8::ComputeGemm(
sizeof(p_scale_b)));

// float 8
#if CUDA_VERSION >= 11080
#if !defined(DISABLE_FLOAT8_TYPES)
if (dtype_Y == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN ||
dtype_Y == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2) {
// For FP8 output, cuBLAS requires C_type to be same as bias_type
@@ -280,15 +289,14 @@ Status GemmFloat8::ComputeGemm(
CUBLAS_RETURN_IF_ERROR(
cublasLtMatrixLayoutCreate(&Cdesc, d_cuda_type, M, N, ldd));
}
} else {
CUBLAS_RETURN_IF_ERROR(
cublasLtMatrixLayoutCreate(&Cdesc, d_cuda_type, M, N, ldd));
}
#else
// An output is still needed but it is not initialized.
CUBLAS_RETURN_IF_ERROR(
cublasLtMatrixLayoutCreate(&Cdesc, d_cuda_type, M, N, ldd));
#endif
} else {
CUBLAS_RETURN_IF_ERROR(
cublasLtMatrixLayoutCreate(&Cdesc, d_cuda_type, M, N, ldd));
}

if (row_major_compute) {
cublasLtOrder_t matrixOrder = CUBLASLT_ORDER_ROW;
@@ -345,7 +353,7 @@ Status GemmFloat8::ComputeGemm(
". Check NVIDIA documentation to see what combination is valid: ",
"https://docs.nvidia.com/cuda/cublas/"
"index.html?highlight=cublasLtMatmulAlgoGetHeuristic#"
"cublasltmatmulalgogetheuristic.");
"cublasltmatmulalgogetheuristic. CUDA>=11.8 is required to use float 8 types.");

void* workspace = nullptr;
if (workspaceSize > 0) {
@@ -381,7 +389,8 @@ Status GemmFloat8::ComputeGemm(
", shape_A=", shape_A[0], "x", shape_A[1], ", shape_B=", shape_B[0], "x",
shape_B[1], ", M=", M, ", N=", N, ", K=", K, ", lda=", lda, ", ldb=", ldb,
", ldd=", ldd, ", workspaceSize=", workspaceSize,
", rowMajorCompute=", (row_major_compute ? 1 : 0), ".");
", rowMajorCompute=", (row_major_compute ? 1 : 0),
". CUDA>=11.8 is required to use float 8 types.");

if (workspaceSize > 0) {
CUDA_RETURN_IF_ERROR(cudaFree(workspace));
@@ -87,12 +87,19 @@ std::vector<NodeAndMoveInfo> WhereMoves() {
MoveAll(q, ArgType::kOutput)};
return moves;
}
QDQReplaceWithNew SplitReplacer() {
QDQReplaceWithNew SplitReplacer(bool has_split_as_input) {
NTO::NodeLocation dq{NTO::NodeType::kInput, 0};
NTO::NodeLocation target{NTO::NodeType::kTarget, 0};
NTO::NodeLocation q{NTO::NodeType::kOutput, 0};
std::vector<NodeAndMoveInfo> moves{
MoveAndAppend(dq, ArgType::kInput, 0, ArgType::kInput),
MoveAll(q, ArgType::kOutput)};
std::vector<NodeAndMoveInfo> moves{MoveAndAppend(dq, ArgType::kInput, 0, ArgType::kInput)};

if (has_split_as_input) {
// Move the optional split input to the new node.
moves.push_back(MoveAndAppend(target, ArgType::kInput, 1, ArgType::kInput, true));
}

moves.push_back(MoveAll(q, ArgType::kOutput));

return QDQReplaceWithNew(kOnnxDomain, "Split", std::move(moves));
}

@@ -247,7 +254,12 @@ MatMulReplaceWithQLinear::MatMulReplaceWithQLinear()
}

Status SplitReplaceWithQuant::Run(Graph& graph, const NodesToOptimize& selected_nodes) const {
return SplitReplacer().Run(graph, selected_nodes);
const auto& target_node = selected_nodes.Target();
const auto& input_defs = target_node.InputDefs();

// The 'split' attribute became an optional input at opset 13.
bool has_split_as_input = target_node.SinceVersion() >= 13 && input_defs.size() == 2;
return SplitReplacer(has_split_as_input).Run(graph, selected_nodes);
}

Status MatMulReplaceWithQLinear::Run(Graph& graph, const NodesToOptimize& selected_nodes) const {
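As the comment in `SplitReplaceWithQuant::Run` above notes, Split's sizes moved from a `split` attribute to an optional second input at opset 13, which is why the replacer forwards a second input only when the target node is opset 13+ and actually has two inputs. A small Python sketch of the two node forms, assuming the standard `onnx.helper` API:

```python
from onnx import helper

# Before opset 13: split sizes are an attribute.
split_v11 = helper.make_node("Split", inputs=["x"], outputs=["y0", "y1"], axis=0, split=[2, 3])

# Opset 13+: split sizes are an optional second input (axis stays an attribute).
split_v13 = helper.make_node("Split", inputs=["x", "split_sizes"], outputs=["y0", "y1"], axis=0)
```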
@@ -20,7 +20,7 @@ void SplitQDQRules(SelectorActionRegistry& qdq_selector_action_registry) {
const std::string action_name{"dropSplitQDQ"};
std::unique_ptr<Action> action = std::make_unique<QDQ::SplitReplaceWithQuant>();
#if !defined(ORT_MINIMAL_BUILD)
std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::OutputVariadicSelector>();
std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::SplitSelector>(true /*req_equal_quant_params*/);
qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
{{"Split", {}}},
std::move(selector),
@@ -253,7 +253,39 @@ void InputVariadicSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder
builder.num_input_defs = 1; // set to 1 as the first input is variadic
}

void OutputVariadicSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder) const {
bool SplitNodeGroupSelector::Check(const GraphViewer& graph_viewer,
const Node& node,
const std::vector<const Node*>& dq_nodes,
const std::vector<const Node*>& q_nodes) const {
if (!CheckQDQNodes(graph_viewer, node, dq_nodes, q_nodes, 1)) {
return false;
}

auto get_const_initializer = [&graph_viewer](const std::string& initializer_name) {
return graph_viewer.GetConstantInitializer(initializer_name, true);
};

const Node& dq_node = *dq_nodes.front();
int32_t dt_input = dq_node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();

// All Q outputs should have the same data type as the DQ input and (optionally) equal quantization parameters.
for (size_t q_idx = 0; q_idx < q_nodes.size(); q_idx++) {
const Node& q_node = *q_nodes[q_idx];

if (dt_input != q_node.OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type()) {
return false;
}

if (req_equal_quant_params_ &&
!IsQDQPairSupported(q_node, dq_node, get_const_initializer, graph_viewer.ModelPath())) {
return false;
}
}

return true;
}

void SplitSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder) const {
builder.num_output_defs = 1; // set to 1 as the first output is variadic
}

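The `req_equal_quant_params_` path above defers to `IsQDQPairSupported` to confirm that each output Q node carries the same quantization parameters as the input DQ node, so dropping the Q/DQ nodes around Split leaves the numerics unchanged. A rough numpy sketch of the property being enforced (illustrative only; the helper name is made up):

```python
import numpy as np

def quant_params_match(dq_scale, dq_zero_point, q_scale, q_zero_point) -> bool:
    """Q/DQ around a data-movement op like Split can be dropped only when the
    scale and zero point on both sides are identical."""
    return (np.array_equal(dq_scale, q_scale)
            and np.array_equal(dq_zero_point, q_zero_point))

# Per-tensor example: identical scale/zero-point on both sides -> safe to drop.
print(quant_params_match(np.float32(0.02), np.int8(3), np.float32(0.02), np.int8(3)))  # True
print(quant_params_match(np.float32(0.02), np.int8(3), np.float32(0.05), np.int8(3)))  # False
```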