Skip to content

Commit

Permalink
HalfATTAv2_hm
Browse files Browse the repository at this point in the history
  • Loading branch information
PikaCat-OuO committed Jan 29, 2025
1 parent 4f011f7 commit 9d719bc
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 40 deletions.
11 changes: 5 additions & 6 deletions src/nnue/features/half_ka_v2_hm.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,10 @@ class HalfKAv2_hm {
static constexpr std::uint32_t HashValue = 0xd17b100;

// Number of feature dimensions
static constexpr IndexType Dimensions = 6 * 2 * 3 * static_cast<IndexType>(PS_NB);
static constexpr IndexType Dimensions = 2 * 3 * static_cast<IndexType>(PS_NB);

// Get king_index and mirror information
static constexpr auto KingBuckets = []() {
// Get Mirror information
static constexpr auto NeedMirror = []() {
#define M(s) ((1 << 3) | s)
// Stored as (mirror << 3 | bucket)
constexpr uint8_t KingBuckets[SQUARE_NB] = {
Expand All @@ -96,16 +96,15 @@ class HalfKAv2_hm {
// clang-format on
};
#undef M
std::array<std::array<std::pair<int, bool>, SQUARE_NB>, SQUARE_NB> v{};
std::array<std::array<bool, SQUARE_NB>, SQUARE_NB> v{};
for (uint8_t ksq = SQ_A0; ksq <= SQ_I9; ++ksq)
for (uint8_t oksq = SQ_A0; oksq <= SQ_I9; ++oksq)
{
uint8_t king_bucket_ = KingBuckets[ksq];
int king_bucket = king_bucket_ & 0x7;
bool mirror =
(king_bucket_ >> 3) || ((king_bucket & 1) && (KingBuckets[oksq] >> 3));
v[ksq][oksq].first = king_bucket;
v[ksq][oksq].second = mirror;
v[ksq][oksq] = mirror;
}
return v;
}();
Expand Down
17 changes: 1 addition & 16 deletions src/nnue/nnue_accumulator.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,21 +48,6 @@ struct alignas(CacheLineSize) Accumulator {
// is commonly referred to as "Finny Tables".
struct AccumulatorCaches {

// clang-format off
static constexpr uint8_t KingCacheMaps[SQUARE_NB] = {
0, 0, 0, 6, 0, 3, 0, 0, 0,
0, 0, 0, 7, 1, 4, 0, 0, 0,
0, 0, 0, 8, 2, 5, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 8, 2, 5, 0, 0, 0,
0, 0, 0, 7, 1, 4, 0, 0, 0,
0, 0, 0, 6, 0, 3, 0, 0, 0,
};
// clang-format on

template<typename Network>
AccumulatorCaches(const Network& network) {
clear(network);
Expand Down Expand Up @@ -95,7 +80,7 @@ struct AccumulatorCaches {

std::array<Entry, COLOR_NB>& operator[](int index) { return entries[index]; }

std::array<std::array<Entry, COLOR_NB>, (9 + 3) * 2 * 3> entries;
std::array<std::array<Entry, COLOR_NB>, 2 * 3 * 2> entries;
};

template<typename Network>
Expand Down
24 changes: 9 additions & 15 deletions src/nnue/nnue_feature_transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -527,11 +527,10 @@ class FeatureTransformer {
assert(computed->accumulator.computed[Perspective]);
assert(computed->next != nullptr);

const Square ksq = pos.king_square(Perspective);
const Square oksq = pos.king_square(~Perspective);
auto [king_bucket, mirror] = FeatureSet::KingBuckets[ksq][oksq];
auto attack_bucket = FeatureSet::make_attack_bucket(pos, Perspective);
auto bucket = king_bucket * 6 + attack_bucket;
const Square ksq = pos.king_square(Perspective);
const Square oksq = pos.king_square(~Perspective);
bool mirror = FeatureSet::NeedMirror[ksq][oksq];
auto bucket = FeatureSet::make_attack_bucket(pos, Perspective);

// The size must be enough to contain the largest possible update.
// That might depend on the feature set and generally relies on the
Expand Down Expand Up @@ -689,17 +688,12 @@ class FeatureTransformer {
void update_accumulator_refresh(const Position& pos, AccumulatorCaches::Cache* cache) const {
assert(cache != nullptr);

const Square ksq = pos.king_square(Perspective);
const Square oksq = pos.king_square(~Perspective);
auto [king_bucket, mirror] = FeatureSet::KingBuckets[ksq][oksq];
auto attack_bucket = FeatureSet::make_attack_bucket(pos, Perspective);
auto bucket = king_bucket * 6 + attack_bucket;
const Square ksq = pos.king_square(Perspective);
const Square oksq = pos.king_square(~Perspective);
bool mirror = FeatureSet::NeedMirror[ksq][oksq];
auto bucket = FeatureSet::make_attack_bucket(pos, Perspective);

auto cache_index = AccumulatorCaches::KingCacheMaps[ksq];
if (cache_index < 3 && mirror)
cache_index += 9;

auto& entry = (*cache)[cache_index * 6 + attack_bucket][Perspective];
auto& entry = (*cache)[bucket * 2 + mirror][Perspective];
FeatureSet::IndexList removed, added;

for (Color c : {WHITE, BLACK})
Expand Down
8 changes: 5 additions & 3 deletions src/position.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -495,9 +495,11 @@ void Position::do_move(Move m,

if (pc == make_piece(us, KING))
{
dp.requires_refresh[us] = true;
bool mirror_before = Eval::NNUE::FeatureSet::KingBuckets[king_square(them)][from].second;
bool mirror_after = Eval::NNUE::FeatureSet::KingBuckets[king_square(them)][to].second;
bool mirror_before = Eval::NNUE::FeatureSet::NeedMirror[from][king_square(them)];
bool mirror_after = Eval::NNUE::FeatureSet::NeedMirror[to][king_square(them)];
dp.requires_refresh[us] = (mirror_before != mirror_after);
mirror_before = Eval::NNUE::FeatureSet::NeedMirror[king_square(them)][from];
mirror_after = Eval::NNUE::FeatureSet::NeedMirror[king_square(them)][to];
dp.requires_refresh[them] = (mirror_before != mirror_after);
}
else
Expand Down

0 comments on commit 9d719bc

Please sign in to comment.