Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create new stream for data copy for IOBinding input scenario #14719

Merged
merged 2 commits into from
Feb 20, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 58 additions & 43 deletions onnxruntime/core/framework/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,48 @@ static common::Status CopyInputsAcrossDevices(const SessionState& session_state,
return Status::OK();
}

#ifdef ORT_ENABLE_STREAM
struct DeviceStreamCollectionHolder {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DeviceStreamCollectionHolder

Why do we need a holder here?
Might the collection get "recycled" accidentally?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lei just moved this function to a new place. The original implementation uses this holder as a guard to make sure that, if any error happens, the device streams are returned to the session state correctly.

DeviceStreamCollectionHolder(
const SessionState& session_state) : session_state_(session_state),
p_(session_state.AcquireDeviceStreamCollection()) {
}

~DeviceStreamCollectionHolder() {
if (p_) {
session_state_.RecycleDeviceStreamCollection(std::move(p_));
}
}

const SessionState& session_state_;
std::unique_ptr<DeviceStreamCollection> p_;
};

// Rebinds the streams of a device stream collection to the parent graph's stream.
//
// A null parent_stream makes this a no-op. Otherwise every non-null slot of
// device_stream_collection is replaced by parent_stream through SetDeviceStream(); null slots
// are left untouched. Raises via ORT_THROW when a non-null slot's device is not equal to the
// parent stream's device.
static void UpdateWithParentStream(DeviceStreamCollection& device_stream_collection,
                                   Stream* parent_stream) {
  if (parent_stream == nullptr) {
    return;
  }

  // TODO: in theory, the current subgraph's streams should depend on the parent stream.
  // With the current code structure that causes issues with resource sharing and stream
  // lifetime, and it may also add stream-sync cost in the single-stream case.
  // As a first phase, simply run the whole subgraph on the parent stream.
  for (size_t slot = 0; slot < device_stream_collection.NumStreams(); ++slot) {
    auto* const current = device_stream_collection.GetStream(slot);
    if (current == nullptr) {
      continue;
    }

    // If this logical stream is not on the same EP instance as the parent stream, and the
    // EP instance does have async streams (i.e. it is not an EP like CPU), raise an error:
    // there is no code to set up the dependency at this moment.
    auto&& current_device = current->GetDevice();
    auto&& parent_device = parent_stream->GetDevice();
    if (current_device != parent_device) {
      ORT_THROW("Subgraph has nodes running on device: ", current_device.Type(),
                " while parent graph node running on device: ", parent_device.Type(),
                ", this is not supported yet.");
    }

    device_stream_collection.SetDeviceStream(slot, parent_stream);
  }
}
#endif

// public method to do a single copy. used by external partners
common::Status CopyOneInputAcrossDevices(const SessionState& session_state, const std::string& input_name,
const OrtValue& orig_mlvalue, OrtValue& new_mlvalue) {
Expand All @@ -526,8 +568,23 @@ common::Status CopyOneInputAcrossDevices(const SessionState& session_state, cons
copy_info.source_device = orig_mlvalue.Get<Tensor>().Location().device;
#endif

Stream* device_stream = nullptr;
#ifdef ORT_ENABLE_STREAM
DeviceStreamCollectionHolder device_stream_collection_holder(session_state);
if (device_stream_collection_holder.p_ != nullptr) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if (device_stream_collection_holder.p_ != nullptr)

device_stream_collection_holder.p_

DeviceStreamCollection* device_stream_collection = device_stream_collection_holder.p_.get();
gsl::span<Stream*> streams = device_stream_collection->GetStreams();
for (Stream* stream : streams) {
if (stream && stream->GetDevice().Type() != OrtDevice::CPU) {
device_stream = stream;
break;
}
}
}
#endif

// copy_info.target_device is not set leaving to be equal to CPU.
return BatchOrCopyMLValue(session_state, copy_info, orig_mlvalue, new_mlvalue, nullptr);
return BatchOrCopyMLValue(session_state, copy_info, orig_mlvalue, new_mlvalue, device_stream);
}

static common::Status CopyOutputsAcrossDevices(const SessionState& session_state,
Expand Down Expand Up @@ -566,48 +623,6 @@ static common::Status CopyOutputsAcrossDevices(const SessionState& session_state
return Status::OK();
}

#ifdef ORT_ENABLE_STREAM
// Scope guard for a DeviceStreamCollection borrowed from a SessionState.
//
// The constructor obtains a collection with SessionState::AcquireDeviceStreamCollection() and
// stores it in p_. The destructor hands it back with RecycleDeviceStreamCollection() when p_ is
// still non-null, so the collection is returned on every exit path of the enclosing scope,
// including early returns.
struct DeviceStreamCollectionHolder {
  // `explicit` prevents a SessionState from silently converting into a holder (which would
  // acquire a stream collection as a hidden side effect of the conversion).
  explicit DeviceStreamCollectionHolder(const SessionState& session_state)
      : session_state_(session_state),
        p_(session_state.AcquireDeviceStreamCollection()) {
  }

  // Exactly one holder is responsible for recycling a given collection, so the type is
  // neither copyable nor movable. (It already was not, implicitly, because of the
  // user-declared destructor plus the unique_ptr and reference members; this states it.)
  DeviceStreamCollectionHolder(const DeviceStreamCollectionHolder&) = delete;
  DeviceStreamCollectionHolder& operator=(const DeviceStreamCollectionHolder&) = delete;
  DeviceStreamCollectionHolder(DeviceStreamCollectionHolder&&) = delete;
  DeviceStreamCollectionHolder& operator=(DeviceStreamCollectionHolder&&) = delete;

  ~DeviceStreamCollectionHolder() {
    // p_ is only recycled when it holds a collection; a null p_ is skipped.
    if (p_) {
      session_state_.RecycleDeviceStreamCollection(std::move(p_));
    }
  }

  // Session state the collection was acquired from and is recycled to.
  const SessionState& session_state_;
  // The borrowed collection; checked for null before recycling.
  std::unique_ptr<DeviceStreamCollection> p_;
};

// Rebinds the streams of a device stream collection to the parent graph's stream.
//
// A null parent_stream makes this a no-op. Otherwise every non-null slot of
// device_stream_collection is replaced by parent_stream through SetDeviceStream(); null slots
// are left untouched. Raises via ORT_THROW when a non-null slot's device is not equal to the
// parent stream's device.
static void UpdateWithParentStream(DeviceStreamCollection& device_stream_collection,
                                   Stream* parent_stream) {
  if (parent_stream == nullptr) {
    return;
  }

  // TODO: in theory, the current subgraph's streams should depend on the parent stream.
  // With the current code structure that causes issues with resource sharing and stream
  // lifetime, and it may also add stream-sync cost in the single-stream case.
  // As a first phase, simply run the whole subgraph on the parent stream.
  for (size_t slot = 0; slot < device_stream_collection.NumStreams(); ++slot) {
    auto* const current = device_stream_collection.GetStream(slot);
    if (current == nullptr) {
      continue;
    }

    // If this logical stream is not on the same EP instance as the parent stream, and the
    // EP instance does have async streams (i.e. it is not an EP like CPU), raise an error:
    // there is no code to set up the dependency at this moment.
    auto&& current_device = current->GetDevice();
    auto&& parent_device = parent_stream->GetDevice();
    if (current_device != parent_device) {
      ORT_THROW("Subgraph has nodes running on device: ", current_device.Type(),
                " while parent graph node running on device: ", parent_device.Type(),
                ", this is not supported yet.");
    }

    device_stream_collection.SetDeviceStream(slot, parent_stream);
  }
}
#endif

static common::Status
ExecuteGraphImpl(const SessionState& session_state,
const FeedsFetchesManager& feeds_fetches_manager,
Expand Down