Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Pir]Fix Value eq error when using set #58896

Merged
merged 53 commits into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
6539a97
[Pir]Fix Value eq error when using set
0x45f Nov 10, 2023
3899a7d
Fix iter
0x45f Nov 10, 2023
7f6c298
Refine code
0x45f Nov 10, 2023
0793a38
add ValueDict
zrr1999 Nov 10, 2023
f8abfa3
fix
zrr1999 Nov 13, 2023
6e5947b
update valueset
zrr1999 Nov 13, 2023
09f5eac
update dict
zrr1999 Nov 13, 2023
ad7461d
fix bug
zrr1999 Nov 14, 2023
531621c
improve
zrr1999 Nov 14, 2023
6b252fc
fix
zrr1999 Nov 14, 2023
7eb2685
fix contains
zrr1999 Nov 14, 2023
d7792bd
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
0x45f Nov 15, 2023
bc174c9
Fix set(Value) or dict[Value] = xx code
0x45f Nov 15, 2023
8a6f0a7
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
0x45f Nov 16, 2023
d7b38e8
Fix ut
0x45f Nov 16, 2023
963d0f8
Fix double grad ut
0x45f Nov 16, 2023
d299a50
Fix ut
0x45f Nov 17, 2023
8bae46b
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
0x45f Nov 24, 2023
ba7a2cd
Forbid opresult hash
0x45f Nov 24, 2023
4a50818
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
0x45f Nov 26, 2023
ace1ae1
Fix set and dict
0x45f Nov 26, 2023
cd3e2ea
Fix dy2st
0x45f Nov 26, 2023
470a2aa
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
0x45f Dec 7, 2023
faefbe8
Refine code
0x45f Dec 7, 2023
526eecd
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
0x45f Dec 7, 2023
78cbe1b
Forbid eq
0x45f Dec 7, 2023
9be9129
Fix map and value_eq
0x45f Dec 7, 2023
176e09b
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
0x45f Dec 8, 2023
ea46d6b
Fix None hash
0x45f Dec 8, 2023
dcaf6c4
Fix decomp
0x45f Dec 8, 2023
1df44e2
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
0x45f Dec 11, 2023
55cbfcc
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
0x45f Dec 12, 2023
b8b6416
Refine value set/dict
0x45f Dec 12, 2023
16bc3c6
Add hash
0x45f Dec 12, 2023
3b0a7cd
fix clone program, return a assciated array.
2742195759 Dec 12, 2023
1a7767d
Merge commit 'refs/pull/58896/head' of https://github.com/PaddlePaddl…
2742195759 Dec 12, 2023
7ab860c
Fix iter
0x45f Dec 12, 2023
31686de
Fix code
0x45f Dec 13, 2023
294abe4
Fix backward
0x45f Dec 13, 2023
b168af9
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
0x45f Dec 13, 2023
96d8cf3
Fix decomp and add copy()
0x45f Dec 13, 2023
e3637da
Format code
0x45f Dec 13, 2023
ef65985
Fix ut
0x45f Dec 14, 2023
7e6422c
Fix layer params set
0x45f Dec 14, 2023
a43cd6e
Fix prim op test
0x45f Dec 14, 2023
cf1ce82
Fix named_parameters
0x45f Dec 14, 2023
88b3e9f
Support value __eq__
0x45f Dec 15, 2023
4ec0f71
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
0x45f Dec 15, 2023
2611c9f
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
0x45f Dec 15, 2023
7b8cf7e
Fix test cond ==
0x45f Dec 15, 2023
6af0553
Add ut for value set/dict
0x45f Dec 18, 2023
b8953f2
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
0x45f Dec 18, 2023
4464870
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
0x45f Dec 18, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions paddle/fluid/pybind/pir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -761,8 +761,8 @@ void BindValue(py::module *m) {
[](Value self) {
return paddle::dialect::scale(self, -1.0, 0.0, true);
})
.def("__eq__", &Value::operator==)
.def("__hash__", [](Value self) { return std::hash<pir::Value>{}(self); })
.def("is_same", &Value::operator==)
.def("hash", [](Value self) { return std::hash<pir::Value>{}(self); })
.def("__repr__", &Value2String);
// For basaic operators
OVERRIDE_OPERATOR_FOR_EACH(__add__, add, 1.0, other, true);
Expand Down Expand Up @@ -1029,7 +1029,8 @@ static auto GetNoNeedBufferValue(const ::pir::Block *whole_block,
no_need_buffer_values.end());
}

using OpResultMap = std::unordered_map<pir::OpResult, pir::OpResult>;
using OpResultMap =
std::pair<std::vector<pir::OpResult>, std::vector<pir::OpResult>>;
std::pair<std::shared_ptr<Program>, OpResultMap> CloneProgram(
const Program &program) {
// Limitation of this function:
Expand All @@ -1042,12 +1043,14 @@ std::pair<std::shared_ptr<Program>, OpResultMap> CloneProgram(
auto *cloned_op = BuildOpFrom(&op, value_map);
cloned_program->block()->push_back(cloned_op);
}
std::unordered_map<pir::OpResult, pir::OpResult> op_result_map;
std::vector<pir::OpResult> associated_array_key, associated_array_value;
for (auto &pair : value_map) {
op_result_map[pair.first.dyn_cast<pir::OpResult>()] =
pair.second.dyn_cast<pir::OpResult>();
associated_array_key.push_back(pair.first.dyn_cast<pir::OpResult>());
associated_array_value.push_back(pair.second.dyn_cast<pir::OpResult>());
}
return std::make_pair(cloned_program, op_result_map);
return std::make_pair(
cloned_program,
std::make_pair(associated_array_key, associated_array_value));
}

void AppendSetParameter(Program *forward_program,
Expand Down
149 changes: 142 additions & 7 deletions python/paddle/autograd/backward_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,149 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import collections
import warnings
from collections.abc import Sequence
from typing import Any

from paddle import pir
from paddle.base import core
from paddle.base.wrapped_decorator import signature_safe_contextmanager


class ValueWrapper:
def __init__(self, value) -> None:
if isinstance(value, ValueWrapper):
assert isinstance(value._value, (type(None), pir.Value))
else:
assert isinstance(value, (type(None), pir.Value))
self._value = value._value if isinstance(value, ValueWrapper) else value

def __hash__(self) -> int:
if isinstance(self._value, pir.Value):
return self._value.hash()
else:
return hash(self._value)

def __eq__(self, other) -> bool:
if not isinstance(other, ValueWrapper):
warnings.warn(
f'In ValueWrapper.__eq__ expected type of `other` is ValueWrapper but received {other.__class__}.'
)
return False

if self._value is None or other._value is None:
return self._value is None and other._value is None
return self._value.is_same(other._value)


class ValueDict:
def __init__(
self,
iter=None,
*,
default_factory=None,
):
self._items: dict[ValueWrapper] = {}
self._default_factory = default_factory
if iter is not None:
for key, val in iter.items():
self[key] = val

def copy(self):
ret = ValueDict()
ret._items = self._items.copy()
ret._default_factory = self._default_factory
return ret

def update(self, other_dict):
for key, val in other_dict.items():
self[key] = val

def keys(self):
for key in self._items.keys():
yield key._value

def values(self):
return self._items.values()

def items(self):
for key, val in self._items.items():
yield key._value, val

def pop(self, key):
if not self.__contains__(key):
raise KeyError(f'{key} is not in ValueDict')
return self._items.pop(ValueWrapper(key))

def __setitem__(self, key, val: Any):
self._items[ValueWrapper(key)] = val

def __getitem__(self, key):
if not self.__contains__(key):
if self._default_factory is not None:
self[key] = self._default_factory()
else:
raise KeyError(f'{key} is not in ValueDict')
return self._items[ValueWrapper(key)]

def __bool__(self):
return bool(self._items)

def __len__(self):
return len(self._items)

def __iter__(self):
return self.keys()

def __contains__(self, key):
return ValueWrapper(key) in self._items


class ValueSet:
def __init__(
self, iter: Sequence[ValueWrapper] | set[ValueWrapper] | None = None
):
self._set: set[ValueWrapper] = set()
if iter is not None:
for val in iter:
self.add(val)

def copy(self):
ret = ValueSet()
ret._set = self._set.copy()
return ret

def add(self, val):
if not self.__contains__(val):
self._set.add(ValueWrapper(val))

def update(self, other: set):
for val in other:
self.add(val)

def __and__(self, other: ValueSet):
return ValueSet(self._set & other._set)

def __or__(self, other: ValueSet):
return ValueSet(self._set | other._set)

def __bool__(self):
return bool(self._set)

def __len__(self):
return len(self._set)

def __iter__(self):
for val in self._set:
yield val._value

def __contains__(self, val):
return ValueWrapper(val) in self._set


class State:
"""
record relationship of forward op/value and backward op/value
Expand All @@ -30,24 +165,24 @@ class State:
def __init__(self, block):
self.block = block
# opresult -> list(list(opresult))
self.value_to_valuegrad = collections.defaultdict(list)
self.value_to_sumvaluegrad = collections.defaultdict(list)
self.value_to_valuegrad = ValueDict(default_factory=list)
self.value_to_sumvaluegrad = ValueDict(default_factory=list)
# operation -> list(operation)
self.op_to_opgrad = collections.defaultdict(list)

# opresult -> list(opresult)
self.valuegrad_to_value = collections.defaultdict(list)
self.sumvaluegrad_to_value = collections.defaultdict(list)
self.valuegrad_to_value = ValueDict(default_factory=list)
self.sumvaluegrad_to_value = ValueDict(default_factory=list)
# operation -> list(operation)
self.opgrad_to_op = collections.defaultdict(list)
# only for controlflow
# inside_value is sub block value, which will yield to parent block,
# parant block value is outside_value
self.inside_value_to_outside_value_map = {}
self.inside_value_to_outside_value_map = ValueDict()

def turn_map(self) -> None:
self.valuegrad_to_value = collections.defaultdict(list)
self.sumvaluegrad_to_value = collections.defaultdict(list)
self.valuegrad_to_value = ValueDict(default_factory=list)
self.sumvaluegrad_to_value = ValueDict(default_factory=list)
self.opgrad_to_op = collections.defaultdict(list)

for k, v in self.value_to_valuegrad.items():
Expand Down
Loading