Skip to content

Commit

Permalink
Create new stream for data copy for IOBinding input scenario (#14719)
Browse files Browse the repository at this point in the history
### Description
Create new stream for data copy for IOBinding input scenario



### Motivation and Context
Previously, in bindInput(), a nullptr Stream was passed when copying data
across devices. This caused the default stream to be used, which hurt
performance.
This PR is to fix #14484

---------

Co-authored-by: Lei Cao <leca@microsoft.com>
  • Loading branch information
2 people authored and PatriceVignola committed Feb 22, 2023
1 parent 999747d commit 71e7e35
Showing 1 changed file with 58 additions and 43 deletions.
101 changes: 58 additions & 43 deletions onnxruntime/core/framework/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,48 @@ static common::Status CopyInputsAcrossDevices(const SessionState& session_state,
return Status::OK();
}

#ifdef ORT_ENABLE_STREAM
// Scoped owner of a DeviceStreamCollection obtained from a SessionState.
// The collection is acquired in the constructor and, if one is still held when
// the holder goes out of scope, it is handed back to the same SessionState via
// RecycleDeviceStreamCollection().
struct DeviceStreamCollectionHolder {
  // Session state the collection is acquired from and later recycled to.
  const SessionState& session_state_;
  // The acquired collection. When null, the destructor recycles nothing.
  std::unique_ptr<DeviceStreamCollection> p_;

  DeviceStreamCollectionHolder(const SessionState& session_state)
      : session_state_(session_state),
        p_(session_state.AcquireDeviceStreamCollection()) {}

  ~DeviceStreamCollectionHolder() {
    if (!p_) {
      return;
    }
    // Ownership goes back to the session state.
    session_state_.RecycleDeviceStreamCollection(std::move(p_));
  }
};

// Points every populated stream slot of `device_stream_collection` at `parent_stream`.
// Does nothing when `parent_stream` is null. Throws (via ORT_THROW) if a populated
// slot's stream reports a different device than `parent_stream`.
static void UpdateWithParentStream(DeviceStreamCollection& device_stream_collection,
                                   Stream* parent_stream) {
  if (!parent_stream) {
    return;
  }

  // TODO: in theory, the current subgraph's streams should depend on the parent stream.
  // With the current code structure that causes issues with resource sharing and stream
  // lifetime, and it may add stream-sync cost in the single stream case.
  // As a first phase, all subgraph execution is simply placed on the parent stream.
  for (size_t idx = 0; idx < device_stream_collection.NumStreams(); ++idx) {
    auto* current = device_stream_collection.GetStream(idx);
    if (!current) {
      continue;  // empty slot, nothing to rebind
    }

    // If this logical stream is not on the same device as the parent stream,
    // raise an error: there is no code to set up the dependency at this moment.
    if (current->GetDevice() != parent_stream->GetDevice()) {
      ORT_THROW("Subgraph has nodes running on device: ", current->GetDevice().Type(),
                " while parent graph node running on device: ", parent_stream->GetDevice().Type(),
                ", this is not supported yet.");
    }

    device_stream_collection.SetDeviceStream(idx, parent_stream);
  }
}
#endif

// public method to do a single copy. used by external partners
common::Status CopyOneInputAcrossDevices(const SessionState& session_state, const std::string& input_name,
const OrtValue& orig_mlvalue, OrtValue& new_mlvalue) {
Expand All @@ -521,8 +563,23 @@ common::Status CopyOneInputAcrossDevices(const SessionState& session_state, cons
copy_info.source_device = orig_mlvalue.Get<Tensor>().Location().device;
#endif

Stream* device_stream = nullptr;
#ifdef ORT_ENABLE_STREAM
DeviceStreamCollectionHolder device_stream_collection_holder(session_state);
if (device_stream_collection_holder.p_ != nullptr) {
DeviceStreamCollection* device_stream_collection = device_stream_collection_holder.p_.get();
gsl::span<Stream*> streams = device_stream_collection->GetStreams();
for (Stream* stream : streams) {
if (stream && stream->GetDevice().Type() != OrtDevice::CPU) {
device_stream = stream;
break;
}
}
}
#endif

// copy_info.target_device is not set leaving to be equal to CPU.
return BatchOrCopyMLValue(session_state, copy_info, orig_mlvalue, new_mlvalue, nullptr);
return BatchOrCopyMLValue(session_state, copy_info, orig_mlvalue, new_mlvalue, device_stream);
}

static common::Status CopyOutputsAcrossDevices(const SessionState& session_state,
Expand Down Expand Up @@ -561,48 +618,6 @@ static common::Status CopyOutputsAcrossDevices(const SessionState& session_state
return Status::OK();
}

#ifdef ORT_ENABLE_STREAM
// Scoped owner of a DeviceStreamCollection obtained from a SessionState.
// The collection is acquired in the constructor and, if one is still held when
// the holder goes out of scope, it is handed back to the same SessionState via
// RecycleDeviceStreamCollection().
struct DeviceStreamCollectionHolder {
  // Session state the collection is acquired from and later recycled to.
  const SessionState& session_state_;
  // The acquired collection. When null, the destructor recycles nothing.
  std::unique_ptr<DeviceStreamCollection> p_;

  DeviceStreamCollectionHolder(const SessionState& session_state)
      : session_state_(session_state),
        p_(session_state.AcquireDeviceStreamCollection()) {}

  ~DeviceStreamCollectionHolder() {
    if (!p_) {
      return;
    }
    // Ownership goes back to the session state.
    session_state_.RecycleDeviceStreamCollection(std::move(p_));
  }
};

// Points every populated stream slot of `device_stream_collection` at `parent_stream`.
// Does nothing when `parent_stream` is null. Throws (via ORT_THROW) if a populated
// slot's stream reports a different device than `parent_stream`.
static void UpdateWithParentStream(DeviceStreamCollection& device_stream_collection,
                                   Stream* parent_stream) {
  if (!parent_stream) {
    return;
  }

  // TODO: in theory, the current subgraph's streams should depend on the parent stream.
  // With the current code structure that causes issues with resource sharing and stream
  // lifetime, and it may add stream-sync cost in the single stream case.
  // As a first phase, all subgraph execution is simply placed on the parent stream.
  for (size_t idx = 0; idx < device_stream_collection.NumStreams(); ++idx) {
    auto* current = device_stream_collection.GetStream(idx);
    if (!current) {
      continue;  // empty slot, nothing to rebind
    }

    // If this logical stream is not on the same device as the parent stream,
    // raise an error: there is no code to set up the dependency at this moment.
    if (current->GetDevice() != parent_stream->GetDevice()) {
      ORT_THROW("Subgraph has nodes running on device: ", current->GetDevice().Type(),
                " while parent graph node running on device: ", parent_stream->GetDevice().Type(),
                ", this is not supported yet.");
    }

    device_stream_collection.SetDeviceStream(idx, parent_stream);
  }
}
#endif

static common::Status
ExecuteGraphImpl(const SessionState& session_state,
const FeedsFetchesManager& feeds_fetches_manager,
Expand Down

0 comments on commit 71e7e35

Please sign in to comment.