Skip to content

Commit

Permalink
fix VisitVariable
Browse files Browse the repository at this point in the history
  • Loading branch information
chengduoZH committed Apr 19, 2018
1 parent fbb75c6 commit 0357128
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 13 deletions.
10 changes: 5 additions & 5 deletions paddle/fluid/framework/details/broadcast_op_handle.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@ namespace framework {
namespace details {

struct BroadcastOpHandle : public OpHandleBase {
const std::vector<Scope *> &local_scopes_;
const std::vector<platform::Place> &places_;

public:
BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places);

Expand All @@ -41,10 +39,12 @@ struct BroadcastOpHandle : public OpHandleBase {

protected:
void RunImpl() override;

void WaitInputVarGenerated(const VarHandle &in_var);
};

private:
const std::vector<Scope *> &local_scopes_;
const std::vector<platform::Place> &places_;
};
} // namespace details
} // namespace framework
} // namespace paddle
16 changes: 8 additions & 8 deletions paddle/fluid/framework/details/variable_visitor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,22 @@ namespace paddle {
namespace framework {
namespace details {
template <typename Func>
static void VisitVariable(Variable* var, Func func) {
static void VisitVariable(Variable* var, Func* func) {
if (var->IsType<LoDTensor>()) {
func(var->GetMutable<LoDTensor>());
(*func)(var->GetMutable<LoDTensor>());
} else if (var->IsType<SelectedRows>()) {
func(var->GetMutable<SelectedRows>());
(*func)(var->GetMutable<SelectedRows>());
} else {
PADDLE_THROW("Not supported type %s", var->Type().name());
}
}

template <typename Func>
static void VisitVariable(const Variable& var, Func func) {
static void VisitVariable(const Variable& var, Func* func) {
if (var.IsType<LoDTensor>()) {
func(var.Get<LoDTensor>());
(*func)(var.Get<LoDTensor>());
} else if (var.IsType<SelectedRows>()) {
func(var.Get<SelectedRows>());
(*func)(var.Get<SelectedRows>());
} else {
PADDLE_THROW("Not supported type %s", var.Type().name());
}
Expand All @@ -56,7 +56,7 @@ struct TensorVisitor {

Tensor& VariableVisitor::GetMutableTensor(Variable* var) {
TensorVisitor vistor;
VisitVariable(var, vistor);
VisitVariable(var, &vistor);
return *vistor.result_;
}

Expand Down Expand Up @@ -85,7 +85,7 @@ struct ShareDimsAndLoDVisitor {

void VariableVisitor::ShareDimsAndLoD(const Variable& src, Variable* trg) {
ShareDimsAndLoDVisitor visitor{trg};
VisitVariable(src, visitor);
VisitVariable(src, &visitor);
}

} // namespace details
Expand Down

0 comments on commit 0357128

Please sign in to comment.