Skip to content

Commit

Permalink
Fixed the segfault.
Browse files Browse the repository at this point in the history
  • Loading branch information
Aaron Lun committed Jul 27, 2023
1 parent f3058d2 commit 856e0c0
Showing 1 changed file with 33 additions and 18 deletions.
51 changes: 33 additions & 18 deletions src/scranpy/feature_selection/model_gene_variances.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,21 @@
from mattress import TatamiNumericPointer, tatamize
from ..cpphelpers import lib

# TODO: move out for more general use.
def factorize(x):
    """Encode a sequence of hashable values as integer codes.

    Each distinct value in ``x`` is assigned a code in order of first
    appearance, starting from 0.

    Args:
        x: Sized iterable of hashable values (e.g. a list or 1-D array).

    Returns:
        A dict with two entries:

        - ``"levels"``: list of the distinct values of ``x``, in order of
          first appearance.
        - ``"indices"``: ``numpy.int32`` array with one entry per element
          of ``x``, such that ``levels[indices[i]]`` is the ``i``-th value.
    """
    levels = []
    mapping = {}
    output = np.empty((len(x),), dtype=np.int32)

    for i, lev in enumerate(x):
        code = mapping.get(lev)
        if code is None:
            code = len(levels)
            mapping[lev] = code
            # Record the value itself (not its code) so that callers can
            # translate the integer codes back to the original labels.
            levels.append(lev)
        output[i] = code

    return {"levels": levels, "indices": output}

def model_gene_variances(x, block = None, span = 0.3, num_threads = 1):
if not isinstance(x, TatamiNumericPointer):
x = tatamize(x)
Expand All @@ -24,10 +39,8 @@ def model_gene_variances(x, block = None, span = 0.3, num_threads = 1):
num_threads)

else:
# TODO: need some way of factorizing a block to a numpy Int32 array.
# factorize(block)
block32 = block.astype(np.int32)
nlevels = block32.max() + 1
fac = factorize(block)
nlevels = len(fac["levels"])

all_means = []
all_variances = []
Expand All @@ -39,18 +52,18 @@ def model_gene_variances(x, block = None, span = 0.3, num_threads = 1):
all_residuals_ptr = np.ndarray((nlevels,), dtype=np.uintp)

for l in range(nlevels):
cur_means = np.ndarray((nlevels,), dtype=np.uintp)
cur_variances = np.ndarray((nlevels,), dtype=np.uintp)
cur_fitted = np.ndarray((nlevels,), dtype=np.uintp)
cur_residuals = np.ndarray((nlevels,), dtype=np.uintp)
cur_means = np.ndarray((NR,), dtype=np.float64)
cur_variances = np.ndarray((NR,), dtype=np.float64)
cur_fitted = np.ndarray((NR,), dtype=np.float64)
cur_residuals = np.ndarray((NR,), dtype=np.float64)

all_means_ptr[l] = cur_means.ctypes.data
all_variances_ptr[l] = cur_variances.ctypes.data
all_fitted_ptr[l] = cur_fitted.ctypes.data
all_residuals_ptr[l] = cur_residuals.ctypes.data

all_means.append(cur_means)
all_variance.append(cur_variance)
all_variances.append(cur_variances)
all_fitted.append(cur_fitted)
all_residuals.append(cur_residuals)

Expand All @@ -60,26 +73,28 @@ def model_gene_variances(x, block = None, span = 0.3, num_threads = 1):
variances.ctypes.data,
fitted.ctypes.data,
residuals.ctypes.data,
nlevels,
block32.ctypes.data,
nlevels,
fac["indices"].ctypes.data,
all_means_ptr.ctypes.data,
all_variances_ptr.ctypes.data,
all_fitted_ptr.ctypes.data,
all_residuals_ptr.ctypes.data,
span,
num_threads)

extra = {
"means": all_means,
"variances": all_variances,
"fitted": all_fitted,
"residuals": all_residuals
}
extra = {}
for i in range(nlevels):
extra[fac["levels"][i]] = {
"means": all_means[i],
"variances": all_variances[i],
"fitted": all_fitted[i],
"residuals": all_residuals[i]
}

return {
"means": means,
"variances": variances,
"fitted": fitted,
"residuals": residuals,
"extras": extras
"per_block": extra
}

0 comments on commit 856e0c0

Please sign in to comment.