From e087485e45594d1b0c25456c902a0b0e04122d73 Mon Sep 17 00:00:00 2001 From: Yelrose <270018958@qq.com> Date: Fri, 12 Mar 2021 15:54:45 +0800 Subject: [PATCH 1/5] add load_nodes; change add_node function --- .../distributed/table/common_graph_table.cc | 83 ++++++++++++++----- .../distributed/table/common_graph_table.h | 8 +- 2 files changed, 71 insertions(+), 20 deletions(-) diff --git a/paddle/fluid/distributed/table/common_graph_table.cc b/paddle/fluid/distributed/table/common_graph_table.cc index f4f235b114dba2..c96eb4680d5aa5 100644 --- a/paddle/fluid/distributed/table/common_graph_table.cc +++ b/paddle/fluid/distributed/table/common_graph_table.cc @@ -64,21 +64,22 @@ size_t GraphShard::get_size() { return res; } -std::list::iterator GraphShard::add_node(GraphNode *node) { - if (node_location.find(node->get_id()) != node_location.end()) - return node_location.find(node->get_id())->second; +std::list::iterator GraphShard::add_node(uint64_t id, std::string feature) { + if (node_location.find(id) != node_location.end()) + return node_location.find(id)->second; - int index = node->get_id() % shard_num % bucket_size; + int index = id % shard_num % bucket_size; + GraphNode *node = new GraphNode(id, std::string("")); std::list::iterator iter = bucket[index].insert(bucket[index].end(), node); - node_location[node->get_id()] = iter; + node_location[id] = iter; return iter; } void GraphShard::add_neighboor(uint64_t id, GraphEdge *edge) { - (*add_node(new GraphNode(id, std::string(""))))->add_edge(edge); + (*add_node(id, std::string("")))->add_edge(edge); } GraphNode *GraphShard::find_node(uint64_t id) { @@ -89,14 +90,54 @@ GraphNode *GraphShard::find_node(uint64_t id) { int32_t GraphTable::load(const std::string &path, const std::string ¶m) { auto cmd = paddle::string::split_string(param, "|"); std::set cmd_set(cmd.begin(), cmd.end()); - bool load_edge = cmd_set.count(std::string("edge")); bool reverse_edge = cmd_set.count(std::string("reverse")); - VLOG(0) << "Reverse Edge " << reverse_edge; - + bool load_edge = cmd_set.count(std::string("edge")); + if(load_edge) { + return this -> load_edges(path, reverse_edge); + } + else { + return this -> load_nodes(path); + } +} + +int32_t GraphTable::load_nodes(const std::string &path) { + auto paths = paddle::string::split_string(path, ";"); + for (auto path : paths) { + std::ifstream file(path); + std::string line; + while (std::getline(file, line)) { + auto values = paddle::string::split_string(line, "\t"); + if (values.size() < 2) continue; + auto id = std::stoull(values[1]); + + + size_t shard_id = id % shard_num; + if (shard_id >= shard_end || shard_id < shard_start) { + VLOG(0) << "will not load " << id << " from " << path + << ", please check id distribution"; + continue; + + } + + std::string node_type = values[0]; + std::vector feature; + feature.push_back(node_type); + for(size_t slice = 2; slice < values.size(); slice ++) { + feature.push_back(values[slice]); + } + auto feat = paddle::string::join_strings(feature, '\t'); + size_t index = shard_id - shard_start; + shards[index].add_node(id, std::string("")); + + } + } + return 0; +} + + +int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) { auto paths = paddle::string::split_string(path, ";"); - VLOG(0) << paths.size(); int count = 0; - for (auto path : paths) { std::ifstream file(path); std::string line; @@ -113,6 +154,7 @@ int32_t GraphTable::load(const std::string &path, const std::string ¶m) { if (values.size() == 3) { weight = std::stod(values[2]); } + size_t src_shard_id = src_id % shard_num; if (src_shard_id >= shard_end || src_shard_id < shard_start) { @@ -121,6 +163,7 @@ int32_t GraphTable::load(const std::string &path, const std::string ¶m) { continue; } + size_t index = src_shard_id - shard_start; GraphEdge *edge = new GraphEdge(dst_id, weight); shards[index].add_neighboor(src_id, edge); @@ -129,19 +172,21 @@ int32_t GraphTable::load(const std::string &path, const std::string ¶m) { VLOG(0) << "Load Finished Total Edge Count " << count; // Build Sampler j + for (auto &shard : shards) { - auto bucket = shard.get_bucket(); - for (int i = 0; i < bucket.size(); i++) { - std::list::iterator iter = bucket[i].begin(); - while (iter != bucket[i].end()) { - auto node = *iter; - node->build_sampler(); - iter++; - } + auto bucket = shard.get_bucket(); + for (int i = 0; i < bucket.size(); i ++) { + std::list::iterator iter = bucket[i].begin(); + while (iter != bucket[i].end()) { + auto node = *iter; + node->build_sampler(); + iter++; } + } } return 0; } + GraphNode *GraphTable::find_node(uint64_t id) { size_t shard_id = id % shard_num; if (shard_id >= shard_end || shard_id < shard_start) { diff --git a/paddle/fluid/distributed/table/common_graph_table.h b/paddle/fluid/distributed/table/common_graph_table.h index 1f2b8c86d363bc..decf5f1f204623 100644 --- a/paddle/fluid/distributed/table/common_graph_table.h +++ b/paddle/fluid/distributed/table/common_graph_table.h @@ -52,7 +52,7 @@ class GraphShard { } return -1; } - std::list::iterator add_node(GraphNode *node); + std::list::iterator add_node(uint64_t id, std::string feature); GraphNode *find_node(uint64_t id); void add_neighboor(uint64_t id, GraphEdge *edge); std::unordered_map::iterator> @@ -74,7 +74,13 @@ class GraphTable : public SparseTable { virtual int32_t random_sample(uint64_t node_id, int sampe_size, char *&buffer, int &actual_size); virtual int32_t initialize(); + int32_t load(const std::string &path, const std::string ¶m); + + int32_t load_edges(const std::string &path, bool reverse); + + int32_t load_nodes(const std::string &path); + GraphNode *find_node(uint64_t id); virtual int32_t pull_sparse(float *values, const uint64_t *keys, size_t num) { From 2dc7ebfe0342f39fbbfdc16170f39fd5fb6aa7b3 Mon Sep 17 00:00:00 2001 From: Yelrose <270018958@qq.com> Date: Mon, 15 Mar 2021 11:39:44 +0800 Subject: [PATCH 2/5] resolve conflict --- .../distributed/table/common_graph_table.cc | 75 +++++++------------ 1 file changed, 25 insertions(+), 50 deletions(-) diff --git a/paddle/fluid/distributed/table/common_graph_table.cc b/paddle/fluid/distributed/table/common_graph_table.cc index c96eb4680d5aa5..600f17d6d02461 100644 --- a/paddle/fluid/distributed/table/common_graph_table.cc +++ b/paddle/fluid/distributed/table/common_graph_table.cc @@ -13,10 +13,10 @@ // limitations under the License. #include "paddle/fluid/distributed/table/common_graph_table.h" -#include -#include #include +#include #include +#include #include "paddle/fluid/distributed/common/utils.h" #include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/string_helper.h" @@ -136,23 +136,25 @@ int32_t GraphTable::load_nodes(const std::string &path) { int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) { + auto paths = paddle::string::split_string(path, ";"); int count = 0; + for (auto path : paths) { std::ifstream file(path); std::string line; while (std::getline(file, line)) { auto values = paddle::string::split_string(line, "\t"); - count ++; + count++; if (values.size() < 2) continue; auto src_id = std::stoull(values[0]); auto dst_id = std::stoull(values[1]); - if(reverse_edge) { - std::swap(src_id, dst_id); + if (reverse_edge) { + std::swap(src_id, dst_id); } - double weight = 0; + float weight = 0; if (values.size() == 3) { - weight = std::stod(values[2]); + weight = std::stof(values[2]); } size_t src_shard_id = src_id % shard_num; @@ -161,7 +163,6 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) { VLOG(0) << "will not load " << src_id << " from " << path << ", please check id distribution"; continue; - } size_t index = src_shard_id - shard_start; @@ -175,7 +176,8 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) { for (auto &shard : shards) { auto bucket = shard.get_bucket(); - for (int i = 0; i < bucket.size(); i ++) { + + for (int i = 0; i < bucket.size(); i++) { std::list::iterator iter = bucket[i].begin(); while (iter != bucket[i].end()) { auto node = *iter; @@ -193,8 +195,6 @@ GraphNode *GraphTable::find_node(uint64_t id) { return NULL; } size_t index = shard_id - shard_start; - // VLOG(0)<<"try to find node-id "< res = node->sample_k(sample_size); std::vector node_list; - int total_size = 0; - for (auto x : res) { - GraphNode temp; - temp.set_id(x->id); - total_size += temp.get_size(); - node_list.push_back(temp); + actual_size = + res.size() * (GraphNode::id_size + GraphNode::weight_size); + buffer = new char[actual_size]; + int offset = 0; + uint64_t id; + float weight; + for (auto &x : res) { + id = x->get_id(); + weight = x->get_weight(); + memcpy(buffer + offset, &id, GraphNode::id_size); + offset += GraphNode::id_size; + memcpy(buffer + offset, &weight, GraphNode::weight_size); + offset += GraphNode::weight_size; } - buffer = new char[total_size]; - int index = 0; - for (auto x : node_list) { - x.to_buffer(buffer + index); - index += x.get_size(); - } - actual_size = total_size; - return 0; }) .get(); - // GraphNode *node = find_node(node_id, type); - // if (node == NULL) { - // actual_size = 0; - // rwlock_->UNLock(); - // return 0; - // } - // std::vector res = node->sample_k(sample_size); - // std::vector node_list; - // int total_size = 0; - // for (auto x : res) { - // GraphNode temp; - // temp.set_id(x->id); - // temp.set_graph_node_type(x->type); - // total_size += temp.get_size(); - // node_list.push_back(temp); - // } - // buffer = new char[total_size]; - // int index = 0; - // for (auto x : node_list) { - // x.to_buffer(buffer + index); - // index += x.get_size(); - // } - // actual_size = total_size; - // rwlock_->UNLock(); - // return 0; } int32_t GraphTable::pull_graph_list(int start, int total_size, char *&buffer, int &actual_size) { @@ -338,3 +312,4 @@ int32_t GraphTable::initialize() { } } }; + From e35913568f27dd0bb4abe70242b59a4b860f3c6e Mon Sep 17 00:00:00 2001 From: Yelrose <270018958@qq.com> Date: Mon, 15 Mar 2021 11:45:11 +0800 Subject: [PATCH 3/5] resolved conflict --- .../distributed/service/graph_brpc_client.cc | 25 +++--- .../distributed/service/graph_brpc_client.h | 7 +- .../distributed/service/graph_py_service.h | 84 ++++++++++--------- paddle/fluid/distributed/service/ps_client.h | 6 +- paddle/fluid/distributed/table/graph_node.cc | 3 +- paddle/fluid/distributed/table/graph_node.h | 10 ++- .../distributed/table/weighted_sampler.cc | 15 ++-- .../distributed/table/weighted_sampler.h | 19 ++--- .../fluid/distributed/test/graph_node_test.cc | 55 +++++++----- 9 files changed, 116 insertions(+), 108 deletions(-) diff --git a/paddle/fluid/distributed/service/graph_brpc_client.cc b/paddle/fluid/distributed/service/graph_brpc_client.cc index 16014c8dbf23f1..bc6a03b9eaf849 100644 --- a/paddle/fluid/distributed/service/graph_brpc_client.cc +++ b/paddle/fluid/distributed/service/graph_brpc_client.cc @@ -12,15 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/fluid/distributed/service/graph_brpc_client.h" #include #include #include #include +#include #include #include "Eigen/Dense" - #include "paddle/fluid/distributed/service/brpc_ps_client.h" -#include "paddle/fluid/distributed/service/graph_brpc_client.h" #include "paddle/fluid/distributed/table/table.h" #include "paddle/fluid/framework/archive.h" #include "paddle/fluid/string/string_helper.h" @@ -35,9 +35,9 @@ int GraphBrpcClient::get_server_index_by_id(uint64_t id) { return id % shard_num / shard_per_server; } // char* &buffer,int &actual_size -std::future GraphBrpcClient::sample(uint32_t table_id, - uint64_t node_id, int sample_size, - std::vector &res) { +std::future GraphBrpcClient::sample( + uint32_t table_id, uint64_t node_id, int sample_size, + std::vector> &res) { int server_index = get_server_index_by_id(node_id); DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) { int ret = 0; @@ -45,19 +45,16 @@ std::future GraphBrpcClient::sample(uint32_t table_id, if (closure->check_response(0, PS_GRAPH_SAMPLE) != 0) { ret = -1; } else { - VLOG(0) << "check sample response: " - << " " << closure->check_response(0, PS_GRAPH_SAMPLE); auto &res_io_buffer = closure->cntl(0)->response_attachment(); butil::IOBufBytesIterator io_buffer_itr(res_io_buffer); size_t bytes_size = io_buffer_itr.bytes_left(); char *buffer = new char[bytes_size]; io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size); - int start = 0; - while (start < bytes_size) { - GraphNode node; - node.recover_from_buffer(buffer + start); - start += node.get_size(); - res.push_back(node); + int offset = 0; + while (offset < bytes_size) { + res.push_back({*(uint64_t *)(buffer + offset), + *(float *)(buffer + offset + GraphNode::id_size)}); + offset += GraphNode::id_size + GraphNode::weight_size; } } closure->set_promise_value(ret); @@ -69,9 +66,7 @@ std::future GraphBrpcClient::sample(uint32_t table_id, closure->request(0)->set_cmd_id(PS_GRAPH_SAMPLE); closure->request(0)->set_table_id(table_id); closure->request(0)->set_client_id(_client_id); - // std::string type_str = GraphNode::node_type_to_string(type); closure->request(0)->add_params((char *)&node_id, sizeof(uint64_t)); - // closure->request(0)->add_params(type_str.c_str(), type_str.size()); closure->request(0)->add_params((char *)&sample_size, sizeof(int)); PsService_Stub rpc_stub(get_cmd_channel(server_index)); closure->cntl(0)->set_log_id(butil::gettimeofday_ms()); diff --git a/paddle/fluid/distributed/service/graph_brpc_client.h b/paddle/fluid/distributed/service/graph_brpc_client.h index 8e472b96be94d1..84d4dbb78cba40 100644 --- a/paddle/fluid/distributed/service/graph_brpc_client.h +++ b/paddle/fluid/distributed/service/graph_brpc_client.h @@ -18,6 +18,7 @@ #include #include +#include #include "brpc/channel.h" #include "brpc/controller.h" #include "brpc/server.h" @@ -35,9 +36,9 @@ class GraphBrpcClient : public BrpcPsClient { public: GraphBrpcClient() {} virtual ~GraphBrpcClient() {} - virtual std::future sample(uint32_t table_id, uint64_t node_id, - int sample_size, - std::vector &res); + virtual std::future sample( + uint32_t table_id, uint64_t node_id, int sample_size, + std::vector> &res); virtual std::future pull_graph_list(uint32_t table_id, int server_index, int start, int size, diff --git a/paddle/fluid/distributed/service/graph_py_service.h b/paddle/fluid/distributed/service/graph_py_service.h index 25b70d11fe05b6..82272deabf721e 100644 --- a/paddle/fluid/distributed/service/graph_py_service.h +++ b/paddle/fluid/distributed/service/graph_py_service.h @@ -21,8 +21,8 @@ #include #include #include // NOLINT -#include #include +#include #include "google/protobuf/text_format.h" #include "gtest/gtest.h" @@ -47,7 +47,7 @@ class GraphPyService { std::vector keys; std::vector server_list, port_list, host_sign_list; int server_size, shard_num, rank, client_id; - std::unordered_map table_id_map; + std::unordered_map table_id_map; std::thread *server_thread, *client_thread; std::shared_ptr pserver_ptr; @@ -68,7 +68,8 @@ class GraphPyService { int get_shard_num() { return shard_num; } void set_shard_num(int shard_num) { this->shard_num = shard_num; } void GetDownpourSparseTableProto( - ::paddle::distributed::TableParameter* sparse_table_proto, uint32_t table_id) { + ::paddle::distributed::TableParameter* sparse_table_proto, + uint32_t table_id) { sparse_table_proto->set_table_id(table_id); sparse_table_proto->set_table_class("GraphTable"); sparse_table_proto->set_shard_num(shard_num); @@ -97,14 +98,13 @@ class GraphPyService { server_service_proto->set_start_server_port(0); server_service_proto->set_server_thread_num(12); - for(auto& tuple : this -> table_id_map) { - ::paddle::distributed::TableParameter* sparse_table_proto = - downpour_server_proto->add_downpour_table_param(); - GetDownpourSparseTableProto(sparse_table_proto, tuple.second); + for (auto& tuple : this->table_id_map) { + ::paddle::distributed::TableParameter* sparse_table_proto = + downpour_server_proto->add_downpour_table_param(); + GetDownpourSparseTableProto(sparse_table_proto, tuple.second); } return server_fleet_desc; - } ::paddle::distributed::PSParameter GetWorkerProto() { @@ -116,10 +116,10 @@ class GraphPyService { ::paddle::distributed::DownpourWorkerParameter* downpour_worker_proto = worker_proto->mutable_downpour_worker_param(); - for(auto& tuple : this -> table_id_map) { - ::paddle::distributed::TableParameter* worker_sparse_table_proto = - downpour_worker_proto->add_downpour_table_param(); - GetDownpourSparseTableProto(worker_sparse_table_proto, tuple.second); + for (auto& tuple : this->table_id_map) { + ::paddle::distributed::TableParameter* worker_sparse_table_proto = + downpour_worker_proto->add_downpour_table_param(); + GetDownpourSparseTableProto(worker_sparse_table_proto, tuple.second); } ::paddle::distributed::ServerParameter* server_proto = @@ -134,10 +134,10 @@ class GraphPyService { server_service_proto->set_start_server_port(0); server_service_proto->set_server_thread_num(12); - for(auto& tuple : this -> table_id_map) { - ::paddle::distributed::TableParameter* sparse_table_proto = - downpour_server_proto->add_downpour_table_param(); - GetDownpourSparseTableProto(sparse_table_proto, tuple.second); + for (auto& tuple : this->table_id_map) { + ::paddle::distributed::TableParameter* sparse_table_proto = + downpour_server_proto->add_downpour_table_param(); + GetDownpourSparseTableProto(sparse_table_proto, tuple.second); } return worker_fleet_desc; @@ -148,44 +148,46 @@ class GraphPyService { void load_edge_file(std::string name, std::string filepath, bool reverse) { std::string params = "edge"; - if(reverse) { - params += "|reverse"; + if (reverse) { + params += "|reverse"; } - if (this -> table_id_map.count(name)) { - uint32_t table_id = this -> table_id_map[name]; - auto status = - get_ps_client()->load(table_id, std::string(filepath), params); - status.wait(); + if (this->table_id_map.count(name)) { + uint32_t table_id = this->table_id_map[name]; + auto status = + get_ps_client()->load(table_id, std::string(filepath), params); + status.wait(); } } void load_node_file(std::string name, std::string filepath) { std::string params = "node"; - if (this -> table_id_map.count(name)) { - uint32_t table_id = this -> table_id_map[name]; - auto status = - get_ps_client()->load(table_id, std::string(filepath), params); - status.wait(); + if (this->table_id_map.count(name)) { + uint32_t table_id = this->table_id_map[name]; + auto status = + get_ps_client()->load(table_id, std::string(filepath), params); + status.wait(); } } - std::vector sample_k(std::string name, uint64_t node_id, int sample_size) { - std::vector v; - if (this -> table_id_map.count(name)) { - uint32_t table_id = this -> table_id_map[name]; - auto status = worker_ptr->sample(table_id, node_id, sample_size, v); - status.wait(); + std::vector> sample_k(std::string name, + uint64_t node_id, + int sample_size) { + std::vector> v; + if (this->table_id_map.count(name)) { + uint32_t table_id = this->table_id_map[name]; + auto status = worker_ptr->sample(table_id, node_id, sample_size, v); + status.wait(); } return v; } - std::vector pull_graph_list(std::string name, int server_index, int start, - int size) { + std::vector pull_graph_list(std::string name, int server_index, + int start, int size) { std::vector res; - if (this -> table_id_map.count(name)) { - uint32_t table_id = this -> table_id_map[name]; - auto status = - worker_ptr->pull_graph_list(table_id, server_index, start, size, res); - status.wait(); + if (this->table_id_map.count(name)) { + uint32_t table_id = this->table_id_map[name]; + auto status = + worker_ptr->pull_graph_list(table_id, server_index, start, size, res); + status.wait(); } return res; } diff --git a/paddle/fluid/distributed/service/ps_client.h b/paddle/fluid/distributed/service/ps_client.h index b6014b9aea1393..8da749931cc2f0 100644 --- a/paddle/fluid/distributed/service/ps_client.h +++ b/paddle/fluid/distributed/service/ps_client.h @@ -155,9 +155,9 @@ class PSClient { promise.set_value(-1); return fut; } - virtual std::future sample(uint32_t table_id, uint64_t node_id, - int sample_size, - std::vector &res) { + virtual std::future sample( + uint32_t table_id, uint64_t node_id, int sample_size, + std::vector> &res) { LOG(FATAL) << "Did not implement"; std::promise promise; std::future fut = promise.get_future(); diff --git a/paddle/fluid/distributed/table/graph_node.cc b/paddle/fluid/distributed/table/graph_node.cc index 78a586d507ef46..c63fff8883636e 100644 --- a/paddle/fluid/distributed/table/graph_node.cc +++ b/paddle/fluid/distributed/table/graph_node.cc @@ -16,9 +16,8 @@ #include namespace paddle { namespace distributed { -int GraphNode::enum_size = sizeof(int); +int GraphNode::weight_size = sizeof(float); int GraphNode::id_size = sizeof(uint64_t); -int GraphNode::double_size = sizeof(double); int GraphNode::int_size = sizeof(int); int GraphNode::get_size() { return feature.size() + id_size + int_size; } void GraphNode::build_sampler() { diff --git a/paddle/fluid/distributed/table/graph_node.h b/paddle/fluid/distributed/table/graph_node.h index 218d14e01edc17..a8fe5eca3e8244 100644 --- a/paddle/fluid/distributed/table/graph_node.h +++ b/paddle/fluid/distributed/table/graph_node.h @@ -20,11 +20,13 @@ namespace distributed { // enum GraphNodeType { user = 0, item = 1, query = 2, unknown = 3 }; class GraphEdge : public WeightedObject { public: - double weight; - uint64_t id; // GraphNodeType type; GraphEdge() {} - GraphEdge(uint64_t id, double weight) : weight(weight), id(id) {} + GraphEdge(uint64_t id, float weight) : id(id), weight(weight) {} + uint64_t get_id() { return id; } + float get_weight() { return weight; } + uint64_t id; + float weight; }; class GraphNode { public: @@ -35,7 +37,7 @@ class GraphNode { : id(id), feature(feature), sampler(NULL) {} virtual ~GraphNode() {} std::vector get_graph_edge() { return edges; } - static int enum_size, id_size, int_size, double_size; + static int id_size, int_size, weight_size; uint64_t get_id() { return id; } void set_id(uint64_t id) { this->id = id; } // GraphNodeType get_graph_node_type() { return type; } diff --git a/paddle/fluid/distributed/table/weighted_sampler.cc b/paddle/fluid/distributed/table/weighted_sampler.cc index c93bc551f54f34..09ecdc2b642e4a 100644 --- a/paddle/fluid/distributed/table/weighted_sampler.cc +++ b/paddle/fluid/distributed/table/weighted_sampler.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/distributed/table/weighted_sampler.h" +#include namespace paddle { namespace distributed { void WeightedSampler::build(WeightedObject **v, int start, int end) { @@ -37,11 +38,11 @@ std::vector WeightedSampler::sample_k(int k) { k = count; } std::vector sample_result; - double subtract; - std::unordered_map subtract_weight_map; + float subtract; + std::unordered_map subtract_weight_map; std::unordered_map subtract_count_map; while (k--) { - double query_weight = rand() % 100000 / 100000.0; + float query_weight = rand() % 100000 / 100000.0; query_weight *= weight - subtract_weight_map[this]; sample_result.push_back(sample(query_weight, subtract_weight_map, subtract_count_map, subtract)); @@ -49,10 +50,10 @@ std::vector WeightedSampler::sample_k(int k) { return sample_result; } WeightedObject *WeightedSampler::sample( - double query_weight, - std::unordered_map &subtract_weight_map, + float query_weight, + std::unordered_map &subtract_weight_map, std::unordered_map &subtract_count_map, - double &subtract) { + float &subtract) { if (left == NULL) { subtract_weight_map[this] = weight; subtract = weight; @@ -61,7 +62,7 @@ WeightedObject *WeightedSampler::sample( } int left_count = left->count - subtract_count_map[left]; int right_count = right->count - subtract_count_map[right]; - double left_subtract = subtract_weight_map[left]; + float left_subtract = subtract_weight_map[left]; WeightedObject *return_id; if (right_count == 0 || left_count > 0 && left->weight - left_subtract >= query_weight) { diff --git a/paddle/fluid/distributed/table/weighted_sampler.h b/paddle/fluid/distributed/table/weighted_sampler.h index 53bfaa8d301194..9ed2cc04649de8 100644 --- a/paddle/fluid/distributed/table/weighted_sampler.h +++ b/paddle/fluid/distributed/table/weighted_sampler.h @@ -22,15 +22,8 @@ class WeightedObject { public: WeightedObject() {} virtual ~WeightedObject() {} - virtual unsigned long long get_id() { return id; } - virtual double get_weight() { return weight; } - - virtual void set_id(unsigned long long id) { this->id = id; } - virtual void set_weight(double weight) { this->weight = weight; } - - private: - unsigned long long id; - double weight; + virtual uint64_t get_id() = 0; + virtual float get_weight() = 0; }; class WeightedSampler { @@ -38,16 +31,16 @@ class WeightedSampler { WeightedSampler *left, *right; WeightedObject *object; int count; - double weight; + float weight; void build(WeightedObject **v, int start, int end); std::vector sample_k(int k); private: WeightedObject *sample( - double query_weight, - std::unordered_map &subtract_weight_map, + float query_weight, + std::unordered_map &subtract_weight_map, std::unordered_map &subtract_count_map, - double &subtract); + float &subtract); }; } } diff --git a/paddle/fluid/distributed/test/graph_node_test.cc b/paddle/fluid/distributed/test/graph_node_test.cc index a2a87025aaed80..fccca80a173021 100644 --- a/paddle/fluid/distributed/test/graph_node_test.cc +++ b/paddle/fluid/distributed/test/graph_node_test.cc @@ -18,6 +18,7 @@ limitations under the License. */ #include #include #include // NOLINT +#include #include #include "google/protobuf/text_format.h" @@ -49,10 +50,18 @@ namespace memory = paddle::memory; namespace distributed = paddle::distributed; void testGraphToBuffer(); -std::string nodes[] = {std::string("37\taa\t45;0.34\t145;0.31\t112;0.21"), - std::string("96\tfeature\t48;1.4\t247;0.31\t111;1.21"), - std::string("59\ttreat\t45;0.34\t145;0.31\t112;0.21"), - std::string("97\tfood\t48;1.4\t247;0.31\t111;1.21")}; +// std::string nodes[] = {std::string("37\taa\t45;0.34\t145;0.31\t112;0.21"), +// std::string("96\tfeature\t48;1.4\t247;0.31\t111;1.21"), +// std::string("59\ttreat\t45;0.34\t145;0.31\t112;0.21"), +// std::string("97\tfood\t48;1.4\t247;0.31\t111;1.21")}; + +std::string nodes[] = { + std::string("37\t45\t0.34"), std::string("37\t145\t0.31"), + std::string("37\t112\t0.21"), std::string("96\t48\t1.4"), + std::string("96\t247\t0.31"), std::string("96\t111\t1.21"), + std::string("59\t45\t0.34"), std::string("59\t145\t0.31"), + std::string("59\t122\t0.21"), std::string("97\t48\t0.34"), + std::string("97\t247\t0.31"), std::string("97\t111\t0.21")}; char file_name[] = "nodes.txt"; void prepare_file(char file_name[]) { std::ofstream ofile; @@ -210,7 +219,7 @@ void RunBrpcPushSparse() { worker_ptr_->load(0, std::string(file_name), std::string("")); pull_status.wait(); - std::vector v; + std::vector> v; pull_status = worker_ptr_->sample(0, 37, 4, v); pull_status.wait(); // for (auto g : v) { @@ -220,40 +229,46 @@ void RunBrpcPushSparse() { v.clear(); pull_status = worker_ptr_->sample(0, 96, 4, v); pull_status.wait(); + std::unordered_set s = { 111, 48, 247 } ASSERT_EQ(3, v.size()); for (auto g : v) { - std::cout << g.get_id() << std::endl; + // std::cout << g.first << std::endl; + ASSERT_EQ(true, s.find(g.first) != s.end()) } - // ASSERT_EQ(v.size(),3); v.clear(); - pull_status = worker_ptr_->pull_graph_list(0, 0, 0, 1, v); + std::vector nodes; + pull_status = worker_ptr_->pull_graph_list(0, 0, 0, 1, nodes); pull_status.wait(); - ASSERT_EQ(v.size(), 1); - ASSERT_EQ(v[0].get_id(), 37); + ASSERT_EQ(nodes.size(), 1); + ASSERT_EQ(nodes[0].get_id(), 37); // for (auto g : v) { // std::cout << g.get_id() << " " << g.get_graph_node_type() << std::endl; // } // ASSERT_EQ(v.size(),1); - v.clear(); - pull_status = worker_ptr_->pull_graph_list(0, 0, 1, 4, v); + nodes.clear(); + pull_status = worker_ptr_->pull_graph_list(0, 0, 1, 4, nodes); pull_status.wait(); - ASSERT_EQ(v.size(), 1); - ASSERT_EQ(v[0].get_id(), 59); - for (auto g : v) { + ASSERT_EQ(nodes.size(), 1); + ASSERT_EQ(nodes[0].get_id(), 59); + for (auto g : nodes) { std::cout << g.get_id() << std::endl; } distributed::GraphPyService gps1, gps2; std::string ips_str = "127.0.0.1:4211;127.0.0.1:4212"; - std::vector edge_types = { std::string("user2item")}; + std::vector edge_types = {std::string("user2item")}; gps1.set_up(ips_str, 127, 0, 0, edge_types); gps2.set_up(ips_str, 127, 1, 1, edge_types); gps1.load_edge_file(std::string("user2item"), std::string(file_name), 0); - v.clear(); - v = gps2.pull_graph_list(std::string("user2item"), 0, 1, 4); - ASSERT_EQ(v[0].get_id(), 59); - v.clear(); + nodes.clear(); + nodes = gps2.pull_graph_list(std::string("user2item"), 0, 1, 4); + ASSERT_EQ(nodes[0].get_id(), 59); + nodes.clear(); v = gps2.sample_k(std::string("user2item"), 96, 4); ASSERT_EQ(v.size(), 3); + std::cout << "sample result" << std::endl; + for (auto p : v) { + std::cout << p.first << " " << p.second << std::endl; + } // to test in python,try this: // from paddle.fluid.core import GraphPyService // ips_str = "127.0.0.1:4211;127.0.0.1:4212" From 290840f47dc62b942acd09777b572da5a3bbd044 Mon Sep 17 00:00:00 2001 From: Yelrose <270018958@qq.com> Date: Mon, 15 Mar 2021 11:59:14 +0800 Subject: [PATCH 4/5] resolved conflict --- paddle/fluid/distributed/test/graph_node_test.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/distributed/test/graph_node_test.cc b/paddle/fluid/distributed/test/graph_node_test.cc index fccca80a173021..50b4c0eda9b6f1 100644 --- a/paddle/fluid/distributed/test/graph_node_test.cc +++ b/paddle/fluid/distributed/test/graph_node_test.cc @@ -229,10 +229,11 @@ void RunBrpcPushSparse() { v.clear(); pull_status = worker_ptr_->sample(0, 96, 4, v); pull_status.wait(); - std::unordered_set s = { 111, 48, 247 } ASSERT_EQ(3, v.size()); + std::unordered_set s = { 111, 48, 247 }; + ASSERT_EQ(3, v.size()); for (auto g : v) { // std::cout << g.first << std::endl; - ASSERT_EQ(true, s.find(g.first) != s.end()) + ASSERT_EQ(true, s.find(g.first) != s.end()); } v.clear(); std::vector nodes; From 2feadfee4f06622af692e5f9a1c2da05bd748a46 Mon Sep 17 00:00:00 2001 From: Yelrose <270018958@qq.com> Date: Mon, 15 Mar 2021 17:39:30 +0800 Subject: [PATCH 5/5] resolved conflict --- paddle/fluid/distributed/table/common_graph_table.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/distributed/table/common_graph_table.cc b/paddle/fluid/distributed/table/common_graph_table.cc index a352c24c383df0..107c619235ad1a 100644 --- a/paddle/fluid/distributed/table/common_graph_table.cc +++ b/paddle/fluid/distributed/table/common_graph_table.cc @@ -69,7 +69,7 @@ std::list::iterator GraphShard::add_node(uint64_t id, std::string f return node_location.find(id)->second; int index = id % shard_num % bucket_size; - GraphNode *node = new GraphNode(id, std::string("")); + GraphNode *node = new GraphNode(id, feature); std::list::iterator iter = bucket[index].insert(bucket[index].end(), node); @@ -127,7 +127,7 @@ int32_t GraphTable::load_nodes(const std::string &path) { } auto feat = paddle::string::join_strings(feature, '\t'); size_t index = shard_id - shard_start; - shards[index].add_node(id, std::string("")); + shards[index].add_node(id, feat); } }