Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Aug 23, 2022
1 parent f727118 commit dc5b593
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 32 deletions.
57 changes: 31 additions & 26 deletions python/taichi/lang/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1499,6 +1499,7 @@ def _shape_of(expr):
raise TaichiCompilationError(f"Cannot get shape of type {dt}")
return ret


@taichi_scope
def _reduce(e, func):
s = _shape_of(e)
Expand All @@ -1515,6 +1516,7 @@ def _reduce(e, func):

return acc


@taichi_scope
def _bind(m, func):
s = _shape_of(m)
Expand Down Expand Up @@ -1618,6 +1620,7 @@ def matmul(m1, m2):
from taichi.lang.matrix import make_matrix
return make_matrix(entries)


@taichi_scope
def transpose(m, in_place=False):
s = _shape_of(m)
Expand All @@ -1632,13 +1635,13 @@ def transpose(m, in_place=False):
return m
else:
from taichi.lang.matrix import make_matrix
entries = [[0 for _ in range(s[0])]
for _ in range(s[1])]
entries = [[0 for _ in range(s[0])] for _ in range(s[1])]
for i in range(s[0]):
for j in range(s[1]):
entries[j][i] = m[i, j]
return make_matrix(entries)


@taichi_scope
def determinant(a):
"""Returns the determinant of this matrix.
Expand All @@ -1661,9 +1664,9 @@ def determinant(a):
if n == 2 and m == 2:
return a[0, 0] * a[1, 1] - a[0, 1] * a[1, 0]
if n == 3 and m == 3:
return a[0, 0] * (a[1, 1] * a[2, 2] - a[2, 1] * a[1, 2]) - a[
1, 0] * (a[0, 1] * a[2, 2] - a[2, 1] * a[0, 2]) + a[
2, 0] * (a[0, 1] * a[1, 2] - a[1, 1] * a[0, 2])
return a[0, 0] * (a[1, 1] * a[2, 2] - a[2, 1] * a[1, 2]) - a[1, 0] * (
a[0, 1] * a[2, 2] - a[2, 1] * a[0, 2]) + a[2, 0] * (
a[0, 1] * a[1, 2] - a[1, 1] * a[0, 2])
if n == 4 and m == 4:
n = 4

Expand All @@ -1675,15 +1678,16 @@ def E(x, y):
det = det + (-1.0)**i * (
a[i, 0] *
(E(i + 1, 1) *
(E(i + 2, 2) * E(i + 3, 3) - E(i + 3, 2) * E(i + 2, 3)) -
E(i + 2, 1) *
(E(i + 1, 2) * E(i + 3, 3) - E(i + 3, 2) * E(i + 1, 3)) +
E(i + 3, 1) *
(E(i + 1, 2) * E(i + 2, 3) - E(i + 2, 2) * E(i + 1, 3))))
(E(i + 2, 2) * E(i + 3, 3) - E(i + 3, 2) * E(i + 2, 3)) -
E(i + 2, 1) *
(E(i + 1, 2) * E(i + 3, 3) - E(i + 3, 2) * E(i + 1, 3)) +
E(i + 3, 1) *
(E(i + 1, 2) * E(i + 2, 3) - E(i + 2, 2) * E(i + 1, 3))))
return det
raise Exception(
"Determinants of matrices with sizes >= 5 are not supported")


@taichi_scope
def inverse(m):
"""Returns the inverse of this matrix.
Expand All @@ -1708,8 +1712,8 @@ def inverse(m):
return make_matrix([1 / m[0, 0]])
if n == 2:
inv_determinant = impl.expr_init(1.0 / determinant(m))
return inv_determinant * make_matrix([[m[
1, 1], -m[0, 1]], [-m[1, 0], m[0, 0]]])
return inv_determinant * make_matrix([[m[1, 1], -m[0, 1]],
[-m[1, 0], m[0, 0]]])
if n == 3:
n = 3
inv_determinant = impl.expr_init(1.0 / determinant(m))
Expand All @@ -1734,17 +1738,17 @@ def E(x, y):

for i in range(n):
for j in range(n):
entries[j][i] = inv_determinant * (-1)**(i + j) * ((
E(i + 1, j + 1) *
(E(i + 2, j + 2) * E(i + 3, j + 3) -
E(i + 3, j + 2) * E(i + 2, j + 3)) - E(i + 2, j + 1) *
(E(i + 1, j + 2) * E(i + 3, j + 3) -
E(i + 3, j + 2) * E(i + 1, j + 3)) + E(i + 3, j + 1) *
(E(i + 1, j + 2) * E(i + 2, j + 3) -
E(i + 2, j + 2) * E(i + 1, j + 3))))
entries[j][i] = inv_determinant * (-1)**(i + j) * (
(E(i + 1, j + 1) *
(E(i + 2, j + 2) * E(i + 3, j + 3) -
E(i + 3, j + 2) * E(i + 2, j + 3)) - E(i + 2, j + 1) *
(E(i + 1, j + 2) * E(i + 3, j + 3) -
E(i + 3, j + 2) * E(i + 1, j + 3)) + E(i + 3, j + 1) *
(E(i + 1, j + 2) * E(i + 2, j + 3) -
E(i + 2, j + 2) * E(i + 1, j + 3))))
return make_matrix(entries)
raise Exception(
"Inversions of matrices with sizes >= 5 are not supported")
raise Exception("Inversions of matrices with sizes >= 5 are not supported")


def rows(rows):
"""Constructs a matrix by concatenating a list of
Expand Down Expand Up @@ -1779,12 +1783,11 @@ def rows(rows):
return make_matrix([[row[i] for i in range(shape[0])] for row in rows])
if isinstance(rows[0], list):
for row in rows:
assert len(row) == len(
rows[0]), "Input lists share the same shape"
assert len(row) == len(rows[0]), "Input lists share the same shape"
# l-value copy:
return make_matrix([[x for x in row] for row in rows])
raise Exception(
"Cols/rows must be a list of lists, or a list of vectors")
raise Exception("Cols/rows must be a list of lists, or a list of vectors")


def cols(cols):
"""Constructs a Matrix instance by concatenating Vectors/lists column by column.
Expand All @@ -1809,6 +1812,7 @@ def cols(cols):
"""
return transpose(rows(cols))


def trace(m):
"""The sum of a matrix diagonal elements.
Expand All @@ -1831,6 +1835,7 @@ def trace(m):
_sum = _sum + m[i, i]
return _sum


__all__ = [
"acos", "asin", "atan2", "atomic_and", "atomic_or", "atomic_xor",
"atomic_max", "atomic_sub", "atomic_min", "atomic_add", "bit_cast",
Expand Down
3 changes: 2 additions & 1 deletion taichi/analysis/data_source_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ std::vector<Stmt *> get_load_pointers(Stmt *load_stmt) {
if (auto local_load = load_stmt->cast<LocalLoadStmt>()) {
std::vector<Stmt *> result;
for (auto &address : local_load->src.data) {
if (std::find(result.begin(), result.end(), address.var) == result.end()) {
if (std::find(result.begin(), result.end(), address.var) ==
result.end()) {
result.push_back(address.var);
}
}
Expand Down
5 changes: 3 additions & 2 deletions taichi/codegen/cuda/codegen_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM {
auto elem_value_type = tlctx->get_data_type(elem_type);
if (elem_type->is_primitive(PrimitiveTypeID::f32) ||
elem_type->is_primitive(PrimitiveTypeID::f16)) {
elem_value = builder->CreateFPExt(elem_value, tlctx->get_data_type(PrimitiveType::f64));
elem_value = builder->CreateFPExt(
elem_value, tlctx->get_data_type(PrimitiveType::f64));
elem_value_type = tlctx->get_data_type(PrimitiveType::f64);
}
types.push_back(elem_value_type);
Expand Down Expand Up @@ -124,7 +125,7 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM {
formats += "%s";
}
TI_ASSERT_INFO(num_contents < 32,
"CUDA `print()` doesn't support more than 32 entries");
"CUDA `print()` doesn't support more than 32 entries");
}

llvm_val[stmt] = create_print(formats, types, values);
Expand Down
3 changes: 2 additions & 1 deletion taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -834,7 +834,8 @@ void TaskCodeGenLLVM::visit(PrintStmt *stmt) {
auto elem_value = builder->CreateExtractElement(value, i);
if (elem_type->is_primitive(PrimitiveTypeID::f32) ||
elem_type->is_primitive(PrimitiveTypeID::f16))
elem_value = builder->CreateFPExt(elem_value, tlctx->get_data_type(PrimitiveType::f64));
elem_value = builder->CreateFPExt(
elem_value, tlctx->get_data_type(PrimitiveType::f64));
args.push_back(elem_value);
}
formats += data_type_format(arg_stmt->ret_type);
Expand Down
3 changes: 2 additions & 1 deletion taichi/ir/control_flow_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,8 @@ bool CFGNode::store_to_load_forwarding(bool after_lower_access,
}
}
if (regular) {
// if (!stmt->ret_type->is<TensorType>() || contain_variable(reach_in, stmt))
// if (!stmt->ret_type->is<TensorType>() || contain_variable(reach_in,
// stmt))
result = get_store_forwarding_data(alloca, i);
}
} else if (auto global_load = stmt->cast<GlobalLoadStmt>()) {
Expand Down
3 changes: 2 additions & 1 deletion taichi/transforms/constant_fold.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,8 @@ class ConstantFold : public BasicStmtVisitor {
// TI_TRACE("Got constant idx: {}", idx_val);
// if (ptr_offset->origin->ret_type->is<TensorType>()) {
// if (auto matrix = ptr_offset->origin->cast<MatrixInitStmt>()) {
// TI_ASSERT_INFO(idx_val < matrix->values.size(), "Matrix indexing out-of-bound on value {}", stmt->name());
// TI_ASSERT_INFO(idx_val < matrix->values.size(), "Matrix
// indexing out-of-bound on value {}", stmt->name());
// stmt->replace_usages_with(matrix->values[idx_val]);
// modifier.erase(stmt);
// }
Expand Down

0 comments on commit dc5b593

Please sign in to comment.