Skip to main content

Centroid Decomposition

Code

#include <bits/stdc++.h>
using namespace std;

// Capacity of the per-node arrays below. Node ids are used directly as
// indices into g/sz/pa, so every id must be in [0, N).
const int N = 5'000'005;
// Adjacency sets: g[u] holds (neighbour, value) pairs, where the value is the
// third number read for the edge in solve(). Each edge is stored in both
// directions. centroid_decomposition() erases edges from these sets as it
// runs, so the graph is consumed by the decomposition.
// NOTE(review): this declares N std::set objects with static storage no
// matter how small the input is — confirm the memory limit allows it.
set<pair<int, int>> g[N];
// sz[u]: subtree size of u as written by the most recent dfs1() call that
//        visited u.
// pa[c]: value assigned in centroid_decomposition() — the `p` argument of the
//        call that selected c, or c itself when that call had p == -1.
int sz[N], pa[N];

/// Computes subtree sizes for the tree component containing `u`, rooted at `u`.
///
/// Writes sz[x] for every node x reachable from `u` without stepping to `p`,
/// and returns sz[u], i.e. the number of nodes in that component.
///
/// @param u  root of the traversal.
/// @param p  neighbour of `u` that must not be entered (-1 for none).
/// @return   number of nodes visited (size of the component).
///
/// Implemented iteratively: a recursive version needs one stack frame per
/// tree level, which overflows the call stack on path-like trees once the
/// node count approaches the size of the global arrays.
int dfs1(int u, int p = -1) {
    // (node, parent) pairs in discovery order; a node always appears after
    // its parent, so the reverse order is a valid bottom-up order.
    std::vector<std::pair<int, int>> order;
    order.emplace_back(u, p);

    for (std::size_t i = 0; i < order.size(); ++i) {
        // Copy the entry: emplace_back below may reallocate `order`.
        const auto [node, parent] = order[i];
        sz[node] = 1;
        for (const auto& edge : g[node]) {
            if (edge.first != parent) {
                order.emplace_back(edge.first, node);
            }
        }
    }

    // Fold every child's finished size into its parent. Index 0 is the root,
    // whose "parent" `p` lies outside the traversal and must not be touched.
    for (std::size_t i = order.size(); i-- > 1;) {
        sz[order[i].second] += sz[order[i].first];
    }
    return sz[u];
}

/// Traversal skeleton run once per centroid; fill in the marked hooks with
/// the problem-specific work.
///
/// @param u  node currently being visited.
/// @param p  neighbour of `u` that is not entered (the node we came from).
/// @param c  value forwarded unchanged to every recursive call; the caller in
///           this file passes the current centroid.
void dfs2(int u, int p, int c) {
    // do something
    for (const auto& edge : g[u]) {
        const int v = edge.first;                     // neighbour
        [[maybe_unused]] const int ca = edge.second;  // value stored with the edge
        if (v != p) {
            // do something
            dfs2(v, u, c);
            // do something
        }
    }
    // do something
}

/// Finds the centroid of the component rooted at `u`.
///
/// Expects sz[] to hold subtree sizes for a rooting at the initial `u`
/// (as produced by dfs1(u, p)). Starting from `u`, repeatedly steps into the
/// first child whose subtree holds more than n / 2 nodes; the node where no
/// such child exists is returned.
///
/// @param u  starting node (root used when sz[] was computed).
/// @param p  neighbour of `u` that is never entered.
/// @param n  total number of nodes in the component.
/// @return   the node at which the walk stops.
///
/// Written as a loop instead of self-recursion so that the walk, which can
/// be about n / 2 steps long on a path, does not depend on tail-call
/// optimisation to stay within the call stack.
int centroid(int u, int p, int n) {
    const int half = n / 2;
    for (bool moved = true; moved;) {
        moved = false;
        for (const auto& edge : g[u]) {
            const int v = edge.first;
            if (v != p && sz[v] > half) {
                // Descend into the heavy child; leave the inner loop at once
                // because it iterates over the adjacency of the old `u`.
                p = u;
                u = v;
                moved = true;
                break;
            }
        }
    }
    return u;
}

/// Recursively decomposes the component containing `u`.
///
/// Steps for one call:
///   1. dfs1 measures the component, centroid picks its centroid `c`.
///   2. pa[c] is set to `p`, or to `c` itself when this is the top-level
///      call (p == -1).
///   3. dfs2 runs from `c` over the whole component.
///   4. Every edge incident to `c` is removed from both endpoints and the
///      component behind it is decomposed with `c` as its parent.
///
/// @param u  any node of the component to decompose.
/// @param p  value recorded as the parent of this component's centroid;
///           -1 marks the top-level call.
void centroid_decomposition(int u, int p = -1) {
    const int total = dfs1(u, p);
    const int c = centroid(u, p, total);

    const int parent = (p == -1) ? c : p;
    pa[c] = parent;

    dfs2(c, parent, c);

    // Snapshot the adjacency first: the loop erases from g[c] while walking.
    const std::vector<std::pair<int, int>> incident(g[c].begin(), g[c].end());
    for (const std::pair<int, int>& edge : incident) {
        const int v = edge.first;
        const int ca = edge.second;
        g[c].erase(std::make_pair(v, ca));
        g[v].erase(std::make_pair(c, ca));
        centroid_decomposition(v, c);
    }
}

inline void solve() {
int n; cin >> n;
for (int i = 1; i < n; ++i) {
int a, b, c; cin >> a >> b >> c;
g[a].emplace(b, c); g[b].emplace(a, c);
}
centroid_decomposition(1);
}