Skip to content

Commit

Permalink
process makefile
Browse files Browse the repository at this point in the history
  • Loading branch information
cyber-pioneer committed Aug 24, 2023
1 parent 99d4a56 commit bef3511
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 9 deletions.
3 changes: 1 addition & 2 deletions python/paddle/decomposition/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
def mean(x, axis, keepdim):
"""define composite rule of op mean"""
x_shape = x.shape
axes = axis or tuple(range(0, len(x_shape)))
axes = (axes,) if isinstance(axes, int) else axes
axes = (axis,) if isinstance(axis, int) else tuple(range(0, len(x_shape)))
sum_x = sum(x, axis=axes, keepdim=keepdim)
value_to_fill = 1
for axis in axes:
Expand Down
7 changes: 3 additions & 4 deletions python/paddle/incubate/autograd/composite_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,11 +171,11 @@ def layernorm_composite(x, scale, bias, epsilon, begin_norm_axis):
out = difference * rsqrt_var

if scale is not None:
if x.shape[begin_norm_axis:] is not scale.shape:
if x.shape[begin_norm_axis:] != scale.shape:
scale = reshape(scale, x.shape[begin_norm_axis:])
out = out * scale
if bias is not None:
if x.shape[begin_norm_axis:] is not bias.shape:
if x.shape[begin_norm_axis:] != bias.shape:
bias = reshape(bias, x.shape[begin_norm_axis:])
out = out + bias

Expand Down Expand Up @@ -266,8 +266,7 @@ def mean_composite(x, axis, keepdim):
is_amp = True
x = cast(x, "float32")

axes = axis or list(range(0, len(x.shape)))
axes = [axes] if isinstance(axes, int) else axes
axes = [axis] if isinstance(axis, int) else list(range(0, len(x.shape)))
sum_x = sum(x, axis=axes, keepdim=keepdim)
ele_nums_list = [x.shape[axis] for axis in axes]
if ele_nums_list == []:
Expand Down
16 changes: 13 additions & 3 deletions test/prim/new_ir_prim/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
set(TEST_PRIM_PURE_NEW_IR_CASES test_prim_program)

foreach(target ${TEST_PRIM_PURE_NEW_IR_CASES})
py_test_modules(${target} MODULES ${target} ENVS GLOG_v=1
FLAGS_enable_new_ir_api=true)
endforeach()

file(
GLOB TEST_INTERP_CASES
GLOB TEST_PRIM_TRANS_NEW_IR_CASES
RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
"test_*.py")
string(REPLACE ".py" "" TEST_INTERP_CASES "${TEST_INTERP_CASES}")
string(REPLACE ".py" "" TEST_PRIM_TRANS_NEW_IR_CASES
"${TEST_PRIM_TRANS_NEW_IR_CASES}")

list(REMOVE_ITEM TEST_PRIM_TRANS_NEW_IR_CASES ${TEST_PRIM_PURE_NEW_IR_CASES})

foreach(target ${TEST_INTERP_CASES})
foreach(target ${TEST_PRIM_TRANS_NEW_IR_CASES})
py_test_modules(${target} MODULES ${target} ENVS GLOG_v=1
FLAGS_enable_new_ir_in_executor=true)
endforeach()

0 comments on commit bef3511

Please sign in to comment.