Skip to content

Commit

Permalink
Merge pull request #3 from lizexu123/add_trt
Browse files Browse the repository at this point in the history
add_trt_2
  • Loading branch information
lizexu123 authored Jul 11, 2024
2 parents 6034d99 + 691dc83 commit f8e7d37
Showing 1 changed file with 52 additions and 3 deletions.
55 changes: 52 additions & 3 deletions test/ir/pir/fused_pass/test_pir_trt_op_marker_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,14 @@ def build_ir_program(self):
data_format='NCHW',
bias_attr=False,
)

conv2d_transpose = paddle.nn.Conv2DTranspose(
in_channels=1,
out_channels=32,
kernel_size=3,
padding=1,
data_format='NCHW',
bias_attr=False,
)
y = create_parameter(
name="y",
shape=bias_shape,
Expand All @@ -54,12 +61,17 @@ def build_ir_program(self):
)
act_op = paddle.nn.ReLU()
act_out = act_op(paddle.add(conv2d(x), y))
conv_transpose = act_op(paddle.add(conv2d_transpose(x), y))

add_out = paddle.add(act_out, conv_transpose)
pool2d = paddle.nn.MaxPool2D(
kernel_size=2, stride=2, padding=0
)
padding_out = pool2d(act_out)
padding_out = pool2d(add_out)
batch_norm = paddle.nn.BatchNorm(32)
batch_norm_out = batch_norm(padding_out)
softmax = paddle.nn.Softmax()
softmax_out = softmax(padding_out)
softmax_out = softmax(batch_norm_out)
reshaped_out = paddle.reshape(
softmax_out, [softmax_out.shape[0], -1]
)
Expand Down Expand Up @@ -223,5 +235,42 @@ def test_check_output(self):
self.check_pass_correct()


class TestFlattenConCatPattern(PassTest):
    """Exercise ``trt_op_marker_pass`` on a transpose -> flatten -> concat graph.

    The test asserts (via ``valid_op_map``) that the fused
    ``pd_op.fusion_transpose_flatten_concat`` op does NOT appear after the
    marker pass runs — the pass is expected to only tag ops, not fuse them.
    """

    def is_program_valid(self, program=None):
        # Every program generated by sample_program is acceptable here.
        return True

    def sample_program(self):
        """Yield ([main_program, startup_program], need_translate) pairs."""
        with paddle.pir_utils.IrGuard():
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            for shape in [[2, 1, 1, 19]]:
                with paddle.pir.core.program_guard(main_prog, start_prog):
                    x = paddle.static.data(
                        name='x', shape=shape, dtype='float32'
                    )
                    # transpose NHWC-style perm, then collapse axes 0..2.
                    transposed = paddle.transpose(x, perm=[0, 3, 1, 2])
                    flatten_layer = paddle.nn.Flatten(
                        start_axis=0, stop_axis=2
                    )
                    flattened = flatten_layer(transposed)
                    # Single-input concat keeps the pattern shape the pass
                    # looks for without changing the data.
                    out = paddle.concat([flattened], axis=1)
                    out = paddle.assign(out)
                    self.pass_attr_list = [{'trt_op_marker_pass': {}}]
                    self.feeds = {
                        "x": np.random.random(shape).astype("float32"),
                    }
                    self.fetch_list = [out]
                    # The fused op must be absent after the marker pass.
                    self.valid_op_map = {
                        "pd_op.fusion_transpose_flatten_concat": 0,
                    }
                    yield [main_prog, start_prog], False

    def setUp(self):
        # TRT marking is only meaningful on CUDA builds; skip CPU-only.
        if core.is_compiled_with_cuda():
            self.places.append(paddle.CUDAPlace(0))

    def test_check_output(self):
        self.check_pass_correct()


# Run the unittest discovery/runner when executed as a script.
if __name__ == "__main__":
    unittest.main()

0 comments on commit f8e7d37

Please sign in to comment.