Skip to content

Commit

Permalink
Fix location emission
Browse files Browse the repository at this point in the history
  • Loading branch information
driazati committed Oct 21, 2022
1 parent 0cf4cb8 commit dec0d8a
Show file tree
Hide file tree
Showing 10 changed files with 140 additions and 52 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -271,3 +271,6 @@ gallery/how_to/work_with_microtvm/micro_tvmc.py

# Printed TIR code on disk
*.tir

# GDB history file
.gdb_history
74 changes: 74 additions & 0 deletions gallery/tutorial/debug_tir.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
.. _tutorial-topi:
Debugging TIR
=============
"""

# sphinx_gallery_start_ignore
from tvm import testing

testing.utils.install_request_hook(depth=3)
# sphinx_gallery_end_ignore

import tvm
import tvm.testing
import numpy as np
from tvm.script import tir as T

# Installing dependencies
#
# .. code-block:: bash
#
# pip install -q tensorflow
# apt-get -qq install curl


@tvm.script.ir_module
class MyModule:
@T.prim_func
def main(a: T.handle, b: T.handle):
# We exchange data between function by handles, which are similar to pointer.
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# Create buffer from handles.
A = T.match_buffer(a, (8,), dtype="float32")
B = T.match_buffer(b, (8,), dtype="float32")
for i in range(8):
# A block is an abstraction for computation.
with T.block("B"):
# Define a spatial block iterator and bind it to value i.
vi = T.axis.spatial(8, i)
assert 1 == 0, "Some numbers"
B[vi] = A[vi] + 1.0


print("Actually starting ------")
with tvm.transform.PassContext(opt_level=3, config={"tir.enable_debug": True}):
runtime_module = tvm.build(MyModule, target="llvm")

# print(runtime_module.get_source())
print(type(runtime_module))

a = tvm.nd.array(np.arange(8).astype("float32"))
b = tvm.nd.array(np.zeros((8,)).astype("float32"))
print("EXECUTING ------")
runtime_module(a, b)
print(a)
print(b)
49 changes: 45 additions & 4 deletions src/printer/tir_text_printer_debug.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,25 +25,66 @@

#include "tir_text_printer_debug.h"

#include <optional>
#include <string>

#include "text_printer.h"

namespace tvm {
namespace tir {

std::string span_text(const Span& span) {
std::optional<std::string> span_text(const Span& span) {
if (!span.defined()) {
return "missing";
return std::nullopt;
}
std::string source("file");

std::string source("main.tir");
// TODO(driazati): This segfaults even with a guard around source_name, so the
// filename always defaults to main.tir (llvm ignores this filename anyways)
// if (span->source_name.defined()) {
// source = span->source_name->name;
// }
return source + ":" + std::to_string(span->line) + ":" + std::to_string(span->column);
}

template <typename ObjectPtr>
void add_all_relevant_lines(const std::vector<std::tuple<const ObjectPtr*, size_t>>& data,
size_t current_line, Doc* output) {
ICHECK(output) << "output must be a valid Doc";
for (const auto& item : data) {
if (std::get<1>(item) != current_line - 1) {
// Item is not relevant for this line, skip it
continue;
}

// Print out the item's span info if present
auto text = span_text(std::get<0>(item)->span);
if (text.has_value()) {
*output << *text;
} else {
*output << "missing";
}
*output << ", ";
}
}

Doc TIRTextPrinterDebug::NewLine() {
current_line_ += 1;

return TIRTextPrinter::NewLine();
if (!show_spans_) {
return TIRTextPrinter::NewLine();
}

Doc output;

output << " [";

add_all_relevant_lines(exprs_by_line_, current_line_, &output);
add_all_relevant_lines(stmts_by_line_, current_line_, &output);

output << "]" << TIRTextPrinter::NewLine();

return output;
}

#define X(TypeName) \
Expand Down
6 changes: 5 additions & 1 deletion src/printer/tir_text_printer_debug.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ namespace tir {

class TIRTextPrinterDebug : public TIRTextPrinter {
public:
TIRTextPrinterDebug() : TIRTextPrinter(false, &meta_), current_line_(1) {}
explicit TIRTextPrinterDebug(bool show_spans)
: TIRTextPrinter(false, &meta_), current_line_(1), show_spans_(show_spans) {}

std::vector<std::tuple<const PrimExprNode*, size_t>> GetExprsByLine() const {
return exprs_by_line_;
Expand All @@ -61,6 +62,9 @@ class TIRTextPrinterDebug : public TIRTextPrinter {
// Line that the printer is currently printing
size_t current_line_;

// Whether to include spans relevant to each line before a newline or not
bool show_spans_;

// Record of all stmts and exprs and their corresponding line
std::vector<std::tuple<const StmtNode*, size_t>> stmts_by_line_;
std::vector<std::tuple<const PrimExprNode*, size_t>> exprs_by_line_;
Expand Down
2 changes: 0 additions & 2 deletions src/target/llvm/codegen_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -952,7 +952,6 @@ llvm::Value* CodeGenCPU::CreateCallPacked(const CallNode* op, bool use_string_lo
}

llvm::Value* CodeGenCPU::CreateCallTracePacked(const CallNode* op) {
EmitDebugLocation(op);
ICHECK_EQ(op->args.size(), 6U);
PackedCall pc = MakeCallPackedLowered(op->args, op->dtype, op->args[3].as<IntImmNode>()->value,
op->args[4].as<IntImmNode>()->value, true);
Expand Down Expand Up @@ -1388,7 +1387,6 @@ void CodeGenCPU::AddStartupFunction() {
}

llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) {
EmitDebugLocation(op);
if (op->op.same_as(builtin::tvm_call_packed_lowered())) {
return CreateCallPacked(op, true /* use_string_lookup */);
} else if (op->op.same_as(builtin::tvm_call_trace_packed_lowered())) {
Expand Down
33 changes: 2 additions & 31 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1189,7 +1189,6 @@ void CodeGenLLVM::EmitFloat16ConversionBuiltins(bool use_float16_abi) {
}

llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
EmitDebugLocation(op);
if (op->op.same_as(builtin_call_llvm_intrin_) || op->op.same_as(builtin_call_llvm_pure_intrin_)) {
ICHECK_GE(op->args.size(), 2U);
llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value);
Expand Down Expand Up @@ -1226,7 +1225,6 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
} else if (op->op.same_as(builtin::bitwise_not())) {
return builder_->CreateNot(MakeValue(op->args[0]));
} else if (op->op.same_as(builtin::bitwise_xor())) {
EmitDebugLocation(op);
return builder_->CreateXor(MakeValue(op->args[0]), MakeValue(op->args[1]));
} else if (op->op.same_as(builtin::shift_left())) {
return builder_->CreateShl(MakeValue(op->args[0]), MakeValue(op->args[1]));
Expand Down Expand Up @@ -1353,29 +1351,20 @@ void CodeGenLLVM::Scalarize(const PrimExpr& e, std::function<void(int i, llvm::V
}

// Visitors
llvm::Value* CodeGenLLVM::VisitExpr_(const VarNode* op) {
EmitDebugLocation(op);
return GetVarValue(op);
}
llvm::Value* CodeGenLLVM::VisitExpr_(const VarNode* op) { return GetVarValue(op); }

llvm::Value* CodeGenLLVM::VisitExpr_(const CastNode* op) {
EmitDebugLocation(op);
return CreateCast(op->value.dtype(), op->dtype, MakeValue(op->value));
}
llvm::Value* CodeGenLLVM::VisitExpr_(const IntImmNode* op) {
EmitDebugLocation(op);
return llvm::ConstantInt::getSigned(DTypeToLLVMType(op->dtype), op->value);
}

llvm::Value* CodeGenLLVM::VisitExpr_(const FloatImmNode* op) {
EmitDebugLocation(op);
return llvm::ConstantFP::get(DTypeToLLVMType(op->dtype), op->value);
}

llvm::Value* CodeGenLLVM::VisitExpr_(const StringImmNode* op) {
EmitDebugLocation(op);
return GetConstString(op->value);
}
llvm::Value* CodeGenLLVM::VisitExpr_(const StringImmNode* op) { return GetConstString(op->value); }

#define DEFINE_CODEGEN_BINARY_OP(Op) \
llvm::Value* CodeGenLLVM::Create##Op(DataType t, llvm::Value* a, llvm::Value* b) { \
Expand All @@ -1397,7 +1386,6 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const StringImmNode* op) {
} \
} \
llvm::Value* CodeGenLLVM::VisitExpr_(const Op##Node* op) { \
EmitDebugLocation(op); \
return Create##Op(op->dtype, MakeValue(op->a), MakeValue(op->b)); \
}

Expand All @@ -1417,7 +1405,6 @@ DEFINE_CODEGEN_BINARY_OP(Mul);
} \
} \
llvm::Value* CodeGenLLVM::VisitExpr_(const Op##Node* op) { \
EmitDebugLocation(op); \
return Create##Op(op->a.dtype(), MakeValue(op->a), MakeValue(op->b)); \
}

Expand All @@ -1427,7 +1414,6 @@ DEFINE_CODEGEN_CMP_OP(GT);
DEFINE_CODEGEN_CMP_OP(GE);

llvm::Value* CodeGenLLVM::VisitExpr_(const DivNode* op) {
EmitDebugLocation(op);
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
if (op->dtype.is_int()) {
Expand All @@ -1441,7 +1427,6 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const DivNode* op) {
}

llvm::Value* CodeGenLLVM::VisitExpr_(const ModNode* op) {
EmitDebugLocation(op);
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
if (op->dtype.is_int()) {
Expand All @@ -1455,21 +1440,18 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const ModNode* op) {
}

llvm::Value* CodeGenLLVM::VisitExpr_(const MinNode* op) {
EmitDebugLocation(op);
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
return builder_->CreateSelect(CreateLT(op->a.dtype(), a, b), a, b);
}

llvm::Value* CodeGenLLVM::VisitExpr_(const MaxNode* op) {
EmitDebugLocation(op);
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
return builder_->CreateSelect(CreateGT(op->a.dtype(), a, b), a, b);
}

llvm::Value* CodeGenLLVM::VisitExpr_(const EQNode* op) {
EmitDebugLocation(op);
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
if (op->a.dtype().is_int() || op->a.dtype().is_uint()) {
Expand All @@ -1480,7 +1462,6 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const EQNode* op) {
}

llvm::Value* CodeGenLLVM::VisitExpr_(const NENode* op) {
EmitDebugLocation(op);
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
if (op->a.dtype().is_int() || op->a.dtype().is_uint()) {
Expand All @@ -1491,28 +1472,23 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const NENode* op) {
}

llvm::Value* CodeGenLLVM::VisitExpr_(const AndNode* op) {
EmitDebugLocation(op);
return builder_->CreateAnd(MakeValue(op->a), MakeValue(op->b));
}

llvm::Value* CodeGenLLVM::VisitExpr_(const OrNode* op) {
EmitDebugLocation(op);
return builder_->CreateOr(MakeValue(op->a), MakeValue(op->b));
}

llvm::Value* CodeGenLLVM::VisitExpr_(const NotNode* op) {
EmitDebugLocation(op);
return builder_->CreateNot(MakeValue(op->a));
}

llvm::Value* CodeGenLLVM::VisitExpr_(const SelectNode* op) {
EmitDebugLocation(op);
return builder_->CreateSelect(MakeValue(op->condition), MakeValue(op->true_value),
MakeValue(op->false_value));
}

llvm::Value* CodeGenLLVM::VisitExpr_(const LetNode* op) {
EmitDebugLocation(op);
auto it = let_binding_.find(op->var);
if (it != let_binding_.end()) {
ICHECK(deep_equal_(it->second->value, op->value))
Expand Down Expand Up @@ -1630,7 +1606,6 @@ void CodeGenLLVM::BufferAccessHelper(
}

llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) {
EmitDebugLocation(op);
DataType value_dtype = op->dtype;

std::vector<llvm::Value*> loads;
Expand Down Expand Up @@ -1668,7 +1643,6 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) {
}

llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) {
EmitDebugLocation(op);
if (auto* ptr_op = op->op.as<OpNode>()) {
auto call_op = GetRef<Op>(ptr_op);
if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) {
Expand All @@ -1695,7 +1669,6 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) {
}

llvm::Value* CodeGenLLVM::VisitExpr_(const RampNode* op) {
EmitDebugLocation(op);
llvm::Value* vec = llvm::UndefValue::get(DTypeToLLVMType(op->dtype));
for (int i = 0; i < op->lanes; ++i) {
vec = builder_->CreateInsertElement(
Expand All @@ -1705,7 +1678,6 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const RampNode* op) {
}

llvm::Value* CodeGenLLVM::VisitExpr_(const ShuffleNode* op) {
EmitDebugLocation(op);
std::vector<llvm::Value*> vecs(op->vectors.size());
int total_lanes = 0;
for (int i = 0, e = op->vectors.size(); i < e; ++i) {
Expand All @@ -1730,7 +1702,6 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const ShuffleNode* op) {
}

llvm::Value* CodeGenLLVM::VisitExpr_(const BroadcastNode* op) {
EmitDebugLocation(op);
return CreateBroadcast(MakeValue(op->value), op->lanes);
}

Expand Down
Loading

0 comments on commit dec0d8a

Please sign in to comment.