Skip to content

回文计数问题 - 题解

标签与难度

标签: 字符串, 回文自动机, PAM, 动态规划, 组合计数 难度: 2400

题目大意喵~

你好呀,指挥官!这道题是这样的喵~

我们有一个由小写字母组成的字符串 s。需要找到所有满足特定条件的非空字符串集合 A 的数量,结果对 998244353 取模。

集合 A 需要满足的条件是:

  1. 集合 A 里的每一个字符串 p,都必须是 s 的一个非空子串。
  2. 对于集合 A任意两个字符串 pq,我们定义 t 为它们的最长公共后缀 (lcs(p,q))。那么 p, q, p-t (p去掉后缀t), q-t (q去掉后缀t) 这四个字符串都必须是回文串(或者是空串),并且 t 本身不能为空。

简单来说,我们要找一些回文子串组成的集合,这些集合内部的成员两两之间关系非常和谐,满足一个很强的回文性质约束,呐。

解题思路分析

喵哈哈,这道题看起来条件好多,有点吓人呢!不过别怕,让我来一层一层地剥开它的心,看看里面到底藏着什么秘密,喵~

关键条件的转化

我们先来仔细研究一下那个最复杂的条件:对于集合 A 里的任意 pq,令 t = lcs(p,q),则 p, q, t, p-tq-t 都得是回文串。

首先,如果一个集合 A 满足条件,那么对于任意一个元素 p ∈ A,我们可以取 q=p。此时 t = lcs(p,p) = p。根据条件,pp-t(也就是空串)都得是回文串。这告诉我们一个基本事实:集合 A 中的所有元素本身都必须是回文串

现在,我们考虑集合中两个不同的元素 pq。设 p = u + tq = v + t,其中 t 是它们的最长公共后缀。条件要求 u, v, t, p, q 都是回文串。

一个字符串由两部分 ut 拼接而成,即 u+t。如果 ut 本身都是回文串,u+t 什么时候也是回文串呢? reverse(u+t) = reverse(t) + reverse(u)。因为 ut 是回文串,所以 reverse(t)=t, reverse(u)=u。 所以 reverse(u+t) = t+u。 要使 u+t 成为回文串,必须满足 u+t = reverse(u+t),也就是 u+t = t+u。 两个字符串 ut 可交换 (ut=tu) 的充要条件是,它们都是同一个更短的字符串 w 的幂!即 u = w^a, t = w^b。 又因为 ut 是回文串,所以它们的"基底" w 也必须是回文串。

这个结论太重要了!它告诉我们:

  • p = u+tw 的幂。
  • q = v+tw 的幂。
  • t 也是 w 的幂。

因为 t 同时是 pq 的后缀,并且它们都是同一个回文基底 w 的幂,所以 pq 必须都是 w 的幂次。 这意味着,一个合法的集合 A 中,所有字符串都必须是同一个“本原回文串” w 的不同次幂!(本原回文串是指它本身不能表示为另一个更短的回文串的幂)。

例如,如果 w = "a",那么合法的集合可以是 {"a", "aa"}{"a", "aaa", "aaaa"} 等。 如果 w = "aba",那么合法的集合可以是 {"aba", "abaaba"}

问题的新形式

我们的问题转化为了:

  1. 找出 s 的所有回文子串。
  2. 将这些回文子串按它们的“本原回文根” w 分组。
  3. 对于每一个根 w,它所对应的所有幂次(且是s的子串)构成了_一个_可以从中挑选元素的“超级集合” G_w
  4. 任何一个合法的集合 A 必须是某个 G_w 的非空子集。
  5. 所以,总方案数就是对每一个 w,计算 G_w 的非空子集数量(即 2Gw12^{|G_w|} - 1),然后把它们全部加起来。

如何高效实现?

要解决这个问题,我们需要一个强大的工具来处理回文串——回文自动机(Palindromic Automaton, PAM),也叫回文树,喵~

PAM 可以在线性时间内找出字符串 s 的所有本质不同的回文子串。PAM 的每个节点(除了两个初始根节点)都唯一对应一个回文子串。

我们还需要一个关键性质:如果 w^k 是 s 的一个子串,那么对于所有 1jk1 \le j \le k,w^j 也必然是 s 的子串(因为 w^jw^k 的前缀)。 所以,对于一个本原回文根 w,如果它在 s 中出现的最大幂次是 k,那么它的“超级集合” G_w 的大小就是 k,包含了 {w^1, w^2, ..., w^k}。我们需要计算的贡献就是 2k12^k-1

那么,我们如何用 PAM 找到每个本原回文根 w 和它对应的最大幂次 k 呢?

这就要用到 PAM 的一个进阶技巧——回文串的算术级数。 在 PAM 中,每个节点 u 都有一个 fail_link 指针,指向 u 所代表的回文串的最长回文后缀。 我们可以观察 ufail_link[u] 的长度差 diff[u] = len[u] - len[fail_link[u]]。如果一个回文串 P 是另一个回文串 Q 的幂(比如 P=Q^k),它们往往会形成一个 diff 值相同的链。

我们可以为 PAM 的每个节点 u 额外维护两个信息:

  • series_root[u]: 指向 u 所在的回文幂次系列的最短的那个回文串(也就是本原回文根)对应的节点。
  • power[u]: u 是它的根的多少次幂。

在构建 PAM 的过程中,当我们创建一个新节点 u 时:

  1. 找到它的最长回文后缀 v = fail_link[u]
  2. 比较它们的长度差:diff[u] = len[u] - len[v]diff[v] = len[v] - len[fail_link[v]]
  3. 如果 diff[u] == diff[v],说明 u 延续了 v 的幂次系列!所以 series_root[u] = series_root[v],并且 power[u] = power[v] + 1
  4. 如果 diff[u] != diff[v],说明 u 开启了一个新的系列。它自己就是根,所以 series_root[u] = u,并且 power[u] = 1

这个递推关系可以更高效地实现,就像参考代码中那样,通过维护一个指向当前系列上一个不同 diff 的节点的指针(有时叫做 slinkans)。

最终的计算

有了这些信息,计算答案就简单了:

  1. 完整构建 PAM,并为每个节点计算出 series_rootpower
  2. 我们只需要对每个本原回文根 w 算一次贡献。
  3. 我们可以遍历所有节点 u(从长到短遍历比较方便)。
  4. 用一个 visited 数组记录哪些 series_root 已经被计算过。
  5. 对于节点 u,如果 visited[series_root[u]]false,说明这是我们第一次遇到这个 w 系列的串。此时 u 就是这个系列里最长的串,它的幂次 power[u] 就是这个系列的最大幂次 k
  6. 我们将 2k12^k-1 加入总答案,并标记 visited[series_root[u]] = true
  7. 这样,每个系列只会被其最长的成员统计一次,完美!

好啦,思路就是这样!接下来就让我们用代码把这只我的奇思妙想变成现实吧,喵~

代码实现

cpp
#include <iostream>
#include <vector>
#include <string>
#include <cstring>

using namespace std;

const int MAXN = 5000005;
const int MOD = 998244353;

// 预处理2的幂,方便计算 2^k - 1
long long power_of_2[MAXN];

// 回文自动机 (PAM)
namespace PAM {
    struct Node {
        int len;          // 节点代表的回文串长度
        int fail_link;    // 指向最长回文后缀的节点
        int series_root;  // 所在回文幂次系列的根节点
        int child[26];    // 转移边
    };

    Node nodes[MAXN];
    char text[MAXN];
    int node_count;      // 节点总数
    int last_node;       // 上一个插入字符后对应的最长回文后缀节点

    // 初始化PAM,建立两个根节点
    // 0号节点代表偶数长度回文串的虚根,长度为0
    // 1号节点代表奇数长度回文串的虚根,长度为-1
    void init() {
        node_count = 1;
        nodes[0].len = 0;
        nodes[0].fail_link = 1;
        nodes[1].len = -1;
        nodes[1].fail_link = 1; // 1的fail指向自己
        last_node = 0;
        text[0] = -1; // 设置哨兵,防止越界
    }

    // 沿着fail链找到一个节点p,使得 text[i-len[p]-1] == text[i]
    int get_fail(int current_node, int text_idx) {
        while (text[text_idx - nodes[current_node].len - 1] != text[text_idx]) {
            current_node = nodes[current_node].fail_link;
        }
        return current_node;
    }

    // 插入一个新字符
    void insert(int char_code, int text_idx) {
        int p = get_fail(last_node, text_idx);

        // 如果这个新回文串不存在,则创建新节点
        if (!nodes[p].child[char_code]) {
            int new_node = ++node_count;
            nodes[new_node].len = nodes[p].len + 2;

            // 确定新节点的fail_link
            if (p == 1) { // 如果是从奇根转移过来,fail指向偶根
                nodes[new_node].fail_link = 0;
            } else {
                nodes[new_node].fail_link = nodes[get_fail(nodes[p].fail_link, text_idx)].child[char_code];
            }
            
            nodes[p].child[char_code] = new_node;

            // 核心逻辑:判断是否延续了回文幂次系列
            int v = nodes[new_node].fail_link;
            // 如果新回文串的长度差等于其fail节点的根的长度,说明是同一个系列
            if (nodes[nodes[v].series_root].len == nodes[new_node].len - nodes[v].len) {
                nodes[new_node].series_root = nodes[v].series_root;
            } else {
                // 否则,开启一个新系列,自己就是根
                nodes[new_node].series_root = new_node;
            }
        }
        last_node = nodes[p].child[char_code];
    }
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    // 预处理2的幂
    power_of_2[0] = 1;
    for (int i = 1; i < MAXN; ++i) {
        power_of_2[i] = (power_of_2[i - 1] * 2) % MOD;
    }

    string s;
    cin >> s;
    int n = s.length();

    PAM::init();
    for (int i = 0; i < n; ++i) {
        PAM::text[i + 1] = s[i];
        PAM::insert(s[i] - 'a', i + 1);
    }
    
    long long total_sets = 0;
    vector<bool> visited_series(PAM::node_count + 1, false);

    // 从最长的回文串开始遍历 (节点编号越大,一般长度越长)
    // 这样可以保证第一次遇到一个系列时,是该系列的最长者
    for (int i = PAM::node_count; i >= 2; --i) {
        int root_node = PAM::nodes[i].series_root;
        if (visited_series[root_node]) {
            continue;
        }
        
        // 计算当前回文串是其根的多少次幂
        // len(i) / len(root) 就是幂次 k
        int power = PAM::nodes[i].len / PAM::nodes[root_node].len;
        
        // 该系列贡献 2^k - 1
        long long contribution = (power_of_2[power] - 1 + MOD) % MOD;
        total_sets = (total_sets + contribution) % MOD;
        
        visited_series[root_node] = true;
    }

    cout << total_sets << endl;

    return 0;
}

复杂度分析

  • 时间复杂度: O(N)O(N),其中 NN 是字符串 s 的长度。构建回文自动机是均摊 O(N)O(N) 的。之后遍历所有节点计算答案也是 O(N)O(N) 的,因为 PAM 的节点数最多为 N+2N+2
  • 空间复杂度: O(NΣ)O(N \cdot |\Sigma|),其中 Σ|\Sigma| 是字符集大小(这里是26)。主要是 PAM 的节点和转移边占用的空间。如果用 map 存转移边,可以优化到 O(N)O(N),但常数会大一些。对于这道题,O(N26)O(N \cdot 26) 是完全可以接受的。

知识点总结

  1. 问题转化: 解题的第一步,也是最重要的一步,就是将题目中复杂的约束条件转化为一个更清晰、更具数学结构的等价问题。这里我们将复杂的 lcs 和回文条件转化为了“集合中所有元素都是同一个本原回文串的幂”。
  2. 回文自动机 (PAM): 这是解决各种回文子串问题的神器,喵~ 它能在线性时间内找出所有本质不同的回文子串,并建立它们之间的后缀关系(fail_link)。
  3. PAM 的进阶应用: 本题不仅用了 PAM 的基本功能,还利用了其 fail 链上的“算术级数”性质来高效地识别一个回文串是否是另一个回文串的幂。这是 PAM 一个非常巧妙的应用,值得好好体会和学习,呐!
  4. 组合计数: 最终的答案统计用到了基本的组合知识。对于一个大小为 k 的集合,它的非空子集数量为 2k12^k - 1

希望这篇题解能帮助你更好地理解这道有趣的题目!如果还有不明白的地方,随时可以再来问我哦,喵~