-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path2316_Count_Unreachable_Pairs_of_Nodes_in_a_Undirected_Graph.cpp
158 lines (144 loc) · 3.51 KB
/
2316_Count_Unreachable_Pairs_of_Nodes_in_a_Undirected_Graph.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
/*******************************************************************************
* 2316_Count_Unreachable_Pairs_of_Nodes_in_a_Undirected_Graph.cpp
* Billy.Ljm
* 25 Mar 2023
*
* =======
* Problem
* =======
* https://leetcode.com/problems/count-unreachable-pairs-of-nodes-in-an-undirected-graph/
*
* You are given an integer n. There is an undirected graph with n nodes,
* numbered from 0 to n - 1. You are given a 2D integer array edges where
* edges[i] = [ai, bi] denotes that there exists an undirected edge connecting
* nodes ai and bi. Return the number of pairs of different nodes that are
* unreachable from each other.
*
* ===========
* My Approach
* ===========
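* We'll use union-find to group the nodes into disjoint sets (the connected
* components of the graph) and count the size of each set. Any two nodes in
* different sets are unreachable from each other, so each node pairs with
* every node in the sets counted before its own.
*
* This will have a time complexity of O((n+e)·α(n)), which is effectively
* linear, and a space complexity of O(n). Here n is the number of nodes, e is
* the number of edges, and α is the inverse Ackermann function.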
******************************************************************************/
#include <iostream>
#include <vector>
#include <numeric>
/**
 * Union-find/Disjoint-set data structure over nodes 0 to size-1.
 * Uses path compression in find() and union by rank in unionn().
 */
class UnionFind {
private:
    std::vector<int> parent; // parent[i] is i's parent; a root is its own parent
    std::vector<int> rank;   // upper bound on tree height, only meaningful at roots
public:
    /**
     * Class Constructor. Every node starts in its own singleton set.
     *
     * @param size total number of nodes, labelled 0 to size-1
     */
    UnionFind(int size) : parent(size), rank(size, 0) {
        std::iota(std::begin(parent), std::end(parent), 0);
    }
    /**
     * Find the root of the set containing a node. Uses path compression:
     * every node visited is re-pointed directly at the root.
     *
     * @param i node to find the root of; must be in [0, size)
     *
     * @return root (representative) of the set containing node i
     */
    int find(int i) {
        if (parent[i] != i) {
            parent[i] = find(parent[i]);
        }
        return parent[i];
    }
    /**
     * Merge the sets containing x and y. Uses union by rank: the shallower
     * tree is attached under the deeper one.
     *
     * @param x node to union with y
     * @param y node to union with x
     */
    void unionn(int x, int y) {
        int xroot = find(x);
        int yroot = find(y);
        if (xroot == yroot) {
            // Already in the same set. Without this early exit the equal-rank
            // branch below would bump rank[xroot] for no merge, inflating
            // ranks and skewing every later union-by-rank decision.
            return;
        }
        if (rank[xroot] < rank[yroot]) {
            parent[xroot] = yroot;
        }
        else if (rank[xroot] > rank[yroot]) {
            parent[yroot] = xroot;
        }
        else {
            // equal ranks: attach y under x, and x's tree grows by one level
            parent[yroot] = xroot;
            rank[xroot]++;
        }
    }
};
/**
 * Solution
 */
class Solution {
public:
    /**
     * Counts the unordered pairs of distinct nodes that are unreachable from
     * each other, i.e. that lie in different connected components.
     *
     * @param n number of nodes, labelled 0 to n-1
     * @param edges undirected edges specified as [i,j] between node i and node j
     *
     * @return number of unreachable pairs (0 when there are fewer than 2 nodes)
     */
    long long countPairs(int n, std::vector<std::vector<int>>& edges) {
        // fewer than two nodes means no pairs; also avoids sizing vectors
        // with a negative count
        if (n < 2) {
            return 0LL;
        }
        // union-find: merge the endpoints of every edge into one component
        UnionFind uf(n);
        for (const std::vector<int>& edge : edges) { // by reference: no per-edge copy
            uf.unionn(edge[0], edge[1]);
        }
        // size of each component, indexed by its root; non-root slots stay 0
        std::vector<long long> members(n, 0LL);
        for (int i = 0; i < n; i++) {
            members[uf.find(i)]++;
        }
        // each node of the current component pairs with every node of the
        // components already counted, so each cross pair is counted once
        long long pairs = 0LL; // number of unreachable pairs
        long long seen = 0LL;  // nodes in components counted so far
        for (long long size : members) {
            pairs += seen * size;
            seen += size;
        }
        return pairs;
    }
};
/**
 * << operator for vectors. Writes the elements comma-separated inside square
 * brackets, e.g. "[1,2,3]"; an empty vector is written as "[]". Nested vectors
 * are printed recursively, e.g. "[[0,1],[2]]".
 */
template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v) {
    os << "[";
    // Emit the separator before every element except the first. Printing a
    // trailing comma and then "\b" only looks right on a terminal. It leaves a
    // raw backspace in files and pipes, and it ate the "[" for empty vectors.
    const char* separator = "";
    for (const auto& item : v) {
        os << separator << item;
        separator = ",";
    }
    os << "]";
    return os;
}
/**
 * Test driver: runs countPairs on each sample graph and prints the call
 * together with its result.
 */
int main(void) {
    // one sample graph: node count plus its undirected edge list
    struct TestCase {
        int nodes;
        std::vector<std::vector<int>> links;
    };
    std::vector<TestCase> cases = {
        {3, {{0, 1}, {0, 2}, {1, 2}}},
        {7, {{0, 2}, {0, 5}, {2, 4}, {1, 6}, {5, 4}}},
    };

    Solution solver;
    for (auto& tc : cases) {
        std::cout << "countPairs(" << tc.nodes << ", " << tc.links << ") = "
                  << solver.countPairs(tc.nodes, tc.links) << std::endl;
    }
    return 0;
}