[RFC][Tracking Issue][AMP] Tracking Issue for Mixed Precision Pass #8296

Closed
10 of 18 tasks
AndrewZhaoLuo opened this issue Jun 21, 2021 · 7 comments
Labels
relay:op src/relay/op topi python/tvm/topi type:rfc-tracking RFC progress tracking. Ref: https://github.com/apache/tvm-rfcs

@AndrewZhaoLuo
Contributor

AndrewZhaoLuo commented Jun 21, 2021

This issue tracks work on supporting mixed precision within TVM.

RFC: apache/tvm-rfcs#6

Edge case ops:

Other discussions:

  • Creating default ALLOW, FOLLOW, NEVER lists for ops
  • Move certain pooling / average operations that are global into NEVER list (or use FP32 accumulation).
  • Write a tutorial

Tasks which may help:

Benchmarking improvements from pass: https://docs.google.com/spreadsheets/d/12lgyfuHaRS-X4uG-1iQOV8oAuPpuVAbspcmkOSPRFHQ/edit?usp=sharing

@AndrewZhaoLuo AndrewZhaoLuo changed the title Tracking Issue for Mixed Precision Pass [AMP] Tracking Issue for Mixed Precision Pass Jun 25, 2021
@AndrewZhaoLuo
Contributor Author

cc @Lunderberg

@comaniac comaniac changed the title [AMP] Tracking Issue for Mixed Precision Pass [RFC][Tracking Issue][AMP] Tracking Issue for Mixed Precision Pass Jul 27, 2021
@comaniac comaniac added the type:rfc-tracking RFC progress tracking. Ref: https://github.com/apache/tvm-rfcs label Jul 27, 2021
@AndrewZhaoLuo
Contributor Author

cc @masahi

@masahi
Member

masahi commented Aug 4, 2021

I've hit a nasty issue. On CPU targets, our sorting-related ops are implemented in C++ (https://github.com/apache/tvm/blob/main/src/runtime/contrib/sort/sort.cc#L436), and they don't support fp16. So ops like topk, argsort, and nms do not work with the fp16 + CPU target combination. We could add all of them to the NEVER list, but that would introduce unnecessary casts for GPU targets, because sorting on GPU is implemented in TIR and so has no issues with fp16.

Maybe we need to add a specialized CPU sort for fp16 or rewrite the CPU sort in TIR (the same issue would come up with int4, bfloat16, etc.). The former would not be hard, since we just need to add a specialized comparison functor for fp16 like https://github.com/apache/tvm/blob/main/src/runtime/contrib/sort/sort.cc#L40-L43
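
To make the "specialized comparison" idea concrete, here is a minimal Python sketch (not the C++ functor in sort.cc, and not TVM code). It assumes fp16 values arrive as raw 16-bit patterns and sorts them by decoding each one to a wider float first. The helper names `half_bits_to_float`, `float_to_half_bits`, and `argsort_fp16` are hypothetical.

```python
import struct


def half_bits_to_float(bits: int) -> float:
    """Decode a 16-bit IEEE 754 half-precision bit pattern into a Python float."""
    return struct.unpack("<e", struct.pack("<H", bits))[0]


def float_to_half_bits(value: float) -> int:
    """Encode a Python float as a 16-bit half-precision bit pattern."""
    return struct.unpack("<H", struct.pack("<e", value))[0]


def argsort_fp16(raw_bits, descending=False):
    """Return the indices that sort fp16 values stored as raw uint16 patterns.

    The comparison happens on the decoded values, which plays the role of
    a comparison functor specialized for fp16.
    """
    keys = [half_bits_to_float(b) for b in raw_bits]
    return sorted(range(len(raw_bits)), key=lambda i: keys[i], reverse=descending)


values = [1.5, -2.0, 0.25, 3.0]
raw = [float_to_half_bits(v) for v in values]
order = argsort_fp16(raw)
```

Comparing the raw bit patterns as unsigned integers would place -2.0 last, because the sign bit makes its pattern the largest, which is why the comparison has to go through a decode step.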

@masahi
Member

masahi commented Aug 24, 2021

It looks like transformer-like models have many softmax ops that introduce a lot of casting before and after them, like https://gist.github.com/masahi/0d7d96ae88722b616a906cec2054559e#file-transformer-txt-L137-L143

The fact that softmax and the following cast to fp16 are not fused surprised me. This is because the op pattern for softmax is kOpaque:

reg.register_pattern("nn.softmax", OpPattern.OPAQUE)

The cast overheads are big if they are not fused, so we are leaving a lot of perf on the table.

@yzhliu Is there a reason softmax op pattern cannot be OUT_ELEMWISE_FUSABLE?

@masahi
Copy link
Member

masahi commented Sep 3, 2021

@AndrewZhaoLuo What is our goal wrt mixed_type accumulation? Assuming we do find cases where mixed accumulation is beneficial, how are we going to decide when to enable or disable it? Currently we can only choose one or the other on a per-op basis:

# return ["float32", mixed_precision_type]
return [mixed_precision_type, mixed_precision_type]

@AndrewZhaoLuo
Contributor Author

Yeah, the issue behind creating defaults is that we cannot create defaults that work best for every situation. This is especially true since whenever we want speed we trade away accuracy, which can sometimes become a problem.

For the defaults I envision that for most ops we don't accumulate to FP32. For some ops like the global pools and sums we might turn it on. Really the best way to determine the criteria is to do a lot of the work you've been doing in trying out different models in different applications and seeing what needs to be turned on and off.
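
A rough sketch of what such defaults could look like, in plain Python rather than TVM code. It reads the two-element list from the snippet earlier in the thread as [accumulation dtype, output dtype], which is an assumption, and the op set and the function name `default_out_dtypes` are hypothetical choices for illustration only.

```python
# Hypothetical set of ops that accumulate in fp32 by default. The names
# follow Relay naming, but which ops belong here is an assumption that
# would need to be validated on real models.
FP32_ACCUMULATE_OPS = {"nn.global_avg_pool2d", "sum", "mean"}


def default_out_dtypes(op_name, mixed_precision_type="float16"):
    """Return [accumulation_dtype, output_dtype] for one op."""
    if op_name in FP32_ACCUMULATE_OPS:
        # Accumulate in fp32 for accuracy, still emit the reduced type.
        return ["float32", mixed_precision_type]
    # Most ops: stay in the reduced type throughout.
    return [mixed_precision_type, mixed_precision_type]
```

A user who needs more accuracy or more speed would then edit the set instead of touching every op.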

That being said, this is really designed to be a tool which sometimes requires the user to go back and modify the provided default values, either to get more speed if their model can afford it, or more accuracy if they need it. It requires investigation, and I don't think we can hit all cases well. A tutorial here would help (which is on my long list of TODOs).

Finally, while things are done on a per-op basis, the actual mixed precision function can look at parts of the relay call such as the attributes, the node, or the input tensor sizes. Therefore we can be smart about the conversion (e.g. for global pooling, only accumulate in fp32 if the input to output reduction is large enough). Again, a tutorial or example would help flesh this out.
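
As a sketch of the global pooling example, assuming the decision only needs the input and output shapes: the function below is plain Python, not a registered TVM conversion function, and the name `pooling_out_dtypes` and the threshold of 1024 are hypothetical.

```python
from math import prod

# Hypothetical cutoff: accumulate in fp32 once this many input elements
# are reduced into each output element. The value is an assumption.
REDUCTION_THRESHOLD = 1024


def pooling_out_dtypes(input_shape, output_shape,
                       mixed_precision_type="float16",
                       threshold=REDUCTION_THRESHOLD):
    """Return [accumulation_dtype, output_dtype] based on reduction size."""
    # Number of input elements folded into each output element.
    reduction = prod(input_shape) // max(prod(output_shape), 1)
    accumulation = "float32" if reduction >= threshold else mixed_precision_type
    return [accumulation, mixed_precision_type]
```

With these assumed numbers, a 7x7 global pool (49 elements per output) stays in fp16, while a 112x112 one (12544 elements per output) accumulates in fp32.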

@masahi
Member

masahi commented Sep 24, 2021

@AndrewZhaoLuo I briefly looked at bfloat16. While fp16 vs bf16 makes no difference for the conversion pass, it seems it is going to take a lot of effort to compile and run a bf16 model end to end, for at least two reasons:

  • The constant folding pass doesn't work on bfloat16 input
  • Numpy doesn't understand bfloat16, but some topi schedules (winograd conv) try to create a numpy array of type out_dtype, which in this case is bfloat16.

Since tensorcore can natively run bf16 workloads at the same rate as fp16, and bf16 on x86 servers is becoming a thing, it would be nice to have good support for bf16 across the stack in the future.
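
For context on the Numpy point, one possible workaround (an assumption here, not something this thread settles on) is to carry bf16 values as raw 16-bit integers, since bfloat16 is the upper half of a float32. The sketch below uses only the standard library, and the helper names are hypothetical.

```python
import struct


def float32_to_bf16_bits(value: float) -> int:
    """Convert a float to bfloat16 bits, rounding to nearest even."""
    if value != value:
        # NaN: return a canonical quiet NaN pattern.
        return 0x7FC0
    bits = struct.unpack("<I", struct.pack("<f", value))[0]
    # Adding 0x7FFF plus the lowest kept bit rounds ties to even.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    return ((bits + rounding_bias) >> 16) & 0xFFFF


def bf16_bits_to_float32(bits: int) -> float:
    """Expand bfloat16 bits back to a float by zero-filling the low half."""
    return struct.unpack("<f", struct.pack("<I", (bits & 0xFFFF) << 16))[0]
```

An array library that lacks a bfloat16 dtype could then hold these patterns in a 16-bit unsigned integer array and convert at the boundaries.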

@areusch areusch added the needs-triage PRs or issues that need to be investigated by maintainers to find the right assignees to address it label Oct 19, 2022
@Lunderberg Lunderberg added topi python/tvm/topi relay:op src/relay/op and removed needs-triage PRs or issues that need to be investigated by maintainers to find the right assignees to address it labels Oct 19, 2022
@tqchen tqchen closed this as completed Sep 20, 2024