-
Notifications
You must be signed in to change notification settings - Fork 224
/
test_graph_optimize_state.cc
72 lines (61 loc) · 3.3 KB
/
test_graph_optimize_state.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#include "compiler/graph_optimize_state.h"
#include "doctest/doctest.h"
#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h"
using namespace FlexFlow;
TEST_SUITE(FF_TEST_SUITE) {
  // Exercises operator== and operator!= of GraphOptimizeState:
  //  * two states wrapping the same PCG (and the same machine mapping and
  //    runtime) must compare equal;
  //  * two states wrapping PCGs built from different layer sequences must
  //    compare unequal.
  TEST_CASE("graph_optimize_state:equality") {
    ParallelComputationGraphBuilder builder;

    // Shape shared by the input tensors of both graphs built below.
    ParallelTensorShape input_shape =
        ParallelTensorShape{ParallelTensorDims{
                                FFOrdered<ShardParallelDim>{
                                    ShardParallelDim{32, 2},
                                    ShardParallelDim{16, 1},
                                },
                                ReplicaParallelDimSet{
                                    SumDegree{1},
                                    DiscardCopyDegree{1},
                                },
                            },
                            DataType::FLOAT};

    // First graph: input0 -> dense0 -> dense1.
    parallel_tensor_guid_t input0 =
        builder.create_input_tensor(input_shape, true, "input0");
    parallel_tensor_guid_t dense0 = builder.dense(input0,
                                                  8,
                                                  Activation::RELU,
                                                  true,
                                                  DataType::FLOAT,
                                                  std::nullopt,
                                                  std::nullopt,
                                                  "dense0");
    parallel_tensor_guid_t dense1 = builder.dense(dense0,
                                                  4,
                                                  Activation::RELU,
                                                  true,
                                                  DataType::FLOAT,
                                                  std::nullopt,
                                                  std::nullopt,
                                                  "dense1");

    ParallelComputationGraph pcg = builder.pcg;

    // `machine_mapping` is determined by the PCG and the device mapping
    // algorithm, and `runtime` is determined by the PCG and the device mapping,
    // so their values here do not matter.
    std::unordered_map<Node, MachineView> empty_machine_views;
    MachineMapping empty_machine_mapping(empty_machine_views);

    // Same PCG, same mapping, same runtime => equal states.
    CHECK(GraphOptimizeState(GraphOptimizeResult(pcg, empty_machine_mapping),
                             0) ==
          GraphOptimizeState(GraphOptimizeResult(pcg, empty_machine_mapping),
                             0));

    // Second graph: input0 -> dense0 only. It must be assembled through its
    // own builder (`builder_`) and its own input tensor (`input0_`); going
    // through `builder` would merely append nodes to the first graph instead
    // of producing an independent one.
    ParallelComputationGraphBuilder builder_;

    parallel_tensor_guid_t input0_ =
        builder_.create_input_tensor(input_shape, true, "input0");
    parallel_tensor_guid_t dense0_ = builder_.dense(input0_,
                                                    8,
                                                    Activation::RELU,
                                                    true,
                                                    DataType::FLOAT,
                                                    std::nullopt,
                                                    std::nullopt,
                                                    "dense0");

    ParallelComputationGraph pcg_ = builder_.pcg;

    // Different PCGs => unequal states, even with identical mapping/runtime.
    CHECK(GraphOptimizeState(GraphOptimizeResult(pcg, empty_machine_mapping),
                             0) !=
          GraphOptimizeState(GraphOptimizeResult(pcg_, empty_machine_mapping),
                             0));
  }
}