-
Notifications
You must be signed in to change notification settings - Fork 330
/
Copy pathnode2vec_randomwalk.cc
329 lines (275 loc) · 10.4 KB
/
node2vec_randomwalk.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
/*
Tencent is pleased to support the open source community by making
Plato available.
Copyright (C) 2019 THL A29 Limited, a Tencent company.
All rights reserved.
Licensed under the BSD 3-Clause License (the "License"); you may
not use this file except in compliance with the License. You may
obtain a copy of the License at
https://opensource.org/licenses/BSD-3-Clause
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" basis,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied. See the License for the specific language governing
permissions and limitations under the License.
See the AUTHORS file for names of contributors.
*/
#include <cstdint>
#include <cstdlib>
#include <limits>
#include <memory>
#include <string>
#include <algorithm>
#include "boost/format.hpp"
#include "plato/util/perf.hpp"
#include "plato/util/hdfs.hpp"
#include "plato/util/foutput.h"
#include "plato/graph/graph.hpp"
#include "plato/engine/walk.hpp"
// ---- command-line flags ----------------------------------------------------
// Edge list to load; parsed with plato::edge_format_t::CSV by the *_walk() loaders.
DEFINE_string(input, "", "input graph file, in csv format, eg: src,dst[,weight]");
// When non-empty, walk() writes each finished walk as one line of space-separated vertex ids.
DEFINE_string(output, "", "output path, in csv format, gzip compressed");
// main() dispatches to biased_walk() (float edge weights) when true, unbiased_walk() otherwise.
DEFINE_bool(is_weighted, false, "random walk with bias or not");
// node2vec return parameter: walk() uses 1/p as the weight for stepping back to the previous vertex.
DEFINE_double(p, 1.0, "backward bias for randomwalk");
// node2vec in-out parameter: walk() uses 1/q as the weight for a vertex not adjacent to the previous one.
DEFINE_double(q, 0.5, "forward bias for randomwalk");
// Copied into plato::walk_opts_t::epochs_.
DEFINE_uint32(epoch, 1, "how many epoch should perform");
// Copied into plato::walk_opts_t::max_steps_; a walk terminates once step_id_ reaches it.
DEFINE_uint32(step, 10, "steps per epoch");
// Copied into plato::walk_opts_t::start_rate.
// NOTE(review): the help text says "'rate'%" but the default is 0.02 -- confirm whether
// the engine interprets this value as a fraction or as a percentage.
DEFINE_double(rate, 0.02, "start 'rate'% walker per one bsp, reduce memory consumption");
// gflags validator for string flags: a value is accepted only when it
// contains at least one character. The flag name is not consulted.
bool string_not_empty(const char* /*flagname*/, const std::string& value) {
  return !value.empty();
}
// Register string_not_empty as the gflags validator for --input so that an
// empty input path is rejected.
DEFINE_validator(input, &string_not_empty);
// Process-level setup: parse gflags (the `true` argument asks gflags to remove
// recognised flags from argv), then initialise glog and send log output to stderr.
// argc/argv are taken by value, so the caller's argc/argv variables are not updated.
void init(int argc, char** argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
google::LogToStderr();
}
// Partitioner used to construct the walk engines in biased_walk()/unbiased_walk().
using partition_t = plato::hash_by_source_t<>;

// Commands a walker carries in n2v_walk_t::command_; they select what the
// walk callback does when the walker arrives at a vertex.
constexpr uint8_t N2V_COMMAND_MOVE          = 0x01;  // propose (and possibly take) the next step
constexpr uint8_t N2V_COMMAND_IS_NEIGHBOR   = 0x02;  // at from_: test whether proposal_ is adjacent to it
constexpr uint8_t N2V_COMMAND_RESPONSE      = 0x03;  // back at the current vertex with the adjacency answer
constexpr uint8_t N2V_COMMAND_KEEPFOOTPRINT = 0x04;  // record the current step in the walk's footprint

// Per-walker node2vec state, carried as plato::walker_t<n2v_walk_t>::udata_.
struct n2v_walk_t {
  bool is_negb_;           // result of the IS_NEIGHBOR query: proposal_ is adjacent to from_
  uint8_t command_;        // one of the N2V_COMMAND_* values above
  plato::vid_t proposal_;  // candidate next vertex drawn by the sampler
  float prob_;             // uniform draw compared against the candidate's weight
  plato::vid_t from_;      // previously visited vertex
};
/*
* Implement Node2vec with rejection sampling
*
* reference:
* Ke Yang, MingXing Zhang, Kang Chen, Xiaosong Ma, Yang Bai, and Yong Jiang. 2019.
* KnightKing: A Fast Distributed Graph Ran- dom Walk Engine.
* In ACM SIGOPS 27th Symposium on Operating Systems Principles (SOSP ’19),
* October 27–30, 2019
*
* */
template <typename ENGINE>
void walk(ENGINE& engine) {
using walker_spec_t = plato::walker_t<n2v_walk_t>;
using walk_context_spec_t = plato::walk_context_t<n2v_walk_t, typename ENGINE::partition_t>;
auto& cluster_info = plato::cluster_info_t::get_instance();
float upbnd = std::max(1.0, std::max(1.0 / FLAGS_p, 1.0 / FLAGS_q)); // upper bound
float lwbnd = std::min(1.0, std::min(1.0 / FLAGS_p, 1.0 / FLAGS_q)); // lower bound
// init output
std::unique_ptr<plato::fs_mt_omp_output_t> output;
if (0 != FLAGS_output.length()) {
output.reset(new plato::fs_mt_omp_output_t(FLAGS_output,
(boost::format("%04d_") % cluster_info.partition_id_).str(), true));
}
plato::walk_opts_t opts;
opts.max_steps_ = FLAGS_step;
opts.epochs_ = FLAGS_epoch;
opts.start_rate = FLAGS_rate;
auto* g_sampler = engine.sampler();
engine.template walk<n2v_walk_t>(
[&](walker_spec_t*) { },
[&](walk_context_spec_t&& context, walker_spec_t& walker) {
auto& wdata = walker.udata_;
auto is_terminate = [&](void) {
return walker.step_id_ >= opts.max_steps_;
};
auto output_footprint = [&](plato::footprint_t& footprint) {
if (output) {
auto& ostream = output->ostream();
if (0 != footprint.idx_) {
ostream << footprint.path_[0];
for (plato::vid_t i = 1; i < footprint.idx_; ++i) {
ostream << " " << footprint.path_[i];
}
ostream << "\n";
}
}
};
auto oot = [&](void) { // output when terminated
if (is_terminate()) {
plato::footprint_t footprint = context.erase_footprint(walker);
output_footprint(footprint);
return true;
}
return false;
};
auto make_proposal = [&](void) {
std::uniform_real_distribution<float> dist(0, upbnd);
auto choose_edge = g_sampler->sample(walker.current_v_id_, context.urng());
CHECK(choose_edge != NULL);
plato::vid_t proposal = choose_edge->neighbour_;
float prob = dist(context.urng());
wdata.proposal_ = proposal;
wdata.prob_ = prob;
};
auto accept_proposal = [&](void) {
++walker.step_id_;
wdata.from_ = walker.current_v_id_;
walker.current_v_id_ = wdata.proposal_;
};
auto akm = [&](void) { // accept && keep_footprint && move
accept_proposal();
wdata.command_ = N2V_COMMAND_KEEPFOOTPRINT;
context.move_to(walker.walk_id_, walker);
if (false == is_terminate()) {
wdata.command_ = N2V_COMMAND_MOVE;
context.move_to(walker.current_v_id_, walker);
}
};
if (0 == walker.step_id_) {
if (is_terminate()) { return ; } // nonsense
++walker.step_id_;
context.keep_footprint(walker);
if (oot()) { return ; }
make_proposal();
accept_proposal();
context.keep_footprint(walker);
if (oot()) { return ; }
wdata.command_ = N2V_COMMAND_MOVE;
context.move_to(walker.current_v_id_, walker);
} else {
switch (wdata.command_) {
case N2V_COMMAND_MOVE:
{
while (true) {
make_proposal();
if (wdata.prob_ < lwbnd) { // accept directly
akm();
break;
} else if (wdata.proposal_ == wdata.from_) { // proposal_ is last visted vertex
if (wdata.prob_ < (1.0 / FLAGS_p)) { // accept
akm();
break;
}
} else { // ask last visted vertex
wdata.command_ = N2V_COMMAND_IS_NEIGHBOR;
context.move_to(wdata.from_, walker);
break;
}
}
break;
}
case N2V_COMMAND_IS_NEIGHBOR:
{
wdata.is_negb_ = g_sampler->existed(wdata.from_, wdata.proposal_);
wdata.command_ = N2V_COMMAND_RESPONSE;
context.move_to(walker.current_v_id_, walker);
break;
}
case N2V_COMMAND_RESPONSE:
{
if (wdata.is_negb_) {
if (wdata.prob_ < 1.0) {
akm();
} else { // proposal again
wdata.command_ = N2V_COMMAND_MOVE;
context.move_to(walker.current_v_id_, walker);
}
} else {
if (wdata.prob_ < (1.0 / FLAGS_q)) {
akm();
} else { // proposal again
wdata.command_ = N2V_COMMAND_MOVE;
context.move_to(walker.current_v_id_, walker);
}
}
break;
}
case N2V_COMMAND_KEEPFOOTPRINT:
{
context.keep_footprint(walker);
oot();
break;
}
default:
CHECK(false) << "unknown command: " << wdata.command_;
break;
}
}
}, opts);
}
void biased_walk(void) {
//using walk_engine_spec_t = plato::walk_engine_t<true>;
using part_spec_t = plato::hash_by_source_t<>;
using walk_engine_spec_t = plato::walk_engine_t<plato::cbcsr_t<float, part_spec_t>, float>;
plato::stop_watch_t watch;
plato::graph_info_t graph_info(false);
watch.mark("t1");
auto cache = plato::load_edges_cache<float, plato::vid_t, plato::edge_cache_t>(&graph_info, FLAGS_input, plato::edge_format_t::CSV,
plato::float_decoder);
auto& cluster_info = plato::cluster_info_t::get_instance();
if (0 == cluster_info.partition_id_) {
LOG(INFO) << "edges: " << graph_info.edges_;
LOG(INFO) << "vertices: " << graph_info.vertices_;
LOG(INFO) << "max_v_id: " << graph_info.max_v_i_;
LOG(INFO) << "is_directed_: " << graph_info.is_directed_;
LOG(INFO) << "load edges cache cost: " << watch.show("t1") / 1000.0 << "s";
}
std::shared_ptr<partition_t> partitioner(new partition_t());
//walk_engine_spec_t engine(graph_info, *cache, partitioner);
walk_engine_spec_t engine(graph_info, *cache, partitioner);
cache.reset();
walk(engine);
}
void unbiased_walk(void) {
//using walk_engine_spec_t = plato::walk_engine_t<false>;
using part_spec_t = plato::hash_by_source_t<>;
using walk_engine_spec_t = plato::walk_engine_t<plato::cbcsr_t<plato::empty_t, part_spec_t>, plato::empty_t>;
plato::stop_watch_t watch;
plato::graph_info_t graph_info(false);
watch.mark("t1");
auto cache = plato::load_edges_cache<plato::empty_t, plato::vid_t, plato::edge_cache_t>(&graph_info, FLAGS_input, plato::edge_format_t::CSV,
plato::dummy_decoder<plato::empty_t>);
auto& cluster_info = plato::cluster_info_t::get_instance();
if (0 == cluster_info.partition_id_) {
LOG(INFO) << "edges: " << graph_info.edges_;
LOG(INFO) << "vertices: " << graph_info.vertices_;
LOG(INFO) << "max_v_id: " << graph_info.max_v_i_;
LOG(INFO) << "is_directed_: " << graph_info.is_directed_;
LOG(INFO) << "load edges cache cost: " << watch.show("t1") / 1000.0 << "s";
}
std::shared_ptr<partition_t> partitioner(new partition_t());
walk_engine_spec_t engine(graph_info, *cache, partitioner);
cache.reset();
walk(engine);
}
// Entry point: parse flags and set up logging, initialise the cluster, log the
// effective configuration on partition 0, then run the weighted or unweighted
// node2vec walk depending on --is_weighted.
int main(int argc, char** argv) {
  init(argc, argv);

  auto& cluster_info = plato::cluster_info_t::get_instance();
  cluster_info.initialize(&argc, &argv);

  if (0 == cluster_info.partition_id_) {
    LOG(INFO) << "input: " << FLAGS_input;
    LOG(INFO) << "output: " << FLAGS_output;
    LOG(INFO) << "is_weighted: " << FLAGS_is_weighted;
    LOG(INFO) << "p: " << FLAGS_p;
    LOG(INFO) << "q: " << FLAGS_q;
    LOG(INFO) << "epoch: " << FLAGS_epoch;
    LOG(INFO) << "step: " << FLAGS_step;
    LOG(INFO) << "rate: " << FLAGS_rate;
  }

  if (FLAGS_is_weighted) {
    biased_walk();
  } else {
    unbiased_walk();
  }
  return 0;
}