-
Notifications
You must be signed in to change notification settings - Fork 2.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[async] [cuda] AsyncEngine now supports CUDA #1687
Conversation
@@ -673,8 +673,6 @@ def to_numpy(self, keep_dims=False, as_vector=None): | |||
ret = np.zeros(self.shape + shape_ext, dtype=to_numpy_type(self.dtype)) | |||
from .meta import matrix_to_ext_arr | |||
matrix_to_ext_arr(self, ret, as_vector) | |||
import taichi as ti | |||
ti.sync() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not needed since we already synchronize in Kernel.__call__
static std::unordered_map<std::thread::id, CUDAContext *> instances; | ||
static std::mutex mut; | ||
{ | ||
// critical section | ||
auto _ = std::lock_guard<std::mutex>(mut); | ||
|
||
auto tid = std::this_thread::get_id(); | ||
if (instances.find(tid) == instances.end()) { | ||
instances[tid] = new CUDAContext(); | ||
// We expect CUDAContext to live until the process ends, thus the raw | ||
// pointers and `new`s. | ||
} | ||
return *instances[tid]; | ||
} | ||
static auto context = new CUDAContext(); | ||
return *context; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't exactly remember why I had to create a standalone CUDAContext for each thread.
Creating a single context seems to work well.
@@ -1,5 +1,7 @@ | |||
#pragma once |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since now we have one CUDAContext
shared by all threads, I add mutex to serialize the driver API calls.
Codecov Report
@@ Coverage Diff @@
## master #1687 +/- ##
==========================================
+ Coverage 43.78% 43.82% +0.03%
==========================================
Files 42 42
Lines 5940 5940
Branches 1023 1024 +1
==========================================
+ Hits 2601 2603 +2
+ Misses 3189 3187 -2
Partials 150 150
Continue to review full report at Codecov.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome!
Related issue = #742
[Click here for the format server]