Skip to content

Commit

Permalink
bug fix for recursion in VarRewriter
Browse files Browse the repository at this point in the history
  • Loading branch information
leissa committed Mar 16, 2024
1 parent a15a0a3 commit 40c2da1
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 67 deletions.
1 change: 1 addition & 0 deletions include/thorin/phase/phase.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ class Pipeline : public Phase {
/// @name phases
///@{
const auto& phases() const { return phases_; }

/// Add a Phase.
/// You don't need to pass the World to @p args - it will be passed automatically.
/// If @p P is a Pass, this method will wrap this in a PassPhase.
Expand Down
33 changes: 13 additions & 20 deletions include/thorin/rewrite.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,38 +48,31 @@ class ScopeRewriter : public Rewriter {
const Scope& scope_;
};

/// Stops rewriting when leaving the Scope.
class VarRewriter : public Rewriter {
public:
VarRewriter(Ref var, Ref subst)
VarRewriter(const Var* var, Ref arg)
: Rewriter(var->world()) {
if (var) {
if (auto w = var->isa<Var>()) {
vars_.emplace(w);
map(w, subst);
} else {
assert(var == subst);
}
}
vars_.emplace(var);
map(var, arg);
}

Ref rewrite_imm(Ref imm) override {
if (imm->local_vars().empty() && imm->local_muts().empty()) return imm; // safe to skip
return Rewriter::rewrite_imm(imm);
}

Ref rewrite_mut(Def* mut) override {
if (mut->sym().str() == "pow_else") outln("hey");
if (descend(mut)) {
if (auto var = mut->has_var()) {
vars_.emplace(var);
return Rewriter::rewrite_mut(mut);
}
if (auto var = mut->has_var()) vars_.emplace(var);
return Rewriter::rewrite_mut(mut);
}
return map(mut, mut);
}

bool descend(Def* mut) const {
for (auto op : mut->extended_ops()) {
auto fvs = op->free_vars();
for (auto var : vars_)
if (fvs.contains(var)) return true;
}
auto fvs = mut->free_vars();
for (auto var : vars_)
if (fvs.contains(var)) return true;
return false;
}

Expand Down
71 changes: 33 additions & 38 deletions lit/core/pow.thorin
Original file line number Diff line number Diff line change
@@ -1,48 +1,44 @@
// RUN: rm -f %t.ll
// RUN: %thorin %s --output-ll %t.ll -o - | FileCheck %s
//
// if b<=0:
// 1
// else
// a*pow(a,b-1)
//
// pow(a,b,ret):
// ((pow_else,pow_then)#cmp) ()
// then():
// ret 1
// else():
// pow(a,b-1,cont)
// cont(v):
// ret (a*v)

.plugin core;

/// if b<=0:
/// 1
/// else
/// a*pow(a,b-1)
///
/// pow(a,b,ret):
/// ((pow_else,pow_then)#cmp) ()
/// then():
/// ret 1
/// else():
/// pow(a,b-1,cont)
/// cont(v):
/// ret (a*v)
///
.con pow ((a b: %core.I32), ret: .Cn %core.I32) = {
.con pow_then [] = ret 1:%core.I32;

.con pow_cont [v:%core.I32] = {
.let m = %core.wrap.mul 0 (a,v);
ret m
};
.con pow_else [] = {
.let b_1 = %core.wrap.sub 0 (b,1:%core.I32);
pow ((a,b_1),pow_cont)
};
.let cmp = %core.icmp.e (b,0:%core.I32);
((pow_else, pow_then)#cmp) ()
};

.con .extern main [mem : %mem.M, argc : %core.I32, argv : %mem.Ptr (%mem.Ptr (%core.I8, 0), 0), return : .Cn [%mem.M, %core.I32]] = {
.con ret_cont r::[%core.I32] = return (mem, r);
.fun pow(a b: %core.I32): %core.I32 =
.con pow_cont(v: %core.I32) =
return (%core.wrap.mul 0 (a, v));

.let c = (42:%core.I32, 2:%core.I32);
pow (c,ret_cont)
};
.con pow_then() =
return 1:%core.I32;

.con pow_else() =
.let b_1 = %core.wrap.sub 0 (b, 1:%core.I32);
pow ((a, b_1), pow_cont);

.let cmp = %core.icmp.e (b, 0:%core.I32);
(pow_else, pow_then)#cmp ();

// CHECK-DAG: .con pow_{{[0-9_]+}} _{{[0-9_]+}}::[b_{{[0-9_]+}}: .Idx 4294967296, ret_{{[0-9_]+}}: .Cn .Idx 4294967296]{{(@.*)?}}= {
// CHECK-DAG: .con ret_{{[0-9_]+}} _{{[0-9_]+}}: .Idx 4294967296{{(@.*)?}}= {
// CHECK-DAG: ret_{{[0-9_]+}} _{{[0-9_]+}}
.fun .extern main(mem: %mem.M, argc: %core.I32, argv: %mem.Ptr0 (%mem.Ptr0 %core.I8)): [%mem.M, %core.I32] =
.con ret_cont(res: %core.I32) = return (mem, res);

.let c = (42:%core.I32, 2:%core.I32);
pow (c,ret_cont);
// CHECK-DAG: .con pow_{{[0-9_]+}} _{{[0-9_]+}}::[b_{{[0-9_]+}}: .Idx 4294967296, return_{{[0-9_]+}}: .Cn .Idx 4294967296]{{(@.*)?}}= {
// CHECK-DAG: .con return_{{[0-9_]+}} _{{[0-9_]+}}: .Idx 4294967296{{(@.*)?}}= {
// CHECK-DAG: return_{{[0-9_]+}} _{{[0-9_]+}}

// CHECK-DAG: .con pow_then_{{[0-9_]+}} []{{(@.*)?}}= {
// CHECK-DAG: _{{[0-9_]+}} 1:(.Idx 4294967296)
Expand All @@ -57,4 +53,3 @@

// CHECK-DAG: .let _{{[0-9_]+}}: .Idx 2 = %core.icmp.xyglE 4294967296 (0:(.Idx 4294967296), b_{{[0-9_]+}});
// CHECK-DAG: (pow_else_{{[0-9_]+}}, pow_then_{{[0-9_]+}})#_{{[0-9_]+}} ()

16 changes: 7 additions & 9 deletions src/thorin/rewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,13 @@ DefVec rewrite(Def* mut, Ref arg, const Scope& scope) {
}

DefVec rewrite(Def* mut, Ref arg) {
#if 1
Scope scope(mut);
return rewrite(mut, arg, scope);
#else
VarRewriter rw(mut->var(), arg);
DefVec result(mut->num_ops());
for (size_t i = 0, e = result.size(); i != e; ++i) result[i] = rw.rewrite(mut->op(i));
return result;
#endif
if (auto var = mut->has_var()) {
auto rw = VarRewriter(var, arg);
DefVec result(mut->num_ops());
for (size_t i = 0, e = result.size(); i != e; ++i) result[i] = rw.rewrite(mut->op(i));
return result;
}
return DefVec(mut->ops().begin(), mut->ops().end());
}

} // namespace thorin

0 comments on commit 40c2da1

Please sign in to comment.