
Move device lock before the execution instead of tensor gathering #3457

Merged
merged 20 commits into master from move_sync_lock on Apr 28, 2022

Conversation

JackCaoG
Collaborator

This PR ports pytorch/pytorch#74864 to pytorch/xla and should help with #3434. It still needs more testing, but the concept seems right to me.

I think this optimization is based on the assumption that even if the previous execution has not finished, we can still truncate the tensor's pending IR into a BackendData with the correct shape. The actual handle assignment to that BackendData won't happen until the previous execution finishes, but that is OK since we only need the real data handle at the next execution.
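For illustration, here is a minimal C++ sketch of that placeholder pattern, using simplified stand-in types rather than the actual BackendData/XrtData classes:

#include <cstdint>
#include <memory>
#include <string>

// Stand-in for the runtime buffer handle wrapped by BackendData/XrtData.
struct Handle {
  int64_t id;
};

// Stand-in for BackendData: the shape is known up front, while the real
// handle is only filled in once the producing execution finishes.
struct BackendData {
  std::string shape;
  std::shared_ptr<Handle> handle;  // nullptr until Assign() runs

  bool HasValue() const { return handle != nullptr; }
  void Assign(std::shared_ptr<Handle> h) { handle = std::move(h); }
};

int main() {
  // At graph finalization the tensor's pending IR is cut and replaced by a
  // placeholder that only carries the shape.
  auto placeholder = std::make_shared<BackendData>();
  placeholder->shape = "f32[2,2]";

  // The next trace can already use `placeholder` as a graph input; only the
  // actual execution needs the handle, and by then the previous execution
  // has assigned it.
  placeholder->Assign(std::make_shared<Handle>(Handle{42}));
  return placeholder->HasValue() ? 0 : 1;
}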

There is a catch: the placeholder creation and IR truncation only happen when force_xla_data is true. AFAICT GetTensor will have force_xla_data == false and SyncTensor will have force_xla_data == true. I think this is fine. Consider the case of

a = torch.tensor(1)
b = torch.tensor(1)
c = a + b
print (c) ---> GetTensor, won't truncate c's IR
d = c + a
xm.mark_step() --> Synctensor, will truncate IR

we expect print not to change the IR of c anyway, so there is no need to wait for that execution to finish before we finalize the second graph.

@JackCaoG JackCaoG requested a review from wconstab March 30, 2022 00:06
@wconstab
Collaborator

There is a catch: the placeholder creation and IR truncation only happen when force_xla_data is true. AFAICT GetTensor will have force_xla_data == false and SyncTensor will have force_xla_data == true. I think this is fine. Consider the case of

Regarding this, I think the behavior is OK/correct but the naming could be improved so it is more obvious from the API names what is happening.

Collaborator

@wconstab wconstab left a comment


LGTM, although I am also waiting to hear your results to confirm correctness/performance across your benchmark workloads

@JackCaoG
Collaborator Author

@wconstab @desertfire Both the CPU and GPU CI failed at the TestParallelTensorMNIST test. This is likely a real error; I will take a look tomorrow. Let's hold off on merging the LTC PR for now.

@JackCaoG
Collaborator Author

@desertfire the failure was coming from the op-by-op executor; I fixed it by adding the barrier. However, there is one change I think we need to apply to the LTC code: we need to check that coll is not empty, since it might return early in https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/core/lazy_graph_executor.cpp#L615. Accessing the device in that case will segfault.

@desertfire

@desertfire the failure was coming from the op-by-op executor; I fixed it by adding the barrier. However, there is one change I think we need to apply to the LTC code: we need to check that coll is not empty, since it might return early in https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/core/lazy_graph_executor.cpp#L615. Accessing the device in that case will segfault.

Makes sense. Let me update the LTC change.

@JackCaoG
Collaborator Author

JackCaoG commented Apr 1, 2022

@desertfire I found a bug that I don't have a very good solution for. In RunPostOrder we actually access the real data handle:

device_data->data()->GetOpaqueHandle();

This step is to make sure po_data.parameters_data has unique xla_data. The issue is that the true data handle won't be assigned until the previous execution finishes, in

async->tensors_data[i]->Assign(*results[i]);
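Roughly, the dedup being described could look like this sketch (stand-in types; the real code lives in RunPostOrder and dereferences the opaque handle):

#include <cstdint>
#include <memory>
#include <unordered_set>
#include <vector>

// Stand-in for the data attached to a device_data IR node.
struct DeviceData {
  std::shared_ptr<int64_t> handle;  // still nullptr while the previous execution runs

  // The real GetOpaqueHandle() reads the underlying runtime handle, which is
  // exactly what is not available yet for a placeholder.
  int64_t GetOpaqueHandle() const { return *handle; }
};

// Rough shape of the handle-based dedup over graph parameters.
std::vector<std::shared_ptr<DeviceData>> DedupParameters(
    const std::vector<std::shared_ptr<DeviceData>>& inputs) {
  std::unordered_set<int64_t> seen;
  std::vector<std::shared_ptr<DeviceData>> unique;
  for (const auto& data : inputs) {
    // Faults when `data` is a placeholder whose handle has not been assigned.
    if (seen.insert(data->GetOpaqueHandle()).second) {
      unique.push_back(data);
    }
  }
  return unique;
}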

@desertfire

GetOpaqueHandle

Hmm, the LTC behavior diverges here.
GetHandle still works, because CreateDataPlaceHolder would make sure that TSData is there.

If this change doesn't work for XLA, we can either add an option to control it or derive a subclass of LazyGraphExecutor. Given how small this change is, I would say add an option and let XLA override its value. What do you think?
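A rough sketch of the "option that XLA overrides" idea, with purely illustrative names (the real base class is torch::lazy::LazyGraphExecutor; this hook does not exist upstream under this name):

// Illustrative only: the hook name and classes below are made up for the sketch.
class LazyGraphExecutorSketch {
 public:
  virtual ~LazyGraphExecutorSketch() = default;

  // Upstream default keeps the new early-lock behavior.
  virtual bool UseEagerDeviceLock() const { return true; }
};

class XlaGraphExecutorSketch : public LazyGraphExecutorSketch {
 public:
  // XLA overrides the knob while RunPostOrder still needs real handles.
  bool UseEagerDeviceLock() const override { return false; }
};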

@wconstab
Collaborator

wconstab commented Apr 1, 2022

If this change doesn't work for XLA, we can either add an option to control it or derive a subclass of LazyGraphExecutor. Given how small this change is, I would say add an option and let XLA override its value. What do you think?

Well, we really want a solution for XLA since AWS is the first reporter of this issue. (Even though you're right that we could land it for TS only and it would strictly be an improvement).

This step is to make sure po_data.parameters_data has unique xla_data. The issue is that the true data handle won't be assigned until the previous execution finishes

Hmm, don't you have a way, before launching the computation, to know how many data handles will be returned, or which of them will be unique?

@JackCaoG
Collaborator Author

JackCaoG commented Apr 1, 2022

@desertfire the placeholder does not have a valid handle, for XLA at least; it gets filled in later. The issue here is that RunPostOrder wants to access the real handle.

@JackCaoG
Collaborator Author

JackCaoG commented Apr 1, 2022

Let me run some experiments. I think when two graphs do not have any dependency on each other, we can run RunPostOrder and access the real handle.

For example:

a = torch.tensor(1)
b = torch.tensor(1)
c = a + b
xm.mark_step() 
d = c + a           -----> c's handle is not ready yet, we need to block RunPostOrder
xm.mark_step() 
a = torch.tensor(1)
b = torch.tensor(1)
c = a + b
xm.mark_step() 
d = a * x           -------> we can proceed
xm.mark_step() 
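The conditional wait described by these two examples could look roughly like this sketch (stand-in types, not the actual pytorch/xla code):

#include <memory>
#include <vector>

// Stand-in for backend data; HasValue() is false while the producing
// execution has not assigned a real handle yet.
struct BackendData {
  std::shared_ptr<int> handle;
  bool HasValue() const { return handle != nullptr; }
};

// Only block the new execution when one of its inputs is still an unassigned
// placeholder, i.e. when the two graphs really have a data dependency.
bool NeedsToWaitForPreviousExecution(
    const std::vector<std::shared_ptr<BackendData>>& parameters) {
  for (const auto& data : parameters) {
    if (!data->HasValue()) {
      return true;  // like `d = c + a` right after the mark_step that produced c
    }
  }
  return false;  // like the second example, built only from materialized data
}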

@wconstab the graph is traversed in a post-order fashion, and the same tensor might be an input to multiple results. For example:

a = torch.tensor(1)
b = torch.tensor(1)
c = a * 2 + b
d = a * 4

In this case a will show up in the graph twice (I think; this needs verification). Checking the device handle is the easiest way to dedup, since if the device handles are the same we know the inputs are the same (from the runtime's point of view). I don't know whether two different tensors can ever share the same handle; if that never happens, we can probably do the dedup at the Tensor/Data level instead of at the handle level, as sketched below.
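A minimal sketch of that alternative, deduping by Data pointer identity instead of by opaque runtime handle (stand-in types; whether this is safe depends on the open question above about two tensors sharing a handle):

#include <memory>
#include <unordered_set>
#include <vector>

struct BackendData {};  // stand-in; identity here is the shared_ptr itself

// Dedup by Data pointer identity, so no real runtime handle is needed and
// placeholders can be deduped before the previous execution finishes.
std::vector<std::shared_ptr<BackendData>> DedupByDataPointer(
    const std::vector<std::shared_ptr<BackendData>>& inputs) {
  std::unordered_set<const BackendData*> seen;
  std::vector<std::shared_ptr<BackendData>> unique;
  for (const auto& data : inputs) {
    if (seen.insert(data.get()).second) {
      unique.push_back(data);
    }
  }
  return unique;
}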

@miladm miladm self-assigned this Apr 1, 2022
@miladm
Collaborator

miladm commented Apr 7, 2022

Update:

I am trying the above test scenarios from @JackCaoG + the tests in test_operations.py to verify this implementation.

It turns out this implementation leads to a couple of segfault scenarios.

  1. One source of failure is that calling GetOpaqueHandle under RunPostOrder leads to a segfault. A solution is to call TensorCollectionBarrier just before the GetOpaqueHandle call in RunPostOrder.
  2. Another source of failure is caused by moving TensorCollectionBarrier past SaveTensorsGraphInfo in SyncTensorsGraphInternal. I am pinpointing the line that causes the error.

[WIP]

@miladm
Collaborator

miladm commented Apr 8, 2022

It turns out that when coll.indices.empty() is true, XLATensor::SyncTensorsGraphInternal returns a nullptr. Before returning nullptr, we need to call TensorCollectionBarrier(&coll) (ref).

The current implementation calls TensorCollectionBarrier() in RunPostOrder() before this line. I have a working implementation that passes all our existing tests.
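For reference, a simplified sketch of the early-return guard and barrier placement being described (the signatures here are assumptions, not the actual pytorch/xla code):

#include <cstddef>
#include <memory>
#include <vector>

// Simplified stand-ins; the real types live in the XLA/LTC graph executors.
struct SyncTensorCollection {
  std::vector<size_t> indices;
};
struct Async {};

void TensorCollectionBarrier(SyncTensorCollection* coll) {
  // Real version: waits on / locks the devices referenced by `coll`.
  (void)coll;
}

// Rough shape of the guard: even when there is nothing to sync, run the
// barrier before returning nullptr so later handle accesses cannot race with
// an execution that is still in flight.
std::shared_ptr<Async> SyncTensorsGraphInternalSketch(SyncTensorCollection coll) {
  if (coll.indices.empty()) {
    TensorCollectionBarrier(&coll);
    return nullptr;
  }
  // ... barrier, post-order walk, compile, and async execution go here ...
  return std::make_shared<Async>();
}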

The OpByOp code path runs into similar segfaults, which I am currently addressing.

@miladm miladm added enhancement New feature or request perf labels Apr 12, 2022
@wconstab
Collaborator

@miladm just catching up: it sounds like you fixed at least one of the major failures. Is there still something else failing, or is it just a matter of cleaning up the PR at this point?

@miladm
Collaborator

miladm commented Apr 18, 2022

To measure the performance gain of this implementation, I ran it against the ResNet50 and MNIST models. Neither model shows a performance gain compared to the current PyTorch/XLA master commit. This is rather surprising, because it suggests there are no independent graphs that run back to back. I'd like to hear others' thoughts on this. CC @JackCaoG @wconstab

In light of the above observation, we could consider dropping the conditional barriers and replacing them with an implementation that doesn't require dependency tracking via the GetOpaqueHandle API. If so, to confirm in advance that we will get the expected gain, I suggest we discuss a test scheme that verifies this solution closes the expected 7% MNIST performance gap. Please share your thoughts.

@wconstab
Collaborator

To measure the performance gain of this implementation, I ran it against the ResNet50 and MNIST models.

Are you able to confirm with profiling that either of these models on your hardware shows bubbles in between kernel executions on the accelerator due to this RunPostOrder locking issue?

@JackCaoG
Collaborator Author

This is rather surprising, because it suggests there are no independent graphs that run back to back

I think this is an incorrect assumption. It is possible that the RunPostOrder and FetchTensorData lag simply does not impact some models as much. You can add a print statement to the conditional lock to confirm, or capture a profile.

@miladm
Collaborator

miladm commented Apr 18, 2022

Yup, I will extract the profile to confirm if the same performance observation made on GPU (here) holds true for TPU devices.

@miladm
Collaborator

miladm commented Apr 19, 2022

The profiling numbers for this branch and master are attached below.

  • It turns out the gap is similar between the two branch runs. Each ExecuteComputation takes around 4.4ms.
  • The profiles suggest the observed gap on TPU is more than 2x wider than on GPU (see here).
  • The profiles "seem to suggest" that CollectSyncTensors takes less relative time on TPU than on GPU.
  • On TPU, a MarkStep takes around 0.4ms.

master branch profile:

[profile screenshot]

move_sync_lock branch profile:

I was not able to capture the ExecuteComputation trace in the screenshot, as it was plotted many rows away from this profile row. Though, as expected, it started immediately after StepMarker.

[profile screenshot]

@wconstab
Collaborator

It turns out the gap is similar between the two branch runs. Each ExecuteComputation takes around 4.4ms.
The profiles suggest the observed gap on TPU is more than 2x wider than on GPU (#3434 (comment)).
The profiles "seem to suggest" that CollectSyncTensors takes less relative time on TPU than on GPU.

Are you concluding that the bug and the fix are not observable on TPU because of this relative time difference, and that we therefore have to test the fix on GPU? Or something else?

On TPU, a MarkStep takes around 0.4ms.
I was not able to capture the ExecuteComputation trace in the screenshot, as it was plotted many rows away from this profile row. Though, as expected, it started immediately after StepMarker.

Are you saying that on TPU, the 'ExecuteComputation' calls happen back to back both with and without the fix? I think that would indicate the bug was never present on TPU for some reason, but I can't verify this from the screenshot since I only see the second 'ExecuteComputation', not the previous one.

@miladm
Collaborator

miladm commented Apr 20, 2022

Update on the results shared above:

  • It turns out the reason I observed no perf difference between master and this branch was that all ExecuteComputation calls completed well before the next SyncTensorsGraphInternal call. See below for the fix.

To verify that this PR gives the intended performance gain, and to respond to @will-cromar's questions above:

  • Raising --batch_size=2048 on MNIST: this makes ExecuteComputation overlap with the next SyncTensorsGraphInternal call. For this training batch size, I observe an 8.4% speedup.
  • Raising --batch_size=8192 on MNIST: I observe a 16.3% speedup.

As intended, the conditional barrier enables an ExecuteComputation call to overlap with the next graph processing cycle; an example profile snapshot is included below.

[profile screenshot]

@miladm miladm merged commit 5d1c421 into master Apr 28, 2022
@miladm miladm deleted the move_sync_lock branch April 28, 2022 05:09
desertfire added a commit to pytorch/pytorch that referenced this pull request May 6, 2022
Summary: cherry-picking pytorch/xla#3457

ghstack-source-id: 06503402cee5d93e58bf6fbe97c0b00eb8d0f81c
Pull Request resolved: #76974
desertfire added commits to pytorch/pytorch that referenced this pull request May 9, 2022

Summary: cherry-picking pytorch/xla#3457

Pull Request resolved: #76974
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request May 9, 2022
facebook-github-bot pushed a commit to pytorch/pytorch that referenced this pull request May 13, 2022
Summary:
cherry-picking pytorch/xla#3457

Pull Request resolved: #76974

Approved by: https://github.com/wconstab, https://github.com/JackCaoG

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/36150c63a777a7cc854006d5d4bcef011a71671a

Reviewed By: malfet, mikekgfb

Differential Revision: D36269299

Pulled By: desertfire

fbshipit-source-id: a2a793cfdc4538ed6713d73d30015dd9c3c0a841
@philipturner
Contributor

philipturner commented Jun 29, 2022

This seems related to s4tf/s4tf#14, which I am trying to fix. Swift for TensorFlow and PyTorch share the same code base for LazyTensor/XLA execution. The pytorch/xla directory third_party/xla_client matches the s4tf/s4tf directory Sources/xla/xla_client. The pytorch/xla directory torch_xla/csrc matches the s4tf/s4tf directory Sources/xla/xla_tensor. I'm having a crash when creating a LazyTensorBarrier (equivalent to mark_step() in PyTorch). It calls XrtData::HasValue() somewhere and asserts that it returns true; unfortunately the underlying handle is a null pointer, and I can't trace why that happens. The crash did not appear when compiled against TensorFlow 2.4, but it now crashes when I compile against TensorFlow 2.9.

Do you have any insight into why the crash might have happened? I'm not an expert on XLA. Even if this patch resolves the crash, the crash could pop up somewhere else.

@JackCaoG
Collaborator Author

JackCaoG commented Jul 1, 2022

@philipturner This might be a bit tricky to debug. I don't see why the TF version would affect XrtData. My best guess is that an empty XrtData was created using

XrtData(std::string device, Shape device_shape)
    : Data(std::move(device), std::move(device_shape)),
      handle_ptr(nullptr) {}

as a placeholder in

xla/torch_xla/csrc/tensor.cpp, lines 1405 to 1408 at f0efaf5:

xla_data =
    WrapXlaData(xla::ComputationClient::Get()->CreateDataPlaceholder(
        tensor_device.toString(), std::move(shape)));
tensor.SetXlaData(xla_data, config.sync_xla_data);

but it never gets an assignment in

xla/torch_xla/csrc/tensor.cpp, lines 1440 to 1445 at f0efaf5:

for (size_t i = 0; i < results.size(); ++i) {
  if (async->tensors_data[i] != nullptr) {
    UnwrapXlaData(async->tensors_data[i])->Assign(*results[i]);
  } else {
    async->tensors_data[i] = WrapXlaData(std::move(results[i]));
  }
}

That being said, this part of the logic is purely lazy tensor and XRT related; I'm not sure why TF would affect it. This might also be a race condition.

@philipturner
Contributor

philipturner commented Jul 8, 2022

I synchronized S4TF's x10/xla_client directory up to this commit, although x10/xla_tensor (equivalent to torch_xla/csrc) might be a bit behind. This PR did not fix the crash. In S4TF, we don't need to wrap/unwrap the XLA data because it's always an xla::ComputationClient::DataPtr. I don't have any functions called WrapXlaData or UnwrapXlaData, and I don't think I need any such functions.

I located the third section of code you mentioned above. It's inside XLATensor::ScheduleSyncTensorsGraph. I modified it so it logged the state of tensors during each loop iteration. It's a bit messy, but the print output gets the job done.

      for (size_t i = 0; i < results.size(); ++i) {
        TF_VLOG(0) << "Starter ============";
        if (async->tensors_data[i] != nullptr) {
          TF_VLOG(0) << "Case 1";
          if (async->tensors_data[i]->HasValue()) {
            TF_VLOG(0) << "Case 1.1";
          } else {
            TF_VLOG(0) << "Case 1.2";
          }
        } else {
          TF_VLOG(0) << "Case 2";
          if (async->tensors_data[i]->HasValue()) {
            TF_VLOG(0) << "Case 2.1";
          } else {
            TF_VLOG(0) << "Case 2.2";
          }
        }

        if (async->tensors_data[i] != nullptr) {
          async->tensors_data[i]->Assign(*results[i]);
        } else {
          async->tensors_data[i] = std::move(results[i]);
        }
      }

This is all the output I got from the tutorial run, up to the point where it crashed. It's "Case 1.2" every time; not a single 1.1, 2.1, or 2.2. That means the xla::ComputationClient::DataPtr comes in as a non-null pointer, but without any data: it returns false for HasValue(), then gets assigned actual data in ->Assign(*results[i]).

Test Case '-[x10Tests.XLATensorTests testLazyTensorBarrier]' started.
2022-07-08 15:19:00.977198: I tensorflow/compiler/tf2xla/xla_tensor/tensor.cpp:1234] Starter ============
2022-07-08 15:19:00.977209: I tensorflow/compiler/tf2xla/xla_tensor/tensor.cpp:1236] Case 1
2022-07-08 15:19:00.977211: I tensorflow/compiler/tf2xla/xla_tensor/tensor.cpp:1240] Case 1.2
2022-07-08 15:19:00.977213: I tensorflow/compiler/tf2xla/xla_tensor/tensor.cpp:1234] Starter ============
2022-07-08 15:19:00.977214: I tensorflow/compiler/tf2xla/xla_tensor/tensor.cpp:1236] Case 1
2022-07-08 15:19:00.977252: I tensorflow/compiler/tf2xla/xla_tensor/tensor.cpp:1240] Case 1.2
2022-07-08 15:19:00.977269: I tensorflow/compiler/tf2xla/xla_tensor/tensor.cpp:1234] Starter ============
2022-07-08 15:19:00.977271: I tensorflow/compiler/tf2xla/xla_tensor/tensor.cpp:1236] Case 1
2022-07-08 15:19:00.977273: I tensorflow/compiler/tf2xla/xla_tensor/tensor.cpp:1240] Case 1.2
2022-07-08 15:19:00.977274: I tensorflow/compiler/tf2xla/xla_tensor/tensor.cpp:1234] Starter ============
2022-07-08 15:19:00.977276: I tensorflow/compiler/tf2xla/xla_tensor/tensor.cpp:1236] Case 1
2022-07-08 15:19:00.977277: I tensorflow/compiler/tf2xla/xla_tensor/tensor.cpp:1240] Case 1.2
2022-07-08 15:19:00.977279: I tensorflow/compiler/tf2xla/xla_tensor/tensor.cpp:1234] Starter ============
2022-07-08 15:19:00.977280: I tensorflow/compiler/tf2xla/xla_tensor/tensor.cpp:1236] Case 1
2022-07-08 15:19:00.977281: I tensorflow/compiler/tf2xla/xla_tensor/tensor.cpp:1240] Case 1.2
2022-07-08 15:19:00.977283: I tensorflow/compiler/tf2xla/xla_tensor/tensor.cpp:1234] Starter ============
2022-07-08 15:19:00.977284: I tensorflow/compiler/tf2xla/xla_tensor/tensor.cpp:1236] Case 1
2022-07-08 15:19:00.977285: I tensorflow/compiler/tf2xla/xla_tensor/tensor.cpp:1240] Case 1.2
2022-07-08 15:19:00.977287: I tensorflow/compiler/tf2xla/xla_tensor/tensor.cpp:1234] Starter ============
2022-07-08 15:19:00.977288: I tensorflow/compiler/tf2xla/xla_tensor/tensor.cpp:1236] Case 1
2022-07-08 15:19:00.977290: I tensorflow/compiler/tf2xla/xla_tensor/tensor.cpp:1240] Case 1.2
2022-07-08 15:19:00.977292: I tensorflow/compiler/tf2xla/xla_tensor/tensor.cpp:1234] Starter ============
2022-07-08 15:19:00.977293: I tensorflow/compiler/tf2xla/xla_tensor/tensor.cpp:1236] Case 1
2022-07-08 15:19:00.977295: I tensorflow/compiler/tf2xla/xla_tensor/tensor.cpp:1240] Case 1.2
2022-07-08 15:19:00.984498: F tensorflow/compiler/xla/xla_client/tf_logging.cc:26] tensorflow/compiler/tf2xla/xla_tensor/tensor.cpp:585 : Check failed: xla_data->HasValue() 
*** Begin stack trace ***
	tensorflow::CurrentStackTrace()
	swift_xla::XLATensor::GetXlaData()
	swift_xla::XLATensor::ToTensor(bool)
	XLATensor_materialize
...
*** End stack trace ***
Trying to access XLA data while an async operation is in flight: f32[]
error: Exited with signal code 6

For the first and second sections of code you mentioned, we initialize the data a bit differently, but the overall result is the same: handle_ptr starts out as a nullptr.

Code section 2:

      xla_data = xla::GetX10Device(tensor_device)
                     ->CreateDataPlaceholder(std::move(shape));
      tensor.SetXlaData(xla_data, config.sync_xla_data);

CreateDataPlaceholder is defined in xrt_computation_client.cc, but we refactored a lot of things two years ago. In PyTorch, the definition does not appear in that C++ file.

class XrtComputationClient::XrtDevice : public ComputationClient::Device {
 public:
  ...
  DataPtr CreateDataPlaceholder(Shape shape) override {
    return std::make_shared<XrtData>(this, std::move(shape));
  }
};

The S4TF version of the first code section from your previous comment:

  struct XrtData : public Data {
    XrtData(Device* device, Shape device_shape)
        : Data(device, std::move(device_shape)),
          handle_ptr(nullptr) {}
    XrtData(XrtDevice* device, Shape device_shape, int64_t handle);

    XrtData(XrtDevice* device, Shape device_shape, XrtHandlePtr handle);

This might also be a race condition.

The Swift test testLazyTensorBarrier has failed every single time I've run it (~50 times in total), so it seems like a deterministic crash. I doubt it's a race condition, because data races are usually non-deterministic. Also, the crash never happens with the old TF/X10 2.4 binary on Linux. I haven't actually tested the v2.4 binary on macOS with XLA/LazyTensorBarrier, but I presume it wouldn't fail.

I'll see if I can undo most of the refactoring done by the old S4TF team, which should bring our implementation closer in line with yours. That will also allow more features from PyTorch to be integrated into S4TF, such as setting the XLA RNG seed asynchronously. Most importantly, it will be easier to debug because our source code will differ a lot less.

@philipturner
Contributor

philipturner commented Jul 8, 2022

Here's a more precise look at the crash site.

xla::ComputationClient::DataPtr XLATensor::GetXlaData() {
  bool up_to_date = true;
  // My additions to existing code for the sake of debugging begin here.
  xla::ComputationClient::DataPtr xla_data = CurrentXlaData();
  if ((xla_data != nullptr) && (xla_data->HasValue() == false)) {
    TF_VLOG(0) << "Encountered something that would cause the crash.";
    ApplyPendingGraph();
  }
  // My additions to existing code for the sake of debugging end here.

  ir::Value ir_value;
  if (up_to_date) {
    xla::ComputationClient::DataPtr xla_data = CurrentXlaData();
    if (xla_data != nullptr) {
      // This still fails, even though I applied the pending graph.
      XLA_CHECK(xla_data->HasValue())
          << "Trying to access XLA data while an async operation is in flight: "
          << xla_data->shape();
      return xla_data;
    }
  }
  if (ir_value) {
    AssignIrValue(std::move(ir_value));
  }
  if (data()->ir_value) {
    ApplyPendingGraph();
  } else {
    XLA_CHECK(data()->tensor_data);
    data()->xla_data = TensorToXlaData(*data()->tensor_data, GetDevice());
  }
  return data()->xla_data;
}

Terminal output

Test Case '-[x10Tests.XLATensorTests testLazyTensorBarrier]' started.
2022-07-08 15:59:11.153362: I tensorflow/compiler/tf2xla/xla_tensor/tensor.cpp:584] Encountered something that would cause the crash.
2022-07-08 15:59:11.160997: F tensorflow/compiler/xla/xla_client/tf_logging.cc:26] tensorflow/compiler/tf2xla/xla_tensor/tensor.cpp:593 : Check failed: xla_data->HasValue() 
*** Begin stack trace ***
	tensorflow::CurrentStackTrace()
	swift_xla::XLATensor::GetXlaData()
	swift_xla::XLATensor::ToTensor(bool)
	XLATensor_materialize

Repository branch: philipturner/s4tf:fix-xla-bug for s4tf/s4tf#19.

ApplyPendingGraph does nothing here because it only checks whether the DataPtr is null, not whether the underlying handle is null. That's why it crashes anyway. Is there supposed to be a scenario where the DataPtr itself is non-null, but the underlying handle_ptr is null?

void XLATensor::ApplyPendingGraph() {
  DeviceBarrier(GetDevice());
  // This method is called to ensure that the tensor data is available on
  // device, so that a call to CurrentXlaData() returns a valid pointer.
  if (CurrentXlaData() == nullptr) {
    std::vector<XLATensor> tensors({*this});
    SyncTensorsGraph(&tensors, {}, /*wait=*/true, /*sync_xla_data=*/false);
  }
}
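For comparison, here is a minimal, self-contained sketch of the HasValue()-aware check being discussed; the types and method names are stand-ins, not the real swift_xla classes, and it is not a claim about what the upstream fix should be:

#include <memory>

// Minimal stand-ins so the sketch is self-contained; the real types are
// swift_xla::XLATensor and xla::ComputationClient::Data.
struct Data {
  std::shared_ptr<int> handle;  // null until the async execution assigns it
  bool HasValue() const { return handle != nullptr; }
};
using DataPtr = std::shared_ptr<Data>;

struct TensorSketch {
  DataPtr xla_data;

  DataPtr CurrentXlaData() const { return xla_data; }
  void SyncPendingGraph() { /* real code: SyncTensorsGraph(..., wait=true) */ }

  // HasValue()-aware variant of the check in ApplyPendingGraph() above: treat
  // a placeholder with a null handle the same as missing data, instead of
  // only checking for a null DataPtr.
  void ApplyPendingGraphSketch() {
    DataPtr data = CurrentXlaData();
    if (data == nullptr || !data->HasValue()) {
      SyncPendingGraph();
    }
  }
};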

Labels
enhancement New feature or request perf
Development

Successfully merging this pull request may close these issues.

RunPostOrder and FetchTensorData methods adds overhead and prevents complete pipelining of two executions