Skip to content

Commit

Permalink
Fix scan issues (#2729)
Browse files Browse the repository at this point in the history
  • Loading branch information
blzheng authored Apr 7, 2024
1 parent e1095f7 commit 7bc3869
Show file tree
Hide file tree
Showing 18 changed files with 212 additions and 139 deletions.
6 changes: 3 additions & 3 deletions csrc/cpu/aten/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -392,10 +392,10 @@ at::Tensor woq_linear_pack_weight(
int64_t weight_int4_size_bytes = weight_int4.numel();
int64_t pad_size_bytes = weight_int4_size_bytes - weight_size_bytes;
std::memcpy(weight_int4.data_ptr(), weight.data_ptr(), weight_size_bytes);
std::memset(
std::fill_n(
(uint8_t*)weight_int4.data_ptr() + weight_size_bytes,
0,
pad_size_bytes);
pad_size_bytes,
0);
return woq_tpp_gemm_packB_stub(
kCPU, weight_int4, weight_dtype, block_n, block_k, lowp_mode);
}
Expand Down
4 changes: 2 additions & 2 deletions csrc/cpu/aten/MergedEmbeddingBag.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class EMBROW {
length = len;
arr.resize(length);
data = &arr[0];
memset(data, 0, len * sizeof(T));
std::fill_n(data, len, T(0));
}
};

Expand All @@ -36,7 +36,7 @@ class EMBROWFixLen {

// Points `data` at the first element of `arr` and sets emb_dim elements to
// T(0). std::fill_n is used on its own: the earlier byte-wise memset over the
// same range was redundant and is only well-defined for trivially copyable T.
// NOTE(review): `len` is not read here; the element count comes from emb_dim —
// confirm that is intended.
EMBROWFixLen(int32_t len) {
  data = &arr[0];
  std::fill_n(data, emb_dim, T(0));
}
};

Expand Down
2 changes: 1 addition & 1 deletion csrc/cpu/aten/kernels/WoqLinearKrnl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1507,7 +1507,7 @@ void dot_update(
// Sets the first N elements of each of M rows to T(0).
//
// C      - pointer to the first element of row 0.
// M      - number of rows to clear.
// N      - number of elements cleared at the start of every row.
// stride - distance, in elements, between the starts of consecutive rows.
//
// Elements at positions N..stride-1 of a row are left untouched. Each row is
// cleared once with std::fill_n, which assigns T(0) element-wise; the previous
// additional memset over the same range was redundant and relied on T being
// trivially copyable with an all-bits-zero representation of zero.
template <typename T>
void zero_fill(T* C, int M, int N, int stride) {
  for (int m = 0; m < M; m++) {
    std::fill_n(C + m * stride, N, T(0));
  }
}

Expand Down
6 changes: 2 additions & 4 deletions csrc/cpu/aten/kernels/WoqTppKrnl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2016,10 +2016,8 @@ void qlinear_woq_affine_impl(
64, num_threads * M * N * sizeof(TGemmOut));
y_private_valid = (bool*)std::aligned_alloc(
64, num_threads * (M / BLOCK_M) * Nc * sizeof(bool));
memset(
y_private_valid,
0,
sizeof(bool) * num_threads * (M / BLOCK_M) * Nc);
std::fill_n(
y_private_valid, num_threads * (M / BLOCK_M) * Nc, false);
}
auto y_private_ptr = GetVLAPtr<TGemmOut>(y_private, {M, Nc, Nb});
auto y_private_valid_ptr =
Expand Down
8 changes: 6 additions & 2 deletions csrc/cpu/aten/utils/woq.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,16 @@ class DotMicroKernel {
(trans_a ? LIBXSMM_GEMM_FLAG_TRANS_A : LIBXSMM_GEMM_FLAG_NONE) |
(trans_b ? LIBXSMM_GEMM_FLAG_TRANS_B : LIBXSMM_GEMM_FLAG_NONE);
libxsmm_gemm_batch_reduce_config brconfig;
memset(&brconfig, 0, sizeof(libxsmm_gemm_batch_reduce_config));
std::fill_n(
reinterpret_cast<char*>(&brconfig),
sizeof(libxsmm_gemm_batch_reduce_config),
0);
brconfig.br_type = LIBXSMM_GEMM_BATCH_REDUCE_NONE;

kernel_func_ = libxsmm_dispatch_brgemm_v2(
brshape, brflags, /*prefetch_flags=*/0, brconfig);
memset(&gemm_param_, 0, sizeof(libxsmm_gemm_param));
std::fill_n(
reinterpret_cast<char*>(&gemm_param_), sizeof(libxsmm_gemm_param), 0);
}

void operator()(void* A, void* B, void* C) {
Expand Down
4 changes: 2 additions & 2 deletions csrc/cpu/comm/shm_reduction.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,8 @@ class ShmReduction {
shmCtx_.nblocks = MAX_SHM_BLOCK_COUNT;
if (rank_ == 0) {
torch_ipex::cpu::create_shm(&shmCtx_);
memset(shmCtx_.state, 0, shmCtx_.nstates * sizeof(int));
memset((void*)shmCtx_.blockState, 0, shmCtx_.nstates * shmCtx_.nblocks);
std::fill_n(shmCtx_.state, shmCtx_.nstates, 0);
std::fill_n(shmCtx_.blockState, shmCtx_.nstates * shmCtx_.nblocks, 0);
}

callback(shmCtx_.pid_fd, 2);
Expand Down
4 changes: 2 additions & 2 deletions csrc/cpu/tpp/jit_compile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ void* jit_compile_and_load(
int fd = mkstemp(libname);
unlink(libname);
char fdname[50];
sprintf(fdname, "/proc/self/fd/%d", fd);
snprintf(fdname, sizeof(fdname), "/proc/self/fd/%d", fd);
auto cmd = std::string("g++ -shared -fPIC -x c++ ") + flags;
cmd = cmd + " -o " + fdname + " " + filename;
printf("JIT COMPILE: %s\n", cmd.c_str());
Expand Down Expand Up @@ -66,7 +66,7 @@ void* jit_from_str(
int fd = mkstemp(filename);
unlink(filename);
char fdname[50];
sprintf(fdname, "/proc/self/fd/%d", fd);
snprintf(fdname, sizeof(fdname), "/proc/self/fd/%d", fd);
write(fd, src.c_str(), src.length());
return jit_from_file(fdname, flags, func_name);
#else
Expand Down
Loading

0 comments on commit 7bc3869

Please sign in to comment.