Skip to content

Commit

Permalink
Merge and resolve conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
2742195759 committed Sep 13, 2023
1 parent 7575338 commit 18c08f8
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 32 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/eager/to_static/run_program_op_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ static auto GetNameFromValue(const ::pir::Block *block,
if (op->name() == "pd_op.data") {
name =
op->attributes().at("name").dyn_cast<pir::StrAttribute>().AsString();
value2name[op->results()[0].value_impl()] = name;
value2name[op->results()[0].Value::impl()] = name;
} else if (op->name() == "builtin.set_parameter") {
name = op->attributes()
.at("parameter_name")
Expand Down
8 changes: 4 additions & 4 deletions paddle/fluid/pybind/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ Operation *BuildOpFrom(
cloned_op->results().begin(),
std::back_inserter(tmp), // NOLINT, just a placeholder.
[&value_map](const OpResult &a, const OpResult &b) { // NOLINT
value_map[a.value_impl()] = b.value_impl();
value_map[a.Value::impl()] = b.Value::impl();
return 1;
});
return cloned_op;
Expand Down Expand Up @@ -623,7 +623,7 @@ std::vector<pir::Value> AnalysisMiddleVariable(
forward_range,
[&middle_values, &backward_inputs, &x_or_param](Operation *op) {
for (auto &t : op->results()) {
auto v = Value(t.value_impl());
auto v = Value(t.Value::impl());
if (backward_inputs.count(v) && !x_or_param.count(v))
middle_values.push_back(v);
}
Expand Down Expand Up @@ -667,7 +667,7 @@ SplitedResult ForwardBackwardSplit(

auto op_result_to_value = [](const pir::OpResult &r) {
if (r.impl() == nullptr) return Value(nullptr);
return Value(r.value_impl());
return Value(r.Value::impl());
};

std::transform(op_result_forward_inputs.begin(),
Expand Down Expand Up @@ -734,7 +734,7 @@ SplitedResult ForwardBackwardSplit(
dtype,
place);
counter += 1;
backward_value_map[v] = op->results()[0].value_impl();
backward_value_map[v] = op->results()[0].Value::impl();
};

auto create_output_fn_forward = [&ctx,
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/pybind/op_function_common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -864,7 +864,7 @@ void CastPyArg2AttrValues(PyObject* obj,
if (opresult->impl() == nullptr) {
results.emplace_back(pir::Value(nullptr));
} else {
results.emplace_back(pir::Value(opresult->value_impl()));
results.emplace_back(pir::Value(opresult->Value::impl()));
}
}
} else {
Expand Down
26 changes: 0 additions & 26 deletions test/ir/new_ir/test_new_ir_to_static.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ def func(x):
x.stop_gradient = False
y.stop_gradient = False
ans = func(x)
print("Ans: ", ans)
print(static_func.get_concrete_program(x)[1].train_program)
out = static_func(x)

np.testing.assert_allclose(
Expand Down Expand Up @@ -69,30 +67,6 @@ def func(x):
)


# class TestDy2staticNewIR2(unittest.TestCase):
# def test_basic_layer(self):
# class SimpleNet(paddle.nn.Layer):
# def __init__(self):
# super().__init__()
# self.linear = paddle.nn.Linear(10, 10)

# def forward(self, x):
# return self.linear(x)

# net = SimpleNet()
# x = paddle.randn((10, 10))
# x.stop_gradient = False
# ans = net(x)
# print("Ans: ", ans)
# net = paddle.jit.to_static(net)
# print(net.forward.get_concrete_program(x)[1].train_program)
# out = net(x)

# np.testing.assert_allclose(
# out.numpy(), ans.numpy(), rtol=1e-05, atol=1e-8
# )


class TestDy2staticNewIR3(unittest.TestCase):
def test_complex_layer(self):
def output_pure_func(x, y):
Expand Down

0 comments on commit 18c08f8

Please sign in to comment.