add no grad when det = 0
arcadiaphy committed May 24, 2019
1 parent 8aa6e3b commit 9826d25
Showing 2 changed files with 8 additions and 6 deletions.
10 changes: 4 additions & 6 deletions src/operator/tensor/la_op-inl.h
@@ -884,7 +884,7 @@ struct inverse_backward {
   }
 };
 
-// Here we set grad to zero if det = 0 as a temporary method
+// Here we set grad to zero if det = 0
 struct StopZeroDetGrad {
   template&lt;typename DType&gt;
   MSHADOW_XINLINE static void Map(int i, int grad_step, DType *grad, DType *det, DType zero_det) {
@@ -897,8 +897,7 @@ struct StopZeroDetGrad {
 
 // Backward of det(A) is derived from Jacobi's formula.
 // The closed form solution is pretty easy when A is invertible.
-// For non-invertible A, grad is not backwarded now.
-// TODO(arcadiaphy) add implementation for non-invertible case
+// For non-invertible A, grad is not backwarded.
 struct det_backward {
   template&lt;typename xpu, typename DType&gt;
   static void op(const Tensor&lt;xpu, 1, DType&gt;&amp; ddet,
@@ -924,9 +923,8 @@ struct det_backward {
 
 // Backward of slogdet(A) is derived from Jacobi's formula.
 // The closed form solution is pretty easy when A is invertible.
-// For non-invertible A, grad is not backwarded now.
+// For non-invertible A, grad is not backwarded.
 // Grad is not properly defined on sign, so it's not backwarded either.
-// TODO(arcadiaphy) add implementation for non-invertible case
 struct slogdet_backward {
   template&lt;typename xpu, typename DType&gt;
   static void op(const Tensor&lt;xpu, 1, DType&gt;&amp; dlogabsdet,
@@ -945,7 +943,7 @@
       Shape3(logabsdet.size(0), 1, 1)), mxnet::TShape(LU.shape_)) * \
       transpose(LU, Shape3(0, 2, 1));
     Stream&lt;xpu&gt; *s = ctx.get_stream&lt;xpu&gt;();
-    // stop grad for zero det temporarily
+    // stop grad for zero det
     Kernel&lt;StopZeroDetGrad, xpu&gt;::Launch(s, dA.shape_.Size(), dA.size(1) * dA.size(2), \
       dA.dptr_, logabsdet.dptr_, DType(-INFINITY));
   }
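For context on the comments above: both det_backward and slogdet_backward follow from Jacobi's formula, which has a simple closed form when A is invertible. A sketch of the math (standard linear algebra, not quoted from the source):

```latex
% Jacobi's formula for a matrix A(t) depending on a scalar parameter t:
\frac{d}{dt}\det A(t) = \det A(t)\,\operatorname{tr}\!\Big(A(t)^{-1}\,\frac{dA(t)}{dt}\Big)
% which yields the closed-form gradients used by the two backward structs:
\frac{\partial \det A}{\partial A} = \det(A)\,(A^{-1})^{\top},
\qquad
\frac{\partial \log\lvert\det A\rvert}{\partial A} = (A^{-1})^{\top}
```

Both expressions involve A^{-1}, which does not exist when det(A) = 0; that is exactly the case StopZeroDetGrad masks out. In slogdet_backward the singular case is detected through logabsdet == -INFINITY, as the Launch call above shows.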
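The body of StopZeroDetGrad is collapsed in the diff above. As a rough illustration of the masking idea it implements, here is a minimal, self-contained sketch in plain C++; the function name stop_zero_det_grad and the flat-vector layout are assumptions for this example, not MXNet code:

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Illustrative sketch (not the MXNet kernel itself): zero out the incoming
// gradient for every batch element whose determinant marker equals zero_det.
// In the slogdet case, log|det(A)| == -INFINITY marks a singular matrix,
// so that matrix's entire grad_step-sized gradient block is zeroed.
void stop_zero_det_grad(std::vector<double>* grad, const std::vector<double>& det,
                        int grad_step, double zero_det) {
  for (size_t i = 0; i < grad->size(); ++i) {
    int batch_ind = static_cast<int>(i) / grad_step;  // which matrix this element belongs to
    if (det[batch_ind] == zero_det) {
      (*grad)[i] = 0.0;
    }
  }
}

int main() {
  // Batch of two 2x2 matrices: grad_step = 2 * 2 = 4 elements per matrix.
  std::vector<double> grad(8, 1.0);
  std::vector<double> logabsdet = {0.5, -INFINITY};  // second matrix is singular
  stop_zero_det_grad(&grad, logabsdet, 4, -INFINITY);
  for (double g : grad) std::printf("%g ", g);  // prints: 1 1 1 1 0 0 0 0
  std::printf("\n");
  return 0;
}
```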
4 changes: 4 additions & 0 deletions src/operator/tensor/la_op.cc
@@ -952,6 +952,10 @@ If *n>2*, *det* is performed separately on the trailing two dimensions
 for all inputs (batch mode).
 
 .. note:: The operator supports float32 and float64 data types only.
 
+.. note:: No gradient is backwarded when det(A) == 0, because this case is
+          rarely hit in floating-point computation, and Jacobi's formula for
+          the determinant gradient is not computationally efficient when A is
+          non-invertible.
 Examples::
 
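To make the documented behavior concrete, here is a hedged 2x2 illustration of the gradient rule the note describes; det_with_grad is a hypothetical helper written for this sketch, not part of the MXNet API:

```cpp
#include <cstdio>

// For a 2x2 matrix, the gradient of det(A) is det(A) * inverse(A)^T by
// Jacobi's formula when A is invertible, which reduces to the cofactor
// matrix. When det(A) == 0, the gradient is defined as zero, matching the
// note above.
void det_with_grad(const double A[2][2], double* det, double dA[2][2]) {
  *det = A[0][0] * A[1][1] - A[0][1] * A[1][0];
  if (*det == 0.0) {
    // det(A) == 0: stop the gradient instead of evaluating inverse(A).
    dA[0][0] = dA[0][1] = dA[1][0] = dA[1][1] = 0.0;
    return;
  }
  // Cofactor matrix of A, i.e. det(A) * inverse(A)^T for the 2x2 case.
  dA[0][0] =  A[1][1];
  dA[0][1] = -A[1][0];
  dA[1][0] = -A[0][1];
  dA[1][1] =  A[0][0];
}

int main() {
  double A[2][2] = {{1.0, 2.0}, {2.0, 4.0}};  // singular: rows are dependent
  double det, dA[2][2];
  det_with_grad(A, &det, dA);
  std::printf("det = %g, dA = [[%g, %g], [%g, %g]]\n",
              det, dA[0][0], dA[0][1], dA[1][0], dA[1][1]);
  // prints: det = 0, dA = [[0, 0], [0, 0]]
  return 0;
}
```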
