From c08ca90ac939d6b211bfdba47a7473e3454653cf Mon Sep 17 00:00:00 2001 From: zixuanweeei Date: Mon, 17 Jun 2019 20:45:21 +0800 Subject: [PATCH 1/4] Effective multinomial --- src/operator/random/sample_multinomial_op.h | 41 ++++++++++++--------- 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/src/operator/random/sample_multinomial_op.h b/src/operator/random/sample_multinomial_op.h index 377df4f313da..c0f6d479631e 100644 --- a/src/operator/random/sample_multinomial_op.h +++ b/src/operator/random/sample_multinomial_op.h @@ -122,25 +122,28 @@ inline bool SampleMultinomialOpType(const nnvm::NodeAttrs& attrs, struct SampleMultinomialKernel { template MSHADOW_XINLINE static void Map(int i, index_t K, index_t M, - DType* dist, float* uniform, IType* out, - DType* prob) { + DType* dist, float* uniform, float* cum_table, + IType* out, DType* prob) { + cum_table[i*K] = 0.0; + // CDF table + for (index_t c = 1; c < K + 1; ++c) { + cum_table[i*K + c] = cum_table[i*K + c - 1] + dist[i*K + c - 1]; + } for (index_t j = 0; j < M; ++j) { + index_t left = 0, right = K; + index_t middle = left + (right - left) / 2; DType loc = static_cast(uniform[i*M + j]); - DType acc = 0; - bool found = false; - for (index_t k = 0; k < K; ++k) { - acc += dist[i*K + k]; - if (acc > loc) { - found = true; - out[i*M + j] = static_cast(k); - if (prob != nullptr) prob[i*M + j] = logf(dist[i*K + k]); - break; + while (right - left > 0) { + middle = left + (right - left) / 2; + DType cum_prob = cum_table[i*K + middle]; + if (cum_prob < loc) { + left = middle + 1; + } else { + right = middle; } } - if (!found) { - out[i*M + j] = static_cast(K-1); - if (prob != nullptr) prob[i*M + j] = logf(dist[i*K + K - 1]); - } + out[i*M + j] = static_cast(middle); + if (prob != nullptr) prob[i*M + j] = logf(dist[i*K + middle - 1]); } } }; @@ -163,12 +166,14 @@ void SampleMultinomialForward(const nnvm::NodeAttrs& attrs, Stream *s = ctx.get_stream(); MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { Random *prnd = ctx.requested[0].get_random(s); - Tensor uniform = - ctx.requested[1].get_space_typed(Shape1(N*M), s); + Tensor workspace = + ctx.requested[1].get_space_typed(Shape1(N*M + N*(K + 1)), s); + Tensor uniform(workspace.dptr_, Shape1(N*M)); prnd->SampleUniform(&uniform, 0, 1); MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, IType, { Kernel::Launch( - s, N, K, M, inputs[0].dptr(), uniform.dptr_, outputs[0].dptr(), + s, N, K, M, inputs[0].dptr(), workspace.dptr_, workspace.dptr_ + N*M, + outputs[0].dptr(), param.get_prob ? outputs[1].dptr() : nullptr); }); }); From a5484ac1e9df0ee9617c881013b6099fd8a260e1 Mon Sep 17 00:00:00 2001 From: zixuanweeei Date: Tue, 18 Jun 2019 10:47:19 +0800 Subject: [PATCH 2/4] Meaningful uniform data pointer as input --- src/operator/random/sample_multinomial_op.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/random/sample_multinomial_op.h b/src/operator/random/sample_multinomial_op.h index c0f6d479631e..ae8f73088b9f 100644 --- a/src/operator/random/sample_multinomial_op.h +++ b/src/operator/random/sample_multinomial_op.h @@ -172,7 +172,7 @@ void SampleMultinomialForward(const nnvm::NodeAttrs& attrs, prnd->SampleUniform(&uniform, 0, 1); MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, IType, { Kernel::Launch( - s, N, K, M, inputs[0].dptr(), workspace.dptr_, workspace.dptr_ + N*M, + s, N, K, M, inputs[0].dptr(), uniform.dptr_, workspace.dptr_ + N*M, outputs[0].dptr(), param.get_prob ? outputs[1].dptr() : nullptr); }); From 5540eec02a616829a9efc39e64a1115cccd47456 Mon Sep 17 00:00:00 2001 From: zixuanweeei Date: Tue, 18 Jun 2019 13:53:51 +0800 Subject: [PATCH 3/4] Remove beginning Zeros from CDFs --- src/operator/random/sample_multinomial_op.h | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/operator/random/sample_multinomial_op.h b/src/operator/random/sample_multinomial_op.h index ae8f73088b9f..d6bb0241d179 100644 --- a/src/operator/random/sample_multinomial_op.h +++ b/src/operator/random/sample_multinomial_op.h @@ -124,10 +124,11 @@ struct SampleMultinomialKernel { MSHADOW_XINLINE static void Map(int i, index_t K, index_t M, DType* dist, float* uniform, float* cum_table, IType* out, DType* prob) { - cum_table[i*K] = 0.0; + float acc = 0.0; // CDF table - for (index_t c = 1; c < K + 1; ++c) { - cum_table[i*K + c] = cum_table[i*K + c - 1] + dist[i*K + c - 1]; + for (index_t c = 0; c < K; ++c) { + acc += dist[i*K + c]; + cum_table[i*K + c] = acc; } for (index_t j = 0; j < M; ++j) { index_t left = 0, right = K; @@ -142,8 +143,8 @@ struct SampleMultinomialKernel { right = middle; } } - out[i*M + j] = static_cast(middle); - if (prob != nullptr) prob[i*M + j] = logf(dist[i*K + middle - 1]); + out[i*M + j] = static_cast(left); + if (prob != nullptr) prob[i*M + j] = logf(dist[i*K + left]); } } }; @@ -167,7 +168,7 @@ void SampleMultinomialForward(const nnvm::NodeAttrs& attrs, MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { Random *prnd = ctx.requested[0].get_random(s); Tensor workspace = - ctx.requested[1].get_space_typed(Shape1(N*M + N*(K + 1)), s); + ctx.requested[1].get_space_typed(Shape1(N*M + N*K), s); Tensor uniform(workspace.dptr_, Shape1(N*M)); prnd->SampleUniform(&uniform, 0, 1); MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, IType, { From 6c3c49b89caed65598dc2e19902fe70a17491128 Mon Sep 17 00:00:00 2001 From: zixuanweeei Date: Tue, 18 Jun 2019 15:16:30 +0800 Subject: [PATCH 4/4] Double precision for accumulated var --- src/operator/random/sample_multinomial_op.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/operator/random/sample_multinomial_op.h b/src/operator/random/sample_multinomial_op.h index d6bb0241d179..5a0b9bb21acb 100644 --- a/src/operator/random/sample_multinomial_op.h +++ b/src/operator/random/sample_multinomial_op.h @@ -124,11 +124,11 @@ struct SampleMultinomialKernel { MSHADOW_XINLINE static void Map(int i, index_t K, index_t M, DType* dist, float* uniform, float* cum_table, IType* out, DType* prob) { - float acc = 0.0; + double acc = 0.0; // CDF table for (index_t c = 0; c < K; ++c) { acc += dist[i*K + c]; - cum_table[i*K + c] = acc; + cum_table[i*K + c] = static_cast(acc); } for (index_t j = 0; j < M; ++j) { index_t left = 0, right = K;