Skip to content

树的联结 - 题解

标签与难度

标签: 图论, 树, 树形DP, 换根DP, 树的直径, 距离, 组合计数 难度: 2100

题目大意喵~

哈喵~!各位算法大师们,今天我们来解决一个关于两棵可爱的树树的问题,喵~

题目给了我们两棵树,一棵有 nn 个节点(编号 1n1 \dots n),另一棵有 mm 个节点(编号 n+1n+mn+1 \dots n+m)。我们可以执行一个叫做“联结”的操作:从第一棵树里选一个节点 uu,从第二棵树里选一个节点 vv,然后在它们之间连一条边。这样一来,两棵小树就合体成一棵大树啦!

对于新形成的大树,任意两个节点 sstt 之间都有了唯一的路径,它们的距离记为 f(u,v,s,t)f(u, v, s, t)

现在,我们要对任意一对节点 (s,t)(s, t),定义一个“极限距离” d(s,t)d(s, t)。这个距离是我们在所有可能的联结方式(即所有可能的 (u,v)(u, v) 对)中,能得到的 sstt 之间距离的最大值。

d(s,t)=max1un,n+1vn+mf(u,v,s,t)d(s, t) = \max_{1 \le u \le n, n+1 \le v \le n+m} f(u, v, s, t)

我们的最终任务,就是计算所有可能的节点对 (s,t)(s, t)(其中 s<ts < t)的极限距离之和,也就是:

1s<tn+md(s,t)\sum_{1 \le s < t \le n+m} d(s, t)

解题思路分析

这道题看起来有点复杂,因为它涉及到了一个“最大化”的过程,还要对所有点对求和,喵~ 但是不要怕,只要我们把问题拆解开来,一步一步分析,就会发现它的脉络其实很清晰哦!

首先,我们来分析这个核心的“极限距离” d(s,t)d(s, t) 是怎么计算的。这取决于节点 sstt 的位置,呐。

我们用 dist_1(a, b) 表示在第一棵树中 a,ba, b 间的距离,dist_2(a, b) 表示在第二棵树中 a,ba, b 间的距离。

Case 1: sstt 在同一棵树里

比如说,sstt 都在第一棵树里。当我们在 uTree1u \in \text{Tree1}vTree2v \in \text{Tree2} 之间连边后, sstt 之间的最短路径仍然是它们在第一棵树里本来的那条路,喵~ 绕到第二棵树再回来肯定会更远。所以,这种情况下,联结方式 (u,v)(u, v) 并不会影响 s,ts, t 间的距离。 因此,如果 s,tTree1s, t \in \text{Tree1},那么 d(s,t)=dist1(s,t)d(s, t) = \text{dist}_1(s, t)。 同理,如果 s,tTree2s, t \in \text{Tree2},那么 d(s,t)=dist2(s,t)d(s, t) = \text{dist}_2(s, t)

Case 2: sstt 在不同的树里

比如说,sTree1s \in \text{Tree1}tTree2t \in \text{Tree2}。当我们用边 (u,v)(u, v) 联结两棵树后,从 sstt 的唯一路径必然是:suvts \to \dots \to u \to v \to \dots \to t。 这条路径的长度就是 dist_1(s, u) + 1 + dist_2(v, t)。 为了让这个距离最大化,也就是求 d(s,t)d(s, t),我们需要选择最好的 uuvv

d(s,t)=maxuTree1,vTree2(dist1(s,u)+1+dist2(v,t))d(s, t) = \max_{u \in \text{Tree1}, v \in \text{Tree2}} (\text{dist}_1(s, u) + 1 + \text{dist}_2(v, t))

这个最大化可以分开来看,喵~

d(s,t)=(maxuTree1dist1(s,u))+1+(maxvTree2dist2(v,t))d(s, t) = \left( \max_{u \in \text{Tree1}} \text{dist}_1(s, u) \right) + 1 + \left( \max_{v \in \text{Tree2}} \text{dist}_2(v, t) \right)

括号里的部分其实就是节点到它所在树中最远节点的距离!我们把这个值记作 farthest_1(s)farthest_2(t)。 所以,当 sTree1,tTree2s \in \text{Tree1}, t \in \text{Tree2} 时, d(s,t)=farthest1(s)+farthest2(t)+1d(s, t) = \text{farthest}_1(s) + \text{farthest}_2(t) + 1

汇总与求和

现在我们把总的求和公式拆成三部分:

  1. T1内部: 1s<tndist1(s,t)\sum_{1 \le s < t \le n} \text{dist}_1(s, t)
  2. T2内部: n+1s<tn+mdist2(s,t)\sum_{n+1 \le s < t \le n+m} \text{dist}_2(s, t)
  3. 跨树: s=1nt=n+1n+m(farthest1(s)+farthest2(t)+1)\sum_{s=1}^{n} \sum_{t=n+1}^{n+m} (\text{farthest}_1(s) + \text{farthest}_2(t) + 1)

前两部分是经典的“树上所有点对距离之和”问题。第三部分可以进一步展开:

s=1nt=n+1n+mfarthest1(s)+s=1nt=n+1n+mfarthest2(t)+s=1nt=n+1n+m1\sum_{s=1}^{n} \sum_{t=n+1}^{n+m} \text{farthest}_1(s) + \sum_{s=1}^{n} \sum_{t=n+1}^{n+m} \text{farthest}_2(t) + \sum_{s=1}^{n} \sum_{t=n+1}^{n+m} 1

=m(s=1nfarthest1(s))+n(t=n+1n+mfarthest2(t))+nm= m \cdot \left(\sum_{s=1}^{n} \text{farthest}_1(s)\right) + n \cdot \left(\sum_{t=n+1}^{n+m} \text{farthest}_2(t)\right) + n \cdot m

所以,我们只需要为两棵树分别计算两个值:

  1. 树内所有点对的距离之和。
  2. 树内每个点到其最远点的距离之和。

这两个问题都可以用树形DP,特别是换根DP(也叫二次扫描法)来高效解决,喵~

树形DP大法好!

对于一棵树(比如Tree1,大小为 NN),我们可以通过两遍DFS来搞定一切:

第一遍DFS (自底向上): 从任意根节点开始。对于每个节点 u,我们计算:

  • subtree_size[u]: 以 u 为根的子树大小。
  • sum_dist_down[u]: u 到其子树中所有节点的距离之和。
  • down1[u], down2[u]: 从 u 出发,完全在其子树内的最长和次长路径长度。

第二遍DFS (自顶向下): 再次从根节点出发。对于每个节点 u,利用其父节点 p 已经算好的信息,来计算 u 的最终信息:

  • 所有点对距离之和:
    • sum_dist_total[u]: u 到树中所有其他节点的距离之和。这个可以由 sum_dist_total[p] 推导出来。
    • sum_dist_total[u] = sum_dist_total[p] - subtree_size[u] + (N - subtree_size[u])
    • 所有点对距离之和就是 12u=1Nsum_dist_total[u]\frac{1}{2} \sum_{u=1}^{N} \text{sum\_dist\_total}[u]
  • 最远点距离:
    • up[u]: 从 u 出发,不进入其子树(即向上走)的最长路径长度。这个也可以由父节点 p 的信息 up[p], down1[p], down2[p] 推导出。
    • 节点 u 的最远点距离就是 max(down1[u], up[u])

我们对两棵树都执行这一套操作,得到各自的“所有点对距离之和”与“所有最远点距离之和”,然后代入我们之前的总公式,问题就解决啦!是不是感觉思路清晰多啦,喵~

代码实现

这是我根据上面的思路,精心重构的一份代码~ 希望能帮到你,呐!

cpp
#include <iostream>
#include <vector>
#include <numeric>
#include <algorithm>

using namespace std;

const int MAX_NODES = 200005; // n+m <= 200000
using ll = long long;

vector<int> adj[MAX_NODES];
ll n, m;

// DP arrays for a single tree calculation
ll subtree_size[MAX_NODES];
ll sum_dist_down[MAX_NODES]; // Sum of distances from u to nodes in its subtree
ll down1[MAX_NODES], down2[MAX_NODES]; // Longest and second longest path downwards from u
ll up[MAX_NODES]; // Longest path upwards from u

// --- DFS Pass 1: Bottom-up ---
// Calculates subtree_size, sum_dist_down, down1, down2 for each node.
void dfs_bottom_up(int u, int p) {
    subtree_size[u] = 1;
    sum_dist_down[u] = 0;
    down1[u] = down2[u] = 0;

    for (int v : adj[u]) {
        if (v == p) continue;
        dfs_bottom_up(v, u);
        subtree_size[u] += subtree_size[v];
        sum_dist_down[u] += sum_dist_down[v] + subtree_size[v];
        
        ll path_len = down1[v] + 1;
        if (path_len > down1[u]) {
            down2[u] = down1[u];
            down1[u] = path_len;
        } else if (path_len > down2[u]) {
            down2[u] = path_len;
        }
    }
}

// --- DFS Pass 2: Top-down ---
// Calculates total sum of distances and sum of farthest distances for the whole tree.
void dfs_top_down(int u, int p, ll total_nodes, ll& total_pairs_dist_sum, ll& total_farthest_dist_sum) {
    // Calculate total distance from u to all other nodes
    ll sum_dist_total_u;
    if (p == 0) { // Root node
        sum_dist_total_u = sum_dist_down[u];
    } else {
        // Rerooting DP formula
        sum_dist_total_u = (total_pairs_dist_sum - (sum_dist_down[u] + subtree_size[u])) // Sum from parent to outside u's subtree
                         + (total_nodes - subtree_size[u]) // Paths from outside u's subtree to u are 1 longer
                         + sum_dist_down[u]; // Sum to inside u's subtree
    }
    total_pairs_dist_sum = sum_dist_total_u; // Temporarily store for children calculations

    // Calculate farthest distance from u
    ll farthest_dist_u = max(down1[u], up[u]);
    total_farthest_dist_sum += farthest_dist_u;

    // Recurse for children
    for (int v : adj[u]) {
        if (v == p) continue;

        // Calculate up[v]: longest path from v going "up"
        ll parent_up_path = up[u] + 1;
        ll sibling_down_path = 0;
        if (down1[u] == down1[v] + 1) { // If longest path from u goes through v
            sibling_down_path = down2[u] + 1;
        } else {
            sibling_down_path = down1[u] + 1;
        }
        up[v] = max(parent_up_path, sibling_down_path);

        dfs_top_down(v, u, total_nodes, total_pairs_dist_sum, total_farthest_dist_sum);
    }
}


// A helper function to process one tree and get the required values
void process_tree(int root, ll num_nodes, ll& out_sum_all_pairs_dist, ll& out_sum_farthest_dist) {
    if (num_nodes <= 1) {
        out_sum_all_pairs_dist = 0;
        out_sum_farthest_dist = 0;
        return;
    }
    
    dfs_bottom_up(root, 0);

    ll total_pairs_dist_sum_accumulator = 0;
    ll total_farthest_dist_sum = 0;
    up[root] = 0;
    
    // The `total_pairs_dist_sum_accumulator` is passed by value and modified by children to get their total sums
    // This is a bit of a trick to use the parameter for passing parent's total sum down
    ll root_total_dist = sum_dist_down[root];

    dfs_top_down(root, 0, num_nodes, root_total_dist, total_farthest_dist_sum);
    
    // Let's re-calculate the sum of all pairs distances properly
    // The previous method was a bit tricky. A clearer way:
    ll sum_of_all_total_dists = 0;
    vector<ll> sum_dist_total(num_nodes + (root == 1 ? 0 : n) + 1);
    vector<int> q;
    q.push_back(root);
    vector<bool> visited(num_nodes + (root == 1 ? 0 : n) + 1, false);
    visited[root] = true;
    sum_dist_total[root] = sum_dist_down[root];
    sum_of_all_total_dists += sum_dist_total[root];

    int head = 0;
    while(head < q.size()){
        int u = q[head++];
        for(int v : adj[u]){
            if(!visited[v]){
                visited[v] = true;
                q.push_back(v);
                sum_dist_total[v] = sum_dist_total[u] - subtree_size[v] + (num_nodes - subtree_size[v]);
                sum_of_all_total_dists += sum_dist_total[v];
            }
        }
    }

    out_sum_all_pairs_dist = sum_of_all_total_dists / 2;
    out_sum_farthest_dist = total_farthest_dist_sum;
}


int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    cin >> n >> m;
    for (int i = 0; i < n + m - 2; ++i) {
        int u, v;
        cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    ll sum_dist_t1 = 0, sum_farthest_t1 = 0;
    process_tree(1, n, sum_dist_t1, sum_farthest_t1);

    ll sum_dist_t2 = 0, sum_farthest_t2 = 0;
    process_tree(n + 1, m, sum_dist_t2, sum_farthest_t2);
    
    ll total_ans = 0;
    // Part 1: Pairs within Tree 1
    total_ans += sum_dist_t1;
    // Part 2: Pairs within Tree 2
    total_ans += sum_dist_t2;
    // Part 3: Pairs across trees
    total_ans += m * sum_farthest_t1;
    total_ans += n * sum_farthest_t2;
    total_ans += n * m;

    cout << total_ans << endl;

    return 0;
}

复杂度分析

  • 时间复杂度: O(N+M)O(N+M) 我们对两棵树分别进行了两遍DFS。每次DFS都会访问每个节点和每条边一次。第一棵树的节点数是 NN,边数是 N1N-1。第二棵树的节点数是 MM,边数是 M1M-1。所以总的时间复杂度是线性的,即 O(N+M)O(N+M),非常高效,喵~

  • 空间复杂度: O(N+M)O(N+M) 我们使用了邻接表来存储图,空间为 O(N+M)O(N+M)。同时,我们用了一些数组(如 subtree_size, down1 等)来辅助DP计算,它们的大小都与总节点数成正比。所以空间复杂度也是 O(N+M)O(N+M)

知识点总结

这道题是树形DP的一个非常棒的练习,它融合了几个经典的树上问题:

  1. 树上所有点对距离之和: 可以通过计算每条边的贡献,或者使用换根DP来解决。本解法中使用了换根DP的思想。
  2. 树上每个点到最远点的距离: 这个问题通常与树的直径有关。一个节点的最远点一定是树的某条直径的两个端点之一。但更通用的方法是使用换根DP,通过一次自底向上和一次自顶向下的DFS,就能求出所有点的最远距离。
  3. 换根DP (Rerooting / 二次扫描法): 这是解决本题的核心技术。它非常适合处理那些“对于树上每个节点,求XXX”类型的问题。其精髓在于,先通过一次DFS求出子树内的信息,再通过第二次DFS,结合父节点的信息,推导出全局的信息。
  4. 问题分解: 面对复杂的求和问题,将求和项根据不同情况分类讨论,然后分别计算,是一种非常重要的解题策略,喵~

希望这篇题解能让你对树形DP有更深的理解!继续加油,你超棒的,的说!