support view strategy in eager_fluid state #40830

Merged · 7 commits · Mar 31, 2022
23 changes: 22 additions & 1 deletion paddle/fluid/eager/auto_code_generator/eager_generator.cc
@@ -1705,10 +1705,31 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
}
}
}
generated_function_body += "\n";

VLOG(6) << "Generated Outs Map";

// [Generation] Apply View Strategy (Tensor)
if (inplace_map.empty() && view_op_map.count(op_type)) {
const char* HANDLE_VIEW_BETWEEN_INPUT_AND_OUTPUT =
" if (ins.count(\"%s\") && outs.count(\"%s\")) {\n"
" egr::EagerUtils::HandleViewBetweenInputAndOutput(ins[\"%s\"][0], "
"outs[\"%s\"][0]);\n"
" };\n";

std::string view_strategy_str = "";
std::string view_input_name = view_op_map[op_type].first;
std::string view_output_name = view_op_map[op_type].second;
view_strategy_str += paddle::string::Sprintf(
HANDLE_VIEW_BETWEEN_INPUT_AND_OUTPUT, view_input_name, view_output_name,
view_input_name, view_output_name);

generated_function_body += view_strategy_str;
generated_function_body += "\n";

VLOG(6) << "Generated View Strategy";
}
generated_function_body += "\n";

// [Generation] Get Attrs
dygraph_function_args_str +=
", const paddle::framework::AttributeMap& attr_map";
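For concreteness, once Sprintf fills the template above, the generated forward function contains a guard along these lines (a sketch assuming a view op whose registered pair in view_op_map is "X" / "Out"; the actual names come from the op registry):

  // Emitted code for a hypothetical view pair ("X", "Out").
  if (ins.count("X") && outs.count("Out")) {
    egr::EagerUtils::HandleViewBetweenInputAndOutput(ins["X"][0], outs["Out"][0]);
  };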
27 changes: 27 additions & 0 deletions paddle/fluid/eager/utils.cc
@@ -244,6 +244,33 @@ std::vector<std::shared_ptr<EagerVariable>> EagerUtils::CreateVars(
return res;
}

void EagerUtils::HandleViewBetweenInputAndOutput(
const std::shared_ptr<EagerVariable>& input_var,
const std::shared_ptr<EagerVariable>& view_output_var) {
PADDLE_ENFORCE_EQ(
input_var->Var().IsInitialized(), true,
paddle::platform::errors::InvalidArgument(
"Tensor %s has not been initialized!", input_var->name()));

if (phi::DenseTensor::classof(input_var->GetTensorBase().get())) {
auto input_dense_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(input_var->GetTensorBase());
PADDLE_ENFORCE_EQ(
input_dense_tensor->IsInitialized(), true,
paddle::platform::errors::InvalidArgument(
"DenseTensor %s has not been initialized!", input_var->name()));

auto* view_output_tensor =
view_output_var->MutableVar()->GetMutable<phi::DenseTensor>();
view_output_tensor->ShareBufferWith(*input_dense_tensor);
view_output_tensor->ShareInplaceVersionCounterWith(*input_dense_tensor);

VLOG(3) << "Perform View between Output Var(" << view_output_var->name()
<< ") and Input Var(" << input_var->name()
<< "), share allocation and inplace version.";
}
}

void EagerUtils::ModifyInplaceInput(
const std::shared_ptr<EagerVariable>& inplace_variable,
paddle::experimental::Tensor* inplace_tensor) {
5 changes: 5 additions & 0 deletions paddle/fluid/eager/utils.h
@@ -165,6 +165,11 @@ class EagerUtils {
}
}

// View Strategy
static void HandleViewBetweenInputAndOutput(
const std::shared_ptr<EagerVariable>& input_var,
const std::shared_ptr<EagerVariable>& view_output_var);

// TensorWrapper Utils
static paddle::experimental::Tensor RecoverTensorWrapper(
TensorWrapper* tw, const std::shared_ptr<GradNodeBase>& grad_node);
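The test diff below exercises this path end to end. As a minimal sketch of the user-visible behavior it enables (assuming a Paddle build with eager/dygraph mode and using reshape as the view op, mirroring the assertions in the tests):

import paddle

x = paddle.rand([3, 4])
y = paddle.reshape(x, [2, 6])  # view op: y reuses x's allocation in eager mode

x[0] = 2.  # in-place write through the shared buffer
# Both tensors share one allocation and one inplace version counter.
assert x.inplace_version == 1
assert y.inplace_version == 1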
@@ -29,13 +29,8 @@
# View APIs include: `squeeze`, `unsqueeze`, `reshape`, `flatten`, `detach`
class TestDygraphViewReuseAllocation(unittest.TestCase):
def setUp(self):
self.set_flag_to_test_eager_mode()
self.init_shape()

# some ops don't support eager_final_state for now
def set_flag_to_test_eager_mode(self):
self.flag_test_eager_mode = False

def init_shape(self):
self.input_shape = [2, 3, 1]
self.output_shape = [2, 3]
@@ -46,10 +41,7 @@ def view_api_processing(self, var):
def func_test_view_api(self):
var = paddle.rand(self.input_shape)
view_var = self.view_api_processing(var)
# setitem doesn't support inplace for now.
# replace setitem with inplace exp_ for now.
# view_var[0] = 2.
view_var.exp_()
view_var[0] = 2.
self.assertEqual(var.shape, self.input_shape)
self.assertEqual(view_var.shape, self.output_shape)

@@ -58,9 +50,8 @@ def func_test_view_api(self):
self.assertTrue(np.array_equal(var_numpy, view_var_numpy))

def test_view_api(self):
if self.flag_test_eager_mode:
with _test_eager_guard():
self.func_test_view_api()
with _test_eager_guard():
self.func_test_view_api()
self.func_test_view_api()

def func_test_forward_version(self):
@@ -69,23 +60,20 @@ def func_test_forward_version(self):
view_var = self.view_api_processing(var)
self.assertEqual(view_var.inplace_version, 0)

# var[0] = 2.
var.exp_()
var[0] = 2.
self.assertEqual(var.inplace_version, 1)
self.assertEqual(view_var.inplace_version, 1)

view_var_2 = self.view_api_processing(var)
self.assertEqual(view_var_2.inplace_version, 1)

# var[0] = 3.
var.exp_()
var[0] = 3.
self.assertEqual(view_var.inplace_version, 2)
self.assertEqual(view_var_2.inplace_version, 2)

def test_forward_version(self):
if self.flag_test_eager_mode:
with _test_eager_guard():
self.func_test_forward_version()
with _test_eager_guard():
self.func_test_forward_version()
self.func_test_forward_version()

def func_test_backward_error(self):
@@ -100,8 +88,7 @@ def func_test_backward_error(self):
# Here, the gradient computation will use the value of var_b
var_c = var_b**2
view_var_b = self.view_api_processing(var_b)
# view_var_b[0] = 2. # var_b is modified inplace
view_var_b.exp_()
view_var_b[0] = 2. # var_b is modified inplace

loss = paddle.nn.functional.relu(var_c)
if in_dygraph_mode():
@@ -118,16 +105,12 @@ def func_test_backward_error(self):
loss.backward()

def test_backward_error(self):
if self.flag_test_eager_mode:
with _test_eager_guard():
self.func_test_backward_error()
with _test_eager_guard():
self.func_test_backward_error()
self.func_test_backward_error()


class TestUnsqueezeDygraphViewReuseAllocation(TestDygraphViewReuseAllocation):
def set_flag_to_test_eager_mode(self):
self.flag_test_eager_mode = False

def init_shape(self):
self.input_shape = [2, 3]
self.output_shape = [2, 3, 1]
@@ -137,9 +120,6 @@ def view_api_processing(self, var):


class TestReshapeDygraphViewReuseAllocation(TestDygraphViewReuseAllocation):
def set_flag_to_test_eager_mode(self):
self.flag_test_eager_mode = True

def init_shape(self):
self.input_shape = [3, 4]
self.output_shape = [2, 2, 3]
@@ -149,9 +129,6 @@ def view_api_processing(self, var):


class TestFlattenDygraphViewReuseAllocation(TestDygraphViewReuseAllocation):
def set_flag_to_test_eager_mode(self):
self.flag_test_eager_mode = False

def init_shape(self):
self.input_shape = [3, 4]
self.output_shape = [12]