Skip to content

Commit

Permalink
faster ntt by avoiding 64b in some places
Browse files Browse the repository at this point in the history
  • Loading branch information
yshekel committed May 20, 2024
1 parent 61d16b7 commit 0d3a51f
Showing 1 changed file with 67 additions and 55 deletions.
122 changes: 67 additions & 55 deletions icicle/src/ntt/thread_ntt.cu
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ struct stage_metadata {
uint32_t th_stride;
uint32_t ntt_block_size;
uint32_t batch_id;
uint32_t ntt_block_id;
uint32_t ntt_inp_id;
uint64_t ntt_block_id;
};

#define STAGE_SIZES_DATA \
Expand Down Expand Up @@ -194,205 +194,217 @@ public:
}

DEVICE_INLINE void
loadGlobalData(const E* data, uint64_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
loadGlobalData(const E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
{
const uint64_t data_stride_u64 = data_stride;
if (strided) {
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id +
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size;
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride_u64 * s_meta.ntt_inp_id +
(s_meta.ntt_block_id >> log_data_stride) * data_stride_u64 * s_meta.ntt_block_size;
} else {
data += s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id;
data += (uint64_t)s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id;
}

UNROLL
for (uint32_t i = 0; i < 8; i++) {
X[i] = data[s_meta.th_stride * i * data_stride];
X[i] = data[s_meta.th_stride * i * data_stride_u64];
}
}

DEVICE_INLINE void loadGlobalDataColumnBatch(
const E* data, uint64_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
const E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
{
data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id +
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) *
const uint64_t data_stride_u64 = data_stride;
data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride_u64 * s_meta.ntt_inp_id +
(s_meta.ntt_block_id >> log_data_stride) * data_stride_u64 * s_meta.ntt_block_size) *
batch_size +
s_meta.batch_id;

UNROLL
for (uint32_t i = 0; i < 8; i++) {
X[i] = data[s_meta.th_stride * i * data_stride * batch_size];
X[i] = data[s_meta.th_stride * i * data_stride_u64 * batch_size];
}
}

DEVICE_INLINE void
storeGlobalData(E* data, uint64_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
storeGlobalData(E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
{
const uint64_t data_stride_u64 = data_stride;
if (strided) {
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id +
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size;
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride_u64 * s_meta.ntt_inp_id +
(s_meta.ntt_block_id >> log_data_stride) * data_stride_u64 * s_meta.ntt_block_size;
} else {
data += s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id;
data += (uint64_t)s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id;
}

UNROLL
for (uint32_t i = 0; i < 8; i++) {
data[s_meta.th_stride * i * data_stride] = X[i];
data[s_meta.th_stride * i * data_stride_u64] = X[i];
}
}

DEVICE_INLINE void storeGlobalDataColumnBatch(
E* data, uint64_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
{
data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id +
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) *
const uint64_t data_stride_u64 = data_stride;
data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride_u64 * s_meta.ntt_inp_id +
(s_meta.ntt_block_id >> log_data_stride) * data_stride_u64 * s_meta.ntt_block_size) *
batch_size +
s_meta.batch_id;

UNROLL
for (uint32_t i = 0; i < 8; i++) {
data[s_meta.th_stride * i * data_stride * batch_size] = X[i];
data[s_meta.th_stride * i * data_stride_u64 * batch_size] = X[i];
}
}

DEVICE_INLINE void
loadGlobalData32(const E* data, uint64_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
loadGlobalData32(const E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
{
const uint64_t data_stride_u64 = data_stride;
if (strided) {
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 +
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size;
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride_u64 * s_meta.ntt_inp_id * 2 +
(s_meta.ntt_block_id >> log_data_stride) * data_stride_u64 * s_meta.ntt_block_size;
} else {
data += s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id * 2;
data += (uint64_t)s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id * 2;
}

UNROLL
for (uint32_t j = 0; j < 2; j++) {
UNROLL
for (uint32_t i = 0; i < 4; i++) {
X[4 * j + i] = data[(8 * i + j) * data_stride];
X[4 * j + i] = data[(8 * i + j) * data_stride_u64];
}
}
}

DEVICE_INLINE void loadGlobalData32ColumnBatch(
const E* data, uint64_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
const E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
{
data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 +
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) *
const uint64_t data_stride_u64 = data_stride;
data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride_u64 * s_meta.ntt_inp_id * 2 +
(s_meta.ntt_block_id >> log_data_stride) * data_stride_u64 * s_meta.ntt_block_size) *
batch_size +
s_meta.batch_id;

UNROLL
for (uint32_t j = 0; j < 2; j++) {
UNROLL
for (uint32_t i = 0; i < 4; i++) {
X[4 * j + i] = data[(8 * i + j) * data_stride * batch_size];
X[4 * j + i] = data[(8 * i + j) * data_stride_u64 * batch_size];
}
}
}

DEVICE_INLINE void
storeGlobalData32(E* data, uint64_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
storeGlobalData32(E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
{
const uint64_t data_stride_u64 = data_stride;
if (strided) {
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 +
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size;
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride_u64 * s_meta.ntt_inp_id * 2 +
(s_meta.ntt_block_id >> log_data_stride) * data_stride_u64 * s_meta.ntt_block_size;
} else {
data += s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id * 2;
data += (uint64_t)s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id * 2;
}

UNROLL
for (uint32_t j = 0; j < 2; j++) {
UNROLL
for (uint32_t i = 0; i < 4; i++) {
data[(8 * i + j) * data_stride] = X[4 * j + i];
data[(8 * i + j) * data_stride_u64] = X[4 * j + i];
}
}
}

DEVICE_INLINE void storeGlobalData32ColumnBatch(
E* data, uint64_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
{
data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 +
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) *
const uint64_t data_stride_u64 = data_stride;
data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride_u64 * s_meta.ntt_inp_id * 2 +
(s_meta.ntt_block_id >> log_data_stride) * data_stride_u64 * s_meta.ntt_block_size) *
batch_size +
s_meta.batch_id;

UNROLL
for (uint32_t j = 0; j < 2; j++) {
UNROLL
for (uint32_t i = 0; i < 4; i++) {
data[(8 * i + j) * data_stride * batch_size] = X[4 * j + i];
data[(8 * i + j) * data_stride_u64 * batch_size] = X[4 * j + i];
}
}
}

DEVICE_INLINE void
loadGlobalData16(const E* data, uint64_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
loadGlobalData16(const E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
{
const uint64_t data_stride_u64 = data_stride;
if (strided) {
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 +
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size;
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride_u64 * s_meta.ntt_inp_id * 4 +
(s_meta.ntt_block_id >> log_data_stride) * data_stride_u64 * s_meta.ntt_block_size;
} else {
data += s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id * 4;
data += (uint64_t)s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id * 4;
}

UNROLL
for (uint32_t j = 0; j < 4; j++) {
UNROLL
for (uint32_t i = 0; i < 2; i++) {
X[2 * j + i] = data[(8 * i + j) * data_stride];
X[2 * j + i] = data[(8 * i + j) * data_stride_u64];
}
}
}

DEVICE_INLINE void loadGlobalData16ColumnBatch(
const E* data, uint64_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
const E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
{
data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 +
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) *
const uint64_t data_stride_u64 = data_stride;
data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride_u64 * s_meta.ntt_inp_id * 4 +
(s_meta.ntt_block_id >> log_data_stride) * data_stride_u64 * s_meta.ntt_block_size) *
batch_size +
s_meta.batch_id;

UNROLL
for (uint32_t j = 0; j < 4; j++) {
UNROLL
for (uint32_t i = 0; i < 2; i++) {
X[2 * j + i] = data[(8 * i + j) * data_stride * batch_size];
X[2 * j + i] = data[(8 * i + j) * data_stride_u64 * batch_size];
}
}
}

DEVICE_INLINE void
storeGlobalData16(E* data, uint64_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
storeGlobalData16(E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
{
const uint64_t data_stride_u64 = data_stride;
if (strided) {
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 +
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size;
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride_u64 * s_meta.ntt_inp_id * 4 +
(s_meta.ntt_block_id >> log_data_stride) * data_stride_u64 * s_meta.ntt_block_size;
} else {
data += s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id * 4;
data += (uint64_t)s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id * 4;
}

UNROLL
for (uint32_t j = 0; j < 4; j++) {
UNROLL
for (uint32_t i = 0; i < 2; i++) {
data[(8 * i + j) * data_stride] = X[2 * j + i];
data[(8 * i + j) * data_stride_u64] = X[2 * j + i];
}
}
}

DEVICE_INLINE void storeGlobalData16ColumnBatch(
E* data, uint64_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
{
data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 +
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) *
const uint64_t data_stride_u64 = data_stride;
data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride_u64 * s_meta.ntt_inp_id * 4 +
(s_meta.ntt_block_id >> log_data_stride) * data_stride_u64 * s_meta.ntt_block_size) *
batch_size +
s_meta.batch_id;

UNROLL
for (uint32_t j = 0; j < 4; j++) {
UNROLL
for (uint32_t i = 0; i < 2; i++) {
data[(8 * i + j) * data_stride * batch_size] = X[2 * j + i];
data[(8 * i + j) * data_stride_u64 * batch_size] = X[2 * j + i];
}
}
}
Expand Down

0 comments on commit 0d3a51f

Please sign in to comment.