Skip to content

Commit

Permalink
[IR] Support multi-thread run && delete unused code of new_ir interpr…
Browse files Browse the repository at this point in the history
…eter (#56148)

* add code

* fix bug

* fix bug

* delete unused code

* refine code

* fix bug

* fix bug

* fix bug

* fix bug

* fix bug
  • Loading branch information
zhangbo9674 authored Aug 14, 2023
1 parent 982100a commit 2386db8
Show file tree
Hide file tree
Showing 14 changed files with 348 additions and 1,381 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,12 @@ void NewIrDependencyBuilder::BuildDownstreamMap() {
}
}

void NewIrDependencyBuilder::ShareDependencyFrom(
const NewIrDependencyBuilder& src) {
std::tie(op_downstream_map_, op_happens_before_) = src.GetDependency();
is_build_ = true;
}

} // namespace interpreter
} // namespace framework
} // namespace paddle
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ class NewIrDependencyBuilder : public DependencyBuilder {

void BuildDownstreamMap();

void ShareDependencyFrom(const NewIrDependencyBuilder& src);

private:
std::vector<paddle::framework::InstructionBase*> instructions_; // not_owned
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -684,7 +684,7 @@ platform::DeviceType NewIrStreamAnalyzer::GetWaiterType(
}
}

void NewIrStreamAnalyzer::ShareEventInfoFrom(const StreamAnalyzer& src) {
void NewIrStreamAnalyzer::ShareEventInfoFrom(const NewIrStreamAnalyzer& src) {
event_info_ = src.GetEventInfo();
is_event_info_build_ = true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ class NewIrStreamAnalyzer {
platform::DeviceType GetWaiterType(
const paddle::framework::InstructionBase* instr) const;

void ShareEventInfoFrom(const StreamAnalyzer& src);
void ShareEventInfoFrom(const NewIrStreamAnalyzer& src);

std::shared_ptr<
std::map<const DeviceContext*, std::map<size_t, std::set<size_t>>>>
Expand Down
12 changes: 6 additions & 6 deletions paddle/fluid/framework/new_executor/interpreter_base_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,6 @@ class InterpreterBaseImpl {
virtual paddle::framework::FetchList Run(
const std::vector<std::string>& feed_names, bool need_fetch = true) = 0;

// NOTE(zhangbo): This interface is only used for temporary testing and only
// for testing during the iteration process of the new IR access actuator
// version. It will be deleted in the future.
virtual paddle::framework::FetchList BetaRun(
const std::vector<std::string>& feed_names, bool need_fetch = true) = 0;

virtual void ShareWorkQueueFrom(InterpreterBaseImpl* src) = 0;

virtual void ShareBuildResultsFrom(const InterpreterBaseImpl& src) = 0;
Expand Down Expand Up @@ -107,6 +101,12 @@ class InterpreterBaseImpl {

virtual const interpreter::StreamAnalyzer& GetStreamAnalyzer() const = 0;

virtual const interpreter::NewIrDependencyBuilder& GetNewIrDependencyBuilder()
const = 0;

virtual const interpreter::NewIrStreamAnalyzer& GetNewIrStreamAnalyzer()
const = 0;

virtual bool IsSharedResultsBuild() const = 0;
};

Expand Down
5 changes: 0 additions & 5 deletions paddle/fluid/framework/new_executor/interpretercore.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,6 @@ FetchList InterpreterCore::Run(const std::vector<std::string>& feed_names,
return impl_->Run(feed_names, need_fetch);
}

FetchList InterpreterCore::BetaRun(const std::vector<std::string>& feed_names,
bool need_fetch) {
return impl_->BetaRun(feed_names, need_fetch);
}

void InterpreterCore::ShareWorkQueueFrom(std::shared_ptr<InterpreterCore> src) {
impl_->ShareWorkQueueFrom(const_cast<InterpreterBaseImpl*>(src->Impl()));
}
Expand Down
3 changes: 0 additions & 3 deletions paddle/fluid/framework/new_executor/interpretercore.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,6 @@ class InterpreterCore {
paddle::framework::FetchList Run(const std::vector<std::string>& feed_names,
bool need_fetch = true);

paddle::framework::FetchList BetaRun(
const std::vector<std::string>& feed_names, bool need_fetch = true);

void ShareWorkQueueFrom(std::shared_ptr<InterpreterCore> src);

void ShareBuildResultsFrom(std::shared_ptr<InterpreterCore> src);
Expand Down
Loading

0 comments on commit 2386db8

Please sign in to comment.