Fix GroupNorm to support Opset21 #2928

Merged: 49 commits from hamptonm/feature/groupnorm into main on Sep 13, 2024. Changes shown from all commits.

Commits:
c64ad53
Group norm for opset 21
MegoHam21 Aug 12, 2024
db003ba
Merge branch 'onnx:main' into hamptonm/feature/groupnorm
hamptonm1 Aug 29, 2024
fe10937
Testing phase
MegoHam21 Aug 29, 2024
8979ccb
Merge branch 'onnx:main' into hamptonm/feature/groupnorm
hamptonm1 Sep 3, 2024
fed2948
Fix GroupNorm to support Opset21
MegoHam21 Sep 3, 2024
b650554
Merge branch 'hamptonm/feature/groupnorm' of https://github.com/hampt…
MegoHam21 Sep 3, 2024
35692e8
Merge branch 'main' into hamptonm/feature/groupnorm
hamptonm1 Sep 3, 2024
7209660
Fix format
MegoHam21 Sep 3, 2024
f2e87a4
Merge branch 'hamptonm/feature/groupnorm' of https://github.com/hampt…
MegoHam21 Sep 3, 2024
52721fc
Still fixing format here
MegoHam21 Sep 3, 2024
32d0132
Trying this out
MegoHam21 Sep 4, 2024
fb4b39c
Logic branches
MegoHam21 Sep 4, 2024
c6a0ebb
Fix linter
MegoHam21 Sep 4, 2024
b36f32d
Try again
MegoHam21 Sep 4, 2024
a1a15e3
Update Dockerfile.llvm-project
hamptonm1 Sep 5, 2024
3b0bf5f
Update requirements.txt
hamptonm1 Sep 5, 2024
770b1c4
Merge branch 'main' into hamptonm/feature/groupnorm
hamptonm1 Sep 5, 2024
3f33e0b
Update requirements.txt
hamptonm1 Sep 5, 2024
6b7dc63
Testing it
MegoHam21 Sep 5, 2024
fd20266
Merge branch 'hamptonm/feature/groupnorm' of https://github.com/hampt…
MegoHam21 Sep 5, 2024
8d50a48
Merge branch 'main' into hamptonm/feature/groupnorm
hamptonm1 Sep 5, 2024
76de507
Fix format
MegoHam21 Sep 5, 2024
31e542c
Merge branch 'hamptonm/feature/groupnorm' of https://github.com/hampt…
MegoHam21 Sep 5, 2024
07b4e51
Testing still
MegoHam21 Sep 6, 2024
a3e8359
Still figuring this out
MegoHam21 Sep 7, 2024
6b3f4dd
Try it out
MegoHam21 Sep 10, 2024
da9c0ab
Format fixed
MegoHam21 Sep 10, 2024
41ee2b9
Merge branch 'main' into hamptonm/feature/groupnorm
hamptonm1 Sep 10, 2024
b9fa3b0
Grr
MegoHam21 Sep 10, 2024
654b488
Fixing format
MegoHam21 Sep 10, 2024
30337e5
Keep going
MegoHam21 Sep 10, 2024
08ff821
Add verifier and cleanup code
MegoHam21 Sep 11, 2024
f7d53f7
Fix failure and add lit test
MegoHam21 Sep 11, 2024
2a26b17
Merge branch 'main' into hamptonm/feature/groupnorm
hamptonm1 Sep 11, 2024
424d807
Fixing up
MegoHam21 Sep 11, 2024
16d4c4e
Format again
MegoHam21 Sep 11, 2024
84a10c8
Format verifier
MegoHam21 Sep 11, 2024
d0b51e2
Figure out why it is not working
MegoHam21 Sep 11, 2024
ca8da95
My fault
MegoHam21 Sep 11, 2024
d8b26f9
Update Decompose.cpp
hamptonm1 Sep 11, 2024
6d35971
Update Decompose.cpp again
hamptonm1 Sep 11, 2024
327e6ef
Update Decompose.cpp part 3
hamptonm1 Sep 11, 2024
ba471a1
Still at it
hamptonm1 Sep 12, 2024
6c2de2d
Test it!
hamptonm1 Sep 12, 2024
76e8422
I think I got it to work
MegoHam21 Sep 12, 2024
de19219
Hmmm
MegoHam21 Sep 12, 2024
e459d52
Address comments
MegoHam21 Sep 13, 2024
0f1b161
Fix format
MegoHam21 Sep 13, 2024
41eb8dd
Merge branch 'main' into hamptonm/feature/groupnorm
hamptonm1 Sep 13, 2024
57 changes: 57 additions & 0 deletions docs/Dialects/onnx.md
@@ -3589,6 +3589,63 @@ where the mean and variance are computed per instance per group of channels, and
`scale` and `bias` should be specified for each group of channels. The number of
channels should be divisible by the number of groups `num_groups` so that there are
an equal number of channels per group.

The overall computation has two stages: the first stage normalizes the elements to
have zero mean and unit variance for each instance in each group, and the second
stage scales and shifts the results of the first stage. The floating-point precision
used in the first stage is determined by the `stash_type` attribute. For example,
if `stash_type` is 1, the operator casts all input variables to 32-bit float,
performs the computation, and finally casts the normalized results back to the
original type of `X`. The second stage does not depend on `stash_type`.

When the number of groups is the same as the number of channels, this operator is
equivalent to InstanceNormalization. When there is only one group, this operator
is equivalent to LayerNormalization.
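
To make the two stages concrete, here is a minimal reference sketch for an
NCHW float tensor, assuming `stash_type == 1` (statistics computed in 32-bit
float) and one `scale`/`bias` value per channel; the function name and layout
handling are illustrative, not onnx-mlir or ONNX API:

```
// Reference sketch of GroupNormalization for NCHW data, stash_type == 1.
#include <cmath>
#include <cstdint>
#include <vector>

std::vector<float> groupNormRef(const std::vector<float> &X, int64_t N,
    int64_t C, int64_t HW, int64_t numGroups, const std::vector<float> &scale,
    const std::vector<float> &bias, float epsilon = 1e-5f) {
  std::vector<float> Y(X.size());
  const int64_t chPerGroup = C / numGroups; // C must be divisible by numGroups.
  const int64_t groupSize = chPerGroup * HW;
  for (int64_t n = 0; n < N; ++n) {
    for (int64_t g = 0; g < numGroups; ++g) {
      const int64_t base = (n * C + g * chPerGroup) * HW;
      // Stage 1: mean and variance over the group, in the stash precision.
      float mean = 0.0f, var = 0.0f;
      for (int64_t i = 0; i < groupSize; ++i)
        mean += X[base + i];
      mean /= groupSize;
      for (int64_t i = 0; i < groupSize; ++i)
        var += (X[base + i] - mean) * (X[base + i] - mean);
      var /= groupSize;
      const float invStd = 1.0f / std::sqrt(var + epsilon);
      // Stage 2: per-channel scale and shift, independent of stash_type.
      for (int64_t c = 0; c < chPerGroup; ++c) {
        const int64_t ch = g * chPerGroup + c;
        for (int64_t i = 0; i < HW; ++i) {
          const int64_t idx = base + c * HW + i;
          Y[idx] = scale[ch] * ((X[idx] - mean) * invStd) + bias[ch];
        }
      }
    }
  }
  return Y;
}
```

With `numGroups == C` every group is a single channel (InstanceNormalization);
with `numGroups == 1` a single group spans all channels (LayerNormalization),
matching the equivalences stated above.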

Traits: `AlwaysSpeculatableImplTrait`

Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface`

Effects: `MemoryEffects::Effect{}`

#### Attributes:

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>epsilon</code></td><td>::mlir::FloatAttr</td><td>32-bit float attribute</td></tr>
<tr><td><code>num_groups</code></td><td>::mlir::IntegerAttr</td><td>64-bit signed integer attribute</td></tr>
<tr><td><code>stash_type</code></td><td>::mlir::IntegerAttr</td><td>64-bit signed integer attribute</td></tr>
</table>

#### Operands:

| Operand | Description |
| :-----: | ----------- |
| `X` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values
| `scale` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values
| `bias` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values

#### Results:

| Result | Description |
| :----: | ----------- |
| `Y` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values

### `onnx.GroupNormalizationV18` (ONNXGroupNormalizationV18Op)

_ONNX GroupNormalization operation_

A GroupNormalization function. Carries out group normalization as described in
the paper https://arxiv.org/abs/1803.08494

This operator transforms input according to
```
y = scale * (x - mean) / sqrt(variance + epsilon) + bias,
```
where the mean and variance are computed per instance per group of channels, and
`scale` and `bias` should be specified for each group of channels. The number of
channels should be divisible by the number of groups `num_groups` so that there are
an equal number of channels per group.

When the number of groups is the same as the number of channels, this operator is
equivalent to InstanceNormalization. When there is only one group, this operator
is equivalent to LayerNormalization.
4 changes: 2 additions & 2 deletions docs/SupportedONNXOps-NNPA.md
@@ -3,11 +3,11 @@

# Supported ONNX Operation for Target *NNPA*.

-Onnx-mlir currently supports ONNX operations targeting up to opset 20. Limitations are listed when applicable. This documentation highlights the minimum and maximum opset versions that are fully supported by onnx-mlir and not the version changes.
+Onnx-mlir currently supports ONNX operations targeting up to opset 21. Limitations are listed when applicable. This documentation highlights the minimum and maximum opset versions that are fully supported by onnx-mlir and not the version changes.

* Operations are defined by the [ONNX Standard](https://github.com/onnx/onnx/blob/main/docs/Operators.md).
* **Supported Opsets** indicates the lowest and highest opset a model may have for onnx-mlir to support compiling a model with the operator.
-* A * indicates onnx-mlir is compatible with the latest version of that operator available as of opset 20.
+* A * indicates onnx-mlir is compatible with the latest version of that operator available as of opset 21.


NNPA has hardware limitations in dimension index size and tensor size, which are described in [NNPALimit.hpp](../src/Accelerators/NNPA/Support/NNPALimit.hpp). They are large enough for normal use cases, but if your model exceeds the limitations, CPU is used instead of NNPA.
4 changes: 2 additions & 2 deletions docs/SupportedONNXOps-cpu.md
@@ -3,11 +3,11 @@

# Supported ONNX Operation for Target *cpu*.

-Onnx-mlir currently supports ONNX operations targeting up to opset 20. Limitations are listed when applicable. This documentation highlights the minimum and maximum opset versions that are fully supported by onnx-mlir and not the version changes.
+Onnx-mlir currently supports ONNX operations targeting up to opset 21. Limitations are listed when applicable. This documentation highlights the minimum and maximum opset versions that are fully supported by onnx-mlir and not the version changes.

* Operations are defined by the [ONNX Standard](https://github.com/onnx/onnx/blob/main/docs/Operators.md).
* **Supported Opsets** indicates the lowest and highest opset a model may have for onnx-mlir to support compiling a model with the operator.
-* A * indicates onnx-mlir is compatible with the latest version of that operator available as of opset 20.
+* A * indicates onnx-mlir is compatible with the latest version of that operator available as of opset 21.


| Op |Supported Opsets (inclusive) |Limitations |Notes |
7 changes: 4 additions & 3 deletions requirements.txt
@@ -1,7 +1,8 @@
 lit~=15.0
 # numpy 1.24 deprecates np.object, np.bool, np.float, np.complex, np.str,
 # and np.int which are used heavily in onnx-mlir.
-numpy~=1.22.2, <=1.23.5
+numpy==2.0.1
+onnx==1.16.2
 protobuf==4.21.12
-pytest~=7.2
-pytest-xdist~=3.0
+pytest==8.3.2
+pytest-xdist==3.6.1
4 changes: 3 additions & 1 deletion src/Builder/OpBuildTable.inc
@@ -80,7 +80,7 @@ op_dialect_version_map_["Gradient"] = {1};
op_dialect_version_map_["Greater"] = {13};
op_dialect_version_map_["GreaterOrEqual"] = {16};
op_dialect_version_map_["GridSample"] = {16};
op_dialect_version_map_["GroupNormalization"] = {18};
op_dialect_version_map_["GroupNormalization"] = {21, 18};
op_dialect_version_map_["HammingWindow"] = {17};
op_dialect_version_map_["HannWindow"] = {17};
op_dialect_version_map_["HardSigmoid"] = {6};
@@ -358,6 +358,8 @@ import_handler_map_["GridSample"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXGridSampleOp>;
import_handler_map_["GroupNormalization"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXGroupNormalizationOp>;
import_handler_map_["GroupNormalizationV18"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXGroupNormalizationV18Op>;
import_handler_map_["HammingWindow"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXHammingWindowOp>;
import_handler_map_["HannWindow"] =
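
The `{21, 18}` entry registers two opset versions under one op name: the
newest version keeps the plain `ONNXGroupNormalizationOp`, while the older one
is imported as the suffixed `ONNXGroupNormalizationV18Op`. A rough sketch of
how such a version map can drive the selection (illustrative logic only, not
the actual frontend code):

```
#include <string>
#include <vector>

// Pick the dialect op variant for a model's opset, given the registered
// versions newest-first, e.g. {21, 18} for GroupNormalization. The newest
// version keeps the plain name; older ones get a "V<n>" suffix.
std::string selectOpVariant(const std::string &opName,
    const std::vector<int> &registeredOpsets, int modelOpset) {
  for (size_t i = 0; i < registeredOpsets.size(); ++i)
    if (modelOpset >= registeredOpsets[i])
      return i == 0 ? opName
                    : opName + "V" + std::to_string(registeredOpsets[i]);
  // Model opset predates every registered version: use the oldest variant.
  return opName + "V" + std::to_string(registeredOpsets.back());
}

// selectOpVariant("GroupNormalization", {21, 18}, 21) -> "GroupNormalization"
// selectOpVariant("GroupNormalization", {21, 18}, 18) -> "GroupNormalizationV18"
```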
15 changes: 14 additions & 1 deletion src/Dialect/ONNX/DialectBuilder.cpp
@@ -4,7 +4,7 @@

//===----- DialectBuilder.cpp - Helper functions for ONNX dialects -------===//
//
-// Copyright 2019-2023 The IBM Research Authors.
+// Copyright 2019-2024 The IBM Research Authors.
//
// =============================================================================
//
@@ -164,6 +164,19 @@ Value OnnxBuilder::layerNorm(Type outputType, Value input, Value scale,
toTensor(bias), axisAttr, epsilon, stashTypeAttr);
return layerNormOp.getY();
}
// Overload used when lowering GroupNormalization, which allows a stash_type
// to be specified.
Value OnnxBuilder::layerNorm(Type outputType, Value input, Value scale,
Value bias, int64_t axis, FloatAttr epsilon, IntegerAttr stashType) const {
IntegerAttr axisAttr = getSignedInt64Attr(axis);
Value noneVal = none();
Type noneType = noneVal.getType();
ONNXLayerNormalizationOp layerNormOp =
createOpAndInferShapes<ONNXLayerNormalizationOp>(
/*Y type*/ toTensor(outputType), /*mean type*/ noneType,
/*std dev Type*/ noneType, toTensor(input), toTensor(scale),
toTensor(bias), axisAttr, epsilon, stashType);
return layerNormOp.getY();
}

Value OnnxBuilder::qlinearMatMul(Type outputType, Value a, Value aScale,
Value aZeroPoint, Value b, Value bScale, Value bZeroPoint, Value yScale,
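
A hypothetical call site for the new overload, as it might appear when
decomposing an opset 21 GroupNormalization into LayerNormalization; the
rewrite context and reshaped operands are assumed, and only the final builder
call reflects the signature added above:

```
// Sketch: forward GroupNormalization's stash_type attribute when lowering
// to LayerNormalization. `rewriter`, `loc`, `groupNormOp`, and the reshaped
// values are assumed to exist in the surrounding rewrite pattern; the axis
// value is illustrative.
MultiDialectBuilder<OnnxBuilder> create(rewriter, loc);
FloatAttr epsilon = groupNormOp.getEpsilonAttr();
IntegerAttr stashType = groupNormOp.getStashTypeAttr();
Value normalized = create.onnx.layerNorm(reshapedOutputType, xReshaped,
    scaleReshaped, biasReshaped, /*axis=*/2, epsilon, stashType);
```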
4 changes: 4 additions & 0 deletions src/Dialect/ONNX/DialectBuilder.hpp
@@ -91,6 +91,10 @@ struct OnnxBuilder : DialectBuilder {
mlir::Value layerNorm(mlir::Type outputType, mlir::Value input,
mlir::Value scale, mlir::Value bias, int64_t axis,
mlir::FloatAttr epsilon) const;
// Overload used when lowering GroupNormalization, which allows a stash_type
// to be specified.
mlir::Value layerNorm(mlir::Type outputType, mlir::Value input,
mlir::Value scale, mlir::Value bias, int64_t axis,
mlir::FloatAttr epsilon, mlir::IntegerAttr stashType) const;

// ONNXQLinearMatMulOp
mlir::Value qlinearMatMul(mlir::Type outputType, mlir::Value a,
59 changes: 58 additions & 1 deletion src/Dialect/ONNX/ONNXOps.td.inc
@@ -3122,6 +3122,62 @@ def ONNXGroupNormalizationOp:ONNX_Op<"GroupNormalization",
channels should be divisible by the number of groups `num_groups` so that there are
an equal number of channels per group.

The overall computation has two stages: the first stage normalizes the elements to
have zero mean and unit variance for each instance in each group, and the second
stage scales and shifts the results of the first stage. The floating-point precision
used in the first stage is determined by the `stash_type` attribute. For example,
if `stash_type` is 1, the operator casts all input variables to 32-bit float,
performs the computation, and finally casts the normalized results back to the
original type of `X`. The second stage does not depend on `stash_type`.

When the number of groups is the same as the number of channels, this operator is
equivalent to InstanceNormalization. When there is only one group, this operator
is equivalent to LayerNormalization.
}];
let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X,
AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$scale,
AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$bias,
DefaultValuedAttr<F32Attr, "1e-05">:$epsilon,
SI64Attr:$num_groups,
DefaultValuedAttr<SI64Attr, "1">:$stash_type);
let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 3;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {30};
}
}];
let extraClassDefinition = [{
onnx_mlir::ONNXOpShapeHelper * $cppClass::getShapeHelper(mlir::Operation *op, llvm::ArrayRef<mlir::Value> oper,
onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) {
onnx_mlir::ONNXOpShapeHelper *sh = new onnx_mlir::ONNXGroupNormalizationOpShapeHelper(op, oper, ieb, scope);
assert(sh && "failed to allocate shape helper");
return sh;
}
}];
}

def ONNXGroupNormalizationV18Op:ONNX_Op<"GroupNormalizationV18",
[Pure, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, DeclareOpInterfaceMethods<ShapeHelperOpInterface>]> {
let summary = "ONNX GroupNormalization operation";
let description = [{
A GroupNormalization function. Carries out group normalization as described in
the paper https://arxiv.org/abs/1803.08494

This operator transforms input according to
```
y = scale * (x - mean) / sqrt(variance + epsilon) + bias,
```
where the mean and variance are computed per instance per group of channels, and
`scale` and `bias` should be specified for each group of channels. The number of
channels should be divisible by the number of groups `num_groups` so that there are
an equal number of channels per group.

When the number of groups is the same as the number of channels, this operator is
equivalent to InstanceNormalization. When there is only one group, this operator
is equivalent to LayerNormalization.
@@ -3146,11 +3202,12 @@ def ONNXGroupNormalizationOp:ONNX_Op<"GroupNormalization",
let extraClassDefinition = [{
onnx_mlir::ONNXOpShapeHelper * $cppClass::getShapeHelper(mlir::Operation *op, llvm::ArrayRef<mlir::Value> oper,
onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) {
-onnx_mlir::ONNXOpShapeHelper *sh = new onnx_mlir::ONNXGroupNormalizationOpShapeHelper(op, oper, ieb, scope);
+onnx_mlir::ONNXOpShapeHelper *sh = new onnx_mlir::ONNXGroupNormalizationV18OpShapeHelper(op, oper, ieb, scope);
assert(sh && "failed to allocate shape helper");
return sh;
}
}];
let hasVerifier = 1;
}

def ONNXHammingWindowOp:ONNX_Op<"HammingWindow",
15 changes: 15 additions & 0 deletions src/Dialect/ONNX/ONNXOps/NN/Normalization.cpp
@@ -149,6 +149,21 @@ LogicalResult ONNXInstanceNormalizationOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// GroupNormalizationV18
Review comment from the PR author (Collaborator): @AlexandreEichenberger Is this what you were thinking of? I added a print because if I use emitWarning, it seems all opset 18 tests fail, and I figure we can still enable support in the meantime. I am fine with removing support in the near future, but I was not sure whether any model still uses GroupNorm opset 18.

Reply (Collaborator): LGTM. In general we could also test other properties, but since this op is not going to be used, it's fine.

//===----------------------------------------------------------------------===//
LogicalResult ONNXGroupNormalizationV18Op::verify() {
ONNXGroupNormalizationV18OpAdaptor(*this);
llvm::outs() << "Warning: The previous understanding of Opset 18 for "
                "GroupNormalization is incorrect, as shown in the following "
                "issue: https://github.com/onnx/onnx/issues/5466. Rather, use "
                "Opset 21 for GroupNormalization instead.\n";
return success();
}

// TODO: should there be a shape inference for this one?

//===----------------------------------------------------------------------===//
1 change: 1 addition & 0 deletions src/Dialect/ONNX/ONNXUnsupportedOps.hpp
@@ -77,6 +77,7 @@ CONVERTED_TO_SUPPORTED_OPS(ONNXClipV12Op)
CONVERTED_TO_SUPPORTED_OPS(ONNXClipV6Op)
CONVERTED_TO_SUPPORTED_OPS(ONNXDFTV17Op)
CONVERTED_TO_SUPPORTED_OPS(ONNXGroupNormalizationOp)
CONVERTED_TO_SUPPORTED_OPS(ONNXGroupNormalizationV18Op)
CONVERTED_TO_SUPPORTED_OPS(ONNXPadV18Op)
CONVERTED_TO_SUPPORTED_OPS(ONNXPadV13Op)
CONVERTED_TO_SUPPORTED_OPS(ONNXPadV11Op)
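
Listing both `ONNXGroupNormalizationOp` and `ONNXGroupNormalizationV18Op`
under CONVERTED_TO_SUPPORTED_OPS means neither op is lowered directly; both
are decomposed into ops that already have lowerings. The mathematical core of
that decomposition is that, once the input is viewed as one row per
(instance, group), stage one of group normalization is an ordinary layer
normalization over the trailing axis. A minimal sketch under that assumption:

```
#include <cmath>
#include <cstdint>
#include <vector>

// Normalize each row of a (rows x cols) buffer to zero mean, unit variance.
// With rows = N * num_groups and cols = (C / num_groups) * H * W, this
// reproduces stage one of GroupNormalization; stage two is an elementwise
// multiply-add with the per-channel scale and bias.
void layerNormRows(
    std::vector<float> &data, int64_t rows, int64_t cols, float epsilon) {
  for (int64_t r = 0; r < rows; ++r) {
    float *row = data.data() + r * cols;
    float mean = 0.0f, var = 0.0f;
    for (int64_t i = 0; i < cols; ++i)
      mean += row[i];
    mean /= cols;
    for (int64_t i = 0; i < cols; ++i)
      var += (row[i] - mean) * (row[i] - mean);
    var /= cols;
    const float invStd = 1.0f / std::sqrt(var + epsilon);
    for (int64_t i = 0; i < cols; ++i)
      row[i] = (row[i] - mean) * invStd;
  }
}
```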