support view strategy in eager_fluid state (#40830)
* support view strategy in eager_fluid state

* little change

* little change

* optimize unittest

* fix
pangyoki authored Mar 31, 2022
1 parent 56493c9 commit 2f1c1ae
Showing 4 changed files with 64 additions and 34 deletions.
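
For context on what "view strategy" means here: the output of a view API (`squeeze`, `unsqueeze`, `reshape`, `flatten`, `detach`) reuses the input's allocation and inplace version counter instead of copying data, and this commit enables that behavior for eager mode. Below is a minimal sketch of the user-visible effect, not part of the commit; the `_test_eager_guard` import path is an assumption taken from the conventions of Paddle unit tests of this era.

import numpy as np
import paddle
from paddle.fluid.framework import _test_eager_guard  # assumed import path

with _test_eager_guard():
    var = paddle.rand([2, 3, 1])
    view_var = paddle.squeeze(var)   # view API: output reuses var's allocation

    view_var[0] = 2.                 # in-place write through the view
    # The write is visible through the original tensor as well, because
    # both tensors now share a single buffer.
    assert np.array_equal(var.numpy().reshape([2, 3]), view_var.numpy())
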
23 changes: 22 additions & 1 deletion paddle/fluid/eager/auto_code_generator/eager_generator.cc
@@ -1707,10 +1707,31 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
}
}
}
generated_function_body += "\n";

VLOG(6) << "Generated Outs Map";

// [Generation] Apply View Strategy (Tensor)
if (inplace_map.empty() && view_op_map.count(op_type)) {
const char* HANDLE_VIEW_BETWEEN_INPUT_AND_OUTPUT =
" if (ins.count(\"%s\") && outs.count(\"%s\")) {\n"
" egr::EagerUtils::HandleViewBetweenInputAndOutput(ins[\"%s\"][0], "
"outs[\"%s\"][0]);\n"
" };\n";

std::string view_strategy_str = "";
std::string viwe_input_name = view_op_map[op_type].first;
std::string viwe_output_name = view_op_map[op_type].second;
view_strategy_str += paddle::string::Sprintf(
HANDLE_VIEW_BETWEEN_INPUT_AND_OUTPUT, viwe_input_name, viwe_output_name,
viwe_input_name, viwe_output_name);

generated_function_body += view_strategy_str;
generated_function_body += "\n";

VLOG(6) << "Generated View Strategy";
}
generated_function_body += "\n";

// [Generation] Get Attrs
dygraph_function_args_str +=
", const paddle::framework::AttributeMap& attr_map";
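To make the generated code concrete, here is a small sketch that performs the same string substitution as the `paddle::string::Sprintf` call above, written in plain Python rather than in the generator itself. The slot names "X" and "Out" are hypothetical placeholders for a reshape-like view op; the real generator takes them from `view_op_map[op_type]`.

# Plain-Python sketch of the HANDLE_VIEW_BETWEEN_INPUT_AND_OUTPUT expansion.
template = (
    '  if (ins.count("%s") && outs.count("%s")) {\n'
    '    egr::EagerUtils::HandleViewBetweenInputAndOutput('
    'ins["%s"][0], outs["%s"][0]);\n'
    '  };\n'
)
view_input_name, view_output_name = "X", "Out"   # assumed slot names
print(template % (view_input_name, view_output_name,
                  view_input_name, view_output_name))
# Output:
#   if (ins.count("X") && outs.count("Out")) {
#     egr::EagerUtils::HandleViewBetweenInputAndOutput(ins["X"][0], outs["Out"][0]);
#   };
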
27 changes: 27 additions & 0 deletions paddle/fluid/eager/utils.cc
@@ -244,6 +244,33 @@ std::vector<std::shared_ptr<EagerVariable>> EagerUtils::CreateVars(
return res;
}

void EagerUtils::HandleViewBetweenInputAndOutput(
const std::shared_ptr<EagerVariable>& input_var,
const std::shared_ptr<EagerVariable>& view_output_var) {
PADDLE_ENFORCE_EQ(
input_var->Var().IsInitialized(), true,
paddle::platform::errors::InvalidArgument(
"Tensor %s has not been initialized!", input_var->name()));

if (phi::DenseTensor::classof(input_var->GetTensorBase().get())) {
auto input_dense_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(input_var->GetTensorBase());
PADDLE_ENFORCE_EQ(
input_dense_tensor->IsInitialized(), true,
paddle::platform::errors::InvalidArgument(
"DenseTensor %s has not been initialized!", input_var->name()));

auto* view_output_tensor =
view_output_var->MutableVar()->GetMutable<phi::DenseTensor>();
view_output_tensor->ShareBufferWith(*input_dense_tensor);
view_output_tensor->ShareInplaceVersionCounterWith(*input_dense_tensor);

VLOG(3) << "Perform View between Output Var(" << view_output_var->name()
<< ") and Input Var(" << input_var->name()
<< "), share allocation and inplace version.";
}
}

void EagerUtils::ModifyInplaceInput(
const std::shared_ptr<EagerVariable>& inplace_variable,
paddle::experimental::Tensor* inplace_tensor) {
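Because `HandleViewBetweenInputAndOutput` shares both the buffer and the inplace version counter, an in-place write on either tensor is observed by the other, which is what the modified unit test checks via `inplace_version`. A minimal sketch of that behavior, again assuming the test file's `_test_eager_guard` import path:

import paddle
from paddle.fluid.framework import _test_eager_guard  # assumed import path

with _test_eager_guard():
    var = paddle.rand([3, 4])
    view_var = paddle.reshape(var, [2, 2, 3])   # view API
    assert view_var.inplace_version == 0

    var[0] = 2.   # in-place write on the source tensor
    # Both tensors report the bump because the counter is shared
    # (ShareInplaceVersionCounterWith above).
    assert var.inplace_version == 1
    assert view_var.inplace_version == 1
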
5 changes: 5 additions & 0 deletions paddle/fluid/eager/utils.h
@@ -168,6 +168,11 @@ class EagerUtils {
}
}

// View Strategy
static void HandleViewBetweenInputAndOutput(
const std::shared_ptr<EagerVariable>& input_var,
const std::shared_ptr<EagerVariable>& view_output_var);

// TensorWrapper Utils
static paddle::experimental::Tensor RecoverTensorWrapper(
TensorWrapper* tw, const std::shared_ptr<GradNodeBase>& grad_node);
@@ -29,13 +29,8 @@
 # View APIs include: `squeeze`, `unsqueeze`, `reshape`, `flatten`, `detach`
 class TestDygraphViewReuseAllocation(unittest.TestCase):
     def setUp(self):
-        self.set_flag_to_test_eager_mode()
         self.init_shape()
 
-    # some op don't suport eager_final_state in temporary
-    def set_flag_to_test_eager_mode(self):
-        self.flag_test_eager_mode = False
-
     def init_shape(self):
         self.input_shape = [2, 3, 1]
         self.output_shape = [2, 3]
@@ -46,10 +41,7 @@ def view_api_processing(self, var):
     def func_test_view_api(self):
         var = paddle.rand(self.input_shape)
         view_var = self.view_api_processing(var)
-        # setitem don't support inplace in temporary.
-        # replace setitem with inplace exp_ in temporary.
-        # view_var[0] = 2.
-        view_var.exp_()
+        view_var[0] = 2.
         self.assertEqual(var.shape, self.input_shape)
         self.assertEqual(view_var.shape, self.output_shape)

@@ -58,9 +50,8 @@ def func_test_view_api(self):
         self.assertTrue(np.array_equal(var_numpy, view_var_numpy))
 
     def test_view_api(self):
-        if self.flag_test_eager_mode:
-            with _test_eager_guard():
-                self.func_test_view_api()
+        with _test_eager_guard():
+            self.func_test_view_api()
         self.func_test_view_api()

@@ -69,23 +60,20 @@ def func_test_forward_version(self):
         view_var = self.view_api_processing(var)
         self.assertEqual(view_var.inplace_version, 0)
 
-        # var[0] = 2.
-        var.exp_()
+        var[0] = 2.
         self.assertEqual(var.inplace_version, 1)
         self.assertEqual(view_var.inplace_version, 1)
 
         view_var_2 = self.view_api_processing(var)
         self.assertEqual(view_var_2.inplace_version, 1)
 
-        # var[0] = 3.
-        var.exp_()
+        var[0] = 3.
         self.assertEqual(view_var.inplace_version, 2)
         self.assertEqual(view_var_2.inplace_version, 2)
 
     def test_forward_version(self):
-        if self.flag_test_eager_mode:
-            with _test_eager_guard():
-                self.func_test_forward_version()
+        with _test_eager_guard():
+            self.func_test_forward_version()
         self.func_test_forward_version()

@@ -100,8 +88,7 @@ def func_test_backward_error(self):
             # Here, the gradient computation will use the value of var_b
             var_c = var_b**2
             view_var_b = self.view_api_processing(var_b)
-            # view_var_b[0] = 2. # var_b is modified inplace
-            view_var_b.exp_()
+            view_var_b[0] = 2.  # var_b is modified inplace
 
             loss = paddle.nn.functional.relu(var_c)
             if in_dygraph_mode():
@@ -118,16 +105,12 @@ def func_test_backward_error(self):
                     loss.backward()
 
     def test_backward_error(self):
-        if self.flag_test_eager_mode:
-            with _test_eager_guard():
-                self.func_test_backward_error()
+        with _test_eager_guard():
+            self.func_test_backward_error()
         self.func_test_backward_error()
 
 
 class TestUnsqueezeDygraphViewReuseAllocation(TestDygraphViewReuseAllocation):
-    def set_flag_to_test_eager_mode(self):
-        self.flag_test_eager_mode = False
-
     def init_shape(self):
         self.input_shape = [2, 3]
         self.output_shape = [2, 3, 1]
@@ -137,9 +120,6 @@ def view_api_processing(self, var):


 class TestReshapeDygraphViewReuseAllocation(TestDygraphViewReuseAllocation):
-    def set_flag_to_test_eager_mode(self):
-        self.flag_test_eager_mode = True
-
     def init_shape(self):
         self.input_shape = [3, 4]
         self.output_shape = [2, 2, 3]
@@ -149,9 +129,6 @@ def view_api_processing(self, var):


 class TestFlattenDygraphViewReuseAllocation(TestDygraphViewReuseAllocation):
-    def set_flag_to_test_eager_mode(self):
-        self.flag_test_eager_mode = False
-
     def init_shape(self):
         self.input_shape = [3, 4]
         self.output_shape = [12]
