-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdisjoint_set.hpp
80 lines (68 loc) · 2.16 KB
/
disjoint_set.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#pragma once
#include <array>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>
// TODO: use an enum class for I/O
// adapted from code found here https://www.geeksforgeeks.org/disjoint-set-data-structures/
//
// Union-find (disjoint-set) over the fixed universe of elements
// {0, 1, ..., N - 1}. Combines union by rank in unify() with path
// compression in find(). Both index arrays are std::array members, so
// the storage lives inline in the DisjointSet object itself.
template <std::size_t N>
class DisjointSet {
    // rank[i]: upper bound on the height of the tree rooted at i
    // (only meaningful while i is a root).
    std::array<int, N> rank{};
    // parent[i]: parent of i in its tree; a root is its own parent.
    std::array<int, N> parent{};

    // True when x is a valid element index in [0, N).
    static auto in_range(int x) noexcept -> bool {
        return x >= 0 && static_cast<std::size_t>(x) < N;
    }

public:
    // Constructs N single-item sets {0}, {1}, ..., {N - 1}.
    DisjointSet() { make_set(); }

    // Number of elements in the universe (fixed at compile time).
    auto get_size() const noexcept -> std::size_t { return N; }

    // (Re)creates N single-item sets, discarding every previous union.
    // Ranks are reset too: leaving stale ranks behind would make later
    // unions attach trees the wrong way round and lose the height bound
    // that union by rank provides.
    void make_set() {
        rank.fill(0);
        // parent = {0, 1, ..., N - 1}: every element is its own root.
        std::iota(parent.begin(), parent.end(), 0);
    }

    // Returns the "set id" of x: the root element of the tree holding x.
    // Two elements are in the same set iff they share a root.
    // Iterative two-pass path compression: the first pass locates the
    // root, the second re-points every node on the walked path directly
    // at that root, flattening the tree for later calls.
    // Precondition: 0 <= x < N (checked by assert in debug builds).
    auto find(int x) -> int {
        assert(in_range(x));
        int root = x;
        while (parent[root] != root) {
            root = parent[root];
        }
        while (parent[x] != root) {
            const int next = parent[x];
            parent[x] = root;
            x = next;
        }
        return root;
    }

    // Merges the set containing x with the set containing y.
    // The root of lower rank is placed under the root of higher rank;
    // on a tie, y's root goes under x's root and x's root gains a rank.
    // No-op when x and y are already in the same set.
    // Precondition: 0 <= x, y < N.
    void unify(int x, int y) {
        const int xroot = find(x);
        const int yroot = find(y);
        if (xroot == yroot) {
            return;
        }
        if (rank[xroot] < rank[yroot]) {
            parent[xroot] = yroot;
            return;
        }
        parent[yroot] = xroot;
        if (rank[xroot] == rank[yroot]) {
            ++rank[xroot];
        }
    }

    // True when x and y currently belong to the same set.
    // Non-const because find() compresses paths as a side effect.
    // Precondition: 0 <= x, y < N.
    auto are_same_set(int x, int y) -> bool {
        return find(x) == find(y);
    }
};