Skip to content

Commit

Permalink
Xu fix rand on windows (#2538)
Browse files Browse the repository at this point in the history
* build csrc\cpu\tpp\init.cpp pass.
* add omp issue jira id for track.
* implement srand48_r as empty function.
  • Loading branch information
xuhancn authored Jan 31, 2024
1 parent a1946e4 commit 901b377
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 3 deletions.
2 changes: 1 addition & 1 deletion csrc/cpu/tpp/bert/fused_self_attention_fwd_tmpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ if (training) {
RECORD_SCOPE(ac_gemm, {t_QL, t_KL_TV});
{
RECORD_FUNCTION("parallel_for", std::vector<c10::IValue>());
#ifndef _WIN32 // TODO: Fix crash on ICX Windows.
#ifndef _WIN32 // TODO: Fix crash on ICX Windows. CMPLRLLVM-55384
#pragma omp parallel for collapse(2) schedule(static, 1)
#else
#pragma omp for
Expand Down
5 changes: 5 additions & 0 deletions csrc/cpu/tpp/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,12 @@ thread_local struct drand48_data drng_state; // For non AVX512 version
unsigned int saved_seed = 0;
void xsmm_manual_seed(unsigned int seed) {
saved_seed = seed;
#ifndef _WIN32
#pragma omp parallel
#else
// TODO: Fix crash on ICX Windows. CMPLRLLVM-55384 ?
//#pragma omp parallel
#endif
{
int tid = omp_get_thread_num();
#ifdef __x86_64__
Expand Down
32 changes: 30 additions & 2 deletions csrc/cpu/tpp/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@
#include <torch/csrc/autograd/VariableTypeUtils.h>
//#include <torch/extension.h>

#ifdef _WIN32
#include <intrin.h>
#include <stdint.h>
#include <stdexcept>
#endif

#include <iostream>
#include <vector>
#ifdef _OPENMP
Expand Down Expand Up @@ -55,6 +61,22 @@ typedef at::Half half;
type(*name) dims = (type(*) dims)(t.data_ptr<type>())
#endif

#ifdef _WIN32
struct drand48_data {
uint16_t __x[3] = {0}; /* Current state. */
uint16_t __old_x[3] = {0}; /* Old state. */
uint16_t __c = 0; /* Additive const. in congruential formula. */
uint16_t __init = 0; /* Flag for initializing. */
uint64_t __a = 0; /* Factor in congruential formula. */
};

int srand48_r(uint64_t seed_val, struct drand48_data* buffer) {
throw std::runtime_error("not implemented.");

return 0;
}
#endif

// defined in init.cpp
extern double ifreq;
extern thread_local unsigned int* rng_state;
Expand All @@ -64,11 +86,17 @@ void init_libxsmm();
void xsmm_manual_seed(unsigned int seed);

#ifdef __x86_64__
#ifdef _WIN32
inline uint64_t rdtsc() {
return __rdtsc();
}
#else
static __inline__ unsigned long long rdtsc(void) {
unsigned hi, lo;
__asm__ __volatile__("rdtsc" : "=a"(lo), "=d"(hi));
return ((unsigned long long)lo) | (((unsigned long long)hi) << 32);
}
#endif
#elif defined(__aarch64__)
static __inline__ unsigned long long rdtsc(void) {
unsigned long long val;
Expand All @@ -88,8 +116,8 @@ static __inline__ unsigned long long rdtsc(void) {
#error "Unsupported architecture for rdtsc"
#endif
inline double getFreq() {
long long int s = rdtsc();
long long int e = rdtsc();
uint64_t s = rdtsc();
uint64_t e = rdtsc();
return (e - s) * 1.0;
}

Expand Down

0 comments on commit 901b377

Please sign in to comment.