Skip to content

How Many Schemes - 题解

标签与难度

标签: AC自动机, 动态规划, 树上算法, 矩阵快速幂, 二进制提升(LCA) 难度: 2500

题目大意喵~

各位Master,欢迎来到第八关呐!(ฅ'ω'ฅ)

这道题是这样的:我们面前有一棵 nn 个节点的树,每条边上都挂着一个非空的小写字母集合。除此之外,还有 mm 个“模式串”。

现在有 qq 次询问,每次询问会给我们树上的两个点 uuvv。我们需要从 uu 走到 vv,沿着唯一路径经过的每一条边,都从它对应的字母集合里选择一个字母。把这些字母按顺序拼接起来,就能形成一个字符串 str

我们的任务是,对于每次询问 (u,v)(u, v),计算有多少种选择字母的方案,可以使得最终生成的字符串 str 至少包含一个给定的模式串作为子串。因为答案可能很大,所以要对 998244353998244353 取模哦!

解题思路分析

这道题看起来好复杂呀,树、路径、字符串匹配、计数……各种要素都凑齐了呢,喵~ 但是不要怕,让我一步一步带你解开它的秘密!

Step 1: "至少一个"的反面是什么喵?

"至少包含一个模式串"这种条件,直接计算起来会很麻烦,因为需要考虑容斥原理,比如一个字符串同时包含多个模式串的情况。一个经典的思路是正难则反!我们来计算它的对立面:不包含任何一个模式串的方案数。

如果我们能求出以下两个值:

  1. uuvv 的路径上,总共有多少种可以生成的字符串?
  2. 其中,有多少种字符串不包含任何一个模式串?

那么,用 (总方案数 - 不包含任何模式串的方案数),就能得到我们想要的答案啦!

总方案数其实很好算。假设路径上的边依次是 e1,e2,,eke_1, e_2, \dots, e_k,它们对应的字母集合分别是 se1,se2,,seks_{e_1}, s_{e_2}, \dots, s_{e_k}。根据乘法原理,总方案数就是 i=1ksei\prod_{i=1}^{k} |s_{e_i}|

那么,问题的核心就变成了:如何计算不包含任何模式串的方案数呢?

Step 2: 多模式串匹配?AC自动机出动!

一看到“多模式串匹配”的问题,我们的DNA就动了,对不对?这不就是 Aho-Corasick (AC) 自动机 的标准应用场景嘛!

我们可以把所有 mm 个模式串建成一个AC自动机。AC自动机本质上是一个Trie树加上一些fail指针。当我们用一个文本串在上面匹配时,fail指针能帮助我们在失配时快速跳转到下一个可能匹配的位置,而不会丢失已匹配的前缀信息。

对于这道题,我们是要计数,而不是简单地判断是否存在。所以,我们需要对AC自动机做一点小小的改造:

  1. 标记终点状态:在Trie中,所有代表模式串末尾的节点,以及通过fail指针能够到达这些末尾节点的节点,都应该被标记为“危险状态”。因为只要到达了这些状态,就意味着我们生成的字符串已经包含了一个模式串作为后缀。
  2. 建立吸收态:为了方便计数,我们可以让这些“危险状态”变成吸收态。也就是说,一旦进入一个危险状态,无论接下来选择什么字符,都会一直停留在危险状态。这样,只要我们最终没有停在危险状态,就说明生成的整个字符串都不包含任何模式串。

具体操作是:

  • 建好Trie树和fail指针。
  • 通过fail指针链传递“危险”标记:如果 fail[u] 是危险的,那么 u 也是危险的。
  • 将所有危险状态的所有出边都指向它自己,形成一个吸收汇点。

这样改造后,我们就有了一个强大的工具:一个可以判断字符串前缀是否包含模式串的状态机。

Step 3: 路径上的DP与矩阵快速幂

现在,我们把在树上路径生成字符串的过程,看作是在AC自动机上行走的过程。

我们可以定义一个DP状态:dp[i][j] 表示从路径起点开始,经过了 ii 条边,最终停留在AC自动机状态 jj 的方案数。

初始状态是 dp[0][0] = 1(在AC自动机的起始状态,方案数为1),其他 dp[0][j] = 0

考虑路径上的第 ii 条边 eie_i,它上面有字母集合 seis_{e_i}。状态转移方程可以这样想:

dp[i][next_j]=j=0S1dp[i1][j]×(从状态 j 通过 sei 中字符转移到 next_j 的方式数)dp[i][\text{next\_j}] = \sum_{j=0}^{S-1} dp[i-1][j] \times (\text{从状态 } j \text{ 通过 } s_{e_i} \text{ 中字符转移到 } \text{next\_j} \text{ 的方式数})

其中 SS 是AC自动机的状态总数。

这个转移过程,其实可以用矩阵来表示!

  • 我们可以构造一个 S×SS \times S转移矩阵 MeiM_{e_i}
  • Mei[p][q]M_{e_i}[p][q] 表示从AC自动机的状态 qq 出发,通过边 eie_i 上的一个字符,能够转移到状态 pp 的方案数。
  • 这个矩阵可以预处理:对于 eie_i 上的每个字符 c,我们遍历AC自动机的所有状态 q,找到转移后的状态 p = next(q, c),然后 M_{e_i}[p][q]++

如果我们的DP状态是一个列向量 V,其中 V[j] = dp[...][j],那么经过一条边 eie_i 的转移就等价于左乘它的转移矩阵:Vnew=Mei×VoldV_{\text{new}} = M_{e_i} \times V_{\text{old}}

一条路径 e1,e2,,eke_1, e_2, \dots, e_k 的总转移矩阵就是这些边矩阵的乘积:Mpath=Mek××Me2×Me1M_{\text{path}} = M_{e_k} \times \dots \times M_{e_2} \times M_{e_1}

Step 4: 树上路径问题与二进制提升(LCA)

我们有好多好多查询,每次都去遍历路径然后一个个乘矩阵,肯定会超时的说!( T﹏T )

但是,树上路径问题有一个非常强大的优化技巧:二进制提升 (LCA)

我们可以把一条 uvu \to v 的路径拆分成两段:uLCA(u,v)u \to \text{LCA}(u,v)LCA(u,v)v\text{LCA}(u,v) \to v。我们可以预处理出树上每个节点向上/向下跳 2k2^k 步的路径所对应的转移矩阵!

  • up_mat[u][k]: 从节点 u 向上走到它的第 2k2^k 个祖先的路径所对应的转移矩阵。
  • down_mat[u][k]: 从节点 u 的第 2k2^k 个祖先向下走到 u 的路径所对应的转移矩阵。

它们的递推关系是:

  • up_mat[u][k] = up_mat[parent(u, 2^(k-1))][k-1] * up_mat[u][k-1] (先走 uu 开始的一半,再走后一半)
  • down_mat[u][k] = down_mat[u][k-1] * down_mat[parent(u, 2^(k-1))][k-1] (先走前一半,再走通往 uu 的后一半)

预处理出这些矩阵后,对于任意一个查询 (u,v)(u, v)

  1. 找到它们的最近公共祖先 L=LCA(u,v)L = \text{LCA}(u, v)
  2. 用二进制提升的方法,像求LCA一样,从 uu 跳到 LL,沿途把对应的 up_mat 矩阵乘起来,得到 MuLM_{u \to L}
  3. 同样地,从 vv 跳到 LL,把对应的 down_mat 矩阵乘起来,得到 MLvM_{L \to v}。注意,这里我们需要的是从 LLvv 的矩阵,所以组合 down_mat 的时候要小心顺序。(或者像参考代码那样,把 vLv \to L 路径上的 down_mat 存起来再反序相乘,喵~)
  4. 总的转移矩阵 Mtotal=MLv×MuLM_{\text{total}} = M_{L \to v} \times M_{u \to L}
  5. 用这个总矩阵乘以初始状态向量 V_init = [1, 0, ..., 0]^T,得到最终的状态向量 V_final
  6. V_final[j] 就是最终停留在状态 j 的方案数。我们把所有非危险状态 j 对应的 V_final[j] 加起来,就得到了不包含任何模式串的方案总数。

Step 5: 把它们拼起来!

好啦,现在我们把所有零件都组装起来:

  1. 总方案数:同样可以用二进制提升预处理。对于每条边,计算其字母集合大小。然后预处理每个节点向上/向下跳 2k2^k 步的路径上的总方案数(即路径上所有边字母集合大小的乘积)。查询时同样拆成两段计算。
  2. 合法方案数:直接计算!我们改变一下DP的定义。
    • 让我们直接计算包含模式串的方案数。
    • 在AC自动机上,只要进入了一个“危险状态”,就说明已经匹配成功了。
    • 我们的DP状态向量 VV[j] 表示到达AC自动机状态 j 的方案数。
    • 最终的状态向量 V_final 算出来后,我们把所有危险状态 j 对应的 V_final[j] 加起来,就是我们想要的答案了!因为一旦进入危险状态,就会被吸收,所以最终停留在危险状态的路径,一定是在某个时刻进入了危险状态的。

这种直接计算的方法更优雅,不需要再算总数然后相减啦!参考代码们用的就是这种聪明的办法,喵~

总结一下最终的算法流程:

  1. 将所有模式串构建成一个AC自动机,并处理好fail指针和“危险”吸收态。
  2. 预处理出树上每条边的转移矩阵 MeM_e
  3. 通过一次DFS,计算出每个节点的深度 dep[u]、父节点 parent[u][0],以及基础的转移矩阵 up_mat[u][0]down_mat[u][0]
  4. 利用二进制提升,计算出所有的 parent[u][k], up_mat[u][k], down_mat[u][k]
  5. 对于每个查询 (u,v)(u, v): a. 计算 LCA(u,v)\text{LCA}(u, v)。 b. 利用预处理的矩阵,计算出 uvu \to v 整条路径的总转移矩阵 MtotalM_{\text{total}}。 c. 计算最终状态向量 Vfinal=Mtotal×[1,0,,0]TV_{\text{final}} = M_{\text{total}} \times [1, 0, \dots, 0]^T。 d. 将 V_final 中所有对应AC自动机“危险状态”的分量求和,即为答案。

这个算法结合了多种强大的工具,虽然有点复杂,但是逻辑非常清晰,是不是感觉很有趣呢?喵~

代码实现

这是我根据上面的思路,精心重构的一份代码哦!希望能帮助到你理解,喵~

cpp
#include <iostream>
#include <vector>
#include <string>
#include <queue>
#include <numeric>
#include <algorithm>

using namespace std;

const int MOD = 998244353;
const int MAX_N = 2510;
const int MAX_AC_STATES = 45;
const int LOG_N = 12;

// 模块化的加减乘,防止溢出喵
int add(int a, int b) { return (a + b >= MOD) ? (a + b - MOD) : (a + b); }
int mul(int a, int b) { return (1LL * a * b) % MOD; }

int AC_STATE_COUNT = 0;

// AC自动机部分
namespace AhoCorasick {
    int trie[MAX_AC_STATES][26];
    int fail[MAX_AC_STATES];
    bool is_dangerous[MAX_AC_STATES];
    int node_count;

    void init() {
        node_count = 1; // 0是虚拟根,1是真实根
        for(int i = 0; i < MAX_AC_STATES; ++i) {
            fill(trie[i], trie[i] + 26, 0);
            fail[i] = 0;
            is_dangerous[i] = false;
        }
    }

    void insert(const string& s) {
        int curr = 1;
        for (char ch : s) {
            int c = ch - 'a';
            if (!trie[curr][c]) {
                trie[curr][c] = ++node_count;
            }
            curr = trie[curr][c];
        }
        is_dangerous[curr] = true;
    }

    void build() {
        queue<int> q;
        for (int i = 0; i < 26; ++i) {
            if (trie[1][i]) {
                fail[trie[1][i]] = 1;
                q.push(trie[1][i]);
            } else {
                trie[1][i] = 1;
            }
        }

        while (!q.empty()) {
            int u = q.front();
            q.pop();

            is_dangerous[u] |= is_dangerous[fail[u]];

            for (int i = 0; i < 26; ++i) {
                if (trie[u][i]) {
                    fail[trie[u][i]] = trie[fail[u]][i];
                    q.push(trie[u][i]);
                } else {
                    trie[u][i] = trie[fail[u]][i];
                }
            }
        }
        
        AC_STATE_COUNT = node_count + 1;
        
        // 将危险状态变为吸收态
        for(int i = 1; i <= node_count; ++i) {
            if (is_dangerous[i]) {
                for(int j = 0; j < 26; ++j) {
                    trie[i][j] = i;
                }
            }
        }
    }
}

// 矩阵结构体
struct Matrix {
    int mat[MAX_AC_STATES][MAX_AC_STATES];

    Matrix() {
        for (int i = 0; i < AC_STATE_COUNT; ++i) {
            fill(mat[i], mat[i] + AC_STATE_COUNT, 0);
        }
    }

    static Matrix identity() {
        Matrix I;
        for (int i = 0; i < AC_STATE_COUNT; ++i) {
            I.mat[i][i] = 1;
        }
        return I;
    }
};

Matrix operator*(const Matrix& A, const Matrix& B) {
    Matrix C;
    for (int i = 0; i < AC_STATE_COUNT; ++i) {
        for (int j = 0; j < AC_STATE_COUNT; ++j) {
            for (int k = 0; k < AC_STATE_COUNT; ++k) {
                C.mat[i][j] = add(C.mat[i][j], mul(A.mat[i][k], B.mat[k][j]));
            }
        }
    }
    return C;
}

vector<int> operator*(const Matrix& A, const vector<int>& v) {
    vector<int> res(AC_STATE_COUNT, 0);
    for (int i = 0; i < AC_STATE_COUNT; ++i) {
        for (int j = 0; j < AC_STATE_COUNT; ++j) {
            res[i] = add(res[i], mul(A.mat[i][j], v[j]));
        }
    }
    return res;
}

// 树结构和二进制提升
struct Edge {
    int to;
    string chars;
};

vector<Edge> adj[MAX_N];
int n, m, q;
int depth[MAX_N];
int parent[MAX_N][LOG_N];
Matrix up_matrices[MAX_N][LOG_N];
Matrix down_matrices[MAX_N][LOG_N];

void dfs(int u, int p, int d, const Matrix& up_mat, const Matrix& down_mat) {
    depth[u] = d;
    parent[u][0] = p;
    up_matrices[u][0] = up_mat;
    down_matrices[u][0] = down_mat;

    for (const auto& edge : adj[u]) {
        int v = edge.to;
        if (v == p) continue;
        
        Matrix edge_mat;
        for (char ch : edge.chars) {
            int c = ch - 'a';
            for (int i = 1; i <= AhoCorasick::node_count; ++i) {
                int next_state = AhoCorasick::trie[i][c];
                edge_mat.mat[next_state][i]++;
            }
        }
        
        dfs(v, u, d + 1, edge_mat, edge_mat);
    }
}

void precompute_binary_lifting() {
    for (int k = 1; k < LOG_N; ++k) {
        for (int i = 1; i <= n; ++i) {
            if (parent[i][k - 1] != 0) {
                int p = parent[i][k - 1];
                parent[i][k] = parent[p][k - 1];
                // up_mat[i][k] = up_mat[p][k-1] * up_mat[i][k-1]
                up_matrices[i][k] = up_matrices[p][k - 1] * up_matrices[i][k - 1];
                // down_mat[i][k] = down_mat[i][k-1] * down_mat[p][k-1]
                down_matrices[i][k] = down_matrices[i][k - 1] * down_matrices[p][k - 1];
            }
        }
    }
}

int lca(int u, int v) {
    if (depth[u] < depth[v]) swap(u, v);
    for (int k = LOG_N - 1; k >= 0; --k) {
        if (depth[u] - (1 << k) >= depth[v]) {
            u = parent[u][k];
        }
    }
    if (u == v) return u;
    for (int k = LOG_N - 1; k >= 0; --k) {
        if (parent[u][k] != parent[v][k]) {
            u = parent[u][k];
            v = parent[v][k];
        }
    }
    return parent[u][0];
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    cin >> n >> m >> q;

    AhoCorasick::init();
    
    vector<pair<int, int>> edges_for_adj;
    vector<string> edge_char_sets(n);
    for (int i = 0; i < n - 1; ++i) {
        int u, v;
        string s;
        cin >> u >> v >> s;
        adj[u].push_back({v, s});
        adj[v].push_back({u, s});
    }

    for (int i = 0; i < m; ++i) {
        string t;
        cin >> t;
        AhoCorasick::insert(t);
    }

    AhoCorasick::build();

    dfs(1, 0, 0, Matrix(), Matrix()); // 根节点到父节点的矩阵是空的,用默认构造的0矩阵
    precompute_binary_lifting();

    for (int i = 0; i < q; ++i) {
        int u, v;
        cin >> u >> v;

        int l = lca(u, v);

        Matrix mat_u_to_lca = Matrix::identity();
        for (int k = LOG_N - 1; k >= 0; --k) {
            if (depth[u] - (1 << k) >= depth[l]) {
                mat_u_to_lca = up_matrices[u][k] * mat_u_to_lca;
                u = parent[u][k];
            }
        }

        Matrix mat_lca_to_v = Matrix::identity();
        vector<Matrix> v_path_matrices;
        for (int k = LOG_N - 1; k >= 0; --k) {
            if (depth[v] - (1 << k) >= depth[l]) {
                v_path_matrices.push_back(down_matrices[v][k]);
                v = parent[v][k];
            }
        }
        reverse(v_path_matrices.begin(), v_path_matrices.end());
        for(const auto& mat : v_path_matrices) {
            mat_lca_to_v = mat * mat_lca_to_v;
        }

        Matrix total_matrix = mat_lca_to_v * mat_u_to_lca;

        vector<int> initial_vec(AC_STATE_COUNT, 0);
        initial_vec[1] = 1; // 从AC自动机的真实根节点1开始

        vector<int> final_vec = total_matrix * initial_vec;

        int ans = 0;
        for (int j = 1; j <= AhoCorasick::node_count; ++j) {
            if (AhoCorasick::is_dangerous[j]) {
                ans = add(ans, final_vec[j]);
            }
        }
        cout << ans << "\n";
    }

    return 0;
}

复杂度分析

  • 时间复杂度: O(tL+(N+Q)logNS3)O(\sum|t| \cdot L + (N+Q)\log N \cdot S^3),其中 t\sum|t| 是所有模式串的总长度, L=26L=26 是字符集大小,NN 是节点数,QQ 是查询数,SS 是AC自动机的状态数(大约等于 t\sum|t|)。

    • AC自动机构建:O(tL)O(\sum|t| \cdot L)
    • 预处理每条边的转移矩阵在DFS中完成,总共是 O(NseavgS)O(N \cdot |s_e|_{avg} \cdot S)
    • 二进制提升预处理:O(NlogNS3)O(N \log N \cdot S^3),瓶颈在于 S×SS \times S 矩阵乘法。
    • 每次查询:LCA为 O(logN)O(\log N),路径矩阵计算需要 O(logN)O(\log N) 次矩阵乘法,所以是 O(QlogNS3)O(Q \log N \cdot S^3)
    • 考虑到 SS 最大约41,这个复杂度是可以通过的。
  • 空间复杂度: O(NlogNS2)O(N \log N \cdot S^2)

    • AC自动机:O(SL)O(S \cdot L)
    • 二进制提升的 parent 数组:O(NlogN)O(N \log N)
    • 存储 up_matricesdown_matricesO(NlogNS2)O(N \log N \cdot S^2),这是空间占用的主要部分。

知识点总结

这道题是一道非常好的综合题,像一个美味的算法大餐,融合了多种技巧,喵~

  1. AC自动机: 解决多模式串匹配问题的利器,本题用它来构建状态转移的基础。
  2. 动态规划: 核心思想是用DP来计数。dp[i][j] 的思想贯穿始终。
  3. 矩阵快速幂: 将DP转移过程抽象成矩阵乘法,是加速线性递推的经典方法。在这里,我们用它来合并路径上的DP转移。
  4. 树上算法与二进制提升 (LCA): 将线性序列上的算法(矩阵乘法)推广到树上路径的关键。通过预处理 2k2^k 步长的信息,可以快速回答任意路径的查询。
  5. 问题转化: "至少一个" -> "计算不包含的补集" 或者 "改造状态机直接计算",都是处理这类计数问题的常用思路。本题解采用了更直接的后者。

希望这篇题解能帮助你攻克这道难题,如果还有不明白的地方,随时可以再来问我哦,喵~ (ฅ^•ﻌ•^ฅ)