Skip to content

Commit

Permalink
Bug fix for inferring reduction dimensions in NDL branch.
Browse files Browse the repository at this point in the history
  • Loading branch information
cha-zhang authored and mahilleb-msft committed Apr 21, 2017
1 parent 43e5494 commit 260ca86
Showing 1 changed file with 18 additions and 5 deletions.
23 changes: 18 additions & 5 deletions Source/ComputationNetworkLib/ConvolutionalNodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,22 @@ class ConvolutionNodeBase : public ComputationNode<ElemType>
shape = TensorShape(dims);
}
protected:
// infer reduction dimensions if m_convolution2D is true, for legacy NDL branch
void InferConvolution2DReductionDims(const TensorShape& inputShape, size_t numChannels)
{
size_t kW = m_kernelShape[0];
size_t kH = m_kernelShape[1];
size_t sW = m_stride[0];
size_t sH = m_stride[1];
m_kernelShape = TensorShape(kW, kH, numChannels);
m_stride = TensorShape(sW, sH, numChannels);
size_t filterRank = 2;
FixVectorShape(filterRank, inputShape.size(), m_autoPad, false);
FixTensorShape(filterRank, inputShape.size(), m_lowerPad, 0);
FixTensorShape(filterRank, inputShape.size(), m_upperPad, 0);
FixVectorShape(filterRank, inputShape.size(), m_sharing, true);
}

// infer reduction dimensions if not given
void InferReductionDims(const TensorShape& inputShape, const TensorShape& fromShape)
{
Expand Down Expand Up @@ -439,13 +455,10 @@ class ConvolutionNode : public ConvolutionNodeBase<ElemType>, public NumInputs<2
auto inDims = ImageDimensions(GetInputSampleLayout(inputIdx), m_imageLayout);
// inputShape is used in ConvolveGeometry which supports only CHW layout.
inputShape = inDims.AsTensorShape(ImageLayoutKind::CHW);
InferConvolution2DReductionDims(inputShape, inDims.m_numChannels);

size_t kW = m_kernelShape[0];
size_t kH = m_kernelShape[1];
size_t sW = m_stride[0];
size_t sH = m_stride[1];
m_kernelShape = TensorShape(kW, kH, inDims.m_numChannels);
m_stride = TensorShape(sW, sH, inDims.m_numChannels);

size_t mapCount = m_mapCount.GetNumElements();
size_t weightCols = kW * kH * inDims.m_numChannels;

Expand Down

0 comments on commit 260ca86

Please sign in to comment.