
remove logdet
arcadiaphy committed May 24, 2019
1 parent e3e489d commit 8aa6e3b
Showing 4 changed files with 0 additions and 105 deletions.
46 changes: 0 additions & 46 deletions src/operator/tensor/la_op-inl.h
@@ -510,25 +510,6 @@ struct det {
}
};

// logdet = log(det(A))
struct logdet {
  template<typename xpu, typename DType>
  static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 1, DType>& logdet,
                 const Tensor<xpu, 3, DType>& LU, const Tensor<xpu, 2, int>& pivot,
                 const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
    Stream<xpu> *s = ctx.get_stream<xpu>();
    Tensor<xpu, 1, DType> sign = ctx.requested[0]
      .get_space_typed<xpu, 1, DType>(logdet.shape_, s);
    Copy(LU, A, s);
    linalg_batch_getrf(LU, pivot, false, s);
    using namespace mxnet_op;
    using namespace mshadow::expr;
    Kernel<SignedLogDet, xpu>::Launch(s, pivot.size(0), pivot.size(1), pivot.dptr_,
                                      LU.dptr_, sign.dptr_, logdet.dptr_);
    const_cast<Tensor<xpu, 1, DType>&>(logdet) = F<mshadow_op::log>(sign) + logdet;
  }
};
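For reference, the removed forward path LU-factorizes each matrix (`linalg_batch_getrf`) and the `SignedLogDet` kernel accumulates log|U_ii| together with the sign implied by the pivots and the diagonal; adding the log of that sign is what turns a zero or negative determinant into `-inf` or `nan`. A minimal numpy/scipy sketch of the same math (`logdet_via_lu` is an illustrative helper, not MXNet API):

```python
import numpy as np
from scipy.linalg import lu_factor

def logdet_via_lu(A):
    """log(det(A)) from an LU factorization, mirroring the removed kernel's math."""
    lu, piv = lu_factor(A)                        # PA = LU, piv holds 0-based row swaps
    swaps = np.sum(piv != np.arange(A.shape[0]))  # each mismatch is one row interchange
    sign = (-1.0) ** swaps * np.prod(np.sign(np.diag(lu)))
    with np.errstate(divide='ignore', invalid='ignore'):
        # log(sign) is 0, -inf or nan for positive, zero or negative determinants
        return np.log(sign) + np.sum(np.log(np.abs(np.diag(lu))))

print(logdet_via_lu(np.array([[2., 3.], [1., 4.]])))   # 1.609438 == log(5)
```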

// sign = sign(det(A))
// logabsdet = log(abs(det(A)))
struct slogdet {
@@ -941,33 +922,6 @@ struct det_backward {
}
};

// Backward of logdet(A) is derived from Jacobi's formula.
// The closed-form solution is straightforward when A is invertible.
// For non-invertible A, the gradient is not propagated for now.
// TODO(arcadiaphy) add implementation for non-invertible case
struct logdet_backward {
  template<typename xpu, typename DType>
  static void op(const Tensor<xpu, 1, DType>& dlogdet,
                 const Tensor<xpu, 1, DType>& logdet,
                 const Tensor<xpu, 3, DType>& LU,
                 const Tensor<xpu, 2, int>& pivot,
                 const Tensor<xpu, 3, DType>& dA,
                 const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
    using namespace mshadow;
    using namespace mshadow::expr;
    using namespace mxnet_op;
    // compute inverse(A) and store it in LU
    linalg_batch_det_backward_helper(LU, pivot, logdet, dA, DType(-INFINITY), ctx);
    const_cast<Tensor<xpu, 3, DType>&>(dA) = broadcast_to(reshape(dlogdet, \
      Shape3(logdet.size(0), 1, 1)), mxnet::TShape(LU.shape_)) * \
      transpose(LU, Shape3(0, 2, 1));
    Stream<xpu> *s = ctx.get_stream<xpu>();
    // temporarily stop the gradient when the determinant is zero
    Kernel<StopZeroDetGrad, xpu>::Launch(s, dA.shape_.Size(), dA.size(1) * dA.size(2), \
      dA.dptr_, logdet.dptr_, DType(-INFINITY));
  }
};
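Per Jacobi's formula, for invertible A the gradient of log(det(A)) with respect to A is inv(A) transposed, scaled by the incoming gradient; when the determinant is zero (logdet is -inf) the `StopZeroDetGrad` kernel suppresses the gradient instead. A small numpy sketch of that closed form (the helper name is purely illustrative):

```python
import numpy as np

def logdet_backward(dlogdet, logdet, A):
    """Jacobi's formula: d log(det(A)) / dA = inv(A).T for invertible A."""
    if np.isinf(logdet):      # det(A) == 0: gradient suppressed, as StopZeroDetGrad does
        return np.zeros_like(A)
    return dlogdet * np.linalg.inv(A).T

A = np.array([[2., 3.], [1., 4.]])
print(logdet_backward(1.0, np.log(np.linalg.det(A)), A))
# [[ 0.8 -0.2]
#  [-0.6  0.4]]
```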

// Backward of slogdet(A) is derived from Jacobi's formula.
// The closed-form solution is straightforward when A is invertible.
// For non-invertible A, the gradient is not propagated for now.
48 changes: 0 additions & 48 deletions src/operator/tensor/la_op.cc
@@ -986,54 +986,6 @@ NNVM_REGISTER_OP(_backward_linalg_det)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FCompute>("FCompute<cpu>", LaOpDetBackward<cpu, 1, det_backward>);

NNVM_REGISTER_OP(_linalg_logdet)
.add_alias("linalg_logdet")
.describe(R"code(Compute the log determinant of a matrix.
Input is a tensor *A* of dimension *n >= 2*.

If *n=2*, *A* is a square matrix. We compute:

  *out* = *log(det(A))*

If *n>2*, *logdet* is performed separately on the trailing two dimensions
for all inputs (batch mode).

.. note:: The operator supports float32 and float64 data types only.

Examples::

   Single matrix log determinant
   A = [[2., 3.], [1., 4.]]
   logdet(A) = [1.609438]

   Batch matrix log determinant
   A = [[[2., 3.], [1., 4.]],
        [[1., 2.], [2., 4.]],
        [[1., 2.], [4., 3.]]]
   logdet(A) = [1.609438, -inf, nan]
)code" ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(3)
.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs)
  { return std::vector<std::string>{"A"}; })
.set_attr<nnvm::FNumVisibleOutputs>("FNumVisibleOutputs", [](const NodeAttrs& attrs) {
  return 1; })
.set_attr<mxnet::FInferShape>("FInferShape", DetShape<1>)
.set_attr<nnvm::FInferType>("FInferType", DetType<1>)
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs)
  { return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; })
.set_attr<FCompute>("FCompute<cpu>", LaOpDetForward<cpu, 1, logdet>)
.set_attr<nnvm::FGradient>("FGradient", ReduceDetGrad<1>{"_backward_linalg_logdet"})
.add_argument("A", "NDArray-or-Symbol", "Tensor of square matrix");
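On a build that still carries this operator (it is exposed to the symbol and ndarray front ends via the `linalg_logdet` alias), the docstring examples can be reproduced as below; the numpy line is only a cross-check of the expected values:

```python
import numpy as np
import mxnet as mx

# Assumes an MXNet build from before this commit, where linalg.logdet still exists.
A = mx.nd.array([[[2., 3.], [1., 4.]],
                 [[1., 2.], [2., 4.]],
                 [[1., 2.], [4., 3.]]])
print(mx.nd.linalg.logdet(A).asnumpy())          # [1.609438, -inf, nan]

with np.errstate(divide='ignore', invalid='ignore'):
    print(np.log(np.linalg.det(A.asnumpy())))    # same values from numpy
```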

NNVM_REGISTER_OP(_backward_linalg_logdet)
.set_num_inputs(4)
.set_num_outputs(1)
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs)
{ return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; })
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FCompute>("FCompute<cpu>", LaOpDetBackward<cpu, 1, logdet_backward>);

NNVM_REGISTER_OP(_linalg_slogdet)
.add_alias("linalg_slogdet")
.describe(R"code(Compute the sign and log of the determinant of a matrix.
6 changes: 0 additions & 6 deletions src/operator/tensor/la_op.cu
@@ -105,12 +105,6 @@ NNVM_REGISTER_OP(_linalg_det)
NNVM_REGISTER_OP(_backward_linalg_det)
.set_attr<FCompute>("FCompute<gpu>", LaOpDetBackward<gpu, 1, det_backward>);

NNVM_REGISTER_OP(_linalg_logdet)
.set_attr<FCompute>("FCompute<gpu>", LaOpDetForward<gpu, 1, logdet>);

NNVM_REGISTER_OP(_backward_linalg_logdet)
.set_attr<FCompute>("FCompute<gpu>", LaOpDetBackward<gpu, 1, logdet_backward>);

NNVM_REGISTER_OP(_linalg_slogdet)
.set_attr<FCompute>("FCompute<gpu>", LaOpDetForward<gpu, 2, slogdet>);

5 changes: 0 additions & 5 deletions tests/python/unittest/test_operator.py
@@ -6536,11 +6536,6 @@ def test_laop_6():
    test_det = mx.sym.linalg.det(data)
    check_fw(test_det, [a], [r])
    check_grad(test_det, [a])
    # logdet
    r = np.log(np.linalg.det(a))
    test_logdet = mx.sym.linalg.logdet(data)
    check_fw(test_logdet, [a], [r])
    check_grad(test_logdet, [a])
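The removed lines lean on the local `check_fw` / `check_grad` helpers; outside the test harness, an equivalent spot check of the forward value and of the gradient against Jacobi's formula could look like the sketch below, which again assumes a build from before this commit:

```python
import numpy as np
import mxnet as mx
from mxnet import autograd

# Sketch only: assumes mx.nd.linalg.logdet is still available (pre-removal build).
a = mx.nd.array(np.random.rand(3, 4, 4) + 4 * np.eye(4), dtype='float64')
a.attach_grad()
with autograd.record():
    out = mx.nd.linalg.logdet(a)
out.backward()                        # head gradient defaults to ones

an = a.asnumpy()
assert np.allclose(out.asnumpy(), np.log(np.linalg.det(an)))
# d log(det(A)) / dA = inv(A).T, batched over the leading axis
assert np.allclose(a.grad.asnumpy(), np.transpose(np.linalg.inv(an), (0, 2, 1)))
```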
    # test slogdet
    r1 = np.array([1., 1., 1.])
    r2 = np.log(np.abs(np.linalg.det(a)))

