Skip to content

Commit

Permalink
polish migrate to cpu
Browse files Browse the repository at this point in the history
  • Loading branch information
cjld committed Apr 7, 2022
1 parent db88d73 commit 9048f3f
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 2 deletions.
2 changes: 1 addition & 1 deletion python/jittor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************

__version__ = '1.3.2.5'
__version__ = '1.3.2.6'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int
Expand Down
3 changes: 2 additions & 1 deletion python/jittor/src/executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,8 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
sync_times++;
}
for (Var* v : op->inputs()) {
migrate_to_cpu(v, allocator);
if (v->allocator->is_cuda())
migrate_to_cpu(v, allocator);
}
if (!use_cuda_managed_allocator) {
for (auto* var : op->outputs()) {
Expand Down
1 change: 1 addition & 0 deletions python/jittor/src/mem/allocator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ void migrate_to_cpu(Var* var, Allocator* allocator) {
);
} else
if (!use_cuda_managed_allocator) {
if (!var->allocator->is_cuda()) return;
// must be a device allocator
Allocation a(allocator, var->size);
checkCudaErrors(cudaMemcpy(a.ptr, var->mem_ptr, var->size, cudaMemcpyDeviceToHost));
Expand Down

0 comments on commit 9048f3f

Please sign in to comment.