From 9826d25e0f9448026f211c1e7a4e92ae3d211eb2 Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Fri, 24 May 2019 19:50:08 +0800 Subject: [PATCH] add no grad when det = 0 --- src/operator/tensor/la_op-inl.h | 10 ++++------ src/operator/tensor/la_op.cc | 4 ++++ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/operator/tensor/la_op-inl.h b/src/operator/tensor/la_op-inl.h index de27187bca9a..42d1f4527575 100644 --- a/src/operator/tensor/la_op-inl.h +++ b/src/operator/tensor/la_op-inl.h @@ -884,7 +884,7 @@ struct inverse_backward { } }; -// Here we set grad to zero if det = 0 as a temporary method +// Here we set grad to zero if det = 0 struct StopZeroDetGrad { template MSHADOW_XINLINE static void Map(int i, int grad_step, DType *grad, DType *det, DType zero_det) { @@ -897,8 +897,7 @@ struct StopZeroDetGrad { // Backward of det(A) is derived from Jacobi's formula. // The closed form solution is pretty easy when A is invertible. -// For non-invertible A, grad is not backwarded now. -// TODO(arcadiaphy) add implementation for non-invertible case +// For non-invertible A, grad is not backwarded. struct det_backward { template static void op(const Tensor& ddet, @@ -924,9 +923,8 @@ struct det_backward { // Backward of slogdet(A) is derived from Jacobi's formula. // The closed form solution is pretty easy when A is invertible. -// For non-invertible A, grad is not backwarded now. +// For non-invertible A, grad is not backwarded. // Grad is not properly defined on sign, so it's not backwarded either. 
-// TODO(arcadiaphy) add implementation for non-invertible case struct slogdet_backward { template static void op(const Tensor& dlogabsdet, @@ -945,7 +943,7 @@ struct slogdet_backward { Shape3(logabsdet.size(0), 1, 1)), mxnet::TShape(LU.shape_)) * \ transpose(LU, Shape3(0, 2, 1)); Stream *s = ctx.get_stream(); - // stop grad for zero det temporarily + // stop grad for zero det Kernel::Launch(s, dA.shape_.Size(), dA.size(1) * dA.size(2), \ dA.dptr_, logabsdet.dptr_, DType(-INFINITY)); } diff --git a/src/operator/tensor/la_op.cc b/src/operator/tensor/la_op.cc index dade5692800d..c426e52a844f 100644 --- a/src/operator/tensor/la_op.cc +++ b/src/operator/tensor/la_op.cc @@ -952,6 +952,10 @@ If *n>2*, *det* is performed separately on the trailing two dimensions for all inputs (batch mode). .. note:: The operator supports float32 and float64 data types only. +.. note:: There is no gradient backwarded when det(A) == 0 because it's + rarely hit upon in floating point computation and Jacobi's + formula on determinant gradient is not computationally efficient + when A is non-invertible. Examples::