【PRIM】Support custom_vjp for reducing video memory #50885

Merged
7 commits merged into PaddlePaddle:develop from prim_custom_vjp on Mar 14, 2023

Conversation

Contributor

@cxxly cxxly commented Feb 24, 2023

PR types

New features

PR changes

Others

Describe

Pcard-66975

  • Add the new CustomVJP feature: when both forward and backward decomposition are enabled, run the backward decomposition first (see the hedged sketch after this list).
  • Enable dy2st to support CustomVJP. @2742195759
  • Fix _lower_composite bugs when an op is not registered in the decomposition rules.
  • Fix the DataType/VarType mapping error in the cast prim API and VJP rules.
  • Remove to_prim from the autograd __all__ list and support a blacklist/whitelist feature.
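
A minimal, hedged sketch of the intended flow (illustrative only: the hook method name and the has_custom_vjp helper are assumptions, not the exact PR code). When both forward and backward decomposition are enabled, ops that register a custom VJP are blacklisted from forward lowering so that their composite backward rules are applied first:

class PrimHooker(PartialProgramLayerHook):
    def __init__(self, program):
        self.custom_vjps = set()
        if core._is_fwd_prim_enabled() and core._is_bwd_prim_enabled():
            # Collect op types that register a custom VJP (hypothetical helper).
            self.custom_vjps = {
                op.type for op in program.block(0).ops if has_custom_vjp(op.type)
            }

    def before_append_backward(self, forward_program):
        # Keep custom-VJP ops whole during forward lowering so their composite
        # backward rules fire first when gradients are appended.
        if core._is_fwd_prim_enabled():
            to_prim(forward_program.block(0), exclude=self.custom_vjps)
        return forward_program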

@paddle-bot

paddle-bot bot commented Feb 24, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@cxxly cxxly force-pushed the prim_custom_vjp branch 11 times, most recently from 18cdd67 to d75ef48 Compare March 2, 2023 13:17
Contributor

@JiabinYang JiabinYang left a comment

some comments

@@ -293,7 +316,8 @@ def _create_pure_fp16_program(self, is_infer_mode=False):
@switch_to_static_graph
def _create_forward_backward_train_program(self):
whole_program = self._train_program
_, forward_end_op_index = self._infer_info('fp32', self._create_program)
# _, forward_end_op_index = self._infer_info('fp32', self._create_program)
Contributor

Don't leave commented-out code here.

Contributor Author

done

class PrimHooker(PartialProgramLayerHook):
    def __init__(self):
        self.custom_vjps = set()
        if core._is_fwd_prim_enabled() and core._is_bwd_prim_enabled():
Contributor

use _is_all_prim_enabled?

Contributor Author

done

self, partial_program_layer, whole_program, backward_start_idx
):
backward_length = (
len(whole_program.block(0).ops) - backward_start_idx
Contributor

Add a comment noting that we still need to support other blocks later.

Contributor Author

Done: the block count is now checked, and an exception is raised when it is greater than 1.
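
A hedged sketch of that guard (names are illustrative; num_blocks is assumed to report the program's block count):

if whole_program.num_blocks > 1:
    raise NotImplementedError(
        "Backward decomposition currently only handles the main block "
        "(block 0); multi-block programs are not supported yet."
    )
backward_length = len(whole_program.block(0).ops) - backward_start_idx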

to_prim(infer_program.block(0))
return infer_program

partial_program = partial_program_from(concrete_program)
Contributor

Is this the only entry point for dy2st?

Contributor Author

Done; fp16 support is enabled.

@@ -1519,7 +1519,7 @@ def _out_grad_names(program_desc, fwd_end_op_index, out_size):
min(fwd_end_op_index + out_size, program_desc.block(0).op_size()),
):
op = program_desc.block(0).op(i)
if op.type() == 'fill_any_like':
if op.type() in ['fill_any_like', "fill_constant"]:
Contributor

why this?

Contributor Author

done
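
A hedged reading of the hunk above (an assumption based on the decomposition rules, not stated in the thread): with composite backward enabled, the op that seeds the initial output gradient may appear as fill_constant instead of fill_any_like, so both op types have to be recognized, e.g.:

def _is_grad_seed_op(op_type):
    # Both the eager op and its decomposed form can create the initial grad.
    return op_type in ("fill_any_like", "fill_constant")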



@switch_to_static_graph
def to_prim(blocks, exclude=frozenset()):
Contributor

Maybe rename it; there are too many functions called to_prim.

Contributor Author

done
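
A hedged usage sketch of the blacklist behaviour shown in the hunk above (the function may have been renamed in the final commit; softmax is only an illustrative op type):

import paddle

paddle.enable_static()
main_program = paddle.static.default_main_program()
# `to_prim` here is the function from the hunk above. Lower composite ops in
# block 0, but keep `softmax` whole so its custom VJP rule can still fire
# when the backward program is built.
to_prim(main_program.block(0), exclude=frozenset({"softmax"}))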

@@ -99,7 +100,12 @@ def _train(self, use_prim, data, axis, keep_dim):
def check_prim(self, net, use_prim):
if not use_prim:
return
fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
fwd_ops = [
Contributor

Leave a comment, or wrap this in a helper function, in case it changes later.

Contributor Author

done
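
One way to seal the collection into a helper, as suggested (a hedged sketch; the helper name is illustrative):

def _collect_op_types(program, block_idx=0):
    # Keep the test independent of how the program is laid out internally.
    return [op.type for op in program.block(block_idx).ops]

fwd_ops = _collect_op_types(net.forward.main_program)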

@cxxly cxxly force-pushed the prim_custom_vjp branch 5 times, most recently from 74fd37a to 156c830 Compare March 10, 2023 10:28
@@ -183,6 +194,7 @@ def __init__(
# Set default mode to train
self.training = True
self._infer_info = ProgramInfo()
self._forward_end_index_map = {}
Contributor

Could this field be integrated into ProgramInfo?

Contributor

Let's merge this PR first; after it is merged, I'll move this field into ProgramInfo.

Aurelius84 and others added 7 commits March 13, 2023 09:20
…Paddle#50557)

* [CINN]Enhance CacheKey hash logic by considering input dtypes

* add unittest

* fix typo

* fix typo

* fix map.at

* fix find

* fix test

* fix cinn cache key structure realize

* using ordered map for attributes

* add test by review advice

---------

Co-authored-by: jiangcheng <thisjiang@qq.com>
* [CINN]Enhance CacheKey hash logic by considering input dtypes (PaddlePaddle#50557)

* [CINN]Enhance CacheKey hash logic by considering input dtypes

* add unittest

* fix typo

* fix typo

* fix map.at

* fix find

* fix test

* fix cinn cache key structure realize

* using ordered map for attributes

* add test by review advice

---------

Co-authored-by: jiangcheng <thisjiang@qq.com>

* [prim] enable dygraph_to_static to support custom_vjp

* fix code in a dy2static-friendly way.

* [dystatic] add hooker for prim

---------

Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Co-authored-by: jiangcheng <thisjiang@qq.com>
Co-authored-by: cxxly <chenxx_id@163.com>
* [CINN]Enhance CacheKey hash logic by considering input dtypes (PaddlePaddle#50557)

---------

Co-authored-by: jiangcheng <thisjiang@qq.com>

* [prim] enable dygraph_to_static to support custom_vjp

* Pr 50885 (#7)

* [CINN]Enhance CacheKey hash logic by considering input dtypes (PaddlePaddle#50557)

* [CINN]Enhance CacheKey hash logic by considering input dtypes

---------

Co-authored-by: jiangcheng <thisjiang@qq.com>

* [prim] enable dygraph_to_static to support custom_vjp

* fix code in a dy2static-friendly way.

* [dystatic] add hooker for prim

---------

Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Co-authored-by: jiangcheng <thisjiang@qq.com>
Co-authored-by: cxxly <chenxx_id@163.com>

* [prim] enable dygraph_to_static to support custom_vjp

* fix cast prim and vjp dtype mapping error bug

* [dy2static-ci] fix dy2static ci errors.

---------

Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Co-authored-by: jiangcheng <thisjiang@qq.com>
Co-authored-by: cxxly <chenxx_id@163.com>
Contributor

@2742195759 2742195759 left a comment

LGTM

Contributor

@JiabinYang JiabinYang left a comment

LGTM, with some questions that can be resolved later.

@@ -3751,7 +3751,6 @@ def __init__(self, program, idx):
self.vars = collections.OrderedDict() # var_name --> var
self.ops = list() # operator list
self.program = program
self.removed_vars = collections.OrderedDict()
Contributor

why remove this?

Contributor Author

pre-commit auto format

@@ -701,6 +736,7 @@ def _prepare_attributes(self):
'program_id',
self.program_id,
]

Contributor

Unnecessary blank line?

Contributor Author

pre-commit auto format

@@ -1119,5 +1155,8 @@ def add_build_strategy_for(
if hasattr(compiled_program._program, 'lr_sheduler'):
builded_program.lr_sheduler = compiled_program._program.lr_sheduler
else:
builded_program = program
# can't just create a new program, we need copy the vardesc.
builded_program = paddle.static.Program()
Contributor

why this?

Contributor Author

This fixes a bug when the program contains only a variable, for example:

@paddle.jit.to_static
def f(x):
    return x
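
A hedged sketch of the fallback path described above (assuming Block._clone_variable is used to copy the var descs; not the exact PR code):

# can't just create a new program, we need copy the vardesc.
builded_program = paddle.static.Program()
for var in program.block(0).vars.values():
    builded_program.block(0)._clone_variable(var, force_persistable=False)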

Contributor

ok

Contributor

@jzhang533 jzhang533 left a comment

LGTM

Contributor

@XiaoguangHu01 XiaoguangHu01 left a comment

LGTM

@cxxly cxxly merged commit 300f36c into PaddlePaddle:develop Mar 14, 2023