Skip to content

Commit

Permalink
- PolicyBookのコード、GitHubに反映忘れていたところ修正。
Browse files Browse the repository at this point in the history
- vectorに対して&vみたいにしてアドレスを取っていたところ、v.data()を使うように修正。
  • Loading branch information
yaneurao committed Dec 17, 2024
1 parent dbdbe9f commit acd1ef8
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 20 deletions.
17 changes: 9 additions & 8 deletions source/book/book.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1492,7 +1492,7 @@ namespace Book
Position pos;
string root_sfen = "startpos moves 7g7f 3c3d 6g6f 8b3b 8h7g 5a6b 2h8h 6b7b 8g8f 3d3e 8f8e 3e3f 3i2h 3f3g+ 2h3g 3a4b 4i3h P*3f 3g2h 4b3c 6i5h 3c4d 7i6h 1c1d P*3g 7a8b";
deque<StateInfo> si;
BookTools::feed_position_string(pos, root_sfen, si, [](Position&){});
BookTools::feed_position_string(pos, root_sfen, si, [](Position&,Move){});

string moves1 = "1g1f 2g2f 3g3f 4g4f 5g5f 6f6e 7f7e 8e8d 9g9f 1i1h 9i9h 2h3i 6h6g 6h7i 7g8f 7g9e 8h7h 8h8f 8h8g 8h9h 3h3i 3h4h 5h4h 5h6g 5i4h 5i4i 5i6i";
string moves2 = string();
Expand All @@ -1518,7 +1518,8 @@ namespace BookTools
// "sfen xxx moves yyy ..."
// また、局面を1つ進めるごとにposition_callback関数が呼び出される。
// 辿った局面すべてに対して何かを行いたい場合は、これを利用すると良い。
void feed_position_string(Position& pos, const std::string& root_sfen, std::deque<StateInfo>& si, const std::function<void(Position&)>& position_callback)
void feed_position_string(Position& pos, const std::string& root_sfen, std::deque<StateInfo>& si,
const std::function<void(Position&,Move)>& position_callback)
{
// issから次のtokenを取得する
auto feed_next = [](Parser::LineScanner& iss)
Expand Down Expand Up @@ -1576,9 +1577,6 @@ namespace BookTools
}
} while (token == "startpos" || token == "sfen" || token == "moves"/* movesは無視してループを回る*/ );

// callbackを呼び出してやる。
position_callback(pos);

// moves以降は1手ずつ進める
while (token != "")
{
Expand All @@ -1591,14 +1589,17 @@ namespace BookTools
if (!move.is_ok())
break;

// callbackを呼び出してやる。
position_callback(pos, move);

si.emplace_back(StateInfo());
pos.do_move(move, si.back());

// callbackを呼び出してやる。
position_callback(pos);

token = feed_next(iss);
}

// 最後の局面でcallbackを呼び出してやる。
position_callback(pos, Move::none());
}

// 平手、駒落ちの開始局面集
Expand Down
6 changes: 5 additions & 1 deletion source/book/book.h
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,11 @@ namespace BookTools
// "sfen xxx moves yyy ..."
// また、局面を1つ進めるごとにposition_callback関数が呼び出される。
// 辿った局面すべてに対して何かを行いたい場合は、これを利用すると良い。
void feed_position_string(Position& pos, const std::string& root_sfen, std::deque<StateInfo>& si, const std::function<void(Position&)>& position_callback = [](Position&) {});
//
// position_callbackは、その局面と、その局面での指し手が引数にセットされて呼び出される。
// 与えたsfenの最後の局面では、MoveはMove::none()が入って呼び出される。
void feed_position_string(Position& pos, const std::string& root_sfen, std::deque<StateInfo>& si,
const std::function<void(Position&, Move)>& position_callback = [](Position&, Move) {});

// 平手、駒落ちの開始局面集
// ここで返ってきた配列の、[0]は平手のsfenであることは保証されている。
Expand Down
63 changes: 59 additions & 4 deletions source/book/policybook.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#include "../position.h"
#include "../thread.h"
#include "../usi.h"

#include "../book/book.h"

// freqの和がUINT16_MAXに収まるようにする。
u16 MoveFreq32Record::overflow_check()
Expand Down Expand Up @@ -61,8 +61,8 @@ Tools::Result PolicyBook::read_book_db(std::string path)
reader.ReadLine(sfen);
if (sfen != "#YANEURAOU-POLICY-DB2024 1.00")
{
sync_cout << "info string Error! policy book header" << sync_endl;
Tools::exit();
sync_cout << "info string Error! invalid policy book header" << sync_endl;
return Tools::ResultCode::FileMismatch;
}
while (true)
{
Expand Down Expand Up @@ -161,6 +161,8 @@ Tools::Result PolicyBook::read_book_db_bin(std::string path)
// PolicyBookを読み込み、."db.bin"ファイルを書き出す。
Tools::Result PolicyBook::read_book()
{

#if !defined(ENABLE_POLICY_BOOK_LEARN)
// まだ読み込んでいないならば..
if (!is_loaded())
{
Expand All @@ -177,7 +179,35 @@ Tools::Result PolicyBook::read_book()
return result;
}

return Tools::Result::Ok(); // 読み込めたことにしておく。
#else
// ただし、ENABLE_POLICY_BOOK_LEARNが定義されているときは、毎回読み込む。(局後学習データがあるため)

// binary化されたPolicyBookがあるなら、それを読み込む。
Tools::Result result = read_book_db_bin();
if (result.is_not_ok())
{
result = read_book_db();
if (result.is_ok())
{
// "db.bin"形式で書き出しておく。(次回の読み込み高速化のため)
result = write_book_db_bin();
}
}

// そもそも読み込んでいないのでmerge不要。
if (result.is_not_ok())
result = read_book_db_bin(POLICY_BOOK_LEARN_DB_BIN_NAME);
else {
PolicyBook pb;
result = pb.read_book_db_bin(POLICY_BOOK_LEARN_DB_BIN_NAME);
// 読み込みに成功したのでmergeする。
if (result.is_ok())
merge_book(pb);
}

#endif

return result; // 読み込めたことにしておく。
}


Expand Down Expand Up @@ -309,6 +339,31 @@ PolicyBookEntry* PolicyBook::probe_policy_book(HASH_KEY key)
return (it != book_body.end() && it->key == key) ? &*it : nullptr;
}

// "position "コマンドのposition以降の文字列を渡して、それを
// POLICY_BOOK_LEARN_DB_BIN_NAMEにappendで書き出す。
void PolicyBook::append_sfen_to_db_bin(const std::string& sfen)
{
Position pos;
StateList si;
std::vector<PolicyBookEntry> entries;

BookTools::feed_position_string(pos, sfen, si, [&](Position& p, Move m) {
// 最後の局面は、m==Move::none()が入ってくる。
if (m == Move::none())
return;
PolicyBookEntry entry;
entry.key = p.hash_key();
entry.move_freq[0] = MoveFreq(m.to_move16(), 1);
entries.push_back(entry);
});

// ファイルにappendする。
SystemIO::BinaryWriter writer;
writer.Open(POLICY_BOOK_LEARN_DB_BIN_NAME, true);
auto result = writer.Write(entries.data(), sizeof(PolicyBookEntry) * entries.size());
sync_cout << "info string append " << POLICY_BOOK_LEARN_DB_BIN_NAME << ". status = " << result.to_string() << sync_endl;
}

#if 0
// PolicyBookのmergeが正常にできているかをテストするコード。
void merge_test()
Expand Down
9 changes: 7 additions & 2 deletions source/book/policybook.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@

static_assert(HASH_KEY_BITS == 128 , "HASH_KEY_BITS must be 128");

#define POLICY_BOOK_DB_NAME "eval/policy_book.db"
#define POLICY_BOOK_DB_BIN_NAME "eval/policy_book.db.bin"
#define POLICY_BOOK_DB_NAME "eval/policy_book.db"
#define POLICY_BOOK_DB_BIN_NAME "eval/policy_book.db.bin"
#define POLICY_BOOK_LEARN_DB_BIN_NAME "eval/policy_book-learn.db.bin"

// ============================================================
// Policy Book
Expand Down Expand Up @@ -84,6 +85,10 @@ class PolicyBook
// PolicyBook同士のmerge
void PolicyBook::merge_book(const PolicyBook& book);

// "position "コマンドのposition以降の文字列を渡して、それを
// POLICY_BOOK_LEARN_DB_BIN_NAMEにappendで書き出す。
void append_sfen_to_db_bin(const std::string& sfen);

// ファイルから読み込んだか?
bool is_loaded() const { return book_body.size() != 0; }

Expand Down
4 changes: 2 additions & 2 deletions source/eval/evaluate_io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,14 @@ namespace EvalIO
{
std::vector<u8> buffer(input_block_size);
std::ifstream ifs(in_.file_or_memory.filename, std::ios::binary);
if (ifs) ifs.read(reinterpret_cast<char*>(&buffer[0]), input_block_size);
if (ifs) ifs.read(reinterpret_cast<char*>(buffer.data()), input_block_size);
else
{
std::cout << "info string read file error , file = " << in_.file_or_memory.filename << std::endl;
return false;
};
std::ofstream ofs(out_.file_or_memory.filename, std::ios::binary);
if (ofs) ofs.write(reinterpret_cast<char*>(&buffer[0]), output_block_size);
if (ofs) ofs.write(reinterpret_cast<char*>(buffer.data()), output_block_size);
else
{
std::cout << "info string write file error , file = " << out_.file_or_memory.filename << std::endl;
Expand Down
4 changes: 2 additions & 2 deletions source/learn/learner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ struct SfenWriter
{
for (auto ptr : buffers)
{
fs.write((const char*)&((*ptr)[0]), sizeof(PackedSfenValue) * ptr->size());
fs.write(reinterpret_cast<const char*>(ptr->data()), sizeof(PackedSfenValue) * ptr->size());

sfen_write_count += ptr->size();

Expand Down Expand Up @@ -2314,7 +2314,7 @@ void shuffle_files(const vector<string>& filenames , const string& output_file_n
// ファイルに書き出す
fstream fs;
fs.open(make_filename(write_file_count++), ios::out | ios::binary);
fs.write((char*)&buf[0], size * sizeof(PackedSfenValue));
fs.write((char*)buf.data(), size * sizeof(PackedSfenValue));
fs.close();
a_count.push_back(size);

Expand Down
2 changes: 1 addition & 1 deletion source/misc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1377,7 +1377,7 @@ namespace SystemIO
// 今回のループで書き込むbyte数
write_size = buf_size - write_cursor;
std::memcpy(&buf[write_cursor], ptr2, write_size);
if (fwrite(&buf[0], buf_size, 1, fp) == 0)
if (fwrite(buf.data(), buf_size, 1, fp) == 0)
return Tools::ResultCode::FileWriteError;

// buf[0..write_cursor-1]が窓で、ループごとにその窓がbuf_sizeずつずれていくと考える。
Expand Down

0 comments on commit acd1ef8

Please sign in to comment.