Skip to content

Commit

Permalink
[mlir] Fix infinite recursion in alias initializer
Browse files Browse the repository at this point in the history
The alias initializer keeps a list of child indices around. When an alias is then marked as non-deferrable, all children are also marked non-deferrable.

This is currently done naively which leads to an infinite recursion if using mutable types or attributes containing a cycle.

This patch fixes this by adding an early return if the alias is already marked non-deferrable. Since this function is the only way to mark an alias as non-deferrable, it is guaranteed that if it is marked non-deferrable, all its children are as well, and it is not required to walk all the children.
This incidentally makes the non-deferrable marking also `O(n)` instead of `O(n^2)` (although not performance sensitive obviously).

Differential Revision: https://reviews.llvm.org/D158932
  • Loading branch information
zero9178 authored and tru committed Aug 31, 2023
1 parent cc7e24c commit b66219d
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 3 deletions.
6 changes: 6 additions & 0 deletions mlir/lib/IR/AsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1043,6 +1043,12 @@ std::pair<size_t, size_t> AliasInitializer::visitImpl(

void AliasInitializer::markAliasNonDeferrable(size_t aliasIndex) {
auto it = std::next(aliases.begin(), aliasIndex);

// If already marked non-deferrable stop the recursion.
// All children should already be marked non-deferrable as well.
if (!it->second.canBeDeferred)
return;

it->second.canBeDeferred = false;

// Propagate the non-deferrable flag to any child aliases.
Expand Down
12 changes: 12 additions & 0 deletions mlir/test/IR/recursive-type.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// RUN: mlir-opt %s -test-recursive-types | FileCheck %s

// CHECK: !testrec = !test.test_rec<type_to_alias, test_rec<type_to_alias>>
// CHECK: ![[$NAME:.*]] = !test.test_rec_alias<name, !test.test_rec_alias<name>>
// CHECK: ![[$NAME2:.*]] = !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>

// CHECK-LABEL: @roundtrip
func.func @roundtrip() {
Expand All @@ -12,6 +14,16 @@ func.func @roundtrip() {
// into inifinite recursion.
// CHECK: !testrec
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec<type_to_alias, test_rec<type_to_alias>>

// CHECK: () -> ![[$NAME]]
// CHECK: () -> ![[$NAME]]
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name, !test.test_rec_alias<name>>
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name, !test.test_rec_alias<name>>

// CHECK: () -> ![[$NAME2]]
// CHECK: () -> ![[$NAME2]]
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
return
}

Expand Down
4 changes: 4 additions & 0 deletions mlir/test/lib/Dialect/Test/TestDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,10 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
return AliasResult::FinalAlias;
}
}
if (auto recAliasType = dyn_cast<TestRecursiveAliasType>(type)) {
os << recAliasType.getName();
return AliasResult::FinalAlias;
}
return AliasResult::NoAlias;
}

Expand Down
22 changes: 22 additions & 0 deletions mlir/test/lib/Dialect/Test/TestTypeDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -369,4 +369,26 @@ def TestTypeElseAnchorStruct : Test_Type<"TestTypeElseAnchorStruct"> {
let assemblyFormat = "`<` (`?`) : (struct($a, $b)^)? `>`";
}

def TestI32 : Test_Type<"TestI32"> {
let mnemonic = "i32";
}

def TestRecursiveAlias
: Test_Type<"TestRecursiveAlias", [NativeTypeTrait<"IsMutable">]> {
let mnemonic = "test_rec_alias";
let storageClass = "TestRecursiveTypeStorage";
let storageNamespace = "test";
let genStorageClass = 0;

let parameters = (ins "llvm::StringRef":$name);

let hasCustomAssemblyFormat = 1;

let extraClassDeclaration = [{
Type getBody() const;

void setBody(Type type);
}];
}

#endif // TEST_TYPEDEFS
51 changes: 51 additions & 0 deletions mlir/test/lib/Dialect/Test/TestTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -482,3 +482,54 @@ void TestDialect::printType(Type type, DialectAsmPrinter &printer) const {
SetVector<Type> stack;
printTestType(type, printer, stack);
}

Type TestRecursiveAliasType::getBody() const { return getImpl()->body; }

void TestRecursiveAliasType::setBody(Type type) { (void)Base::mutate(type); }

StringRef TestRecursiveAliasType::getName() const { return getImpl()->name; }

Type TestRecursiveAliasType::parse(AsmParser &parser) {
thread_local static SetVector<Type> stack;

StringRef name;
if (parser.parseLess() || parser.parseKeyword(&name))
return Type();
auto rec = TestRecursiveAliasType::get(parser.getContext(), name);

// If this type already has been parsed above in the stack, expect just the
// name.
if (stack.contains(rec)) {
if (failed(parser.parseGreater()))
return Type();
return rec;
}

// Otherwise, parse the body and update the type.
if (failed(parser.parseComma()))
return Type();
stack.insert(rec);
Type subtype;
if (parser.parseType(subtype))
return nullptr;
stack.pop_back();
if (!subtype || failed(parser.parseGreater()))
return Type();

rec.setBody(subtype);

return rec;
}

void TestRecursiveAliasType::print(AsmPrinter &printer) const {
thread_local static SetVector<Type> stack;

printer << "<" << getName();
if (!stack.contains(*this)) {
printer << ", ";
stack.insert(*this);
printer << getBody();
stack.pop_back();
}
printer << ">";
}
6 changes: 3 additions & 3 deletions mlir/test/lib/Dialect/Test/TestTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,6 @@ struct FieldParser<std::optional<int>> {

#include "TestTypeInterfaces.h.inc"

#define GET_TYPEDEF_CLASSES
#include "TestTypeDefs.h.inc"

namespace test {

/// Storage for simple named recursive types, where the type is identified by
Expand Down Expand Up @@ -150,4 +147,7 @@ class TestRecursiveType

} // namespace test

#define GET_TYPEDEF_CLASSES
#include "TestTypeDefs.h.inc"

#endif // MLIR_TESTTYPES_H

0 comments on commit b66219d

Please sign in to comment.