-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpaddle_aux.py
105 lines (92 loc) · 3.37 KB
/
paddle_aux.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# This file is generated by PaConvert ToolKit, please Don't edit it!
import paddle
def split(x, num_or_sections, axis=0):
    """Split ``x`` along ``axis`` using ``torch.split`` semantics.

    Args:
        x: the tensor to split.
        num_or_sections: an ``int`` is the size of each chunk (torch
            semantics), not the number of chunks. When the dimension is not
            divisible by it, the last chunk holds the remainder. A list/tuple
            of section sizes is forwarded to ``paddle.split`` unchanged.
        axis: the dimension to split along.

    Returns:
        The sub-tensors produced by ``paddle.split``.
    """
    if isinstance(num_or_sections, int):
        chunk = num_or_sections
        size = x.shape[axis]
        # paddle.split with an int demands exact divisibility, so a ragged split
        # (including size < chunk) must be spelled out as explicit sizes:
        # `full` chunks of `chunk` elements followed by the remainder.
        # A negative size (presumably an unknown dim, -1) keeps the old path.
        if size >= 0 and size % chunk:
            full = size // chunk
            return paddle.split(x, [chunk] * full + [size % chunk], axis)
        return paddle.split(x, size // chunk, axis)
    return paddle.split(x, num_or_sections, axis)
def repeat(self, *args, **kwargs):
    """Emulate ``torch.Tensor.repeat`` with ``paddle.tile``.

    Accepts ``t.repeat(2, 3)``, ``t.repeat((2, 3))`` / ``t.repeat([2, 3])``
    and ``t.repeat(repeats=...)``. Positional arguments take precedence
    over the ``repeats`` keyword.

    Returns:
        The tiled tensor from ``paddle.tile``.

    Raises:
        TypeError: if no repeat counts are given positionally or as
            ``repeats=``. Previously this was an ``assert`` (stripped under
            ``-O``), or a silent ``None`` return when called with no arguments.
    """
    if args:
        # A single tuple/list is the whole repeat spec; otherwise each
        # positional argument is one per-dimension count.
        if len(args) == 1 and isinstance(args[0], (tuple, list)):
            return paddle.tile(self, args[0])
        return paddle.tile(self, list(args))
    if 'repeats' in kwargs:
        return paddle.tile(self, repeat_times=kwargs['repeats'])
    raise TypeError("repeat() missing required argument 'repeats'")


# Install the shim on paddle.Tensor and on paddle.base.framework.Variable
# (presumably the static-graph tensor type).
setattr(paddle.Tensor, 'repeat', repeat)
setattr(paddle.base.framework.Variable, 'repeat', repeat)
def _is_device_string(value):
    """Return True if ``value`` is a device string such as ``'cpu'`` or ``'cuda:0'``."""
    # Only the part before an optional ':<index>' suffix names the device.
    return isinstance(value, str) and value.split(":", 1)[0] in (
        "cpu", "cuda", "gpu", "ipu", "xpu")


def to(self, *args, **kwargs):
    """Emulate ``torch.Tensor.to`` on a ``paddle.Tensor`` (dtype changes only).

    Positional arguments are mapped in order onto the names ``x``, ``y``,
    ``non_blocking``, ``copy`` and ``memory_format``. They are then merged with
    the keyword arguments, and keywords win on conflict. Device moves,
    ``non_blocking``, ``copy`` and ``memory_format`` are not emulated.

    Resolution order:
      * no arguments           -> ``self`` unchanged
      * ``tensor=<Tensor>``    -> cast to that tensor's dtype
      * ``dtype=<dtype>``      -> cast to that dtype
      * ``device=...`` alone   -> ``self`` unchanged
      * ``x`` only             -> a dtype or dtype string is used directly; a
                                  Tensor contributes its dtype; a device string
                                  or any other value keeps ``self.dtype``
      * ``x`` and ``y``        -> a dtype ``x`` or Tensor ``x`` as above; a
                                  device-string ``x`` takes the dtype from ``y``
                                  when ``y`` is a str or dtype, otherwise keeps
                                  ``self.dtype``; any other ``x`` is passed to
                                  ``paddle.cast`` as the dtype
      * anything else          -> ``self`` unchanged

    Returns:
        ``self`` or the result of ``paddle.cast``.

    Raises:
        TypeError: if more than five positional arguments are given.
    """
    positional_names = ["x", "y", "non_blocking", "copy", "memory_format"]
    if len(args) > len(positional_names):
        raise TypeError(
            "to() takes at most {} positional arguments ({} given)".format(
                len(positional_names), len(args)))
    options = dict(zip(positional_names, args))
    options.update(kwargs)

    if not options:
        return self
    if "tensor" in options:
        return paddle.cast(self, options["tensor"].dtype)
    if "dtype" in options:
        # Pass the dtype object through; str(paddle.float32) is not a valid
        # dtype string for paddle.cast.
        return paddle.cast(self, options["dtype"])
    if "device" in options:
        return self
    if "x" not in options:
        return self

    first = options["x"]
    if "y" not in options:
        if isinstance(first, paddle.dtype):
            dtype = first
        elif isinstance(first, str) and not _is_device_string(first):
            dtype = first
        elif isinstance(first, paddle.Tensor):
            dtype = first.dtype
        else:
            dtype = self.dtype
        return paddle.cast(self, dtype)

    second = options["y"]
    if isinstance(first, paddle.dtype):
        dtype = first
    elif isinstance(first, paddle.Tensor):
        # to(other, non_blocking): the second argument is a flag, not a dtype.
        dtype = first.dtype
    elif isinstance(first, str):
        if not _is_device_string(first):
            dtype = first
        elif isinstance(second, (str, paddle.dtype)):
            # to(device, dtype): the dtype is the second argument.
            dtype = second
        else:
            dtype = self.dtype
    else:
        dtype = first
    return paddle.cast(self, dtype)


# Only paddle.Tensor receives this shim.
setattr(paddle.Tensor, 'to', to)
def reshape(self, *args, **kwargs):
    """Tensor ``reshape`` accepting both torch-style and paddle-style calls.

    Supports ``t.reshape(2, 3)``, ``t.reshape((2, 3))`` / ``t.reshape([2, 3])``
    and ``t.reshape(shape=...)``. Positional arguments take precedence over
    the ``shape`` keyword; other keywords are ignored.

    Returns:
        The reshaped tensor from ``paddle.reshape``.

    Raises:
        TypeError: if no shape is given positionally or as ``shape=``.
            Previously this was an ``assert`` (stripped under ``-O``), or a
            silent ``None`` return when called with no arguments.
    """
    if args:
        # A single tuple/list is the whole shape; otherwise each positional
        # argument is one dimension.
        if len(args) == 1 and isinstance(args[0], (tuple, list)):
            return paddle.reshape(self, args[0])
        return paddle.reshape(self, list(args))
    if 'shape' in kwargs:
        return paddle.reshape(self, shape=kwargs['shape'])
    raise TypeError("reshape() missing required argument 'shape'")


# Install the shim on paddle.Tensor and on paddle.base.framework.Variable
# (presumably the static-graph tensor type).
setattr(paddle.Tensor, 'reshape', reshape)
setattr(paddle.base.framework.Variable, 'reshape', reshape)
def add(self, *args, **kwargs):
    """Compute ``self + alpha * other`` in the style of ``torch.Tensor.add``.

    The operand is taken from ``other=``, then ``y=``, then the first
    positional argument. ``alpha`` (keyword only, default 1) scales the
    operand before the addition. A non-Tensor operand is always converted
    with ``paddle.to_tensor``. Previously this conversion was skipped when
    ``alpha=1`` was passed explicitly.

    Returns:
        The result of ``paddle.add(self, operand)``.

    Raises:
        TypeError: if no operand is supplied.
    """
    if 'other' in kwargs:
        y = kwargs['other']
    elif 'y' in kwargs:
        y = kwargs['y']
    elif args:
        y = args[0]
    else:
        raise TypeError("add() missing required argument 'other'")
    alpha = kwargs.get('alpha', 1)
    if alpha != 1:
        # Scale before conversion so a Python scalar product becomes one tensor.
        y = alpha * y
    if not isinstance(y, paddle.Tensor):
        y = paddle.to_tensor(y)
    return paddle.add(self, y)


# Install the shim on paddle.Tensor and on paddle.base.framework.Variable
# (presumably the static-graph tensor type).
setattr(paddle.Tensor, 'add', add)
setattr(paddle.base.framework.Variable, 'add', add)