Skip to content

Commit

Permalink
mc
Browse files Browse the repository at this point in the history
  • Loading branch information
jovoni committed Feb 7, 2024
1 parent 467c917 commit b48480c
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 16 deletions.
30 changes: 15 additions & 15 deletions inst/pydevil/pydevil/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pydevil.model import model
from pydevil.guide import guide
from pydevil.utils import prepare_batch
from pydevil.utils_input import check_and_prepare_input_run_SVDE, detach_tensor
from pydevil.utils_input import check_and_prepare_input_run_SVDE, detach_tensor, detach_tensor_and_numpy
from pydevil.utils_hessian import compute_hessians, compute_sandwiches

def run_SVDE(
Expand Down Expand Up @@ -115,20 +115,20 @@ def run_SVDE(
#lk = dist.NegativeBinomial(logits = eta - torch.log(overdispersion) ,
# total_count= torch.clamp(overdispersion, 1e-9,1e9)).log_prob(input_matrix).sum(dim = 0)

input_data['input_matrix'] = detach_tensor(input_data['input_matrix'])
input_data['model_matrix'] = detach_tensor(input_data['model_matrix'])
input_data['group_matrix'] = detach_tensor(input_data['group_matrix'])
input_data['sf'] = detach_tensor(input_data['sf'])
input_data['offset_matrix'] = detach_tensor(input_data['offset_matrix'])
input_data['beta_estimate_matrix'] = detach_tensor(input_data['beta_estimate_matrix'])
input_data['dispersion_priors'] = detach_tensor(input_data['dispersion_priors'])
input_data['clusters'] = detach_tensor(input_data['clusters'])
input_matrix = detach_tensor(input_matrix)
model_matrix = detach_tensor(model_matrix)
overdispersion = detach_tensor(overdispersion)
coeff = detach_tensor(coeff)
loc = detach_tensor(loc)
UMI = detach_tensor(input_data['sf'])
input_data['input_matrix'] = detach_tensor_and_numpy(input_data['input_matrix'])
input_data['model_matrix'] = detach_tensor_and_numpy(input_data['model_matrix'])
input_data['group_matrix'] = detach_tensor_and_numpy(input_data['group_matrix'])
input_data['sf'] = detach_tensor_and_numpy(input_data['sf'])
input_data['offset_matrix'] = detach_tensor_and_numpy(input_data['offset_matrix'])
input_data['beta_estimate_matrix'] = detach_tensor_and_numpy(input_data['beta_estimate_matrix'])
input_data['dispersion_priors'] = detach_tensor_and_numpy(input_data['dispersion_priors'])
input_data['clusters'] = detach_tensor_and_numpy(input_data['clusters'])
input_matrix = detach_tensor_and_numpy(input_matrix)
model_matrix = detach_tensor_and_numpy(model_matrix)
overdispersion = detach_tensor_and_numpy(overdispersion)
coeff = detach_tensor_and_numpy(coeff)
loc = detach_tensor_and_numpy(loc)
UMI = detach_tensor_and_numpy(input_data['sf'])

# if cuda and torch.cuda.is_available():
# input_matrix = input_matrix.cpu().detach().numpy()
Expand Down
13 changes: 12 additions & 1 deletion inst/pydevil/pydevil/utils_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,19 @@ def ensure_tensor(obj, cuda):
if torch.cuda.is_available():
obj = obj.cuda()
return obj

def detach_tensor(obj):
"""
Unload the tensor from the GPU.
"""
if isinstance(obj, torch.Tensor):
if obj.get_device() == 0:
return obj.cpu().detach()
else:
return obj.detach()
return obj

def detach_tensor_and_numpy(obj):
"""
Unload the tensor from the GPU.
"""
Expand Down

0 comments on commit b48480c

Please sign in to comment.