@@ -50,9 +50,9 @@ SOFTWARE
50
50
#define SIMD_DIV (x, y ) _mm512_div_ps(x, y)
51
51
#define SIMD_LOAD_HALF (x ) \
52
52
_mm512_cvtph_ps (_mm256_loadu_si256((const __m256i *)(x)))
53
- #define SIMD_STORE_HALF (x, d ) \
54
- _mm256_store_ps ( \
55
- x, _mm256_castsi256_ps(_mm512_cvtps_ph( d, _MM_FROUND_TO_NEAREST_INT)))
53
+ #define SIMD_STORE_HALF (x, d ) \
54
+ _mm256_storeu_ps (( float *)(x), _mm256_castsi256_ps(_mm512_cvtps_ph( \
55
+ d, _MM_FROUND_TO_NEAREST_INT)))
56
56
57
57
#elif defined(__AVX256__) or defined(__AVX2__)
58
58
#define SIMD_WIDTH 8
@@ -66,9 +66,9 @@ SOFTWARE
66
66
#define SIMD_SQRT (x ) _mm256_sqrt_ps(x)
67
67
#define SIMD_DIV (x, y ) _mm256_div_ps(x, y)
68
68
#define SIMD_LOAD_HALF (x ) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
69
- #define SIMD_STORE_HALF (x, d ) \
70
- _mm_store_ps ( \
71
- x, _mm_castsi128_ps(_mm256_cvtps_ph( d, _MM_FROUND_TO_NEAREST_INT)))
69
+ #define SIMD_STORE_HALF (x, d ) \
70
+ _mm_storeu_ps (( float *)(x), _mm_castsi128_ps(_mm256_cvtps_ph( \
71
+ d, _MM_FROUND_TO_NEAREST_INT)))
72
72
73
73
#endif
74
74
@@ -83,11 +83,12 @@ union AVX_Data {
83
83
84
84
#endif
85
85
86
- #define STEP (SPAN ) \
87
- void Step_##SPAN(float *_params, float *grads, float *_exp_avg, \
88
- float *_exp_avg_sq, size_t _param_size, \
89
- bool param_half_precision = false , \
90
- bool grad_half_precision = false , float loss_scale = -1 );
86
+ #define STEP (SPAN ) \
87
+ void Step_##SPAN( \
88
+ float *_params, float *grads, float *_exp_avg, float *_exp_avg_sq, \
89
+ size_t _param_size, bool param_half_precision = false , \
90
+ bool grad_half_precision = false , bool momentum_half_precision = false , \
91
+ bool variance_half_precision = false , float loss_scale = -1 );
91
92
92
93
class Adam_Optimizer {
93
94
public:
@@ -141,6 +142,24 @@ class Adam_Optimizer {
141
142
}
142
143
}
143
144
145
+ inline void simd_load (bool is_half, float *ptr, __half *h_ptr,
146
+ AVX_Data &data) {
147
+ if (is_half) {
148
+ data.data = SIMD_LOAD_HALF (h_ptr);
149
+ } else {
150
+ data.data = SIMD_LOAD (ptr);
151
+ }
152
+ }
153
+
154
+ inline void simd_store (bool is_half, float *ptr, __half *h_ptr,
155
+ AVX_Data &data) {
156
+ if (is_half) {
157
+ SIMD_STORE_HALF (h_ptr, data.data );
158
+ } else {
159
+ SIMD_STORE (ptr, data.data );
160
+ }
161
+ }
162
+
144
163
void step (size_t step, float lr, float beta1, float beta2, float epsilon,
145
164
float weight_decay, bool bias_correction, torch::Tensor ¶ms,
146
165
torch::Tensor &grads, torch::Tensor &exp_avg,
0 commit comments