Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Detection output op for SSD #6488

Merged
merged 24 commits into from
Jan 2, 2018

Conversation

sweetsky0901
Copy link
Contributor

@sweetsky0901 sweetsky0901 commented Dec 11, 2017

fix #6225

@sweetsky0901 sweetsky0901 requested a review from pkuyym December 12, 2017 02:35
namespace paddle {
namespace operators {

class Detection_output_OpMaker : public framework::OpProtoAndCheckerMaker {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please follow the naming style.
Detection_output_OpMaker --> DetectionOutputOpMaker

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


class Detection_output_OpMaker : public framework::OpProtoAndCheckerMaker {
public:
Detection_output_OpMaker(framework::OpProto* proto,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same too.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Copy link
Contributor Author

@sweetsky0901 sweetsky0901 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done
thanks

namespace paddle {
namespace operators {

class Detection_output_OpMaker : public framework::OpProtoAndCheckerMaker {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


class Detection_output_OpMaker : public framework::OpProtoAndCheckerMaker {
public:
Detection_output_OpMaker(framework::OpProto* proto,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

}
}
template <typename DeviceContext, typename T>
class Detection_output_Kernel : public framework::OpKernel<T> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

class name use camel-case naming.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

// KNCHW ==> NHWC
// template <typename T>
template <typename T>
void getBBoxFromPriorData(const T* prior_data, const size_t num_bboxes,
Copy link
Contributor

@wanghaox wanghaox Dec 28, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use lowercase naming.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

void getBBoxVarFromPriorData(const T* prior_data, const size_t num,
std::vector<std::vector<T>>& var_vec);
template <typename T>
BBox<T> decodeBBoxWithVar(BBox<T>& prior_bbox,
Copy link
Contributor

@wanghaox wanghaox Dec 28, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use lowercase naming.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Copy link
Contributor Author

@sweetsky0901 sweetsky0901 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

// KNCHW ==> NHWC
// template <typename T>
template <typename T>
void getBBoxFromPriorData(const T* prior_data, const size_t num_bboxes,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

void getBBoxVarFromPriorData(const T* prior_data, const size_t num,
std::vector<std::vector<T>>& var_vec);
template <typename T>
BBox<T> decodeBBoxWithVar(BBox<T>& prior_bbox,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

}
}
template <typename DeviceContext, typename T>
class Detection_output_Kernel : public framework::OpKernel<T> {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@sweetsky0901 sweetsky0901 merged commit 90a33dd into PaddlePaddle:develop Jan 2, 2018
@@ -187,6 +186,36 @@ endfunction()
add_subdirectory(math)
add_subdirectory(nccl)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个PR merge的有点早。这里cmake的冲突来自于 #7067 ,但不应该直接加189-218行。

@wangkuiyi
Copy link
Collaborator

void Compute(const framework::ExecutionContext& context) const override {
const framework::Tensor* in_loc = context.Input<framework::Tensor>("Loc");
const framework::Tensor* in_conf = context.Input<framework::Tensor>("Conf");
const framework::Tensor* in_priorbox =
context.Input<framework::Tensor>("PriorBox");
auto* out = context.Output<framework::Tensor>("Out");
int num_classes = context.template Attr<int>("num_classes");
int top_k = context.template Attr<int>("top_k");
int nms_top_k = context.template Attr<int>("nms_top_k");
int background_label_id = context.template Attr<int>("background_label_id");
float nms_threshold = context.template Attr<float>("nms_threshold");
float confidence_threshold =
context.template Attr<float>("confidence_threshold");
size_t batch_size = in_conf->dims()[1];
int conf_sum_size = in_conf->numel();
// for softmax
std::vector<int64_t> conf_shape_softmax_vec(
{conf_sum_size / num_classes, num_classes});
framework::DDim conf_shape_softmax(
framework::make_ddim(conf_shape_softmax_vec));
// for knchw => nhwc
std::vector<int64_t> loc_shape_vec({1, in_loc->dims()[1], in_loc->dims()[3],
in_loc->dims()[4],
in_loc->dims()[2] * in_loc->dims()[0]});
std::vector<int64_t> conf_shape_vec(
{1, in_conf->dims()[1], in_conf->dims()[3], in_conf->dims()[4],
in_conf->dims()[2] * in_conf->dims()[0]});
framework::DDim loc_shape(framework::make_ddim(loc_shape_vec));
framework::DDim conf_shape(framework::make_ddim(conf_shape_vec));
framework::Tensor loc_tensor;
framework::Tensor conf_tensor;
loc_tensor.mutable_data<T>(loc_shape, context.GetPlace());
conf_tensor.mutable_data<T>(conf_shape, context.GetPlace());
// for cpu
framework::Tensor loc_cpu;
framework::Tensor conf_cpu;
framework::Tensor priorbox_cpu;
const T* priorbox_data = in_priorbox->data<T>();
transpose_fun<DeviceContext, T>(context, *in_loc, &loc_tensor);
transpose_fun<DeviceContext, T>(context, *in_conf, &conf_tensor);
conf_tensor.Resize(conf_shape_softmax);
math::SoftmaxFunctor<DeviceContext, T>()(
context.template device_context<DeviceContext>(), &conf_tensor,
&conf_tensor);
T* loc_data = loc_tensor.data<T>();
T* conf_data = conf_tensor.data<T>();
if (platform::is_gpu_place(context.GetPlace())) {
loc_cpu.mutable_data<T>(loc_tensor.dims(), platform::CPUPlace());
framework::CopyFrom(loc_tensor, platform::CPUPlace(),
context.device_context(), &loc_cpu);
loc_data = loc_cpu.data<T>();
conf_cpu.mutable_data<T>(conf_tensor.dims(), platform::CPUPlace());
framework::CopyFrom(conf_tensor, platform::CPUPlace(),
context.device_context(), &conf_cpu);
conf_data = conf_cpu.data<T>();
priorbox_cpu.mutable_data<T>(in_priorbox->dims(), platform::CPUPlace());
framework::CopyFrom(*in_priorbox, platform::CPUPlace(),
context.device_context(), &priorbox_cpu);
priorbox_data = priorbox_cpu.data<T>();
}
// get decode bboxes
size_t num_priors = in_priorbox->numel() / 8;
std::vector<std::vector<operators::math::BBox<T>>> all_decoded_bboxes;
for (size_t n = 0; n < batch_size; ++n) {
std::vector<operators::math::BBox<T>> decoded_bboxes;
for (size_t i = 0; i < num_priors; ++i) {
size_t prior_offset = i * 8;
size_t loc_pred_offset = n * num_priors * 4 + i * 4;
std::vector<math::BBox<T>> prior_bbox_vec;
math::GetBBoxFromPriorData<T>(priorbox_data + prior_offset, 1,
prior_bbox_vec);
std::vector<std::vector<T>> prior_bbox_var;
math::GetBBoxVarFromPriorData<T>(priorbox_data + prior_offset, 1,
prior_bbox_var);
std::vector<T> loc_pred_data;
for (size_t j = 0; j < 4; ++j)
loc_pred_data.push_back(*(loc_data + loc_pred_offset + j));
math::BBox<T> bbox = math::DecodeBBoxWithVar<T>(
prior_bbox_vec[0], prior_bbox_var[0], loc_pred_data);
decoded_bboxes.push_back(bbox);
}
all_decoded_bboxes.push_back(decoded_bboxes);
}
std::vector<std::map<size_t, std::vector<size_t>>> all_indices;
int num_kept = math::GetDetectionIndices<T>(
conf_data, num_priors, num_classes, background_label_id, batch_size,
confidence_threshold, nms_top_k, nms_threshold, top_k,
all_decoded_bboxes, &all_indices);
if (num_kept <= 0) {
std::vector<int64_t> out_shape_vec({0, 0});
framework::DDim out_shape(framework::make_ddim(out_shape_vec));
out->Resize(out_shape);
return;
}
std::vector<int64_t> out_shape_vec({num_kept, 7});
framework::DDim out_shape(framework::make_ddim(out_shape_vec));
out->mutable_data<T>(out_shape, context.GetPlace());
framework::Tensor out_cpu;
T* out_data = out->data<T>();
if (platform::is_gpu_place(context.GetPlace())) {
out_cpu.mutable_data<T>(out->dims(), platform::CPUPlace());
out_data = out_cpu.data<T>();
}
math::GetDetectionOutput<T>(conf_data, num_kept, num_priors, num_classes,
batch_size, all_indices, all_decoded_bboxes,
out_data);
if (platform::is_gpu_place(context.GetPlace())) {
framework::CopyFrom(out_cpu, platform::CUDAPlace(),
context.device_context(), out);
}
}
-- a single function with more than 100 lines of code -- @wanghaox how could you approve and merge such an extremely low-quality PR?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Detection Output Operator.
5 participants