Skip to content

Commit

Permalink
Merge pull request #5 from WeiyueSu/sample
Browse files Browse the repository at this point in the history
sample with srand
  • Loading branch information
seemingwang authored Mar 19, 2021
2 parents 6627904 + ec2555a commit 46dc17f
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 18 deletions.
7 changes: 4 additions & 3 deletions paddle/fluid/distributed/table/common_graph_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) {
int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
auto paths = paddle::string::split_string<std::string>(path, ";");
int count = 0;
std::string sample_type = "random";

for (auto path : paths) {
std::ifstream file(path);
Expand All @@ -159,9 +160,10 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
if (reverse_edge) {
std::swap(src_id, dst_id);
}
float weight = 0;
float weight = 1;
if (values.size() == 3) {
weight = std::stof(values[2]);
sample_type = "weighted";
}

size_t src_shard_id = src_id % shard_num;
Expand All @@ -184,8 +186,7 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
for (auto &shard : shards) {
auto bucket = shard.get_bucket();
for (int i = 0; i < bucket.size(); i++) {
bucket[i]->build_sampler();
}
bucket[i]->build_sampler(sample_type); }
}
return 0;
}
Expand Down
15 changes: 10 additions & 5 deletions paddle/fluid/distributed/table/graph_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,15 @@ int GraphNode::int_size = sizeof(int);
// Returns id_size + int_size, plus the length of the feature string when
// need_feature is true. to_buffer() uses this value as the buffer size.
int GraphNode::get_size(bool need_feature) {
  int total = id_size + int_size;
  if (need_feature) {
    total += feature.size();
  }
  return total;
}
void GraphNode::build_sampler() {
sampler = new WeightedSampler();
GraphEdge** arr = edges.data();
sampler->build((WeightedObject**)arr, 0, edges.size());
void GraphNode::build_sampler(std::string sample_type) {
if (sample_type == "random"){
sampler = new RandomSampler();
} else if (sample_type == "weighted"){
sampler = new WeightedSampler();
}
//GraphEdge** arr = edges.data();
//sampler->build((WeightedObject**)arr, 0, edges.size());
sampler->build((std::vector<WeightedObject*>*)&edges);
}
void GraphNode::to_buffer(char* buffer, bool need_feature) {
int size = get_size(need_feature);
Expand All @@ -51,4 +56,4 @@ void GraphNode::recover_from_buffer(char* buffer) {
// type = GraphNodeType(int_state);
}
}
}
}
4 changes: 2 additions & 2 deletions paddle/fluid/distributed/table/graph_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class GraphNode {
// Replaces the stored feature string with a copy of the argument.
void set_feature(std::string feature) {
  this->feature = feature;
}
// Returns a copy of the stored feature string.
std::string get_feature() {
  return feature;
}
virtual int get_size(bool need_feature);
virtual void build_sampler();
virtual void build_sampler(std::string sample_type);
virtual void to_buffer(char *buffer, bool need_feature);
virtual void recover_from_buffer(char *buffer);
// Appends the edge pointer to this node's edge list.
virtual void add_edge(GraphEdge *edge) {
  edges.push_back(edge);
}
Expand All @@ -58,7 +58,7 @@ class GraphNode {
protected:
uint64_t id;
std::string feature;
WeightedSampler *sampler;
Sampler *sampler;
std::vector<GraphEdge *> edges;
};
}
Expand Down
79 changes: 74 additions & 5 deletions paddle/fluid/distributed/table/weighted_sampler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,87 @@

#include "paddle/fluid/distributed/table/weighted_sampler.h"
#include <iostream>
#include<unordered_map>
namespace paddle {
namespace distributed {
void WeightedSampler::build(WeightedObject **v, int start, int end) {

// Stores the pointer to the edge list. Nothing is copied: sample_k() later
// reads through this pointer, so the vector must still be alive then.
void RandomSampler::build(std::vector<WeightedObject *> *edges) {
  RandomSampler::edges = edges;
}

std::vector<WeightedObject *> RandomSampler::sample_k(int k) {
int n = edges->size();
if (k > n){
k = n;
}
struct timespec tn;
clock_gettime(CLOCK_REALTIME, &tn);
srand(tn.tv_nsec);
std::vector<WeightedObject *> sample_result;
std::unordered_map<int, int> replace_map;
while(k--){
int rand_int = rand() % n;
auto iter = replace_map.find(rand_int);
if(iter == replace_map.end()){
sample_result.push_back(edges->at(rand_int));
}else{
sample_result.push_back(edges->at(iter->second));
}

iter = replace_map.find(n - 1);
if(iter == replace_map.end()){
replace_map[rand_int] = n - 1;
}else{
replace_map[rand_int] = iter->second;
}
--n;
}
return sample_result;
}

// Constructs an empty node: no children, no object, and zero count/weight.
// count and weight are initialized as well, so a sampler that has not been
// built yet never exposes indeterminate values.
WeightedSampler::WeightedSampler()
    : left(nullptr), right(nullptr), object(nullptr), count(0), weight(0) {}

// Recursively frees both subtrees. `object` is not deleted here.
WeightedSampler::~WeightedSampler() {
  // delete on a null pointer is a no-op, so no explicit checks are needed.
  delete left;
  left = nullptr;
  delete right;
  right = nullptr;
}

// (Re)builds the sampling tree over all objects in *edges.
//
// Any previously built subtrees are released first. A null or empty edge
// list produces an empty sampler (no children, no object, count 0, weight 0)
// rather than calling build_one on an empty range: build_one only stops
// splitting when the range holds exactly one element, so an empty range
// would recurse without end.
//
// NOTE(review): confirm that sample_k handles count == 0 as "nothing to
// sample".
void WeightedSampler::build(std::vector<WeightedObject *> *edges) {
  delete left;  // delete on nullptr is a no-op
  left = nullptr;
  delete right;
  right = nullptr;

  if (edges == nullptr || edges->empty()) {
    object = nullptr;
    count = 0;
    weight = 0;
    return;
  }
  build_one(edges->data(), 0, edges->size());
}

void WeightedSampler::build_one(WeightedObject **v, int start, int end) {
count = 0;
if (start + 1 == end) {
left = right = NULL;
left = right = nullptr;
weight = v[start]->get_weight();
object = v[start];
count = 1;

} else {
left = new WeightedSampler();
right = new WeightedSampler();
left->build(v, start, start + (end - start) / 2);
right->build(v, start + (end - start) / 2, end);
left->build_one(v, start, start + (end - start) / 2);
right->build_one(v, start + (end - start) / 2, end);
weight = left->weight + right->weight;
count = left->count + right->count;
}
Expand All @@ -41,6 +107,9 @@ std::vector<WeightedObject *> WeightedSampler::sample_k(int k) {
float subtract;
std::unordered_map<WeightedSampler *, float> subtract_weight_map;
std::unordered_map<WeightedSampler *, int> subtract_count_map;
struct timespec tn;
clock_gettime(CLOCK_REALTIME, &tn);
srand(tn.tv_nsec);
while (k--) {
float query_weight = rand() % 100000 / 100000.0;
query_weight *= weight - subtract_weight_map[this];
Expand All @@ -54,7 +123,7 @@ WeightedObject *WeightedSampler::sample(
std::unordered_map<WeightedSampler *, float> &subtract_weight_map,
std::unordered_map<WeightedSampler *, int> &subtract_count_map,
float &subtract) {
if (left == NULL) {
if (left == nullptr) {
subtract_weight_map[this] = weight;
subtract = weight;
subtract_count_map[this] = 1;
Expand Down
25 changes: 22 additions & 3 deletions paddle/fluid/distributed/table/weighted_sampler.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <vector>
namespace paddle {
namespace distributed {

class WeightedObject {
public:
WeightedObject() {}
Expand All @@ -26,14 +27,32 @@ class WeightedObject {
virtual float get_weight() = 0;
};

class WeightedSampler {
// Abstract sampler over a list of WeightedObject pointers.
// build() receives the list to sample from; sample_k() returns up to k of
// its elements. The concrete strategy is defined by the subclass.
class Sampler {
 public:
  virtual ~Sampler() = default;

  virtual void build(std::vector<WeightedObject *> *edges) = 0;

  virtual std::vector<WeightedObject *> sample_k(int k) = 0;
};

class RandomSampler: public Sampler {
public:
virtual ~RandomSampler() {}
virtual void build(std::vector<WeightedObject*>* edges);
virtual std::vector<WeightedObject *> sample_k(int k);
std::vector<WeightedObject*>* edges;
};

class WeightedSampler: public Sampler {
public:
WeightedSampler();
virtual ~WeightedSampler();
WeightedSampler *left, *right;
WeightedObject *object;
int count;
float weight;
void build(WeightedObject **v, int start, int end);
std::vector<WeightedObject *> sample_k(int k);
virtual void build(std::vector<WeightedObject*>* edges);
virtual void build_one(WeightedObject **v, int start, int end);
virtual std::vector<WeightedObject *> sample_k(int k);

private:
WeightedObject *sample(
Expand Down

0 comments on commit 46dc17f

Please sign in to comment.