Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Diagnostics][Relay][InferType] Refactor InferType to work on whole module, and use new diagnostics. #6274

Merged
merged 40 commits into from
Oct 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
6361efc
Refactor the type checker to use diagnostics
jroesch Aug 9, 2020
876332b
Apply suggestions from code review
jroesch Oct 2, 2020
2e1627e
Apply suggestions from code review
jroesch Oct 2, 2020
59171bd
Clean up parser
jroesch Oct 2, 2020
781f550
CR feedback
jroesch Oct 2, 2020
c31538e
Apply Bobs suggestions
jroesch Oct 2, 2020
a36a591
Fix up Python interface for diagnostics
jroesch Oct 2, 2020
0540238
Fix test_ir_parser and formatting
jroesch Oct 2, 2020
c378954
Fix cpplint
jroesch Oct 2, 2020
bdbada8
Fix lint
jroesch Oct 2, 2020
e767be3
Fix format
jroesch Oct 2, 2020
df5fe25
More lint
jroesch Oct 2, 2020
5870cbc
Fix format
jroesch Oct 5, 2020
236b39e
Kill dead doc comment
jroesch Oct 5, 2020
84ae8ea
Fix documentation comment
jroesch Oct 5, 2020
9c4c5fd
Rebase fixups
jroesch Oct 5, 2020
2f94a64
Add docs for type.h
jroesch Oct 5, 2020
b81b7a0
Fix parser.cc
jroesch Oct 5, 2020
05b4c7e
Fix unittests
jroesch Oct 5, 2020
58d4dd4
Fix black
jroesch Oct 5, 2020
33847e6
Skip previously typechecked functions
jroesch Oct 6, 2020
e0c1de2
fix ACL
comaniac Oct 6, 2020
b080eac
Fix numerous issues
jroesch Oct 6, 2020
743bbff
Add repr method
jroesch Oct 7, 2020
0281041
Fix issue with Pytest, I am ready to cry
jroesch Oct 7, 2020
987e92a
Fix the rest of tests
jroesch Oct 7, 2020
f829028
Kill dead code
jroesch Oct 7, 2020
2209b08
Fix dignostic tests
jroesch Oct 7, 2020
222b7c1
Fix more tests
jroesch Oct 7, 2020
433deeb
fix more tests (#11)
zhiics Oct 7, 2020
e81eadb
Fix diagnostic.py deinit bug
jroesch Oct 7, 2020
f81e85e
Fix deinit issue
jroesch Oct 8, 2020
ca065d8
Format
jroesch Oct 8, 2020
46f9658
Tweak disabling of override
jroesch Oct 8, 2020
885e404
Format
jroesch Oct 8, 2020
f6c0524
Fix BYOC
jroesch Oct 8, 2020
b32cf3a
Fix TensorArray stuff
jroesch Oct 9, 2020
d8c85b5
Fix PyTorch
jroesch Oct 9, 2020
9e22dbd
Format
jroesch Oct 9, 2020
c2f368b
Format
jroesch Oct 9, 2020
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ if(MSVC)
add_definitions(-D_CRT_SECURE_NO_WARNINGS)
add_definitions(-D_SCL_SECURE_NO_WARNINGS)
add_definitions(-D_ENABLE_EXTENDED_ALIGNED_STORAGE)
add_definitions(-DNOMINMAX)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /bigobj")
Expand Down
5 changes: 5 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,11 @@ jvminstall:
mvn install -P$(JVM_PKG_PROFILE) -Dcxx="$(CXX)" \
-Dcflags="$(PKG_CFLAGS)" -Dldflags="$(PKG_LDFLAGS)" \
-Dcurrent_libdir="$(ROOTDIR)/$(OUTPUTDIR)" $(JVM_TEST_ARGS))
format:
./tests/lint/git-clang-format.sh -i origin/master
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should probably use tests/lint/git-black.sh and tests/lint/git-clang-format.sh

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I think running Black without the caching is actually more aggressive/safe, but YMMV, happy to change it.

black .
cd rust; which cargo && cargo fmt --all; cd ..


# clean rule
clean:
Expand Down
Empty file modified docker/install/ubuntu_install_arm_compute_lib.sh
100644 → 100755
Empty file.
1 change: 0 additions & 1 deletion docker/install/ubuntu_install_ethosn_driver_stack.sh
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,3 @@ git checkout "$repo_revision"

cd "driver"
scons install_prefix="$install_path" install

262 changes: 262 additions & 0 deletions include/tvm/ir/diagnostic.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,262 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file diagnostic.h
* \brief A new diagnostic interface for TVM error reporting.
*
* A prototype of the new diagnostic reporting interface for TVM.
*
* Eventually we hope to promote this file to the top-level and
* replace the existing errors.h.
*/

#ifndef TVM_IR_DIAGNOSTIC_H_
#define TVM_IR_DIAGNOSTIC_H_

#include <tvm/ir/module.h>
#include <tvm/ir/span.h>
#include <tvm/parser/source_map.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/object.h>
#include <tvm/support/logging.h>

#include <fstream>
#include <string>
#include <utility>
#include <vector>

namespace tvm {

using tvm::parser::SourceMap;
using tvm::runtime::TypedPackedFunc;

extern const char* kTVM_INTERNAL_ERROR_MESSAGE;

#define ICHECK_INDENT " "

#define ICHECK_BINARY_OP(name, op, x, y) \
if (dmlc::LogCheckError _check_err = dmlc::LogCheck##name(x, y)) \
dmlc::LogMessageFatal(__FILE__, __LINE__).stream() \
<< kTVM_INTERNAL_ERROR_MESSAGE << std::endl \
<< ICHECK_INDENT << "Check failed: " << #x " " #op " " #y << *(_check_err.str) << ": "

#define ICHECK(x) \
if (!(x)) \
dmlc::LogMessageFatal(__FILE__, __LINE__).stream() \
<< kTVM_INTERNAL_ERROR_MESSAGE << ICHECK_INDENT << "Check failed: " #x << " == false: "

#define ICHECK_LT(x, y) ICHECK_BINARY_OP(_LT, <, x, y)
#define ICHECK_GT(x, y) ICHECK_BINARY_OP(_GT, >, x, y)
#define ICHECK_LE(x, y) ICHECK_BINARY_OP(_LE, <=, x, y)
#define ICHECK_GE(x, y) ICHECK_BINARY_OP(_GE, >=, x, y)
#define ICHECK_EQ(x, y) ICHECK_BINARY_OP(_EQ, ==, x, y)
#define ICHECK_NE(x, y) ICHECK_BINARY_OP(_NE, !=, x, y)
#define ICHECK_NOTNULL(x) \
((x) == nullptr ? dmlc::LogMessageFatal(__FILE__, __LINE__).stream() \
<< kTVM_INTERNAL_ERROR_MESSAGE << __INDENT << "Check not null: " #x \
<< ' ', \
(x) : (x)) // NOLINT(*)

/*! \brief The diagnostic level, controls the printing of the message. */
enum class DiagnosticLevel : int {
kBug = 10,
kError = 20,
kWarning = 30,
kNote = 40,
kHelp = 50,
};

class DiagnosticBuilder;

/*! \brief A compiler diagnostic. */
class Diagnostic;

/*! \brief A compiler diagnostic message. */
class DiagnosticNode : public Object {
public:
/*! \brief The level. */
DiagnosticLevel level;
/*! \brief The span at which to report an error. */
Span span;
/*! \brief The diagnostic message. */
String message;

// override attr visitor
void VisitAttrs(AttrVisitor* v) {
v->Visit("level", &level);
v->Visit("span", &span);
v->Visit("message", &message);
}

bool SEqualReduce(const DiagnosticNode* other, SEqualReducer equal) const {
return equal(this->level, other->level) && equal(this->span, other->span) &&
equal(this->message, other->message);
}

static constexpr const char* _type_key = "Diagnostic";
TVM_DECLARE_FINAL_OBJECT_INFO(DiagnosticNode, Object);
};

class Diagnostic : public ObjectRef {
public:
TVM_DLL Diagnostic(DiagnosticLevel level, Span span, const std::string& message);

static DiagnosticBuilder Bug(Span span);
static DiagnosticBuilder Error(Span span);
static DiagnosticBuilder Warning(Span span);
static DiagnosticBuilder Note(Span span);
static DiagnosticBuilder Help(Span span);

TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Diagnostic, ObjectRef, DiagnosticNode);
};

/*!
* \brief A wrapper around std::stringstream to build a diagnostic.
*/
class DiagnosticBuilder {
public:
/*! \brief The level. */
DiagnosticLevel level;

/*! \brief The source name. */
SourceName source_name;

/*! \brief The span of the diagnostic. */
Span span;

template <typename T>
DiagnosticBuilder& operator<<(const T& val) { // NOLINT(*)
stream_ << val;
return *this;
}

DiagnosticBuilder() : level(DiagnosticLevel::kError), source_name(), span(Span()) {}

DiagnosticBuilder(const DiagnosticBuilder& builder)
: level(builder.level), source_name(builder.source_name), span(builder.span) {}

DiagnosticBuilder(DiagnosticLevel level, Span span) : level(level), span(span) {}

operator Diagnostic() { return Diagnostic(this->level, this->span, this->stream_.str()); }

private:
std::stringstream stream_;
friend class Diagnostic;
};

/*!
* \brief A diagnostic context for recording errors against a source file.
*/
class DiagnosticContext;

/*! \brief Display diagnostics in a given display format.
*
* A diagnostic renderer is responsible for converting the
* raw diagnostics into consumable output.
*
* For example the terminal renderer will render a sequence
* of compiler diagnostics to std::out and std::err in
* a human readable form.
*/
class DiagnosticRendererNode : public Object {
public:
TypedPackedFunc<void(DiagnosticContext ctx)> renderer;

// override attr visitor
void VisitAttrs(AttrVisitor* v) {}

static constexpr const char* _type_key = "DiagnosticRenderer";
TVM_DECLARE_FINAL_OBJECT_INFO(DiagnosticRendererNode, Object);
};

class DiagnosticRenderer : public ObjectRef {
public:
TVM_DLL DiagnosticRenderer(TypedPackedFunc<void(DiagnosticContext ctx)> render);
TVM_DLL DiagnosticRenderer()
: DiagnosticRenderer(TypedPackedFunc<void(DiagnosticContext ctx)>()) {}

void Render(const DiagnosticContext& ctx);

DiagnosticRendererNode* operator->() {
CHECK(get() != nullptr);
return static_cast<DiagnosticRendererNode*>(get_mutable());
}

TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(DiagnosticRenderer, ObjectRef, DiagnosticRendererNode);
};

class DiagnosticContextNode : public Object {
public:
/*! \brief The Module to report against. */
IRModule module;

/*! \brief The set of diagnostics to report. */
Array<Diagnostic> diagnostics;

/*! \brief The renderer set for the context. */
DiagnosticRenderer renderer;

void VisitAttrs(AttrVisitor* v) {
v->Visit("module", &module);
v->Visit("diagnostics", &diagnostics);
}

bool SEqualReduce(const DiagnosticContextNode* other, SEqualReducer equal) const {
return equal(module, other->module) && equal(diagnostics, other->diagnostics);
}

static constexpr const char* _type_key = "DiagnosticContext";
TVM_DECLARE_FINAL_OBJECT_INFO(DiagnosticContextNode, Object);
};

class DiagnosticContext : public ObjectRef {
public:
TVM_DLL DiagnosticContext(const IRModule& module, const DiagnosticRenderer& renderer);
TVM_DLL static DiagnosticContext Default(const IRModule& source_map);

/*! \brief Emit a diagnostic.
* \param diagnostic The diagnostic to emit.
*/
void Emit(const Diagnostic& diagnostic);

/*! \brief Emit a diagnostic and then immediately attempt to render all errors.
*
* \param diagnostic The diagnostic to emit.
*
* Note: this will raise an exception if you would like to instead continue execution
* use the Emit method instead.
*/
void EmitFatal(const Diagnostic& diagnostic);

/*! \brief Render the errors and raise a DiagnosticError exception. */
void Render();

DiagnosticContextNode* operator->() {
CHECK(get() != nullptr);
return static_cast<DiagnosticContextNode*>(get_mutable());
}

TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(DiagnosticContext, ObjectRef, DiagnosticContextNode);
};

DiagnosticRenderer TerminalRenderer(std::ostream& ostream);

} // namespace tvm
#endif // TVM_IR_DIAGNOSTIC_H_
12 changes: 9 additions & 3 deletions include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <tvm/ir/function.h>
#include <tvm/ir/type.h>
#include <tvm/node/container.h>
#include <tvm/parser/source_map.h>

#include <string>
#include <unordered_map>
Expand All @@ -53,14 +54,17 @@ class IRModuleNode : public Object {
Map<GlobalVar, BaseFunc> functions;
/*! \brief A map from global type vars to ADT type data. */
Map<GlobalTypeVar, TypeData> type_definitions;
/*! \brief The source map for the module. */
parser::SourceMap source_map;

IRModuleNode() {}
IRModuleNode() : source_map() {}

void VisitAttrs(AttrVisitor* v) {
v->Visit("functions", &functions);
v->Visit("type_definitions", &type_definitions);
v->Visit("global_var_map_", &global_var_map_);
v->Visit("global_type_var_map_", &global_type_var_map_);
v->Visit("source_map", &source_map);
}

TVM_DLL bool SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const;
Expand Down Expand Up @@ -280,12 +284,14 @@ class IRModule : public ObjectRef {
* \param functions Functions in the module.
* \param type_definitions Type definitions in the module.
* \param import_set Set of imported files in the module
* \param map The module source map.
*/
TVM_DLL explicit IRModule(Map<GlobalVar, BaseFunc> functions,
Map<GlobalTypeVar, TypeData> type_definitions = {},
std::unordered_set<String> import_set = {});
std::unordered_set<String> import_set = {}, parser::SourceMap map = {});

/*! \brief default constructor */
IRModule() : IRModule(Map<GlobalVar, BaseFunc>()) {}
IRModule() : IRModule(Map<GlobalVar, BaseFunc>({})) {}
/*!
* \brief constructor
* \param n The object pointer.
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/ir/span.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class Span : public ObjectRef {
TVM_DLL Span(SourceName source_name, int line, int end_line, int column, int end_column);

/*! \brief Merge two spans into one which captures the combined regions. */
TVM_DLL Span Merge(const Span& other);
TVM_DLL Span Merge(const Span& other) const;

TVM_DEFINE_OBJECT_REF_METHODS(Span, ObjectRef, SpanNode);
};
Expand Down
14 changes: 6 additions & 8 deletions include/tvm/ir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
#ifndef TVM_IR_TRANSFORM_H_
#define TVM_IR_TRANSFORM_H_

#include <tvm/ir/diagnostic.h>
#include <tvm/ir/error.h>
#include <tvm/ir/module.h>
#include <tvm/node/container.h>
Expand Down Expand Up @@ -84,23 +85,19 @@ using TraceFunc =
*/
class PassContextNode : public Object {
public:
/*!
* \brief The error reporter used to notify users why an optimization fails.
*/
ErrorReporter err_reporter;

/*! \brief The default optimization level. */
int opt_level{2};

/*! \brief The list of required passes. */
Array<String> required_pass;
/*! \brief The list of disabled passes. */
Array<String> disabled_pass;
/*! \brief Trace function to be invoked before and after each pass. */
TraceFunc trace_func;

/*! \brief The diagnostic context. */
mutable Optional<DiagnosticContext> diag_ctx;
/*! \brief Pass specific configurations. */
Map<String, ObjectRef> config;
/*! \brief Trace function to be invoked before and after each pass. */
TraceFunc trace_func;

PassContextNode() = default;

Expand Down Expand Up @@ -139,6 +136,7 @@ class PassContextNode : public Object {
v->Visit("required_pass", &required_pass);
v->Visit("disabled_pass", &disabled_pass);
v->Visit("config", &config);
v->Visit("diag_ctx", &diag_ctx);
}

static constexpr const char* _type_key = "transform.PassContext";
Expand Down
Loading