Skip to content

Commit

Permalink
Revert "[mlir][mesh] adding shard-size control (#98145)"
Browse files Browse the repository at this point in the history
This reverts commit fca6983.

Also reverts the fixup: "[mlir] Fix -Wunused-variable in MeshOps.cpp (NFC)"

This reverts commit fc73736.
  • Loading branch information
rengolin committed Aug 7, 2024
1 parent d07f106 commit 3968942
Show file tree
Hide file tree
Showing 28 changed files with 695 additions and 1,641 deletions.
4 changes: 0 additions & 4 deletions mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,6 @@ set(LLVM_TARGET_DEFINITIONS MeshBase.td)
mlir_tablegen(MeshEnums.h.inc -gen-enum-decls)
mlir_tablegen(MeshEnums.cpp.inc -gen-enum-defs)

set(LLVM_TARGET_DEFINITIONS MeshBase.td)
mlir_tablegen(MeshTypes.h.inc -gen-typedef-decls)
mlir_tablegen(MeshTypes.cpp.inc -gen-typedef-defs)

set(LLVM_TARGET_DEFINITIONS MeshOps.td)
mlir_tablegen(MeshOps.h.inc -gen-op-decls)
mlir_tablegen(MeshOps.cpp.inc -gen-op-defs)
Expand Down
112 changes: 90 additions & 22 deletions mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
include "mlir/IR/OpBase.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinTypeInterfaces.td"
include "mlir/IR/CommonAttrConstraints.td"
include "mlir/IR/EnumAttr.td"

//===----------------------------------------------------------------------===//
Expand All @@ -32,13 +31,11 @@ def Mesh_Dialect : Dialect {
];

let useDefaultAttributePrinterParser = 1;
let useDefaultTypePrinterParser = 1;
let hasConstantMaterializer = 1;
}

def Mesh_MeshAxis : I<16>;
def Mesh_MeshAxesAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16">;
def Mesh_ShardShapeAttr : DenseArrayAttrBase<"DenseI64ArrayAttr", "int64_t", "i64">;

//===----------------------------------------------------------------------===//
// Mesh Enums.
Expand All @@ -62,33 +59,104 @@ def Mesh_ReductionKind : I32EnumAttr<"ReductionKind",
}

def Mesh_ReductionKindAttr : EnumAttr<Mesh_Dialect, Mesh_ReductionKind, "partial"> {
let assemblyFormat = "$value";
}

class Mesh_Type<string name, string typeMnemonic, list<Trait> traits = [],
string baseCppClass = "::mlir::Type">
: TypeDef<Mesh_Dialect, name, traits, baseCppClass> {
let mnemonic = typeMnemonic;
}

def Mesh_Sharding : Mesh_Type<"Sharding", "sharding"> {
let summary = "sharding definition";
let assemblyFormat = "";
let assemblyFormat = "`<` $value `>`";
}

//===----------------------------------------------------------------------===//
// Mesh Attribute
//===----------------------------------------------------------------------===//

def Mesh_MeshAxesArrayAttr : AttrDef<Mesh_Dialect, "MeshAxesArray"> {
let mnemonic = "axisarray";
let parameters = (ins ArrayRefParameter<"MeshAxesAttr">:$axes);
let assemblyFormat = "`[` $axes `]`";
def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
let mnemonic = "shard";

let parameters = (ins
AttrParameter<"::mlir::FlatSymbolRefAttr",
"The mesh on which tensors are sharded.">:$mesh,
ArrayRefParameter<"MeshAxesAttr">:$split_axes,
OptionalArrayRefParameter<"MeshAxis">:$partial_axes,
OptionalParameter<"::mlir::mesh::ReductionKind">:$partial_type
);

let summary = "Attribute that extends tensor type to distributed tensor type.";

let description = [{
The MeshSharding attribute is used in a `mesh.shard` operation.
It specifies how a tensor is sharded and distributed across the process
mesh.

1. `mesh`: this attribute is a FlatSymbolRefAttr that refers to the device
mesh where the distributed tensor is placed. The symbol must resolve to a
`mesh.mesh` operation.

2. `split_axes`: is an array composed of int64_t sub-arrays. The outer array's
maximum size is the `rank` of the related tensor. For the i-th sub-array, if
its value is [x, y], it indicates that the tensor's i-th dimension is splitted
along the x and y axes of the device mesh.

3. `partial_axes`: if not empty, this signifies that the tensor is partial
one along the specified mesh axes. An all-reduce should be applied to obtain
the complete tensor, with reduction type being specified by `partial_type`.

4. `partial_type`: indicates the reduction type of the possible all-reduce
op. It has 4 possible values:
`generic`: is not an allowed value inside a shard attribute.

Example:

```
mesh.mesh @mesh0(shape = 2x2x4)

// The tensor is fully replicated on @mesh0.
// Currently, there must be at least one sub-array present in axes, even
// if it's empty. Otherwise, a parsing error will occur.
#mesh.shard<@mesh0, [[]]>

// The tensor is sharded on the first dimension along axis 0 of @mesh0
#mesh.shard<@mesh0, [[0]]>

// The tensor is sharded on the first dimension along axis 0 of @mesh0 and
// it is also a partial_sum along mesh axis 1.
#mesh.shard<@mesh0, [[0], []], partial = sum[1]>

// The tensor is sharded on the first dimension along axis 0 of @mesh0 and
// it is also a partial_max along mesh axis 1.
#mesh.shard<@mesh0, [[0]], partial = max[1]>

// Could be used in the attribute of mesh.shard op
%0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
```
}];
let assemblyFormat = [{
`<` $mesh `,` `[` $split_axes `]` (`,` `partial` `=` $partial_type `[`
$partial_axes^ `]`)? `>`
}];

let builders = [
AttrBuilder<(ins "FlatSymbolRefAttr":$mesh,
"ArrayRef<SmallVector<MeshAxis>>":$split_axes,
"ArrayRef<MeshAxis>": $partial_axes,
"mesh::ReductionKind": $partial_type), [{
SmallVector<MeshAxesAttr> splitAxesAttr = llvm::map_to_vector(
split_axes, [&](ArrayRef<MeshAxis> array) {
return MeshAxesAttr::get($_ctxt, array);
});
return $_get($_ctxt, mesh, splitAxesAttr, partial_axes,
partial_type);
}]>,
AttrBuilder<(ins "FlatSymbolRefAttr":$mesh,
"ArrayRef<SmallVector<MeshAxis>>":$split_axes), [{
return MeshShardingAttr::get($_ctxt, mesh, split_axes, {}, ReductionKind::Sum);
}]>
];

let extraClassDeclaration = [{
size_t size() const { return getAxes().size(); }
auto begin() const { return getAxes().begin(); }
auto end() const { return getAxes().end(); }
bool operator==(::mlir::Attribute rhs) const;
bool operator!=(::mlir::Attribute rhs) const;
bool operator==(::mlir::mesh::MeshShardingAttr rhs) const;
bool operator!=(::mlir::mesh::MeshShardingAttr rhs) const;
}];

let genVerifyDecl = 1;
}

#endif // MLIR_DIALECT_MESH_IR_MESHBASE_TD
79 changes: 11 additions & 68 deletions mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ namespace mesh {

using MeshAxis = int16_t;
using MeshAxesAttr = DenseI16ArrayAttr;
using ShardShapeAttr = DenseI64ArrayAttr;
using HaloSizePairAttr = DenseI64ArrayAttr;

} // namespace mesh
} // namespace mlir
Expand All @@ -35,59 +33,6 @@ using HaloSizePairAttr = DenseI64ArrayAttr;
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshAttributes.h.inc"

namespace mlir {
namespace mesh {

class MeshSharding {
private:
::mlir::FlatSymbolRefAttr mesh;
SmallVector<MeshAxesAttr> split_axes;
SmallVector<MeshAxis> partial_axes;
ReductionKind partial_type;
SmallVector<int64_t> static_halo_sizes;
SmallVector<int64_t> static_sharded_dims_sizes;
SmallVector<Value> dynamic_halo_sizes;
SmallVector<Value> dynamic_sharded_dims_sizes;

public:
MeshSharding() = default;
MeshSharding(Value rhs);
static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_,
ArrayRef<MeshAxesAttr> split_axes_,
ArrayRef<MeshAxis> partial_axes_ = {},
ReductionKind partial_type_ = ReductionKind::Sum,
ArrayRef<int64_t> static_halo_sizes_ = {},
ArrayRef<int64_t> static_sharded_dims_sizes_ = {},
ArrayRef<Value> dynamic_halo_sizes_ = {},
ArrayRef<Value> dynamic_sharded_dims_sizes_ = {});
::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; }
::llvm::StringRef getMesh() const { return mesh.getValue(); }
ArrayRef<MeshAxesAttr> getSplitAxes() const { return split_axes; }
ArrayRef<MeshAxis> getPartialAxes() const { return partial_axes; }
ReductionKind getPartialType() const { return partial_type; }
ArrayRef<int64_t> getStaticHaloSizes() const { return static_halo_sizes; }
ArrayRef<int64_t> getStaticShardedDimsSizes() const {
return static_sharded_dims_sizes;
}
ArrayRef<Value> getDynamicHaloSizes() const { return dynamic_halo_sizes; }
ArrayRef<Value> getDynamicShardedDimsSizes() const {
return dynamic_sharded_dims_sizes;
}
operator bool() const { return (!mesh) == false; }
bool operator==(Value rhs) const;
bool operator!=(Value rhs) const;
bool operator==(const MeshSharding &rhs) const;
bool operator!=(const MeshSharding &rhs) const;
bool equalSplitAndPartialAxes(const MeshSharding &rhs) const;
bool equalHaloAndShardSizes(const MeshSharding &rhs) const;
};

} // namespace mesh
} // namespace mlir

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshTypes.h.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"

Expand All @@ -105,9 +50,9 @@ void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
}

// Is the same tensor replicated on all processes.
inline bool isFullReplication(MeshSharding sharding) {
return sharding.getPartialAxes().empty() &&
llvm::all_of(sharding.getSplitAxes(), [](MeshAxesAttr axes) {
inline bool isFullReplication(MeshShardingAttr attr) {
return attr.getPartialAxes().empty() &&
llvm::all_of(attr.getSplitAxes(), [](MeshAxesAttr axes) {
return axes.asArrayRef().empty();
});
}
Expand Down Expand Up @@ -135,10 +80,8 @@ mesh::MeshOp getMesh(Op op, SymbolTableCollection &symbolTableCollection) {
template <>
inline mesh::MeshOp
getMesh<ShardOp>(ShardOp op, SymbolTableCollection &symbolTableCollection) {
return getMesh(
op.getOperation(),
cast<ShardingOp>(op.getSharding().getDefiningOp()).getMeshAttr(),
symbolTableCollection);
return getMesh(op.getOperation(), op.getShardAttr().getMesh(),
symbolTableCollection);
}

// Get the number of processes that participate in each group
Expand Down Expand Up @@ -188,22 +131,22 @@ inline int64_t gatherDimension(int64_t dimSize, int64_t shardCount) {
// On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x5x1 would
// result in a shape for each shard of ?x2x?.
ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
MeshSharding sharding);
MeshShardingAttr sharding);

// If ranked tensor type return its sharded counterpart.
//
// If not ranked tensor type return `type`.
// `sharding` in that case must be null.
Type shardType(Type type, MeshOp mesh, MeshSharding sharding);
Type shardType(Type type, MeshOp mesh, MeshShardingAttr sharding);

// Insert shard op if there is not one that already has the same sharding.
// May insert resharding if required.
void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
void maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
OpOperand &operand,
OpBuilder &builder);
void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result,
OpBuilder &builder);
void maybeInsertSourceShardingAnnotation(MeshSharding sharding,
void maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
OpResult result, OpBuilder &builder);
void maybeInsertSourceShardingAnnotation(MeshShardingAttr sharding,
OpOperand &operand,
OpBuilder &builder);

Expand Down
Loading

0 comments on commit 3968942

Please sign in to comment.