diff --git a/cpp/include/wholememory/device_reference.cuh b/cpp/include/wholememory/device_reference.cuh
index 4ffde7d44..8f2146ae9 100644
--- a/cpp/include/wholememory/device_reference.cuh
+++ b/cpp/include/wholememory/device_reference.cuh
@@ -27,39 +27,22 @@ class device_reference {
   __device__ __forceinline__ explicit device_reference(const wholememory_gref_t& gref)
     : pointer_(static_cast<DataTypeT*>(gref.pointer)),
       typed_stride_(gref.stride / sizeof(DataTypeT)),
-      rank_memory_offsets_(gref.rank_memory_offsets),
       world_size_(gref.world_size),
-      same_chunk_(gref.same_chunk),
-      estimated_stride_(0),
-      cache_rank_(0),
-      cache_offset_(0),
-      cache_size_(0)
+      same_chunk_(gref.same_chunk)
   {
     assert(gref.stride % sizeof(DataTypeT) == 0);
-    if (typed_stride_ > 0 && !same_chunk_) {
-      estimated_stride_ = rank_memory_offsets_[world_size_] / world_size_;
-      cache_rank_       = 0;
-      cache_offset_     = 0;
-      cache_size_       = rank_memory_offsets_[1] - rank_memory_offsets_[0];
+    if (typed_stride_ != 0 && !same_chunk_) {
+      assert(world_size_ <= 8);  // WHOLEMEMORY_MT_CHUNKED is intra-node only, so at most 8 ranks
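+      // Convert the per-rank byte offsets to element offsets once at
+      // construction, so operator[] needs no division on each access.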
+      for (int i = 0; i <= world_size_; i++) {
+        assert(gref.rank_memory_offsets[i] % sizeof(DataTypeT) == 0);
+        typed_rank_mem_offsets_[i] = gref.rank_memory_offsets[i] / sizeof(DataTypeT);
+      }
     }
   }
   __device__ device_reference() = delete;
 
-  __device__ __forceinline__ size_t copy_offsets_to_shmem(char* shmem, size_t maxsize)
-  {
-    if (typed_stride_ == 0 || same_chunk_) return 0;
-    size_t used_shmem_size = (world_size_ + 1) * sizeof(size_t);
-    if (used_shmem_size > maxsize) return 0;
-    size_t* shmem_offsets = reinterpret_cast<size_t*>(shmem);
-    for (int i = threadIdx.x; i <= world_size_; i += blockDim.x) {
-      shmem_offsets[i] = rank_memory_offsets_[i];
-    }
-    __syncthreads();
-    rank_memory_offsets_           = shmem_offsets;
-    size_t aligned_used_shmem_size = ((used_shmem_size - 1) / 128 + 1) * 128;
-    return aligned_used_shmem_size;
-  }
-
   __device__ __forceinline__ DataTypeT& operator[](size_t index)
   {
     if (typed_stride_ == 0) { return pointer_[index]; }
@@ -68,47 +51,27 @@
       return static_cast<DataTypeT**>(
         static_cast<void*>(pointer_))[rank][index - rank * typed_stride_];
     } else {
-      size_t rank   = 0;
-      size_t offset = index * sizeof(DataTypeT);
-      if (offset >= cache_offset_ && offset < cache_offset_ + cache_size_) {
-        rank = cache_rank_;
-      } else {
-        int estimated_rank = max(world_size_ - 1, int(offset / estimated_stride_));
-        if (rank_memory_offsets_[estimated_rank] > offset) {
-          for (int i = estimated_rank - 1; i >= 0; i--) {
-            if (rank_memory_offsets_[i] <= offset) {
-              rank = i;
-              break;
-            }
-          }
-        } else {
-          for (int i = estimated_rank + 1; i <= world_size_; i++) {
-            if (rank_memory_offsets_[i] > offset) {
-              rank = i - 1;
-              break;
-            }
-          }
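+      // Linear scan over at most world_size_ + 1 (<= 9) precomputed offsets
+      // to find the rank that owns this index.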
+      size_t rank = 0;
+      for (int i = 1; i <= world_size_; i++) {
+        if (index < typed_rank_mem_offsets_[i]) {
+          rank = i - 1;
+          break;
         }
-        cache_rank_   = rank;
-        cache_offset_ = rank_memory_offsets_[rank];
-        cache_size_   = rank_memory_offsets_[rank + 1] - rank_memory_offsets_[rank];
       }
       return static_cast<DataTypeT**>(
-        static_cast<void*>(pointer_))[rank][index - cache_offset_ / sizeof(DataTypeT)];
+        static_cast<void*>(pointer_))[rank][index - typed_rank_mem_offsets_[rank]];
     }
   }
 
  private:
   DataTypeT* pointer_;
-  size_t* rank_memory_offsets_;
   int world_size_;
   size_t typed_stride_;
 
-  size_t estimated_stride_;
   bool same_chunk_;
-  int cache_rank_;
-  size_t cache_offset_;
-  size_t cache_size_;
+  size_t typed_rank_mem_offsets_[8 + 1];  // element offsets; at most 8 intra-node ranks
 };
 
 }  // namespace wholememory
diff --git a/cpp/src/wholememory_ops/functions/bucket_ids_func.cu b/cpp/src/wholememory_ops/functions/bucket_ids_func.cu
index f5b25390b..6bd6b6c44 100644
--- a/cpp/src/wholememory_ops/functions/bucket_ids_func.cu
+++ b/cpp/src/wholememory_ops/functions/bucket_ids_func.cu
@@ -29,10 +29,10 @@
 namespace wholememory_ops {
 
 template <typename IndexT>
-__device__ int dest_rank(IndexT entry_idx,
-                         size_t total_entry_count,
-                         const size_t* embedding_entry_offsets,
-                         int world_size)
+__device__ __forceinline__ int dest_rank(IndexT entry_idx,
+                                         size_t total_entry_count,
+                                         const size_t* embedding_entry_offsets,
+                                         int world_size)
 {
   size_t estimated_entry_per_rank = total_entry_count / world_size;
   int estimated_rank              = min(world_size - 1, int(entry_idx / estimated_entry_per_rank));
@@ -60,13 +60,19 @@
   for (int idx = threadIdx.x; idx < world_size; idx += blockDim.x) {
     rank_count_shared[idx] = 0;
   }
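+  // Stage the entry offsets in shared memory so each dest_rank call avoids global loads.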
+  size_t* embedding_entry_offsets_shared =
+    reinterpret_cast<size_t*>(shmem + sizeof(size_t) * world_size);
+  for (int idx = threadIdx.x; idx <= world_size; idx += blockDim.x) {
+    embedding_entry_offsets_shared[idx] = embedding_entry_offsets[idx];
+  }
   __syncthreads();
-  size_t total_entry_count = embedding_entry_offsets[world_size];
+  size_t total_entry_count = embedding_entry_offsets_shared[world_size];
   for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < indice_count;
        idx += blockDim.x * gridDim.x) {
     IndexT node_idx = indices[idx];
     if (node_idx < 0) continue;
-    int rank = dest_rank(node_idx, total_entry_count, embedding_entry_offsets, world_size);
+    int rank = dest_rank(node_idx, total_entry_count, embedding_entry_offsets_shared, world_size);
     assert(rank >= 0 && rank < world_size);
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
     atomicAdd_block(&rank_count_shared[rank], 1);
@@ -95,7 +101,12 @@
   block_count         = std::min(block_count, sm_count * 4);
   IndexT* indices_ptr = static_cast<IndexT*>(indices);
   indices_ptr += indice_desc.storage_offset;
-  bucket_ids_for_ranks_kernel<<<block_count, BLOCK_SIZE, sizeof(int) * world_size, stream>>>(
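+  // Shared memory layout: world_size size_t-sized slots for the per-rank
+  // counters (padded for alignment), then world_size + 1 staged entry offsets.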
+  bucket_ids_for_ranks_kernel<<<block_count,
+                                BLOCK_SIZE,
+                                sizeof(size_t) * (world_size * 2 + 1),
+                                stream>>>(
     indices_ptr, indice_desc.size, dev_rank_id_count_ptr, embedding_entry_offsets, world_size);
 }
 
diff --git a/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh b/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh
index fd6d0f8d5..e12b3756b 100644
--- a/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh
+++ b/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh
@@ -272,13 +272,14 @@
   int64_t output_stride    = output_desc.stride;
 
   wholememory::device_reference<EmbeddingT> embedding_dev_ref(embedding_gref);
-  size_t used_shm_size = embedding_dev_ref.copy_offsets_to_shmem(shm_in_char, shm_max_size);
 
   typed_data_vector<EmbeddingT, ALIGNMENT> embeddings;
   typed_data_vector<OutputT, ALIGNMENT> outputs;
 
-  int shm_size    = (shm_max_size - used_shm_size) / sizeof(OutputT);
-  OutputT* all_sh = reinterpret_cast<OutputT*>(shm_in_char + used_shm_size);
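+  // Rank offsets are kept inside device_reference itself, so the whole
+  // dynamic shared memory budget is available for staging output rows.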
+  int shm_size    = shm_max_size / sizeof(OutputT);
+  OutputT* all_sh = reinterpret_cast<OutputT*>(shm_in_char);
   OutputT* my_shared;
   bool use_shm = true;
   if (shm_size / (blockDim.x / 32) < output_desc.sizes[1]) {
@@ -346,10 +347,7 @@
 
   int lane_id_in_sub_warp = subwarp.thread_rank();
 
-  constexpr size_t shm_max_size = 1024 * sizeof(size_t);
-  __shared__ char shmem[shm_max_size];
   wholememory::device_reference<EmbeddingT> embedding_dev_ref(embedding_gref);
-  embedding_dev_ref.copy_offsets_to_shmem(shmem, shm_max_size);
 
   int embedding_size       = embedding_desc.sizes[1];
   int64_t embedding_stride = embedding_desc.stride;
@@ -542,10 +540,10 @@
   int async_copy_align     = sizeof(InputT) > 4 ? 1 : 4 / sizeof(InputT);
 
   wholememory::device_reference<EmbeddingT> embedding_dev_ref(embedding_gref);
-  size_t used_shm_size = embedding_dev_ref.copy_offsets_to_shmem(shm_in_char, shm_max_size);
 
-  int shm_size   = (shm_max_size - used_shm_size) / sizeof(InputT);
-  InputT* all_sh = reinterpret_cast<InputT*>(shm_in_char + used_shm_size);
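+  // As in gather_func_kernel, the whole shared memory budget stages input rows.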
+  int shm_size   = shm_max_size / sizeof(InputT);
+  InputT* all_sh = reinterpret_cast<InputT*>(shm_in_char);
   InputT* my_shared;
   int batch_size = (shm_size / (blockDim.x / 32) - async_copy_align) /
                    input_stride;  // indices batch size in lines