SE layer fix when not using fused kernel (#852)
* SE layer fix when not using fused kernel
  - Only the SE layer's fused kernel needs the transposed weights.
  - The non-transposed weights must also be stored in case we have to fall back to the non-fused path.
ankan-ban committed May 17, 2019
1 parent fa926e5 commit 07babd1
Showing 2 changed files with 18 additions and 10 deletions.
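
The idea behind the fix: the fused fp16 NHWC SE kernel consumes the two fully-connected weight matrices in transposed form, while the fallback path uses them in their original layout, so the layer now keeps both copies (w1_/w2_ and w1_t_/w2_t_). Below is a minimal, self-contained sketch of just the transpose step. It assumes cpuTranspose is a plain row-major transpose with a (dst, src, rows, cols) signature, matching how the diff calls it; the buffer names only mirror the members, and none of this is the exact lc0 implementation.

#include <cstdio>
#include <vector>

// Assumed semantics of lc0's cpuTranspose helper: transpose a row-major
// rows x cols matrix 'src' into a cols x rows matrix 'dst'.
static void cpuTranspose(float* dst, const float* src, size_t rows, size_t cols) {
  for (size_t i = 0; i < rows; i++)
    for (size_t j = 0; j < cols; j++) dst[j * rows + i] = src[i * cols + j];
}

int main() {
  const size_t C = 4;          // channels entering the SE block (illustrative)
  const size_t numFc1Out = 2;  // outputs of the first FC layer (illustrative)

  // w1 as loaded from the net: numFc1Out x C, row-major. This non-transposed
  // copy is what the fallback path keeps using.
  std::vector<float> w1(numFc1Out * C);
  for (size_t i = 0; i < w1.size(); i++) w1[i] = (float)i;

  // w1_t: the transposed copy that only the fused SE kernel consumes.
  std::vector<float> w1_t(C * numFc1Out);
  cpuTranspose(w1_t.data(), w1.data(), numFc1Out, C);

  // Element (r, c) of w1 should equal element (c, r) of w1_t.
  std::printf("w1(1,3)=%g  w1_t(3,1)=%g\n", w1[1 * C + 3], w1_t[3 * numFc1Out + 1]);
  return 0;
}
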
src/neural/cuda/layers.cc (26 changes: 16 additions, 10 deletions)
@@ -304,6 +304,11 @@ SELayer<DataType>::SELayer(BaseLayer<DataType>* ip, int fc1Outputs,
   ReportCUDAErrors(cudaMalloc(&w1_, C * numFc1Out_ * sizeof(DataType)));
   ReportCUDAErrors(cudaMalloc(&w2_, 2 * C * numFc1Out_ * sizeof(DataType)));
 
+  if (kUseFusedSELayer && nhwc_) {
+    ReportCUDAErrors(cudaMalloc(&w1_t_, C * numFc1Out_ * sizeof(DataType)));
+    ReportCUDAErrors(cudaMalloc(&w2_t_, 2 * C * numFc1Out_ * sizeof(DataType)));
+  }
+
   ReportCUDAErrors(cudaMalloc(&b1_, numFc1Out_ * sizeof(DataType)));
   ReportCUDAErrors(cudaMalloc(&b2_, 2 * C * sizeof(DataType)));
 
@@ -366,26 +371,27 @@ void SELayer<half>::LoadWeights(float* w1, float* b1, float* w2, float* b2,
   std::vector<float> temp(weight_size2);
 
   // Weight for the first FC layer.
+  ReportCUDAErrors(
+      cudaMemcpy(scratch, w1, weight_size1, cudaMemcpyHostToDevice));
+  copyTypeConverted((half*)w1_, (float*)scratch, num_weights1);
   if (kUseFusedSELayer && nhwc_) {
     // transposed copy for fused SE kernel
     cpuTranspose(temp.data(), w1, numFc1Out_, C);
     ReportCUDAErrors(
         cudaMemcpy(scratch, temp.data(), weight_size1, cudaMemcpyHostToDevice));
-  } else {
-    ReportCUDAErrors(
-        cudaMemcpy(scratch, w1, weight_size1, cudaMemcpyHostToDevice));
+    copyTypeConverted((half*)w1_t_, (float*)scratch, num_weights1);
   }
-  copyTypeConverted((half*)w1_, (float*)scratch, num_weights1);
 
   // Weight for the second FC layer.
+  ReportCUDAErrors(
+      cudaMemcpy(scratch, w2, weight_size2, cudaMemcpyHostToDevice));
+  copyTypeConverted((half*)w2_, (float*)scratch, num_weights2);
   if (kUseFusedSELayer && nhwc_) {
     cpuTranspose(temp.data(), w2, 2 * C, numFc1Out_);
     ReportCUDAErrors(
         cudaMemcpy(scratch, temp.data(), weight_size2, cudaMemcpyHostToDevice));
-  } else {
-    ReportCUDAErrors(
-        cudaMemcpy(scratch, w2, weight_size2, cudaMemcpyHostToDevice));
+    copyTypeConverted((half*)w2_t_, (float*)scratch, num_weights2);
   }
-  copyTypeConverted((half*)w2_, (float*)scratch, num_weights2);
 
   // Bias for the first FC layer.
   ReportCUDAErrors(cudaMemcpy(scratch, b1, numFc1Out_ * sizeof(float),
@@ -443,8 +449,8 @@ void SELayer<half>::Eval(int N, half* output, const half* input,
                          cudnnHandle_t /*cudnn*/, cublasHandle_t cublas) {
   bool se_done = false;
   if (kUseFusedSELayer && nhwc_) {
-    se_done = Se_Fp16_NHWC(N, C, numFc1Out_, output, input2, input, w1_, b1_,
-                           w2_, b2_, bPrev_);
+    se_done = Se_Fp16_NHWC(N, C, numFc1Out_, output, input2, input, w1_t_, b1_,
+                           w2_t_, b2_, bPrev_);
   }
   if (!se_done) {
     assert(output == input2);
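
To make the fallback logic concrete, here is a hedged sketch of the control flow the Eval hunk above relies on. TryFusedSE and CublasSE are hypothetical stand-ins (the real code calls Se_Fp16_NHWC and a cublas-based path); the only point being illustrated is that the fused attempt receives the transposed copies while the fallback receives the originals, which is why both sets of weights must stay allocated.

#include <cstdio>

// Hypothetical stand-ins for the two SE execution paths; names mirror the diff
// but the bodies are placeholders, not lc0 code.
static bool TryFusedSE(const float* w1_t, const float* w2_t) {
  // The fused kernel may decline to run; returning false triggers the
  // fallback, just like the se_done flag in the Eval hunk above.
  (void)w1_t; (void)w2_t;
  return false;
}

static void CublasSE(const float* w1, const float* w2) {
  // Fallback path: plain GEMMs, which expect the weights in their original,
  // non-transposed layout, hence the separate w1_/w2_ copies.
  (void)w1; (void)w2;
  std::printf("ran fallback SE path\n");
}

int main() {
  float w1[8] = {}, w2[16] = {};      // non-transposed copies (w1_, w2_)
  float w1_t[8] = {}, w2_t[16] = {};  // transposed copies (w1_t_, w2_t_)

  bool se_done = false;
  const bool kUseFusedSELayer = true;  // stand-in for lc0's compile-time switch
  const bool nhwc = true;              // fp16/NHWC backend in use

  if (kUseFusedSELayer && nhwc) se_done = TryFusedSE(w1_t, w2_t);
  if (!se_done) CublasSE(w1, w2);  // same control flow as the Eval hunk above
  return 0;
}
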
src/neural/cuda/layers.h (2 changes: 2 additions, 0 deletions)
@@ -190,8 +190,10 @@ class SELayer : public BaseLayer<DataType> {
 
  private:
   DataType* w1_ = nullptr;
+  DataType* w1_t_ = nullptr;  // transposed copy used by fused SE kernel
   DataType* b1_ = nullptr;
   DataType* w2_ = nullptr;
+  DataType* w2_t_ = nullptr;
   DataType* b2_ = nullptr;
   DataType* bPrev_ = nullptr;
   int numFc1Out_;
