Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,20 @@ RUN wget https://github.com/Kitware/CMake/releases/download/v3.24.0/cmake-3.24.0
&& mkdir /opt/cmake \
&& /tmp/cmake-install.sh --skip-license --prefix=/opt/cmake \
&& rm /tmp/cmake-install.sh
ENV PATH /opt/cmake/bin:${PATH}
ENV PATH=/opt/cmake/bin:${PATH}

RUN curl -L "https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh" \
-o "/tmp/Miniconda3.sh"
RUN bash /tmp/Miniconda3.sh -b -p /opt/anaconda
RUN rm -rf /tmp/Miniconda3.sh
RUN cd /opt && eval "$(/opt/anaconda/bin/conda shell.bash hook)"
ENV PATH /opt/anaconda/bin:${PATH}
ENV LD_LIBRARY_PATH /opt/anaconda/lib:${LD_LIBRARY_PATH}
ENV PATH=/opt/anaconda/bin:${PATH}
ENV LD_LIBRARY_PATH=/opt/anaconda/lib:${LD_LIBRARY_PATH}

# python prereqs
RUN conda tos accept --override-channels --channel conda-forge
RUN conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main
RUN conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r
RUN conda install -c conda-forge git
RUN pip install numpy>=2.0.0
RUN pip install scipy>=1.13.0 cython nibabel dipy tqdm
Expand Down
12 changes: 12 additions & 0 deletions cuslines/cudamacro.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,18 @@
exit(EXIT_FAILURE); \
}}

#if CUDART_VERSION >= 13000
#define CUDA_MEM_ADVISE(devPtr, count, advice, device) \
cudaMemLocation loc; \
loc.type = cudaMemLocationTypeDevice; \
loc.id = (device); \
CHECK_CUDA(cudaMemAdvise((devPtr), (count), (advice), loc));
#else
#define CUDA_MEM_ADVISE(devPtr, count, advice, device) \
CHECK_CUDA(cudaMemAdvise((devPtr), (count), (advice), (device)))
#endif


#ifdef USE_NVTX
#include "nvToolsExt.h"

Expand Down
2 changes: 1 addition & 1 deletion cuslines/cuslines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ class GPUTracker {
for (int n = 0; n < ngpus_; ++n) {
CHECK_CUDA(cudaSetDevice(n));
CHECK_CUDA(cudaMallocManaged(&dataf_d[n], sizeof(*dataf_d[n]) * dataf_info.size));
CHECK_CUDA(cudaMemAdvise(dataf_d[n], sizeof(*dataf_d[n]) * dataf_info.size, cudaMemAdviseSetPreferredLocation, n));
CUDA_MEM_ADVISE(dataf_d[n], sizeof(*dataf_d[n]) * dataf_info.size, cudaMemAdviseSetPreferredLocation, n);
CHECK_CUDA(cudaMalloc(&H_d[n], sizeof(*H_d[n]) * H_info.size));
CHECK_CUDA(cudaMalloc(&R_d[n], sizeof(*R_d[n]) * R_info.size));
CHECK_CUDA(cudaMalloc(&delta_b_d[n], sizeof(*delta_b_d[n]) * delta_b_info.size));
Expand Down
2 changes: 1 addition & 1 deletion cuslines/ptt.cu
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ __device__ int get_direction_ptt_d(
REAL_T *__first_val_sh = first_val_sh + tidy;

const REAL_T max_curvature = SIN(max_angle / 2) / step_size; // bigger numbers means wiggle more
const REAL_T probe_step_size = ((step_size / 2) / (PROBE_QUALITY - 1));
const REAL_T probe_step_size = ((step_size / PROBE_FRAC) / (PROBE_QUALITY - 1));

REAL_T __tmp;

Expand Down
10 changes: 5 additions & 5 deletions cuslines/ptt.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@

#define STEP_FRAC 20 // divides output step size (usually 0.5) into this many internal steps
#define PROBE_FRAC 2 // divides output step size (usually 0.5) to find probe length
#define PROBE_QUALITY 4
#define PROBE_QUALITY 4 // Number of probing steps
#define SAMPLING_QUALITY 4 // can be 2-7
#define PROBABILISTIC_BIAS 1 // 1 looks good. can be 0-log_2(N_WARPS) (typically 0-5). 0 is fully probabilistic, 4 is close to deterministic.
#define ALLOW_WEAK_LINK 1
#define DETERMINISTIC_BIAS 0 // Should be 0, higher values bias more towards higher fODF values when tracking
#define ALLOW_WEAK_LINK 0
#define TRIES_PER_REJECTION_SAMPLING 1024
#define DEFAULT_PTT_MINDATASUPPORT 0.05
#define DEFAULT_PTT_MINDATASUPPORT 0.01 // 0.01
#define K_SMALL 0.0001

#define NORM_MIN_SUPPORT (DEFAULT_PTT_MINDATASUPPORT * PROBE_QUALITY)
#define PROBABILISTIC_GROUP_SZ POW2(PROBABILISTIC_BIAS)
#define PROBABILISTIC_GROUP_SZ POW2(DETERMINISTIC_BIAS)

#if SAMPLING_QUALITY == 2
#define DISC_VERT_CNT DISC_2_VERT_CNT
Expand Down