[Pir]Fix Value eq error when using set #58896

Merged: 53 commits, Dec 18, 2023
Changes from 7 commits
Commits (53)
6539a97
[Pir]Fix Value eq error when using set
0x45f Nov 10, 2023
3899a7d
Fix iter
0x45f Nov 10, 2023
7f6c298
Refine code
0x45f Nov 10, 2023
0793a38
add ValueDict
zrr1999 Nov 10, 2023
f8abfa3
fix
zrr1999 Nov 13, 2023
6e5947b
update valueset
zrr1999 Nov 13, 2023
09f5eac
update dict
zrr1999 Nov 13, 2023
ad7461d
fix bug
zrr1999 Nov 14, 2023
531621c
improve
zrr1999 Nov 14, 2023
6b252fc
fix
zrr1999 Nov 14, 2023
7eb2685
fix contains
zrr1999 Nov 14, 2023
d7792bd
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
0x45f Nov 15, 2023
bc174c9
Fix set(Value) or dict[Value] = xx code
0x45f Nov 15, 2023
8a6f0a7
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
0x45f Nov 16, 2023
d7b38e8
Fix ut
0x45f Nov 16, 2023
963d0f8
Fix double grad ut
0x45f Nov 16, 2023
d299a50
Fix ut
0x45f Nov 17, 2023
8bae46b
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
0x45f Nov 24, 2023
ba7a2cd
Forbid opresult hash
0x45f Nov 24, 2023
4a50818
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
0x45f Nov 26, 2023
ace1ae1
Fix set and dict
0x45f Nov 26, 2023
cd3e2ea
Fix dy2st
0x45f Nov 26, 2023
470a2aa
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
0x45f Dec 7, 2023
faefbe8
Refine code
0x45f Dec 7, 2023
526eecd
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
0x45f Dec 7, 2023
78cbe1b
Forbid eq
0x45f Dec 7, 2023
9be9129
Fix map and value_eq
0x45f Dec 7, 2023
176e09b
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
0x45f Dec 8, 2023
ea46d6b
Fix None hash
0x45f Dec 8, 2023
dcaf6c4
Fix decomp
0x45f Dec 8, 2023
1df44e2
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
0x45f Dec 11, 2023
55cbfcc
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
0x45f Dec 12, 2023
b8b6416
Refine value set/dict
0x45f Dec 12, 2023
16bc3c6
Add hash
0x45f Dec 12, 2023
3b0a7cd
fix clone program, return an associated array.
2742195759 Dec 12, 2023
1a7767d
Merge commit 'refs/pull/58896/head' of https://github.com/PaddlePaddl…
2742195759 Dec 12, 2023
7ab860c
Fix iter
0x45f Dec 12, 2023
31686de
Fix code
0x45f Dec 13, 2023
294abe4
Fix backward
0x45f Dec 13, 2023
b168af9
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
0x45f Dec 13, 2023
96d8cf3
Fix decomp and add copy()
0x45f Dec 13, 2023
e3637da
Format code
0x45f Dec 13, 2023
ef65985
Fix ut
0x45f Dec 14, 2023
7e6422c
Fix layer params set
0x45f Dec 14, 2023
a43cd6e
Fix prim op test
0x45f Dec 14, 2023
cf1ce82
Fix named_parameters
0x45f Dec 14, 2023
88b3e9f
Support value __eq__
0x45f Dec 15, 2023
4ec0f71
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
0x45f Dec 15, 2023
2611c9f
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
0x45f Dec 15, 2023
7b8cf7e
Fix test cond ==
0x45f Dec 15, 2023
6af0553
Add ut for value set/dict
0x45f Dec 18, 2023
b8953f2
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
0x45f Dec 18, 2023
4464870
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
0x45f Dec 18, 2023
8 changes: 4 additions & 4 deletions paddle/fluid/pybind/pir.cc
@@ -476,8 +476,8 @@ void BindValue(py::module *m) {
[](Value &self, Value &op_value) {
self.ReplaceAllUsesWith(op_value);
})
.def("__eq__", &Value::operator==)
.def("__eq__",
.def("is_same", &Value::operator==)
.def("is_same",
[](Value &self, OpResult &other) {
return self.impl() == other.Value::impl();
})
@@ -664,8 +664,8 @@ void BindOpResult(py::module *m) {
OVERRIDE_COMPARE_OP_FOR_EACH(__gt__, greater_than);
OVERRIDE_COMPARE_OP_FOR_EACH(__ge__, greater_equal);

-  op_result.def("__eq__", &OpResult::operator==)
-      .def("__eq__",
+  op_result.def("is_same", &OpResult::operator==)
+      .def("is_same",
[](OpResult &self, Value &other) {
return self.Value::impl() == other.impl();
})
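The rename matters because Python containers resolve membership and lookup through __eq__ and __hash__. A minimal sketch of the failure mode, assuming a hypothetical FakeValue stand-in (not the real pir.Value API) whose == builds a compare op instead of returning a bool:

class FakeValue:
    """Hypothetical stand-in for a pir value with tensor-style `==`."""

    def __init__(self, impl):
        self._impl = impl

    def __hash__(self):
        return hash(self._impl)

    def __eq__(self, other):
        # Tensor semantics: build a new "equal" value, not a bool.
        return FakeValue(("equal", self._impl, other._impl))

    def is_same(self, other):
        # Identity semantics, analogous to comparing Value::impl().
        return self._impl == other._impl

values = [FakeValue(1)]
print(FakeValue(1).is_same(values[0]))  # True, and it is a real bool
try:
    print(None in values)
except AttributeError as exc:
    print("membership check crashed:", exc)  # __eq__ assumed other._impl

With == reserved for op building, container lookups either crash or silently use an op object as a truth value, which is why the identity test moves to is_same and the backward code below switches to wrapper containers built on it.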
90 changes: 84 additions & 6 deletions python/paddle/autograd/backward_utils.py
@@ -11,8 +11,86 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import collections
from typing import Any


class ValueInDict:
def __init__(self, value) -> None:
self.value = value

def __hash__(self) -> int:
return hash(self.value)

def __eq__(self, other) -> bool:
if isinstance(other, ValueInDict):
other = other.value
return self.value.is_same(other)


class ValueDict:
def __init__(
self,
iter: dict[ValueInDict, Any] | None = None,
*,
default_factory=None,
):
self._items: dict[ValueInDict, Any] = {}
self._default_factory = default_factory
if iter is not None:
for key, val in iter.items():
self[key] = val

    def update(self, other_dict):
        # Plain dict iteration yields keys only; items() gives key/value pairs.
        for key, val in other_dict.items():
            self[key] = val

def keys(self):
return self._items.keys()

def values(self):
return self._items.values()

def items(self):
return self._items.items()

    def __setitem__(self, key, val: Any):
        # Normalize keys to ValueInDict so every lookup compares via is_same().
        if not isinstance(key, ValueInDict):
            key = ValueInDict(key)
        self._items[key] = val

def __getitem__(self, other_key):
if not self.__contains__(other_key):
if self._default_factory is not None:
self[other_key] = self._default_factory()
else:
self[other_key] = None
return self._items[other_key]

def __and__(self, other_dict: ValueDict):
ret = ValueDict()
for key, val in self._items.items():
if key in other_dict:
ret[key] = val
return ret

def __or__(self, other_dict: ValueDict):
return ValueDict(self._items | other_dict._items)

def __bool__(self):
return bool(self._items)

def __len__(self):
return len(self._items)

    def __iter__(self):
        # keys() returns a view, which is iterable but not itself an iterator.
        return iter(self._items)

def __contains__(self, other_key):
for key in self._items.keys():
if hash(key) == hash(other_key) and key == other_key:
return True
return False

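A short usage sketch, assuming keys behave like pir values (Val below is a hypothetical stand-in that provides only __hash__ and is_same):

class Val:
    def __init__(self, impl):
        self._impl = impl

    def __hash__(self):
        return hash(self._impl)

    def is_same(self, other):
        return self._impl == other._impl

grads = ValueDict(default_factory=list)
v1, v2 = Val("x"), Val("x")  # two distinct wrappers, same underlying value
grads[v1].append("dx")
print(v2 in grads)  # True: membership is resolved through is_same()
print(grads[v2])    # ['dx']: an equivalent key reaches the same entry

Because keys are normalized to ValueInDict, any two Python objects wrapping the same underlying value share one entry, which is what the gradient maps in State need.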

class State:
@@ -25,20 +103,20 @@ class State:
def __init__(self, program):
self.program = program
# opresult -> list(list(opresult))
-        self.value_to_valuegrad = collections.defaultdict(list)
-        self.value_to_sumvaluegrad = collections.defaultdict(list)
+        self.value_to_valuegrad = ValueDict(default_factory=list)
+        self.value_to_sumvaluegrad = ValueDict(default_factory=list)
# operation -> list(operation)
self.op_to_opgrad = collections.defaultdict(list)

# opresult -> list(opresult)
-        self.valuegrad_to_value = collections.defaultdict(list)
-        self.sumvaluegrad_to_value = collections.defaultdict(list)
+        self.valuegrad_to_value = ValueDict(default_factory=list)
+        self.sumvaluegrad_to_value = ValueDict(default_factory=list)
# operation -> list(operation)
self.opgrad_to_op = collections.defaultdict(list)

def turn_map(self) -> None:
-        self.valuegrad_to_value = collections.defaultdict(list)
-        self.sumvaluegrad_to_value = collections.defaultdict(list)
+        self.valuegrad_to_value = ValueDict(default_factory=list)
+        self.sumvaluegrad_to_value = ValueDict(default_factory=list)
self.opgrad_to_op = collections.defaultdict(list)

for k, v in self.value_to_valuegrad.items():
102 changes: 81 additions & 21 deletions python/paddle/autograd/ir_backward.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import collections
import logging
from collections.abc import Sequence
@@ -27,6 +29,63 @@
__all__ = ['grad', 'calc_gradient', 'calc_gradient_helper']


class ValueInSet:
def __init__(self, value) -> None:
self.value = value

def __hash__(self) -> int:
return hash(self.value)

def __eq__(self, other) -> bool:
if isinstance(other, ValueInSet):
other = other.value
return self.value.is_same(other)


class ValueSet:
def __init__(
self, iter: Sequence[ValueInSet] | set[ValueInSet] | None = None
):
self._values: set[ValueInSet] = set()
if iter is not None:
for val in iter:
self.add(val)

    def add(self, other_val):
        # Avoid double-wrapping values that already come from a ValueSet.
        if not isinstance(other_val, ValueInSet):
            other_val = ValueInSet(other_val)
        if not self.__contains__(other_val):
            self._values.add(other_val)

    def update(self, other_set: set):
        # add() wraps raw values itself; wrapping again here would break is_same().
        for val in other_set:
            self.add(val)

def __and__(self, other_set: ValueSet):
ret = ValueSet()
for val in self._values:
if val in other_set:
ret.add(val)
return ret

def __or__(self, other_set: ValueSet):
return ValueSet(self._values | other_set._values)

def __bool__(self):
return bool(self._values)

def __len__(self):
return len(self._values)

def __iter__(self):
return iter(self._values)

def __contains__(self, other_val):
for value in self._values:
if hash(value) == hash(other_val) and value == other_val:
return True
return False

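The set wrapper composes the same way; a minimal sketch, reusing the same kind of hypothetical Val stand-in (only __hash__ and is_same are assumed):

class Val:
    def __init__(self, impl):
        self._impl = impl

    def __hash__(self):
        return hash(self._impl)

    def is_same(self, other):
        return self._impl == other._impl

s = ValueSet([Val("a"), Val("b")])
t = ValueSet([Val("b"), Val("c")])
print(Val("a") in s)  # True: membership goes through is_same()
print(len(s & t))     # 1: intersection matches by underlying value
print(len(s | t))     # 3: union deduplicates equivalent wrappers

This keeps identity-based set algebra available to the pruning and backward passes while leaving == on values free for elementwise semantics.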

def check_type(input, input_name, expected_type, op_name, extra_message=''):
if not isinstance(input, expected_type):
raise TypeError(
@@ -124,7 +183,7 @@ def prepare_grad_outputs(grad_outputs, outputs, state)
complete_outputs = outputs
complete_gradoutputs = grad_outputs

-    visited_output = set()
+    visited_output = ValueSet()
for output in outputs:
if output in visited_output:
continue
@@ -157,7 +216,7 @@

def some_in_set(value_list, value_set):
def operand2value(values):
-        value_set = set()
+        value_set = ValueSet()
for item in values:
if isinstance(item, paddle.pir.OpOperand):
value_set.add(item.source())
@@ -245,7 +304,7 @@ def update_no_grad_set_after_prune(
from inputs to outputs add value not in the path to no_grad_set,
from outputs to inputs add value not in the path to no_grad_set,
'''
-    inputs_set = set(inputs)
+    inputs_set = ValueSet(inputs)
if inputs_set:
for op in block.ops:
if some_in_set(op.operands_source(), inputs_set):
@@ -258,12 +317,12 @@
if value not in inputs_set:
no_grad_set.add(value)

-    outputs_set = set(outputs)
-    no_grad_set_tmp = set()
+    outputs_set = ValueSet(outputs)
+    no_grad_set_tmp = ValueSet()
for op in reversed(effective_forward_ops):
for output in op.results():
if output not in outputs_set and not some_in_set(
-                [output], set(op.operands_source())
+                [output], ValueSet(op.operands_source())
):
no_grad_set_tmp.add(output)

@@ -317,7 +376,7 @@ def inverse_sort_op(ops):


def append_backward_ops(
-    block, effective_forward_ops, no_grad_set, backward_ops, state
+    block, effective_forward_ops, no_grad_set, backward_ops, state: State
):
'''
add grad_op in order of topological inverse sort
@@ -577,7 +636,7 @@ def update_input_grad_map(op, input_grads):


def create_backward_prune_set(inputs, outputs, no_grad_set, state):
-    outputs_set = set()
+    outputs_set = ValueSet()
for input_ in inputs:
if not input_.use_empty():
for item in input_.first_use().owner().operands_source():
@@ -586,18 +645,18 @@ def create_backward_prune_set(inputs, outputs, no_grad_set, state):
else:
            logging.warning("input provided by inputs has no use")

-    inputs_set = set()
+    inputs_set = ValueSet()
for output in outputs:
if state.value_to_valuegrad[output] != []:
inputs_set.add(state.value_to_valuegrad[output][0][0])
-    inputs_set_tmp = set()
+    inputs_set_tmp = ValueSet()
for out_grad in inputs_set:
if not out_grad.use_empty():
for item in out_grad.first_use().owner().operands_source():
inputs_set_tmp.add(item)
inputs_set.update(inputs_set_tmp)

-    no_gradvar_set = set()  # grad_value of value in no_grad_set
+    no_gradvar_set = ValueSet()  # grad_value of value in no_grad_set
for key in state.value_to_valuegrad:
if key in no_grad_set and state.value_to_valuegrad[key] != []:
no_gradvar_set.add(state.value_to_valuegrad[key][0][0])
@@ -640,8 +699,8 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set):
grad_outputs, outputs, state
)

-    inputs_set = set(inputs)
-    outputs_set = set(complete_outputs)
+    inputs_set = ValueSet(inputs)
+    outputs_set = ValueSet(complete_outputs)
effective_forward_ops, _ = prune_ops(
block.ops, inputs_set, outputs_set, no_grad_set
)
@@ -690,7 +749,7 @@ def calc_gradient(outputs, inputs, grad_outputs, no_grad_set):
be: (1) a Value filled with 1 when the i-th element of `grad_outputs`
is None; (2) the i-th element of `grad_outputs` when the i-th element of
`grad_outputs` is a Value. Default None.
-        no_grad_set (set(Value), optional):
+        no_grad_set (list(Value)|tuple(Value), optional):
the Values whose gradients are not needed to compute. Default None.

Return:
@@ -701,7 +760,10 @@ def calc_gradient(outputs, inputs, grad_outputs, no_grad_set):
"""
# record input value and its gradient (Value to Value)
input_to_inputgrad_map = calc_gradient_helper(
-        outputs, inputs, grad_outputs=grad_outputs, no_grad_set=no_grad_set
+        outputs,
+        inputs,
+        grad_outputs=grad_outputs,
+        no_grad_set=ValueSet(no_grad_set),
)

inputgrad = []
@@ -764,7 +826,7 @@ def grad(
`inputs` are unreachable in the graph (i.e., their gradients are None),
error would be raised if allow_unused=False, or None would be returned as
their gradients if allow_unused=True. Default False.
-        no_grad_vars (Value|list(Value)|tuple(Value)|set(Value), optional):
+        no_grad_vars (Value|list(Value)|tuple(Value), optional):
the Values whose gradients are not needed to compute. Default None.

Returns:
Expand Down Expand Up @@ -794,18 +856,16 @@ def grad(
check_type(
no_grad_vars,
'no_grad_vars',
-        ((paddle.pir.Value, paddle.pir.OpResult), list, tuple, set, type(None)),
+        ((paddle.pir.Value, paddle.pir.OpResult), list, tuple, type(None)),
'paddle.autograd.ir_backward.grad',
)
outputs = _as_list(outputs)
inputs = _as_list(inputs)
grad_outputs = _as_list(grad_outputs)
    if no_grad_vars is None:
-        no_grad_set = set()
-    elif no_grad_vars is not set:
-        no_grad_set = set(no_grad_vars)
+        no_grad_set = ValueSet()
    else:
-        no_grad_set = no_grad_vars
+        no_grad_set = ValueSet(no_grad_vars)

input_grad = calc_gradient(outputs, inputs, grad_outputs, no_grad_set)

7 changes: 1 addition & 6 deletions python/paddle/base/backward.py
@@ -2653,7 +2653,6 @@ def gradients(targets, inputs, target_gradients=None, no_grad_set=None):
(paddle.pir.Value, paddle.pir.OpResult),
list,
tuple,
-                set,
type(None),
),
'paddle.autograd.ir_backward.grad',
@@ -2662,11 +2661,7 @@
inputs = _as_list(inputs)
target_gradients = _as_list(target_gradients)
        if no_grad_set is None:
-            no_grad_set = set()
-        elif no_grad_set is not set:
-            no_grad_set = set(no_grad_set)
-        else:
-            no_grad_set = no_grad_set
+            no_grad_set = []
from paddle.autograd.ir_backward import (
calc_gradient as pir_calc_gradient,
)