
[Relay, TOPI] Add negative log likelihood loss (nll_loss) op #8056

Merged: 7 commits merged on Jun 25, 2021

Conversation

zhuzilin (Contributor) commented:

This PR adds the nll_loss op to Relay and TOPI so that we can translate aten::nll_loss in the PyTorch frontend. nll_loss is the function underlying cross entropy in PyTorch and is an important step toward supporting training in TVM.

Thank you for your time on reviewing this PR :).
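For reference, a minimal NumPy sketch of the semantics this op targets, loosely following torch.nn.functional.nll_loss (parameter names and defaults here are illustrative, not the final Relay/TOPI signature):

    import numpy as np

    def nll_loss_ref(predictions, targets, weights, reduction="mean", ignore_index=-100):
        """Reference negative log likelihood loss over log-probabilities.

        predictions: (N, C) log-probabilities, targets: (N,) class indices,
        weights: (C,) per-class weights.
        """
        n = targets.shape[0]
        losses = np.zeros(n, dtype=predictions.dtype)
        scale = np.zeros(n, dtype=predictions.dtype)
        for i in range(n):
            t = targets[i]
            if t == ignore_index:
                continue  # ignored samples contribute neither loss nor weight
            losses[i] = -predictions[i, t] * weights[t]
            scale[i] = weights[t]
        if reduction == "mean":
            return losses.sum() / scale.sum()  # weighted mean, as in PyTorch
        if reduction == "sum":
            return losses.sum()
        return losses  # reduction == "none"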

tkonolige (Contributor) left a comment:

Thanks for submitting this op. I know some people are looking forward to having it available. I've left some comments for things to fix (mostly around documentation).

Review threads (resolved): include/tvm/topi/nn.h, python/tvm/relay/op/nn/nn.py (×2), python/tvm/topi/nn/loss.py, src/relay/op/nn/nn.cc (×3)
zhuzilin (Contributor, Author) replied:

@tkonolige
Thank you for your review! I've added the docs. Because the lint check doesn't allow a parameter named input, I renamed the parameters to follow the same form as cross_entropy and cross_entropy_with_logits (the old names matched those of aten::nll_loss).
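For illustration, a minimal sketch of calling the new Relay op with the renamed parameters (keyword names follow the TOPI signature shown later in this thread; the exact Relay wrapper may differ slightly):

    from tvm import relay

    predictions = relay.var("predictions", shape=(10, 5), dtype="float32")
    targets = relay.var("targets", shape=(10,), dtype="int32")
    weights = relay.var("weights", shape=(5,), dtype="float32")

    # reduction can be "mean", "sum", or "none"; ignore_index skips the given target value.
    loss = relay.nn.nll_loss(predictions, targets, weights,
                             reduction="mean", ignore_index=-100)
    func = relay.Function([predictions, targets, weights], loss)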

altanh (Contributor) left a comment:

Thanks for the PR, excited to see others working on training support! I left some requests. A few more things:

  • please add unit tests for the op
  • are you planning on registering a gradient for this operator?

include/tvm/topi/nn.h (outdated review thread, resolved)
targets : tvm.relay.Expr
The target value of each prediction.

weights : tvm.relay.Expr
altanh (Contributor) commented:

Can we make weights optional, like PyTorch? weights=1 is a pretty common case, I believe, and we could add a fast-path implementation that skips the scaling.
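Purely as an illustration, a hypothetical wrapper showing what optional weights could look like at the Relay level (nll_loss_maybe_weighted and its num_classes argument are made up for this sketch; the PR ultimately keeps weights required):

    from tvm import relay

    def nll_loss_maybe_weighted(predictions, targets, weights=None, num_classes=None,
                                reduction="mean", ignore_index=-100):
        # Hypothetical wrapper: when no class weights are given, fall back to
        # all-ones weights, which is equivalent to the unweighted loss. A real
        # fast path could skip the weight scaling entirely instead.
        if weights is None:
            assert num_classes is not None, "num_classes is needed to build unit weights"
            weights = relay.ones((num_classes,), dtype="float32")
        return relay.nn.nll_loss(predictions, targets, weights,
                                 reduction=reduction, ignore_index=ignore_index)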

zhuzilin (Contributor, Author) replied:

@altanh We can make weights an optional parameter. I wonder if there is an example of a Relay op with an optional tensor parameter that I can learn from. Also, how should we deal with the gradient of an optional parameter? By the way, is there a better way to mark a parameter as "no gradient needed" instead of returning a ones_like grad?

altanh (Contributor) commented on May 24, 2021:

Hmm, not sure, that's a good point; let's just keep the weights for now. As for not needing a gradient, there is currently no other way than putting in some dummy value. It might make sense to introduce a stop_gradient dummy op that cuts the gradient computation off at non-differentiable arguments (this can be a future PR). Thanks!
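For context, a rough sketch of a gradient registration that returns dummy gradients for the non-differentiable arguments, in the spirit of the discussion above (the predictions gradient is a placeholder, not the real backward computation):

    from tvm.relay import zeros_like
    from tvm.relay.op import register_gradient

    @register_gradient("nn.nll_loss")
    def nll_loss_grad(orig, grad):
        # orig is the original call node; orig.args are (predictions, targets, weights).
        predictions, targets, weights = orig.args
        # Placeholder for the real gradient w.r.t. predictions; targets and weights
        # are not differentiable, so they get dummy zero gradients for now.
        pred_grad = zeros_like(predictions)  # TODO: actual backward computation
        return [pred_grad, zeros_like(targets), zeros_like(weights)]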

<< "weights shape = " << weights->shape;
ICHECK(predictions->dtype == weights->dtype && predictions->dtype.is_float())
<< "NLLLossRel: predictions and weights should be of the same floating type.";
ICHECK(targets->dtype.is_int()) << "NLLLossRel: targets should be of int type.";
Reviewer (Contributor) commented:

can you replace these ICHECKs with Diagnostics?

zhuzilin (Contributor, Author) replied:

Is there a plan to change all ICHECKs to diagnostics? Or are there any rules on when to use ICHECK and when to use diagnostics?

Reviewer (Contributor) replied:

Basically, if the error can happen due to user input (e.g., wrong shapes), we should definitely use diagnostics. ICHECK should be reserved for internal compiler checks that should essentially never fail unless there's a bug somewhere. The diagnostic framework is fairly new, so a lot of old code still uses ICHECK incorrectly; unfortunately we just need to slowly go through and update it.

Review threads (outdated, resolved): python/tvm/relay/op/nn/nn.py, python/tvm/topi/nn/loss.py, include/tvm/topi/nn.h
*/
inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const Tensor& weights,
std::string reduction = "mean", int ignore_index = -100,
const std::string name = "nll_loss", const std::string tag = kBroadcast) {
altanh (Contributor) commented:

should the tag be kOpaque to match the Relay pattern?

zhuzilin (Contributor, Author) replied:

@altanh I am confused by the tag in TOPI (the ones in topi/tags.h) and the OpPatternKind in relay/op_attr_types.h. It seems they do not match. Could you explain how they are used in TVM? Thank you~

altanh (Contributor) replied:

I see, that is confusing... I'll get back to you on this soon.

altanh (Contributor) replied:

@altanh Could you take a look at the tests? Thank you~ And I wonder if there is any update with the tag and OpPatternKind?

I don't have much of an update for the tag; maybe you could try leaving it as an empty string?

vinx13 (Member) commented:

tag here is topi-level, sometimes we use it to identify a specific compute operation during schedule, otherwise we can leave it empty

zhuzilin (Contributor, Author) commented on May 19, 2021:

@tkonolige Sorry about the incomplete formula... I don't know what I was thinking yesterday 😱

please add unit tests for the op

I found that the ops are categorized into different levels. Could you tell me how to decide the level of an op? Then I can add the test in the correct place.

are you planning on registering a gradient for this operator?

@altanh Yes, I'm trying to write one for this op. I will try to implement one like the CPU version in PyTorch or ONNX Runtime.

altanh (Contributor) commented on May 20, 2021:

Could you tell me how to decide the level of an op? Then I can add the test in the correct place.

Yep, sorry that the documentation of support levels is lacking! We do have a doc at https://tvm.apache.org/docs/langref/relay_op.html#overview-of-operators, but it seems outdated and is missing quite a bit, so this needs to be improved. For this op, I think we can go with level 10, which matches the existing nn.cross_entropy op. In my opinion we might want to remove that operator in the future, since NLLLoss seems more general and subsumes it, but that will be a separate PR :)

Thanks!

zhuzilin (Contributor, Author) commented on May 21, 2021:

@altanh Thanks. I'll add the test soon. Could you also check the comments above related to the optional weights, diagnostics, and the tag?

altanh (Contributor) commented on May 24, 2021:

@altanh Thanks. I'll add the test soon. Could you also check the comments above related to the optional weights, diagnostics, and the tag?

Replied above.

zhuzilin (Contributor, Author) commented:

@altanh @tkonolige Sorry for the late update. I've just added tests for nll_loss in Relay, TOPI, and the PyTorch frontend. Could you have another look?

zhuzilin (Contributor, Author) commented on Jun 2, 2021:

@altanh Could you take a look at the tests? Thank you~ And I wonder if there is any update with the tag and OpPatternKind?

altanh (Contributor) commented on Jun 2, 2021:

It looks like we have some CI problems currently

zhuzilin (Contributor, Author) commented:

@altanh @tkonolige @tqchen Is there anything I can do to help with this PR? It has been stuck for a while...

tkonolige (Contributor) left a comment:

Sorry for being slow on this. Could you make just a couple more changes?

Review threads (outdated, resolved): python/tvm/relay/op/nn/_nn.py, tests/python/topi/python/test_topi_loss.py (×2)
zhuzilin (Contributor, Author) replied:

@tkonolige Thank you for the reviews! I've updated those parts accordingly. Could you take another look?

tkonolige (Contributor) left a comment:

One last change and then I think this will be good to go.

zhuzilin (Contributor, Author) replied:

One last change and then I think this will be good to go.

@tkonolige Could you point out the change that I need to make? Thank you~

tkonolige (Contributor) left a comment:

Sorry, I forgot to include the actual change I wanted.

@@ -577,6 +577,49 @@ def _verify(input_shape, diagonal_shape, dtype, k=0, align="RIGHT_LEFT"):
_verify((2, 3, 4), (2, 4, 3), "int32", (-1, 2), "RIGHT_RIGHT")


@tvm.testing.uses_gpu
def test_nll_loss():
tkonolige (Contributor) commented:

Switch this to use parametrize_targets.
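For reference, a minimal sketch of the requested change (verify_nll_loss is the helper already defined in test_topi_loss.py):

    import tvm.testing

    @tvm.testing.parametrize_targets
    def test_nll_loss(dev, target):
        # dev and target are injected by the decorator, once per enabled target.
        verify_nll_loss(dev, target, (10, 5))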

zhuzilin (Contributor, Author) commented:

@tkonolige I've updated the test. But there seems to be an error with the CUDA target. Could you give me some help? Thank you! Part of the error message from CI is listed below:

dev = cuda(0), target = 'cuda'

    @tvm.testing.parametrize_targets
    def test_nll_loss(dev, target):
>       verify_nll_loss(dev, target, (10, 5))

tests/python/topi/python/test_topi_loss.py:60:
tests/python/topi/python/test_topi_loss.py:40: in verify_nll_loss
    fn = tvm.build(s, [predictions, targets, weights, nll_loss_result], target, name="nll_loss")
python/tvm/driver/build_module.py:353: in build
    mod_host, mdev = _build_for_device(input_mod, tar, target_host)
python/tvm/driver/build_module.py:177: in _build_for_device
    mod_mixed = tvm.transform.Sequential(opt_mixed)(mod_mixed)
python/tvm/ir/transform.py:161: in __call__
    return _ffi_transform_api.RunPass(self, mod)
...
E  Did you forget to bind?
E  Variable T_divide is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
E  Variable targets is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
E  Variable weights is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
E  Variable targets is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
E  Variable targets is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
E  Variable weights is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
E  Variable targets is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
E  Variable predictions is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
E  Variable targets is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
E  File "/workspace/src/tir/analysis/verify_memory.cc", line 202
E  RuntimeError: Memory verification failed with the following errors:
E  PrimFunc([predictions, targets, weights, T_divide]) attrs={"global_symbol": "nll_loss", "tir.noalias": (bool)1, "target": cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32} {
E    // attr [nll_loss] storage_scope = "global"
E    allocate nll_loss[float32 * 10]
E    // attr [nll_loss_red] storage_scope = "global"
E    allocate nll_loss_red[float32 * 1]
E    // attr [nll_loss_red] storage_scope = "global"
E    allocate nll_loss_red[float32 * 1]
E    for (ax0, 0, 10) {
E      nll_loss[ax0] = tir.if_then_else((targets[ax0] != -100), ((0f - predictions[((ax0*5) + targets[ax0])])*weights[targets[ax0]]), 0f)
E    }
E    nll_loss_red[0] = 0f
E    for (k0, 0, 10) {
E      nll_loss_red[0] = (nll_loss_red[0] + nll_loss[k0])
E    }
E    for (ax0, 0, 10) {
E      nll_loss[ax0] = tir.if_then_else((targets[ax0] != -100), weights[targets[ax0]], 0f)
E    }
E    nll_loss_red[0] = 0f
E    for (k0, 0, 10) {
E      nll_loss_red[0] = (nll_loss_red[0] + nll_loss[k0])
E    }
E    T_divide[0] = (nll_loss_red[0]/nll_loss_red[0])
E  }

zhuzilin (Contributor, Author) commented:

@tkonolige Could you take another look? Thank you~

tqchen (Member) commented on Jun 22, 2021:

cc @altanh @vinx13, it would be great if you could help take another look.

vinx13 (Member) left a comment:

The TOPI test on CUDA didn't pass. It looks like there are some issues with scheduling: when you call register_reduction_schedule, this schedule is used on the CUDA target: https://github.com/apache/tvm/blob/main/python/tvm/topi/cuda/reduction.py#L26. Could you check whether anything is wrong?


altanh (Contributor) commented on Jun 23, 2021:

@tkonolige I've updated the test. But there seems to be an error with the CUDA target. Could you give me some help? Thank you! Part of the error message from CI is listed below:

I think this error is due to using the default schedule (from te.create_schedule), which does not bind the iteration axes to GPU threads. This is a bit outside my area of understanding, but the solution will be to use the respective cuda/gpu schedules for each compute in the kernel (e.g., schedule_reduce, schedule_injective, etc.). Maybe @vinx13 can give some more detailed help?

nll_loss_result = topi.nn.nll_loss(predictions, targets, weights, reduction, ignore_index)

with tvm.target.Target(target):
    s = tvm.te.create_schedule(nll_loss_result.op)
vinx13 (Member) commented:

The default schedule is used here, which caused the CI errors.

Suggested change:
-    s = tvm.te.create_schedule(nll_loss_result.op)
+    fschedule = tvm.topi.testing.get_reduce_schedule(target)
+    s = fschedule([nll_loss_result])
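Putting the suggestion in context, the relevant part of the verify helper would look roughly like this (a sketch assembled from the snippets above; shapes, dtypes, and the surrounding numerical checks are assumed):

    import tvm
    import tvm.testing
    import tvm.topi.testing
    from tvm import te, topi

    def verify_nll_loss(dev, target, prediction_shape, reduction="mean", ignore_index=-100):
        batch, num_classes = prediction_shape
        predictions = te.placeholder(prediction_shape, name="predictions", dtype="float32")
        targets = te.placeholder((batch,), name="targets", dtype="int32")
        weights = te.placeholder((num_classes,), name="weights", dtype="float32")
        nll_loss_result = topi.nn.nll_loss(predictions, targets, weights, reduction, ignore_index)

        with tvm.target.Target(target):
            # Use the target-specific reduce schedule instead of te.create_schedule,
            # so that on cuda the reduction axes get bound to GPU threads.
            fschedule = tvm.topi.testing.get_reduce_schedule(target)
            s = fschedule([nll_loss_result])
        fn = tvm.build(s, [predictions, targets, weights, nll_loss_result], target, name="nll_loss")
        return fn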

zhuzilin (Contributor, Author) replied:

@vinx13 @altanh Thank you for your help!

tag here is topi-level, sometimes we use it to identify a specific compute operation during schedule, otherwise we can leave it empty

If I change the value of tag to an empty string, it will fail the check in schedule_reduce, which is:

if tag.is_injective(operator.tag):

I'm not sure if I need to adjust somewhere else...

zhuzilin requested a review from vinx13 on June 24, 2021 at 02:46.
vinx13 (Member) commented on Jun 24, 2021:

@vinx13 @altanh Thank you for your help!

tag here is topi-level, sometimes we use it to identify a specific compute operation during schedule, otherwise we can leave it empty

If I change the value of tag to an empty string, it will fail the check in schedule_reduce, which is:

if tag.is_injective(operator.tag):

I'm not sure if I need to adjust somewhere else...

In this case we need to change the tag to kInjective, as the reduction op is not a broadcast.

zhuzilin (Contributor, Author) commented on Jun 25, 2021:

In this case we need to change the tag to kInjective, as the reduction op is not a broadcast.

@vinx13 Changing the tag to kInjective will fail the tag check in traverse_after_reduce, in:

def traverse_after_reduce(operator):

whereas the empty tag was failing the check in traverse_before_reduce...

It's worth noting that the tag in question is the one attached to:

  auto T = tvm::te::compute(
      targets->shape,
      ...
      name, tag);

which is an element-wise operation on targets. When reduction="mean" or reduction="sum", T is reduced, whereas when reduction="none", T is returned directly as the result of nll_loss.

Therefore, since nll_loss is registered to use the reduce schedule, T goes through the checks in traverse_before_reduce when reduction="mean" or "sum", and through the checks in traverse_after_reduce when reduction="none". Right now, kBroadcast happens to be the only tag that satisfies both checks and passes CI. To really solve this, I'm afraid we may need to define different tags and different schedules for different op attributes...
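For readers following along, a simplified sketch of the tag-based dispatch in the CUDA reduce schedule that creates this constraint (paraphrased from the linked reduction.py; names and structure are approximate, not a verbatim copy):

    from tvm import te
    from tvm.topi import tag

    def schedule_reduce_sketch(out_op):
        """Rough shape of the CUDA reduce schedule's traversal (not the real code)."""

        def traverse_before_reduce(op):
            # Producers feeding the reduction must be placeholders or injective-like.
            if isinstance(op, te.PlaceholderOp):
                return
            if tag.is_injective(op.tag):      # kBroadcast (and kInjective) pass here
                for t in op.input_tensors:
                    traverse_before_reduce(t.op)
            else:
                raise RuntimeError("Unsupported operator: %s" % op.tag)

        def traverse_after_reduce(op):
            # The output side must be broadcast-like or the reduction itself.
            if isinstance(op, te.PlaceholderOp):
                return
            if tag.is_broadcast(op.tag):      # kBroadcast passes, kInjective does not
                for t in op.input_tensors:
                    traverse_after_reduce(t.op)
            elif op.tag == "comm_reduce":
                for t in op.input_tensors:
                    traverse_before_reduce(t.op)
            else:
                raise RuntimeError("Unsupported operator: %s" % op.tag)

        traverse_after_reduce(out_op)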

vinx13 (Member) left a comment:

LGTM

vinx13 merged commit 2186835 into apache:main on Jun 25, 2021.
vinx13 (Member) commented on Jun 25, 2021:

This is merged, thanks @zhuzilin @altanh @tkonolige @tqchen

zhuzilin (Contributor, Author) commented:

@vinx13 @altanh @tkonolige @tqchen Thank you for your help! I'll submit a PR for the gradient of nll_loss soon.

ylc pushed a commit to ylc/tvm that referenced this pull request on Sep 29, 2021 (…8056), with the commit messages:

* add nll_loss
* enrich the doc and rename parameters
* update upon review
* add tests
* update based on reviews
* update upon reviews
* update upon reviews
zxy844288792 pushed a commit to zxy844288792/tvm that referenced this pull request on Mar 4, 2022 (…8056), with the same commit messages as above.