Add thrust support for nms #5116

Merged
merged 11 commits into from
Mar 23, 2020

Laurawly
Contributor

@Laurawly Laurawly commented Mar 20, 2020

@kazum @masahi @icemelon9 Please feel free to raise any suggestions or comments.

@masahi
Member

masahi commented Mar 21, 2020

@Laurawly Can we refactor thrust_sort and thrust_sort_nms? I see considerable duplication.

The output of this function.
"""
if axis < 0:
    axis = len(data.shape) + axis
Contributor

Add a check for the axis. For example:

assert axis == len(data.shape) - 1, "Supports sorting along the last axis only"

Contributor Author

The original sort_nms does support axes other than -1, so I'll add your swap operations to this operator.
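
The axis handling discussed above can be sketched in plain Python. This is only an illustration of the idea (normalize a negative axis, then exchange the target axis with the last one so a last-axis-only sort can be reused); `normalize_axis`, `swap_with_last`, and `permute` are hypothetical helper names, not functions from this PR.

```python
def normalize_axis(axis, ndim):
    """Map a possibly negative axis into the range [0, ndim)."""
    if not -ndim <= axis < ndim:
        raise ValueError("axis %d is out of range for %d dimensions" % (axis, ndim))
    return axis + ndim if axis < 0 else axis


def swap_with_last(axis, ndim):
    """Return the permutation that exchanges `axis` with the last axis."""
    axis = normalize_axis(axis, ndim)
    perm = list(range(ndim))
    perm[axis], perm[-1] = perm[-1], perm[axis]
    return perm


def permute(shape, perm):
    """Apply an axis permutation to a shape (what a transpose would do)."""
    return [shape[p] for p in perm]


shape = [2, 5, 7]
perm = swap_with_last(0, len(shape))
swapped = permute(shape, perm)    # sort along the last axis of this layout
restored = permute(swapped, perm) # the swap is its own inverse
```

Because the permutation only exchanges two axes, applying it a second time restores the original layout, so the same permutation can be used before and after the sort.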

@masahi
Member

masahi commented Mar 22, 2020

@Laurawly Can we have an implementation like this? The big if/else blocks should go inside thrust_sort_common.

TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  CHECK_GE(args.num_args, 5);
  DLTensor* input = args[0];
  DLTensor* valid_count = args[1];
  DLTensor* values_out = args[2];
  DLTensor* indices_out = args[3];
  bool is_ascend = args[4];

  auto data_dtype = DLDataType2String(input->dtype);
  auto out_dtype = DLDataType2String(indices_out->dtype);

  // Plain sort: every row is sorted over its full length.
  int n_values = input->shape[input->ndim - 1];
  auto get_sort_len = [=](int i) { return n_values; };
  thrust_sort_common(input, values_out, indices_out, is_ascend, get_sort_len, data_dtype, out_dtype);
});

TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort_nms")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  CHECK_GE(args.num_args, 5);
  DLTensor* input = args[0];
  DLTensor* valid_count = args[1];
  DLTensor* values_out = args[2];
  DLTensor* indices_out = args[3];
  bool is_ascend = args[4];

  auto data_dtype = DLDataType2String(input->dtype);
  auto out_dtype = DLDataType2String(indices_out->dtype);

  // NMS sort: the sort length of row i is read from valid_count[i].
  // IndicesType stands for the element type of valid_count; it is not defined in this sketch.
  thrust::device_ptr<IndicesType> valid_count_ptr(static_cast<IndicesType *>(valid_count->data));
  auto get_sort_len = [&valid_count_ptr](int i) { return valid_count_ptr[i]; };
  thrust_sort_common(input, values_out, indices_out, is_ascend, get_sort_len, data_dtype, out_dtype);
});
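
To make the intent of the `get_sort_len` callback concrete, here is a host-side reference model in plain Python. It is a sketch of the semantics only, not the Thrust implementation: `sort_common`, `sort_rows`, and `sort_rows_nms` are hypothetical names, and it assumes that entries past a row's sort length are left in their original order.

```python
def sort_common(rows, is_ascend, get_sort_len):
    """Sort the first get_sort_len(i) entries of row i; return (values, indices)."""
    values_out, indices_out = [], []
    for i, row in enumerate(rows):
        n = get_sort_len(i)
        # Argsort only the leading n entries of this row.
        order = sorted(range(n), key=lambda j: row[j], reverse=not is_ascend)
        # Assumption: the tail past n keeps its original order.
        order += list(range(n, len(row)))
        values_out.append([row[j] for j in order])
        indices_out.append(order)
    return values_out, indices_out


def sort_rows(rows, is_ascend):
    """Plain sort: every row uses its full length."""
    n_values = len(rows[0])
    return sort_common(rows, is_ascend, lambda i: n_values)


def sort_rows_nms(rows, valid_count, is_ascend):
    """NMS sort: row i uses valid_count[i] as its sort length."""
    return sort_common(rows, is_ascend, lambda i: valid_count[i])


scores = [[0.2, 0.9, 0.5, 0.0],
          [0.7, 0.1, 0.3, 0.8]]
nms_values, nms_indices = sort_rows_nms(scores, [3, 2], is_ascend=False)
full_values, full_indices = sort_rows(scores, is_ascend=True)
```

The two entry points differ only in the callback they pass, which is the point of the proposed refactoring: the dtype dispatch and the sorting loop live in one shared function.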

@Laurawly
Contributor Author

> @Laurawly Can we have implementation like this? Big if/else blocks should be put inside thrust_sort_common.

Lemme know if the changes look good to you.

@masahi
Member

masahi commented Mar 23, 2020

> Lemme know if the changes look good to you

Yes this looks great! Thanks

Contributor

@kazum kazum left a comment

Looks good to me, thanks!

@masahi masahi merged commit 25d3542 into apache:master Mar 23, 2020
@masahi
Member

masahi commented Mar 23, 2020

Thanks @Laurawly @kazum

trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Apr 16, 2020
* add argsort_nms_thrust
* consider valid count in thrust nms sort
* make thrust optional
* typo
* typo
* fix pylint
* address some of the comments
* address more comments
* fix lint
* address more comments
* address more comments
zhiics pushed a commit to neo-ai/tvm that referenced this pull request Apr 17, 2020