Skip to content

Commit

Permalink
Pass the correctness check
Browse files Browse the repository at this point in the history
  • Loading branch information
LRY89757 committed Oct 19, 2022
1 parent b28e6f9 commit 19e4511
Showing 1 changed file with 17 additions and 28 deletions.
45 changes: 17 additions & 28 deletions src/layer/grid_sample.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ namespace ncnn {

Grid_Sample::Grid_Sample()
{
// one_blob_only = true;
one_blob_only = false;
support_inplace = false;
}
Expand Down Expand Up @@ -64,24 +65,24 @@ grid_sample_unormalize(int w, float coordx, int align_corner)
return align_corner ? (coordx + 1) / 2.f * (w - 1) : ((coordx + 1) * w - 1) / 2.f;
}

static int border_coord(float coord, int border)
static float border_coord(float coord, int border)
{
return std::min(border, std::max((int)coord, 0));
return std::min(static_cast<float>(border), std::max(coord, static_cast<float>(0)));
}

// Reflects coordinates until they fall between low and high (inclusive).
static int reflect_coord(int coord, int low, int high)
static float reflect_coord(float coord, int low, int high)
{
if (low == high)
{
return 0;
}
int min = low / 2;
int span = static_cast<int>(high - low) / 2;
float min = static_cast<int>(low) / 2;
float span = static_cast<float>(high - low) / 2;
coord = std::fabs(coord - min);
// `fmod` returns same sign as `coord`, which is positive after the `fabs` above.
int extra = std::fmod(coord, span);
int flips = std::floor(coord / span);
float extra = std::fmod(coord, span);
int flips = static_cast<int>(std::floor(coord / span));

return flips % 2 ? (span - extra + min) : (extra + min);
}
Expand All @@ -94,7 +95,6 @@ static float compute_coord(float sx, int w,
{
// clip coordinates to image borders
sx = border_coord(sx, w - 1);
// printf("here!\n";
}
else if (padding_mode == 3) // reflection
{
Expand All @@ -108,7 +108,7 @@ static float compute_coord(float sx, int w,
sx = reflect_coord(sx, -1, 2 * w - 1);
}
// clip coordinates to image borders
sx = border_coord(sx, w);
sx = border_coord(sx, w - 1);
}
return sx;
}
Expand All @@ -133,12 +133,12 @@ static float get_value_bounded(const float* data, float x, float y, int w, int h
int padding_mode, int align_corner)
{
x = compute_coord(x, w, padding_mode, align_corner);
y = compute_coord(y, w, padding_mode, align_corner);
y = compute_coord(y, h, padding_mode, align_corner);

// int ix = static_cast<int>(x);
int ix = static_cast<int>(x);
int iy = static_cast<int>(y);

return in_bounds(x, y, w, h) ? data[iy * w + h] : 0.f;
return in_bounds(ix, iy, w, h) ? data[iy * w + ix] : 0.f;
}

// Based on
Expand Down Expand Up @@ -198,6 +198,7 @@ int Grid_Sample::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>&
return -100;
if (resize_type == 1) // bilinear
{
// GSample_bilinear(src, dst, grid, align_corner, padding_mode);
#pragma omp parallel for num_threads(opt.num_threads) collapse(2)
for (int row = 0; row < outh; row++)
{
Expand Down Expand Up @@ -285,7 +286,7 @@ int Grid_Sample::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>&

if (in_bounds(x, y, w, h))
{
ans += ptr[y * h + w];
ans += ptr[y * w + x];
}

outptr[row * outw + col] = ans;
Expand All @@ -300,15 +301,14 @@ int Grid_Sample::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>&
{
for (int col = 0; col < outw; col++)
{
// const float* gridptr = grid.channel(row).row(col);
const float* gridptr = grid.depth(row).row(col);

// get the coordinate of every output point
float ix = gridptr[0];
float iy = gridptr[1];

ix = grid_sample_unormalize(ix, w, align_corner);
iy = grid_sample_unormalize(iy, h, align_corner);
ix = grid_sample_unormalize(w, ix, align_corner);
iy = grid_sample_unormalize(h, iy, align_corner);

float xnw = std::floor(ix);
float ynw = std::floor(iy);
Expand Down Expand Up @@ -345,17 +345,6 @@ int Grid_Sample::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>&
}
}
}

// float* outptr = top_blob.channel(0);
// outptr[0] = grid.c;
// outptr[1] = grid.d;
// outptr[2] = grid.h;
// outptr[3] = grid.w;

// outptr[4] = bottom_blob.c;
// outptr[5] = bottom_blob.d;
// outptr[6] = bottom_blob.h;
// outptr[7] = bottom_blob.w;
}

if (dims == 4)
Expand All @@ -365,4 +354,4 @@ int Grid_Sample::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>&
return 0;
}

} // namespace ncnn
} // namespace ncnn

0 comments on commit 19e4511

Please sign in to comment.