Skip to content

Commit

Permalink
Merge pull request #2 from WeiyueSu/batch_sample_k
Browse files Browse the repository at this point in the history
Batch sample k
  • Loading branch information
seemingwang authored Mar 16, 2021
2 parents ba57877 + 86ff4d9 commit 2abf38c
Show file tree
Hide file tree
Showing 11 changed files with 143 additions and 88 deletions.
106 changes: 77 additions & 29 deletions paddle/fluid/distributed/service/graph_brpc_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,43 +35,91 @@ int GraphBrpcClient::get_server_index_by_id(uint64_t id) {
return id % shard_num / shard_per_server;
}
// char* &buffer,int &actual_size
std::future<int32_t> GraphBrpcClient::sample(
uint32_t table_id, uint64_t node_id, int sample_size,
std::vector<std::pair<uint64_t, float>> &res) {
int server_index = get_server_index_by_id(node_id);
DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) {
std::future<int32_t> GraphBrpcClient::batch_sample(uint32_t table_id,
std::vector<uint64_t> node_ids, int sample_size,
std::vector<std::vector<std::pair<uint64_t, float> > > &res) {

std::vector<int> request2server;
std::vector<int> server2request(server_size, -1);
res.clear();
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx){
int server_index = get_server_index_by_id(node_ids[query_idx]);
if(server2request[server_index] == -1){
server2request[server_index] = request2server.size();
request2server.push_back(server_index);
}
//res.push_back(std::vector<GraphNode>());
res.push_back(std::vector<std::pair<uint64_t, float>>());
}
size_t request_call_num = request2server.size();
std::vector<std::vector<uint64_t> > node_id_buckets(request_call_num);
std::vector<std::vector<int> > query_idx_buckets(request_call_num);
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx){
int server_index = get_server_index_by_id(node_ids[query_idx]);
int request_idx = server2request[server_index];
node_id_buckets[request_idx].push_back(node_ids[query_idx]);
query_idx_buckets[request_idx].push_back(query_idx);
}

DownpourBrpcClosure *closure = new DownpourBrpcClosure(request_call_num, [&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
if (closure->check_response(0, PS_GRAPH_SAMPLE) != 0) {
ret = -1;
} else {
auto &res_io_buffer = closure->cntl(0)->response_attachment();
butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
size_t bytes_size = io_buffer_itr.bytes_left();
char *buffer = new char[bytes_size];
io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
int offset = 0;
while (offset < bytes_size) {
res.push_back({*(uint64_t *)(buffer + offset),
*(float *)(buffer + offset + GraphNode::id_size)});
offset += GraphNode::id_size + GraphNode::weight_size;
int fail_num = 0;
for (int request_idx = 0; request_idx < request_call_num; ++request_idx){
if (closure->check_response(request_idx, PS_GRAPH_SAMPLE) != 0) {
++fail_num;
} else {
VLOG(0) << "check sample response: "
<< " " << closure->check_response(request_idx, PS_GRAPH_SAMPLE);
auto &res_io_buffer = closure->cntl(request_idx)->response_attachment();
butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
size_t bytes_size = io_buffer_itr.bytes_left();
char *buffer = new char[bytes_size];
io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);

size_t node_num = *(size_t *)buffer;
int *actual_sizes = (int *)(buffer + sizeof(size_t));
char *node_buffer = buffer + sizeof(size_t) + sizeof(int) * node_num;

int offset = 0;
for (size_t node_idx = 0; node_idx < node_num; ++node_idx){
int query_idx = query_idx_buckets.at(request_idx).at(node_idx);
int actual_size = actual_sizes[node_idx];
int start = 0;
while (start < actual_size) {
res[query_idx].push_back({*(uint64_t *)(node_buffer + offset + start),
*(float *)(node_buffer + offset + start + GraphNode::id_size)});
start += GraphNode::id_size + GraphNode::weight_size;
}
offset += actual_size;
}
}
if (fail_num == request_call_num){
ret = -1;
}
}
closure->set_promise_value(ret);
});

auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
;
closure->request(0)->set_cmd_id(PS_GRAPH_SAMPLE);
closure->request(0)->set_table_id(table_id);
closure->request(0)->set_client_id(_client_id);
closure->request(0)->add_params((char *)&node_id, sizeof(uint64_t));
closure->request(0)->add_params((char *)&sample_size, sizeof(int));
PsService_Stub rpc_stub(get_cmd_channel(server_index));
closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0),
closure);

for (int request_idx = 0; request_idx < request_call_num; ++request_idx){
int server_index = request2server[request_idx];
closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE);
closure->request(request_idx)->set_table_id(table_id);
closure->request(request_idx)->set_client_id(_client_id);
// std::string type_str = GraphNode::node_type_to_string(type);
size_t node_num = node_id_buckets[request_idx].size();

closure->request(request_idx)->add_params((char *)node_id_buckets[request_idx].data(), sizeof(uint64_t)*node_num);
closure->request(request_idx)->add_params((char *)&sample_size, sizeof(int));
PsService_Stub rpc_stub(get_cmd_channel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx), closure->response(request_idx),
closure);
}

return fut;
}
Expand Down Expand Up @@ -124,4 +172,4 @@ int32_t GraphBrpcClient::initialize() {
return 0;
}
}
}
}
6 changes: 3 additions & 3 deletions paddle/fluid/distributed/service/graph_brpc_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ class GraphBrpcClient : public BrpcPsClient {
public:
GraphBrpcClient() {}
virtual ~GraphBrpcClient() {}
virtual std::future<int32_t> sample(
uint32_t table_id, uint64_t node_id, int sample_size,
std::vector<std::pair<uint64_t, float>> &res);
virtual std::future<int32_t> batch_sample(uint32_t table_id, std::vector<uint64_t> node_ids,
int sample_size,
std::vector<std::vector<std::pair<uint64_t, float>>> &res);
virtual std::future<int32_t> pull_graph_list(uint32_t table_id,
int server_index, int start,
int size,
Expand Down
27 changes: 22 additions & 5 deletions paddle/fluid/distributed/service/graph_brpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -283,12 +283,29 @@ int32_t GraphBrpcService::graph_random_sample(Table *table,
"graph_random_sample request requires at least 2 arguments");
return 0;
}
uint64_t node_id = *(uint64_t *)(request.params(0).c_str());
size_t num_nodes = request.params(0).size() / sizeof(uint64_t);
uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
int sample_size = *(uint64_t *)(request.params(1).c_str());
char *buffer;
int actual_size;
table->random_sample(node_id, sample_size, buffer, actual_size);
cntl->response_attachment().append(buffer, actual_size);

std::vector<std::future<int>*> tasks;
std::vector<char*> buffers(num_nodes);
std::vector<int> actual_sizes(num_nodes);

for (size_t idx = 0; idx < num_nodes; ++idx){
//std::future<int> task = table->random_sample(node_data[idx], sample_size,
//buffers[idx], actual_sizes[idx]);
table->random_sample(node_data[idx], sample_size,
buffers[idx], actual_sizes[idx]);
//tasks.push_back(&task);
}
//for (size_t idx = 0; idx < num_nodes; ++idx){
//tasks[idx]->get();
//}
cntl->response_attachment().append(&num_nodes, sizeof(size_t));
cntl->response_attachment().append(actual_sizes.data(), sizeof(int)*num_nodes);
for (size_t idx = 0; idx < num_nodes; ++idx){
cntl->response_attachment().append(buffers[idx], actual_sizes[idx]);
}
return 0;
}

Expand Down
8 changes: 4 additions & 4 deletions paddle/fluid/distributed/service/graph_py_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -183,12 +183,12 @@ void GraphPyClient::load_node_file(std::string name, std::string filepath) {
status.wait();
}
}
std::vector<std::pair<uint64_t, float>> GraphPyClient::sample_k(
std::string name, uint64_t node_id, int sample_size) {
std::vector<std::pair<uint64_t, float>> v;
std::vector<std::vector<std::pair<uint64_t, float> > > GraphPyClient::batch_sample_k(
std::string name, std::vector<uint64_t> node_ids, int sample_size) {
std::vector<std::vector<std::pair<uint64_t, float> > > v;
if (this->table_id_map.count(name)) {
uint32_t table_id = this->table_id_map[name];
auto status = worker_ptr->sample(table_id, node_id, sample_size, v);
auto status = worker_ptr->batch_sample(table_id, node_ids, sample_size, v);
status.wait();
}
return v;
Expand Down
5 changes: 2 additions & 3 deletions paddle/fluid/distributed/service/graph_py_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,8 @@ class GraphPyClient : public GraphPyService {
int get_client_id() { return client_id; }
void set_client_id(int client_id) { this->client_id = client_id; }
void start_client();
std::vector<std::pair<uint64_t, float>> sample_k(std::string name,
uint64_t node_id,
int sample_size);
std::vector<std::vector<std::pair<uint64_t, float> > > batch_sample_k(
std::string name, std::vector<uint64_t> node_ids, int sample_size);
std::vector<GraphNode> pull_graph_list(std::string name, int server_index,
int start, int size);
::paddle::distributed::PSParameter GetWorkerProto();
Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/distributed/service/ps_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,9 @@ class PSClient {
promise.set_value(-1);
return fut;
}
virtual std::future<int32_t> sample(
uint32_t table_id, uint64_t node_id, int sample_size,
std::vector<std::pair<uint64_t, float>> &res) {
virtual std::future<int32_t> batch_sample(uint32_t table_id, std::vector<uint64_t> node_ids,
int sample_size,
std::vector<std::vector<std::pair<uint64_t, float>>> &res) {
LOG(FATAL) << "Did not implement";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/distributed/table/common_graph_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,8 @@ GraphNode *GraphTable::find_node(uint64_t id) {
uint32_t GraphTable::get_thread_pool_index(uint64_t node_id) {
return node_id % shard_num_per_table % task_pool_size_;
}
//std::future<int> GraphTable::random_sample(uint64_t node_id, int sample_size,
//char *&buffer, int &actual_size) {
int32_t GraphTable::random_sample(uint64_t node_id, int sample_size,
char *&buffer, int &actual_size) {
return _shards_task_pool[get_thread_pool_index(node_id)]
Expand All @@ -226,6 +228,7 @@ int32_t GraphTable::random_sample(uint64_t node_id, int sample_size,
memcpy(buffer + offset, &weight, GraphNode::weight_size);
offset += GraphNode::weight_size;
}
return 0;
})
.get();
}
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/distributed/table/common_graph_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ class GraphTable : public SparseTable {
virtual ~GraphTable() {}
virtual int32_t pull_graph_list(int start, int size, char *&buffer,
int &actual_size);
//virtual std::future<int> random_sample(uint64_t node_id, int sampe_size, char *&buffer,
//int &actual_size);
virtual int32_t random_sample(uint64_t node_id, int sampe_size, char *&buffer,
int &actual_size);
virtual int32_t initialize();
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/distributed/table/table.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@ class Table {
int &actual_size) {
return 0;
}
//virtual std::future<int> random_sample(uint64_t node_id, int sampe_size, char *&buffer,
//int &actual_size) {
//return std::future<int>();
//}
virtual int32_t pour() { return 0; }

virtual void clear() = 0;
Expand Down
60 changes: 21 additions & 39 deletions paddle/fluid/distributed/test/graph_node_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -216,23 +216,26 @@ void RunBrpcPushSparse() {

/*-----------------------Test Server Init----------------------------------*/
auto pull_status =
worker_ptr_->load(0, std::string(file_name), std::string(""));
worker_ptr_->load(0, std::string(file_name), std::string("edge"));

pull_status.wait();
std::vector<std::pair<uint64_t, float>> v;
pull_status = worker_ptr_->sample(0, 37, 4, v);
std::vector<std::vector<std::pair<uint64_t, float> > > vs;
//std::vector<std::pair<uint64_t, float>> v;
//pull_status = worker_ptr_->sample(0, 37, 4, v);
pull_status = worker_ptr_->batch_sample(0, std::vector<uint64_t>(1, 37), 4, vs);
pull_status.wait();
ASSERT_EQ(v.size(), 3);
v.clear();
pull_status = worker_ptr_->sample(0, 96, 4, v);
ASSERT_EQ(vs[0].size(), 3);
vs.clear();
//pull_status = worker_ptr_->sample(0, 96, 4, v);
pull_status = worker_ptr_->batch_sample(0, std::vector<uint64_t>(1, 96), 4, vs);
pull_status.wait();
std::unordered_set<int> s = {111, 48, 247};
ASSERT_EQ(3, v.size());
for (auto g : v) {
ASSERT_EQ(3, vs[0].size());
for (auto g : vs[0]) {
// std::cout << g.first << std::endl;
ASSERT_EQ(true, s.find(g.first) != s.end());
}
v.clear();
vs.clear();
std::vector<distributed::GraphNode> nodes;
pull_status = worker_ptr_->pull_graph_list(0, 0, 0, 1, nodes);
pull_status.wait();
Expand Down Expand Up @@ -276,38 +279,17 @@ void RunBrpcPushSparse() {
nodes = client2.pull_graph_list(std::string("user2item"), 0, 1, 4);
ASSERT_EQ(nodes[0].get_id(), 59);
nodes.clear();
v = client1.sample_k(std::string("user2item"), 96, 4);
ASSERT_EQ(v.size(), 3);
std::cout << "sample result" << std::endl;
for (auto p : v) {
vs = client1.batch_sample_k(std::string("user2item"), std::vector<uint64_t>(1, 96), 4);
ASSERT_EQ(vs[0].size(), 3);
std::cout << "batch sample result" << std::endl;
for (auto p : vs[0]) {
std::cout << p.first << " " << p.second << std::endl;
}
/*
from paddle.fluid.core import GraphPyService
ips_str = "127.0.0.1:4211;127.0.0.1:4212"
server1 = GraphPyServer()
server2 = GraphPyServer()
client1 = GraphPyClient()
client2 = GraphPyClient()
edge_types = ["user2item"]
server1.set_up(ips_str,127,edge_types,0);
server2.set_up(ips_str,127,edge_types,1);
client1.set_up(ips_str,127,edge_types,0);
client2.set_up(ips_str,127,edge_types,1);
server1.start_server();
server2.start_server();
client1.start_client();
client2.start_client();
client1.load_edge_file(user2item", "input.txt", 0);
list = client2.pull_graph_list("user2item",0,1,4)
for x in list:
print(x.get_id())
list = client1.sample_k("user2item",96, 4);
for x in list:
print(x.get_id())
*/

std::vector<uint64_t> node_ids;
node_ids.push_back(96);
node_ids.push_back(37);
vs = client1.batch_sample_k(std::string("user2item"), node_ids, 4);
ASSERT_EQ(vs.size(), 2);
// to test in python,try this:
// from paddle.fluid.core import GraphPyService
// ips_str = "127.0.0.1:4211;127.0.0.1:4212"
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/pybind/fleet_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,8 @@ void BindGraphPyClient(py::module* m) {
.def("load_node_file", &GraphPyClient::load_node_file)
.def("set_up", &GraphPyClient::set_up)
.def("pull_graph_list", &GraphPyClient::pull_graph_list)
.def("sample_k", &GraphPyClient::sample_k)
.def("start_client", &GraphPyClient::start_client);
.def("start_client", &GraphPyClient::start_client)
.def("batch_sample_k", &GraphPyClient::batch_sample_k);
}

} // end namespace pybind
Expand Down

0 comments on commit 2abf38c

Please sign in to comment.