Skip to content

Commit

Permalink
kOutputDimensionsが4で割り切れない場合の処理がなかったので追加。 (#294)
Browse files (browse the repository at this point in the history)
* `kOutputDimensions`が4で割り切れない場合の処理がなかったので追加。

* `affine_transform_non_ssse3()`を使うように修正した。
  • Loading branch information
KazApps authored Dec 4, 2024
1 parent 4933602 commit 1436a24
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 12 deletions.
21 changes: 11 additions & 10 deletions source/eval/nnue/layers/affine_transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ static void affine_transform_non_ssse3(std::int32_t* output,
#if defined(USE_SSSE3) || defined(USE_NEON_DOTPROD)
#if defined(USE_SSE2)
// At least a multiple of 16, with SSE2.
constexpr IndexType NumChunks = CeilToMultiple<IndexType>(kInputDimensions, 16) / 16;
const __m128i Zeros = _mm_setzero_si128();
constexpr IndexType kNumChunks = CeilToMultiple<IndexType>(kInputDimensions, 16) / 16;
const __m128i kZeros = _mm_setzero_si128();
const auto inputVector = reinterpret_cast<const __m128i*>(input);

#elif defined(USE_NEON)
constexpr IndexType NumChunks = CeilToMultiple<IndexType>(kInputDimensions, 16) / 16;
constexpr IndexType kNumChunks = CeilToMultiple<IndexType>(kInputDimensions, 16) / 16;
const auto inputVector = reinterpret_cast<const int8x8_t*>(input);
#endif

Expand All @@ -36,16 +36,16 @@ static void affine_transform_non_ssse3(std::int32_t* output,

#if defined(USE_SSE2)
__m128i sumLo = _mm_cvtsi32_si128(biases[i]);
__m128i sumHi = Zeros;
__m128i sumHi = kZeros;
const auto row = reinterpret_cast<const __m128i*>(&weights[offset]);
for (IndexType j = 0; j < NumChunks; ++j)
for (IndexType j = 0; j < kNumChunks; ++j)
{
__m128i row_j = _mm_load_si128(&row[j]);
__m128i input_j = _mm_load_si128(&inputVector[j]);
__m128i extendedRowLo = _mm_srai_epi16(_mm_unpacklo_epi8(row_j, row_j), 8);
__m128i extendedRowHi = _mm_srai_epi16(_mm_unpackhi_epi8(row_j, row_j), 8);
__m128i extendedInputLo = _mm_unpacklo_epi8(input_j, Zeros);
__m128i extendedInputHi = _mm_unpackhi_epi8(input_j, Zeros);
__m128i extendedInputLo = _mm_unpacklo_epi8(input_j, kZeros);
__m128i extendedInputHi = _mm_unpackhi_epi8(input_j, kZeros);
__m128i productLo = _mm_madd_epi16(extendedRowLo, extendedInputLo);
__m128i productHi = _mm_madd_epi16(extendedRowHi, extendedInputHi);
sumLo = _mm_add_epi32(sumLo, productLo);
Expand All @@ -62,7 +62,7 @@ static void affine_transform_non_ssse3(std::int32_t* output,

int32x4_t sum = {biases[i]};
const auto row = reinterpret_cast<const int8x8_t*>(&weights[offset]);
for (IndexType j = 0; j < NumChunks; ++j)
for (IndexType j = 0; j < kNumChunks; ++j)
{
int16x8_t product = vmull_s8(inputVector[j * 2], row[j * 2]);
product = vmlal_s8(product, inputVector[j * 2 + 1], row[j * 2 + 1]);
Expand Down Expand Up @@ -136,7 +136,7 @@ class AffineTransform {

static constexpr IndexType get_weight_index(IndexType i) {
#if defined(USE_SSSE3) || defined(USE_NEON_DOTPROD)
return get_weight_index_scrambled(i);
return kOutputDimensions % 4 == 0 ? get_weight_index_scrambled(i) : i;
#else
return i;
#endif
Expand Down Expand Up @@ -310,7 +310,8 @@ class AffineTransform {
else
#endif

{}
affine_transform_non_ssse3<kInputDimensions, kPaddedInputDimensions, kOutputDimensions>(
output, weights_, biases_, input);
}
else if constexpr (kOutputDimensions == 1)
{
Expand Down
5 changes: 3 additions & 2 deletions source/eval/nnue/layers/affine_transform_sparse_input.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ class AffineTransformSparseInput {

static constexpr IndexType get_weight_index(IndexType i) {
#if defined(USE_SSSE3) || USE_NEON >= 8
return get_weight_index_scrambled(i);
return kOutputDimensions % 4 == 0 ? get_weight_index_scrambled(i) : i;
#else
return i;
#endif
Expand Down Expand Up @@ -363,7 +363,8 @@ class AffineTransformSparseInput {
}
else
#endif
{}
affine_transform_non_ssse3<kInputDimensions, kPaddedInputDimensions, kOutputDimensions>(
output, weights_, biases_, input);

#undef vec_set_32
#undef vec_add_dpbusd_32
Expand Down

0 comments on commit 1436a24

Please sign in to comment.