From b48480c8433762d18b0928bc4fffb45aba622a2d Mon Sep 17 00:00:00 2001 From: jovoni Date: Wed, 7 Feb 2024 12:05:30 +0100 Subject: [PATCH] mc --- inst/pydevil/pydevil/interface.py | 30 ++++++++++++++--------------- inst/pydevil/pydevil/utils_input.py | 13 ++++++++++++- 2 files changed, 27 insertions(+), 16 deletions(-) diff --git a/inst/pydevil/pydevil/interface.py b/inst/pydevil/pydevil/interface.py index 2add30f..413984b 100644 --- a/inst/pydevil/pydevil/interface.py +++ b/inst/pydevil/pydevil/interface.py @@ -10,7 +10,7 @@ from pydevil.model import model from pydevil.guide import guide from pydevil.utils import prepare_batch -from pydevil.utils_input import check_and_prepare_input_run_SVDE, detach_tensor +from pydevil.utils_input import check_and_prepare_input_run_SVDE, detach_tensor, detach_tensor_and_numpy from pydevil.utils_hessian import compute_hessians, compute_sandwiches def run_SVDE( @@ -115,20 +115,20 @@ def run_SVDE( #lk = dist.NegativeBinomial(logits = eta - torch.log(overdispersion) , # total_count= torch.clamp(overdispersion, 1e-9,1e9)).log_prob(input_matrix).sum(dim = 0) - input_data['input_matrix'] = detach_tensor(input_data['input_matrix']) - input_data['model_matrix'] = detach_tensor(input_data['model_matrix']) - input_data['group_matrix'] = detach_tensor(input_data['group_matrix']) - input_data['sf'] = detach_tensor(input_data['sf']) - input_data['offset_matrix'] = detach_tensor(input_data['offset_matrix']) - input_data['beta_estimate_matrix'] = detach_tensor(input_data['beta_estimate_matrix']) - input_data['dispersion_priors'] = detach_tensor(input_data['dispersion_priors']) - input_data['clusters'] = detach_tensor(input_data['clusters']) - input_matrix = detach_tensor(input_matrix) - model_matrix = detach_tensor(model_matrix) - overdispersion = detach_tensor(overdispersion) - coeff = detach_tensor(coeff) - loc = detach_tensor(loc) - UMI = detach_tensor(input_data['sf']) + input_data['input_matrix'] = detach_tensor_and_numpy(input_data['input_matrix']) + input_data['model_matrix'] = detach_tensor_and_numpy(input_data['model_matrix']) + input_data['group_matrix'] = detach_tensor_and_numpy(input_data['group_matrix']) + input_data['sf'] = detach_tensor_and_numpy(input_data['sf']) + input_data['offset_matrix'] = detach_tensor_and_numpy(input_data['offset_matrix']) + input_data['beta_estimate_matrix'] = detach_tensor_and_numpy(input_data['beta_estimate_matrix']) + input_data['dispersion_priors'] = detach_tensor_and_numpy(input_data['dispersion_priors']) + input_data['clusters'] = detach_tensor_and_numpy(input_data['clusters']) + input_matrix = detach_tensor_and_numpy(input_matrix) + model_matrix = detach_tensor_and_numpy(model_matrix) + overdispersion = detach_tensor_and_numpy(overdispersion) + coeff = detach_tensor_and_numpy(coeff) + loc = detach_tensor_and_numpy(loc) + UMI = detach_tensor_and_numpy(input_data['sf']) # if cuda and torch.cuda.is_available(): # input_matrix = input_matrix.cpu().detach().numpy() diff --git a/inst/pydevil/pydevil/utils_input.py b/inst/pydevil/pydevil/utils_input.py index 1383065..677f9ae 100644 --- a/inst/pydevil/pydevil/utils_input.py +++ b/inst/pydevil/pydevil/utils_input.py @@ -16,8 +16,19 @@ def ensure_tensor(obj, cuda): if torch.cuda.is_available(): obj = obj.cuda() return obj - + def detach_tensor(obj): + """ + Unload the tensor from the GPU. + """ + if isinstance(obj, torch.Tensor): + if obj.get_device() == 0: + return obj.cpu().detach() + else: + return obj.detach() + return obj + +def detach_tensor_and_numpy(obj): """ Unload the tensor from the GPU. """