cinn(dynamic): support run exp sub subgraph with dynamic shape graph #59640

Merged 3 commits on Dec 4, 2023
3 changes: 2 additions & 1 deletion paddle/cinn/hlir/framework/pir/compilation_task.cc
@@ -56,7 +56,8 @@ void CompilationTask::operator()() {

void CompilationTask::Lowering() {
auto op_lowerer = CreateOpLowerer<pir::GroupPtr>(context_->target_);
context_->SetLoweredFuncs(op_lowerer.BucketLower(context_->group_));
context_->SetLoweredFuncs(
op_lowerer.BucketLower(context_->group_, false, false, false));
}

void CompilationTask::CodegenAndJit() {
21 changes: 19 additions & 2 deletions paddle/cinn/hlir/framework/pir/op_lowering_impl.cc
@@ -118,6 +118,16 @@ std::vector<ir::Tensor> CollectInputTensor(
if (func_args != nullptr) {
func_args->push_back(tensor);
}
} else {
// TODO(6clc): After supporting symbolic calculation,
// 1. Check that the shape of the tensor with the same name is the same
// size
// 2. Or make the symbol expression in compute output tensor consistent
// with the one inferred in shape_analysis
(*tensor_map)[in_value]->sym_shape = tensor->sym_shape;
(*tensor_map)[in_value]->shape = tensor->shape;
(*tensor_map)[in_value]->sym_domain = tensor->sym_domain;
(*tensor_map)[in_value]->domain = tensor->domain;
}
tensors.push_back(tensor);
}
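The else-branch added above keeps the entry cached in tensor_map consistent with the tensor just recomputed for the same pir value, pending the shape checks noted in the TODO. A minimal Python sketch of that reconciliation, using hypothetical field names that mirror the C++ members:

class TensorInfo:
    """Stand-in for ir::Tensor: only the shape-related members used here."""
    def __init__(self, shape, sym_shape):
        self.shape = shape          # static extents, e.g. [-1, 128]
        self.sym_shape = sym_shape  # symbolic extents, e.g. ["S0", 128]


def collect_input_tensor(tensor_map, in_value, tensor, func_args=None):
    if in_value not in tensor_map:
        tensor_map[in_value] = tensor
        if func_args is not None:
            func_args.append(tensor)
    else:
        # Same input value seen again: refresh the cached entry's shape so
        # both views agree (the C++ TODO is to verify they really match).
        cached = tensor_map[in_value]
        cached.sym_shape = tensor.sym_shape
        cached.shape = tensor.shape
    return tensor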
@@ -528,15 +538,22 @@ std::vector<ir::LoweredFunc> OpLowererImpl::PostProcess(
// update args for dynamic dim
int num_tensor_args = static_cast<int>(group_func_args.size());
int non_tensor_arg_idx = group_func_args.size();
std::unordered_set<std::string> int_args_set;
for (int tensor_arg_idx = 0; tensor_arg_idx < num_tensor_args;
tensor_arg_idx++) {
auto tensor_dim = (*group_func_arg_tensors)[tensor_arg_idx]->sym_shape;
int tensor_dim_size = tensor_dim.size();
for (int tensor_arg_dim_idx = 0; tensor_arg_dim_idx < tensor_dim_size;
tensor_arg_dim_idx++) {
if (tensor_dim[tensor_arg_dim_idx]->IsDynamic()) {
group_func_args.emplace_back(ir::_Var_::Make(
tensor_dim[tensor_arg_dim_idx]->GetSymbolName(), common::Int(32)));
const std::string symbol_name =
tensor_dim[tensor_arg_dim_idx]->GetSymbolName();
if (int_args_set.count(symbol_name) != 0) {
continue;
}
int_args_set.insert(symbol_name);
group_func_args.emplace_back(
ir::_Var_::Make(symbol_name, common::Int(32)));
group->int_args_map[non_tensor_arg_idx++] = {tensor_arg_idx,
tensor_arg_dim_idx};
}
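The hunk above registers each dynamic dimension as an extra scalar kernel argument, but only once per symbol name, and records where its runtime value can be read from (tensor index, dim index). A standalone Python sketch of that bookkeeping (names are illustrative, not the actual CINN API):

def collect_dynamic_dim_args(arg_tensors):
    """One extra int argument per distinct dynamic-dim symbol; remember which
    tensor/dim supplies its runtime value (mirrors group->int_args_map)."""
    seen = set()
    int_args_map = {}                # kernel arg position -> (tensor_idx, dim_idx)
    extra_args = []                  # symbol names appended after the tensor args
    next_arg_idx = len(arg_tensors)  # scalar args start right after the tensors
    for tensor_idx, sym_shape in enumerate(arg_tensors):
        for dim_idx, dim in enumerate(sym_shape):
            if not isinstance(dim, str):   # static extent: nothing to pass
                continue
            if dim in seen:                # symbol already has an argument
                continue
            seen.add(dim)
            extra_args.append(dim)
            int_args_map[next_arg_idx] = (tensor_idx, dim_idx)
            next_arg_idx += 1
    return extra_args, int_args_map

For example, two inputs shaped ["S0", 128] and ["S0", 64] yield a single extra argument for S0, read from dim 0 of the first tensor: (["S0"], {2: (0, 0)}).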
24 changes: 17 additions & 7 deletions paddle/cinn/hlir/pe/broadcast.cc
@@ -23,6 +23,7 @@
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/lang/builtin.h"
#include "paddle/cinn/lang/compute.h"
PD_DECLARE_bool(cinn_bucket_compile);

namespace cinn {
namespace hlir {
@@ -241,13 +242,22 @@ Tensor Broadcast(const FuncOp& op,
// the counts of left-shift of tensor b so as to right alignment
int axis_offset = 0;

GetBroadcastShape(a->shape,
b->shape,
&common_shape,
&broadcast_flags1,
&broadcast_flags2,
&axis_offset,
axis);
if (FLAGS_cinn_bucket_compile) {
// TODO(6clc): After supporting symbolic calculation,
// perfect the logic of shape equal judgment
common_shape = a->shape;
broadcast_flags1.resize(common_shape.size(), true);
broadcast_flags2.resize(common_shape.size(), true);
} else {
GetBroadcastShape(a->shape,
b->shape,
&common_shape,
&broadcast_flags1,
&broadcast_flags2,
&axis_offset,
axis);
}

auto fn = [=](const std::vector<Expr>& indice) {
std::vector<Expr> broadcast_indice1;
std::vector<Expr> broadcast_indice2;
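Under FLAGS_cinn_bucket_compile the change above skips broadcast-shape inference and assumes both inputs already share a's shape; otherwise the shapes are right-aligned as before. A small Python sketch of the two paths (a simplified numpy-style alignment, not the real GetBroadcastShape):

def broadcast_shape(a_shape, b_shape, bucket_compile=False):
    """Return (common_shape, flags_a, flags_b); a flag is True when that input
    already has the common extent along the axis (no broadcasting needed)."""
    if bucket_compile:
        # Dynamic-shape bucket compile: symbolic extents cannot be compared
        # yet, so assume the shapes are equal and broadcast nothing.
        common = list(a_shape)
        return common, [True] * len(common), [True] * len(common)

    # Static path: right-align the two shapes and take the max extent per axis.
    rank = max(len(a_shape), len(b_shape))
    pa = [1] * (rank - len(a_shape)) + list(a_shape)
    pb = [1] * (rank - len(b_shape)) + list(b_shape)
    common, flags_a, flags_b = [], [], []
    for da, db in zip(pa, pb):
        if da != db and 1 not in (da, db):
            raise ValueError(f"incompatible extents {da} vs {db}")
        common.append(max(da, db))
        flags_a.append(da == common[-1])
        flags_b.append(db == common[-1])
    return common, flags_a, flags_b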
2 changes: 1 addition & 1 deletion paddle/cinn/runtime/cuda/cuda_util.cc
@@ -80,7 +80,7 @@ class CublasHandle {

int32_t cinn_get_value_in_cuda_kernel_args(void *v_args, int idx) {
cinn_pod_value_t *args = static_cast<cinn_pod_value_t *>(v_args);
return args[idx].operator int32_t();
return args[idx].operator int64_t();
}

void cinn_call_cuda_kernel(void *kernel_fn,
@@ -24,6 +24,7 @@
#if defined(PADDLE_WITH_CUDA)
#include "paddle/cinn/runtime/cinn_runtime.h"
#endif
PD_DECLARE_bool(cinn_bucket_compile);

namespace paddle {
namespace framework {
@@ -50,6 +51,9 @@ class CinnJitInstruction::FnPtrImpl {
for (const auto& int_arg_mp : cinn_kernel_info_.int_args_map) {
func_args_.emplace_back(kernel_args[int_arg_mp.second.arg_idx]->dims().at(
int_arg_mp.second.dim_idx));
func_args_.emplace_back(static_cast<int64_t>(
kernel_args[int_arg_mp.second.arg_idx]->dims().at(
int_arg_mp.second.dim_idx)));
}

// 3. Launch host kernel
@@ -100,13 +104,13 @@ CinnJitInstruction::CinnJitInstruction(

tensor_args_.push_back(tensor);

out_tensor_ = tensor;

auto alloc_tensor_type =
result.type().dyn_cast<paddle::dialect::AllocatedDenseTensorType>();
tensor->set_type(
paddle::dialect::TransToPhiDataType(alloc_tensor_type.dtype()));
tensor->Resize(alloc_tensor_type.dims());
if (!FLAGS_cinn_bucket_compile) {
auto alloc_tensor_type =
result.type().dyn_cast<paddle::dialect::AllocatedDenseTensorType>();
tensor->set_type(
paddle::dialect::TransToPhiDataType(alloc_tensor_type.dtype()));
tensor->Resize(alloc_tensor_type.dims());
}
}
}

@@ -115,7 +119,14 @@ void CinnJitInstruction::Run() {
auto gpu_ctx = static_cast<phi::GPUContext*>(dev_ctx_);

auto stream = gpu_ctx->stream();

for (size_t i = 0; i < tensor_args_.size(); ++i) {
// TODO(6clc): template infer shape from tensor_args_[0].
// After supporting symbolic calculation, perfect the code to query shape
// of output tensor
if (FLAGS_cinn_bucket_compile) {
tensor_args_[i]->Resize(tensor_args_[0]->dims());
}
gpu_ctx->Alloc(tensor_args_[i], tensor_args_[i]->dtype());
}

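Put together, the runtime side of this PR (a) appends the concrete value of every dynamic dim as an int64 kernel argument via int_args_map and (b), under FLAGS_cinn_bucket_compile, defers output sizing to Run(), temporarily shaping every tensor like the first one before allocation. A hedged Python sketch of that launch sequence (paddle internals are mocked; names are illustrative):

def run_jit_instruction(tensor_args, int_args_map, launch_kernel,
                        bucket_compile=True):
    """Sketch of the dynamic-shape launch path (paddle internals mocked)."""
    # 1. With bucket compile the static pass skipped Resize(), so every tensor
    #    is temporarily shaped like tensor_args[0] before being allocated.
    if bucket_compile:
        for t in tensor_args:
            t["dims"] = list(tensor_args[0]["dims"])

    # 2. Tensor arguments are passed through as-is.
    func_args = list(tensor_args)

    # 3. One extra int64 argument per dynamic dim symbol, read from the dims
    #    of the tensor that int_args_map points at (see op_lowering_impl.cc).
    for _, (arg_idx, dim_idx) in sorted(int_args_map.items()):
        func_args.append(int(tensor_args[arg_idx]["dims"][dim_idx]))

    # 4. Hand everything to the compiled kernel.
    launch_kernel(func_args)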
31 changes: 23 additions & 8 deletions test/ir/pir/cinn/CMakeLists.txt
@@ -6,14 +6,29 @@ if(WITH_GPU)

foreach(cinn_pir_test_name ${CINN_PIR_TEST})
string(REGEX REPLACE ".py" "" cinn_pir_test_name ${cinn_pir_test_name})
add_test(
NAME ${cinn_pir_test_name}
COMMAND
${CMAKE_COMMAND} -E env
PYTHONPATH=${CMAKE_BINARY_DIR}:${CMAKE_BINARY_DIR}/python/:$ENV{PYTHONPATH}
FLAGS_enable_pir_api=1 FLAGS_prim_all=True ${PYTHON_EXECUTABLE}
${CMAKE_CURRENT_SOURCE_DIR}/${cinn_pir_test_name}.py
WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
string(FIND "${cinn_pir_test_name}" "symbolic" index_of_substr)

message(${index_of_substr} "liuchao liuchao" ${cinn_pir_test_name})
if(index_of_substr EQUAL -1)
add_test(
NAME ${cinn_pir_test_name}
COMMAND
${CMAKE_COMMAND} -E env
PYTHONPATH=${CMAKE_BINARY_DIR}:${CMAKE_BINARY_DIR}/python/:$ENV{PYTHONPATH}
FLAGS_enable_pir_api=1 FLAGS_prim_all=True ${PYTHON_EXECUTABLE}
${CMAKE_CURRENT_SOURCE_DIR}/${cinn_pir_test_name}.py
WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
else()
add_test(
NAME ${cinn_pir_test_name}
COMMAND
${CMAKE_COMMAND} -E env
PYTHONPATH=${CMAKE_BINARY_DIR}:${CMAKE_BINARY_DIR}/python/:$ENV{PYTHONPATH}
FLAGS_cinn_bucket_compile=True FLAGS_enable_pir_api=1
FLAGS_prim_all=True ${PYTHON_EXECUTABLE}
${CMAKE_CURRENT_SOURCE_DIR}/${cinn_pir_test_name}.py
WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
endif()
set_tests_properties(${cinn_pir_test_name} PROPERTIES LABELS
"RUN_TYPE=CINN")
endforeach()
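Per the branch above, any test whose filename contains "symbolic" is additionally run with FLAGS_cinn_bucket_compile=True. To reproduce that setup when running such a test by hand, the flags can be exported before Paddle initializes; a hedged sketch (the flag names come from this diff, the way of setting them is an assumption):

import os

# Mirror the environment the CMake rule builds for *_symbolic tests; these
# gflags are read when Paddle initializes, so set them before importing it.
os.environ["FLAGS_cinn_bucket_compile"] = "True"
os.environ["FLAGS_enable_pir_api"] = "1"
os.environ["FLAGS_prim_all"] = "True"

import unittest

import test_cinn_sub_graph_symbolic  # the new test added in this PR

if __name__ == "__main__":
    unittest.main(module=test_cinn_sub_graph_symbolic)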
81 changes: 81 additions & 0 deletions test/ir/pir/cinn/test_cinn_sub_graph_symbolic.py
@@ -0,0 +1,81 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Contributor:

2022 -> 2023

Contributor Author:

Got it, will fix it in the next PR.

#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle
from paddle.static import InputSpec


def apply_to_static(net, use_cinn, input_spec=None):
build_strategy = paddle.static.BuildStrategy()
build_strategy.build_cinn_pass = use_cinn
return paddle.jit.to_static(
net,
input_spec=input_spec,
build_strategy=build_strategy,
full_graph=True,
)


def exp_sub(x):
y = paddle.exp(x)
z = y - x
return z


class CINNSubGraphNet(paddle.nn.Layer):
def __init__(self):
super().__init__()
self.fn = exp_sub

def forward(self, x):
out = self.fn(x)
return out


class TestCinnSubGraphBase(unittest.TestCase):
"""
Test Pir API + @to_static + CINN.
"""

def setUp(self):
paddle.seed(2022)
self.prepare_data()

def prepare_data(self):
self.shape = [64, 128]
self.axis = -1
self.x = paddle.randn(self.shape, dtype="float32")
self.x.stop_gradient = False

def eval_symbolic(self, use_cinn):
paddle.seed(2022)
net = CINNSubGraphNet()
input_spec = [InputSpec(shape=[None, 128], dtype='float32')]
net = apply_to_static(net, use_cinn, input_spec)
net.eval()
out = net(self.x)
return out

def test_eval_symolic(self):
cinn_out = self.eval_symbolic(use_cinn=True)
dy_out = self.eval_symbolic(use_cinn=False)
np.testing.assert_allclose(cinn_out.numpy(), dy_out.numpy(), atol=1e-8)


if __name__ == '__main__':
unittest.main()