[CF734E] Anton and Tree - 树的直径

Description

给定 n 个节点的树,每个点为黑色或白色,一次操作可以使一个相同颜色的连通块变成另一种颜色,求使整棵树变成一种颜色的最少操作数。

Solution

将同色连通块缩点后,答案显然等于树的半径

#include <bits/stdc++.h>
using namespace std;

#define int long long
const int N = 1000005;

int a[N], n;
vector<int> g[N];

namespace segment1
{
    int fa[N];

    int find(int x)
    {
        return x == fa[x] ? x : fa[x] = find(fa[x]);
    }

    void merge(int x, int y)
    {
        x = find(x);
        y = find(y);
        fa[x] = y;
    }

    void main()
    {
        ios::sync_with_stdio(false);
        cin >> n;
        for (int i = 1; i <= n; i++)
            fa[i] = i;
        for (int i = 1; i <= n; i++)
            cin >> a[i];
        for (int i = 1; i < n; i++)
        {
            int u, v;
            cin >> u >> v;
            g[u].push_back(v);
            g[v].push_back(u);
            if (a[u] == a[v])
                merge(u, v);
        }
        set<pair<int, int>> s;
        for (int i = 1; i <= n; i++)
        {
            int p = i;
            for (auto j : g[i])
            {
                int q = j;
                p = find(i);
                q = find(j);
                if (p != q)
                {
                    if (p > q)
                        swap(p, q);
                    s.insert({p, q});
                }
            }
        }
        map<int, int> mp;
        for (auto [x, y] : s)
            mp[x]++, mp[y]++;
        int ind = 0;
        for (auto &[x, y] : mp)
            y = ++ind;
        for (int i = 1; i <= n; i++)
            g[i].clear();
        n = ind;
        for (auto [x, y] : s)
        {
            x = mp[x];
            y = mp[y];
            g[x].push_back(y);
            g[y].push_back(x);
        }
    }
} // namespace segment1

namespace segment2
{
    int dis[N];

    void dfs(int p, int from)
    {
        for (int q : g[p])
            if (q != from)
            {
                dis[q] = dis[p] + 1;
                dfs(q, p);
            }
    }

    void main()
    {
        dfs(1, 0);
        int mx = 0;
        for (int i = 1; i <= n; i++)
            if (dis[i] > dis[mx])
                mx = i;
        dis[mx] = 0;
        dfs(mx, 0);
        mx = 0;
        for (int i = 1; i <= n; i++)
            if (dis[i] > dis[mx])
                mx = i;
        cout << (dis[mx] + 1) / 2 << endl;
    }
} // namespace segment2

signed main()
{
    segment1::main();
    segment2::main();
}

相关文章: