Skip to content

Commit

Permalink
【DUAL】fix TakePassSparseReferedValues (PaddlePaddle#251)
Browse files Browse the repository at this point in the history
  • Loading branch information
danleifeng committed Sep 12, 2023
1 parent e7cb818 commit 06919f6
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 7 deletions.
9 changes: 6 additions & 3 deletions paddle/fluid/distributed/ps/service/ps_graph_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ void PsGraphClient::request_handler(const simple::RpcMessageHead &head,
pass_refered->shard_mutex = new std::mutex[info.shard_num];
pass_refered->values->resize(info.shard_num);
info.refered_feas[id].reset(pass_refered);
info.sem_wait.post();
VLOG(0) << "add request_handler table id=" << table_id
<< ", pass id=" << GET_PASS_ID(id) << ", shard_id=" << shard_id
<< ", total_ref=" << total_ref;
Expand Down Expand Up @@ -226,7 +227,7 @@ void PsGraphClient::request_handler(const simple::RpcMessageHead &head,
timeline.Pause();

shard_mutex.lock();
char** valsptr = &shard_values.values[shard_size];
char **valsptr = &shard_values.values[shard_size];
for (size_t i = 0; i < num; ++i) {
valsptr[i] = pull_vals[i];
}
Expand Down Expand Up @@ -257,8 +258,9 @@ std::shared_ptr<SparseShardValues> PsGraphClient::TakePassSparseReferedValues(
const size_t &table_id, const uint16_t &pass_id, const uint16_t &dim_id) {
SparseTableInfo &info = get_table_info(table_id);
uint32_t id = DIM_PASS_ID(dim_id, pass_id);

info.sem_wait.wait();
SparsePassValues *pass_refered = nullptr;

info.pass_mutex.lock();
auto it = info.refered_feas.find(id);
if (it == info.refered_feas.end()) {
Expand All @@ -270,6 +272,7 @@ std::shared_ptr<SparseShardValues> PsGraphClient::TakePassSparseReferedValues(
}
pass_refered = it->second.get();
info.pass_mutex.unlock();

int cnt = pass_refered->wg.count();
VLOG(0) << "table_id=" << table_id
<< ", begin TakePassSparseReferedValues pass_id=" << pass_id
Expand All @@ -280,7 +283,7 @@ std::shared_ptr<SparseShardValues> PsGraphClient::TakePassSparseReferedValues(
shard_ptr.reset(pass_refered->values);
pass_refered->values = nullptr;
// free shard mutex lock
delete [] pass_refered->shard_mutex;
delete[] pass_refered->shard_mutex;

info.pass_mutex.lock();
info.refered_feas.erase(id);
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/distributed/ps/service/ps_graph_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class PsGraphClient : public PsLocalClient {
uint32_t shard_num;
std::mutex pass_mutex;
SparseFeasReferedMap refered_feas;
paddle::framework::Semaphore sem_wait;
};

public:
Expand Down
8 changes: 4 additions & 4 deletions paddle/fluid/distributed/ps/table/common_graph_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1775,7 +1775,7 @@ std::pair<uint64_t, uint64_t> GraphTable::parse_node_file(
FLAGS_graph_edges_split_mode == "dbh" ||
FLAGS_graph_edges_split_mode == "DBH") {
if (!is_key_for_self_rank(id)) {
VLOG(2) << "id " << id << " not matched, node_id: " << node_id_
VLOG(3) << "id " << id << " not matched, node_id: " << node_id_
<< " , node_num:" << node_num_;
continue;
}
Expand Down Expand Up @@ -1854,7 +1854,7 @@ std::pair<uint64_t, uint64_t> GraphTable::parse_node_file(
FLAGS_graph_edges_split_mode == "dbh" ||
FLAGS_graph_edges_split_mode == "DBH") {
if (!is_key_for_self_rank(id)) {
VLOG(2) << "id " << id << " not matched, node_id: " << node_id_
VLOG(3) << "id " << id << " not matched, node_id: " << node_id_
<< " , node_num:" << node_num_;
continue;
}
Expand Down Expand Up @@ -1989,15 +1989,15 @@ std::pair<uint64_t, uint64_t> GraphTable::parse_edge_file(
// only keep hash(src_id) = hash(dst_id) = node_id edges
// src id
if (!is_key_for_self_rank(src_id)) {
VLOG(2) << " node num :" << src_id
VLOG(3) << " node num :" << src_id
<< " not split into node_id_:" << node_id_
<< " node_num:" << node_num_;
continue;
}
// dst id
if (!FLAGS_graph_edges_split_only_by_src_id &&
!is_key_for_self_rank(dst_id)) {
VLOG(2) << " dest node num :" << dst_id
VLOG(3) << " dest node num :" << dst_id
<< " will not add egde, node_id_:" << node_id_
<< " node_num:" << node_num_;
continue;
Expand Down

0 comments on commit 06919f6

Please sign in to comment.