Skip to content

I - Infinity - 题解

标签与难度

标签: 组合数学, 生成函数, 多项式NTT, 多项式对数函数, 多项式指数函数, 群论 难度: 2900

题目大意喵~

哈喵~!各位算法大师们,今天本喵要带大家攻略一道超级有趣的题目!

题目是这样哒:

首先,我们有一个集合 SnS_n,它包含了所有 nn 个元素的排列,也就是我们熟悉的对称群,喵~ 对于 SnS_n 中的每一个排列 σ\sigma,我们定义一个值 ν(σ)\nu(\sigma),它表示集合 {μ1σμμSn}\{ \mu^{-1}\sigma\mu \mid \mu \in S_n \} 中不同元素的数量。

然后,题目会给我们一个固定的整数 kk,和好多好多个查询 nn。对每一个 nn,我们都要计算下面这个式子的值,并对 998244353998244353 取模:

σSnν(σ)k\sum_{\sigma \in S_n} \nu(\sigma)^k

简单来说,就是对所有 nn-阶排列 σ\sigma,计算出它的 ν(σ)\nu(\sigma) 值的 kk 次方,然后把它们全部加起来。是不是听起来就很有挑战性呀?喵~

解题思路分析

这道题看起来披着一层群论的神秘外衣,可能会吓到一些小伙伴,但别怕,本喵会像解开毛线球一样,把思路一点点理清楚的!

步骤一:理解 ν(σ)\nu(\sigma) 的真面目

首先,我们来研究一下 ν(σ)\nu(\sigma) 是什么。在群论中,集合 {μ1σμμSn}\{ \mu^{-1}\sigma\mu \mid \mu \in S_n \} 被称为排列 σ\sigma共轭类(Conjugacy Class)。而 ν(σ)\nu(\sigma) 就是这个共轭类的大小。

对称群 SnS_n 有一个非常美妙的性质:两个排列是共轭的,当且仅当它们有相同的轮换分解结构(Cycle Type)

什么叫轮换分解结构呢?就是一个排列分解成不相交的循环置换后,这些循环的长度组成的集合。比如说,在 S5S_5 中,排列 (1 2)(3 4 5)(1\ 2)(3\ 4\ 5) 有一个长度为2的循环和一个长度为3的循环,它的轮换结构就是 {2,3}\{2, 3\}。所有轮换结构是 {2,3}\{2, 3\} 的排列(比如 (1 3)(2 4 5)(1\ 3)(2\ 4\ 5))都和它在同一个共轭类里。

步骤二:转化求和式

知道了这一点,我们就可以把原来的求和式进行一个漂亮的变身!原来的和是枚举所有的排列 σ\sigma,但我们可以换个角度,按共轭类来分组计算。

SnS_n 的所有共轭类为 C1,C2,,Cm\mathcal{C}_1, \mathcal{C}_2, \dots, \mathcal{C}_m

σSnν(σ)k=j=1mσCjν(σ)k\sum_{\sigma \in S_n} \nu(\sigma)^k = \sum_{j=1}^{m} \sum_{\sigma \in \mathcal{C}_j} \nu(\sigma)^k

因为在同一个共轭类 Cj\mathcal{C}_j 中,所有排列 σ\sigmaν(σ)\nu(\sigma) 值都是相等的,都等于这个类的大小 Cj|\mathcal{C}_j|。所以,内层的和可以简化:

σCjν(σ)k=σCjCjk=CjCjk=Cjk+1\sum_{\sigma \in \mathcal{C}_j} \nu(\sigma)^k = \sum_{\sigma \in \mathcal{C}_j} |\mathcal{C}_j|^k = |\mathcal{C}_j| \cdot |\mathcal{C}_j|^k = |\mathcal{C}_j|^{k+1}

于是,我们的总和就变成了:

j=1mCjk+1\sum_{j=1}^{m} |\mathcal{C}_j|^{k+1}

这可真是个不得了的发现!我们只需要知道所有共轭类的大小,然后求它们的 (k+1)(k+1) 次方和就可以啦!

步骤三:共轭类大小的公式

一个排列的轮换结构可以用一个序列 (c1,c2,,cn)(c_1, c_2, \dots, c_n) 来表示,其中 cic_i 是长度为 ii 的循环的个数。这个结构必须满足 i=1nici=n\sum_{i=1}^n i \cdot c_i = n

对于一个给定的轮换结构 (c1,,cn)(c_1, \dots, c_n),它对应的共轭类的大小是:

C=n!i=1nicici!|\mathcal{C}| = \frac{n!}{\prod_{i=1}^n i^{c_i} c_i!}

一个更有趣的事实是,具有这个轮换结构的排列的个数,也恰好是这个公式!这说明,对于一个共轭类里的任何一个排列 σ\sigmaν(σ)\nu(\sigma)(共轭类大小)就等于与它同类型的排列的数量。

所以,我们的目标是计算:

Answern=ici=n(n!i=1nicici!)k+1\text{Answer}_n = \sum_{\sum i \cdot c_i = n} \left( \frac{n!}{\prod_{i=1}^n i^{c_i} c_i!} \right)^{k+1}

这个和是枚举所有满足 ici=n\sum i \cdot c_i = n 的轮换结构。为了方便,我们令 K=k+1K = k+1

步骤四:请出我们的魔法道具——生成函数!

这个公式看起来非常复杂,直接计算所有整数分拆是不可能的。这种组合计数问题,正是生成函数大显身手的时候!

我们把公式稍微变形一下:

Answern=(n!)Kici=ni=1n1(icici!)K\text{Answer}_n = (n!)^K \sum_{\sum i \cdot c_i = n} \prod_{i=1}^n \frac{1}{(i^{c_i} c_i!)^K}

Bn=ici=ni=1n1(icici!)KB_n = \sum_{\sum i \cdot c_i = n} \prod_{i=1}^n \frac{1}{(i^{c_i} c_i!)^K}。我们来构造 BnB_n 的生成函数 B(x)=n=0BnxnB(x) = \sum_{n=0}^\infty B_n x^n

观察 BnB_n 的结构,它是对所有满足 ici=n\sum i \cdot c_i = ncic_i 求和,这种形式的卷积可以表示为多个多项式的乘积。我们为每个循环长度 ii 定义一个多项式:

Pi(x)=m=01(imm!)KximP_i(x) = \sum_{m=0}^\infty \frac{1}{(i^m m!)^K} x^{im}

这里的 mm 就相当于原来的 cic_i。那么 B(x)B(x) 就是所有这些 Pi(x)P_i(x) 的乘积!

B(x)=i=1Pi(x)B(x) = \prod_{i=1}^\infty P_i(x)

[xn]B(x)[x^n]B(x) 就会自动地从每个 Pi(x)P_i(x) 中选出一项 1(icici!)Kxici\frac{1}{(i^{c_i} c_i!)^K} x^{ic_i},使得指数和 ici=n\sum ic_i = n,这正好就是 BnB_n 的定义!

步骤五:对数-指数魔法(ln-exp 技巧)

无限乘积太难处理了,喵~ 我们可以用对数把它变成加法:

ln(B(x))=ln(i=1Pi(x))=i=1ln(Pi(x))\ln(B(x)) = \ln\left(\prod_{i=1}^\infty P_i(x)\right) = \sum_{i=1}^\infty \ln(P_i(x))

现在我们来仔细看看 ln(Pi(x))\ln(P_i(x))Pi(x)=m=01(m!)K(iK)m(xi)mP_i(x) = \sum_{m=0}^\infty \frac{1}{(m!)^K (i^K)^m} (x^i)^m。 如果我们定义一个基础多项式 F(z)=m=0zm(m!)KF(z) = \sum_{m=0}^\infty \frac{z^m}{(m!)^K},那么 Pi(x)=F(xiiK)P_i(x) = F\left(\frac{x^i}{i^K}\right)。 令 D(x)=ln(B(x))D(x) = \ln(B(x)),则 B(x)=exp(D(x))B(x) = \exp(D(x))

D(x)=i=1ln(F(xiiK))D(x) = \sum_{i=1}^\infty \ln\left(F\left(\frac{x^i}{i^K}\right)\right)

再令 H(z)=ln(F(z))=r=1hrzrH(z) = \ln(F(z)) = \sum_{r=1}^\infty h_r z^r。那么:

ln(F(xiiK))=H(xiiK)=r=1hr(xiiK)r=r=1hr(iK)rxir\ln\left(F\left(\frac{x^i}{i^K}\right)\right) = H\left(\frac{x^i}{i^K}\right) = \sum_{r=1}^\infty h_r \left(\frac{x^i}{i^K}\right)^r = \sum_{r=1}^\infty \frac{h_r}{(i^K)^r} x^{ir}

把所有 ii 的贡献加起来,我们就能得到 D(x)D(x) 的系数 dnd_n

dn=[xn]D(x)=ir=n,i,r1hr(iK)r=rnhr((n/r)K)rd_n = [x^n]D(x) = \sum_{ir=n, i,r \ge 1} \frac{h_r}{(i^K)^r} = \sum_{r|n} \frac{h_r}{((n/r)^K)^r}

最终的算法流程

好耶!我们已经推导出了一条清晰的计算路径,喵~

  1. 预处理: 计算 K=(k+1)(mod998244352)K = (k+1) \pmod{998244352}。因为模数是质数 PP,根据费马小定理,指数需要对 P1P-1 取模。同时预计算阶乘和阶乘逆元。
  2. 构造 F(x)F(x): 构造多项式 F(x)=m=0Nmax1(m!)KxmF(x) = \sum_{m=0}^{N_{max}} \frac{1}{(m!)^K} x^m,其中 NmaxN_{max} 是所有查询中最大的 nn
  3. 计算 H(x)=ln(F(x))H(x) = \ln(F(x)): 使用多项式对数函数算法 (NTT实现),得到 H(x)H(x) 的系数 hrh_r
  4. 计算 D(x)D(x): 根据公式 dn=rnhr((n/r)K)rd_n = \sum_{r|n} \frac{h_r}{((n/r)^K)^r},计算出 D(x)D(x) 的系数 dnd_n。这可以通过枚举每个 rr 和它的倍数 n=irn=ir 来高效完成,复杂度是 O(NlogN)O(N \log N) 级别。
  5. 计算 B(x)=exp(D(x))B(x) = \exp(D(x)): 使用多项式指数函数算法 (NTT实现),得到 B(x)B(x) 的系数 BnB_n
  6. 计算答案: 对于每个查询 nn,最终答案就是 Answern=(n!)KBn(mod998244353)\text{Answer}_n = (n!)^K \cdot B_n \pmod{998244353}。我们可以预先计算出所有 nn 的答案。

这套流程下来,我们就可以在 O(NmaxlogNmax)O(N_{max} \log N_{max}) 的时间复杂度内解决所有查询啦!是不是感觉充满了智慧的光芒,喵~

代码实现

下面就是本喵根据上面的思路,精心为大家准备的代码实现啦!注释很详细的哦,希望能帮到大家,喵~

cpp
#include <iostream>
#include <vector>
#include <numeric>
#include <algorithm>

// MOD is a prime number: 998244353
const int MOD = 998244353;
const int NTT_PRIMITIVE_ROOT = 3;
const int MAXN = 200005;

// Fast power function
long long power(long long base, long long exp) {
    long long res = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp % 2 == 1) res = (res * base) % MOD;
        base = (base * base) % MOD;
        exp /= 2;
    }
    return res;
}

// Modular inverse
long long modInverse(long long n) {
    return power(n, MOD - 2);
}

namespace Poly {
    std::vector<int> rev;
    std::vector<long long> roots;
    int ntt_limit = 1;

    void ensure_ntt_limit(int limit) {
        if (ntt_limit >= limit) return;
        int old_limit = ntt_limit;
        ntt_limit = 1;
        while (ntt_limit < limit) ntt_limit <<= 1;
        
        rev.resize(ntt_limit);
        for (int i = 0; i < ntt_limit; ++i) {
            rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (ntt_limit >> 1) : 0);
        }

        roots.resize(ntt_limit);
        for (int len = 2; len <= ntt_limit; len <<= 1) {
            long long w_len = power(NTT_PRIMITIVE_ROOT, (MOD - 1) / len);
            roots[len / 2] = 1;
            for (int i = len / 2 + 1; i < len; ++i) {
                roots[i] = (roots[i - 1] * w_len) % MOD;
            }
        }
    }

    void ntt(std::vector<long long>& a, bool invert) {
        int n = a.size();
        for (int i = 0; i < n; ++i) {
            if (i < rev[i]) std::swap(a[i], a[rev[i]]);
        }
        for (int len = 2; len <= n; len <<= 1) {
            for (int i = 0; i < n; i += len) {
                for (int j = 0; j < len / 2; ++j) {
                    long long u = a[i + j];
                    long long v = (a[i + j + len / 2] * roots[len / 2 + j]) % MOD;
                    a[i + j] = (u + v) % MOD;
                    a[i + j + len / 2] = (u - v + MOD) % MOD;
                }
            }
        }
        if (invert) {
            std::reverse(a.begin() + 1, a.end());
            long long inv_n = modInverse(n);
            for (long long& x : a) {
                x = (x * inv_n) % MOD;
            }
        }
    }

    std::vector<long long> multiply(std::vector<long long> a, std::vector<long long> b) {
        int sz = a.size() + b.size() - 1;
        int limit = 1;
        while (limit < sz) limit <<= 1;
        ensure_ntt_limit(limit);
        a.resize(limit);
        b.resize(limit);
        ntt(a, false);
        ntt(b, false);
        for (int i = 0; i < limit; ++i) a[i] = (a[i] * b[i]) % MOD;
        ntt(a, true);
        a.resize(sz);
        return a;
    }

    std::vector<long long> inverse(const std::vector<long long>& a, int deg) {
        if (deg == 0) return {};
        std::vector<long long> b = {modInverse(a[0])};
        int current_deg = 1;
        while (current_deg < deg) {
            current_deg <<= 1;
            std::vector<long long> a_slice(current_deg);
            for(int i = 0; i < std::min((int)a.size(), current_deg); ++i) a_slice[i] = a[i];

            int limit = current_deg * 2;
            ensure_ntt_limit(limit);

            std::vector<long long> b_ntt = b;
            b_ntt.resize(limit);
            ntt(b_ntt, false);

            std::vector<long long> a_ntt = a_slice;
            a_ntt.resize(limit);
            ntt(a_ntt, false);

            for (int i = 0; i < limit; ++i) {
                b_ntt[i] = (2 - (a_ntt[i] * b_ntt[i]) % MOD + MOD) % MOD * b_ntt[i] % MOD;
            }
            ntt(b_ntt, true);
            b.resize(current_deg);
            for(int i = 0; i < current_deg; ++i) b[i] = b_ntt[i];
        }
        b.resize(deg);
        return b;
    }

    std::vector<long long> derivative(const std::vector<long long>& a) {
        if (a.empty()) return {};
        std::vector<long long> res(a.size() - 1);
        for (size_t i = 1; i < a.size(); ++i) {
            res[i - 1] = (a[i] * i) % MOD;
        }
        return res;
    }

    std::vector<long long> integral(const std::vector<long long>& a) {
        std::vector<long long> res(a.size() + 1, 0);
        for (size_t i = 0; i < a.size(); ++i) {
            res[i + 1] = (a[i] * modInverse(i + 1)) % MOD;
        }
        return res;
    }

    std::vector<long long> logarithm(const std::vector<long long>& a, int deg) {
        if (deg == 0) return {};
        auto deriv_a = derivative(a);
        auto inv_a = inverse(a, deg);
        auto res = multiply(deriv_a, inv_a);
        res.resize(deg - 1);
        return integral(res);
    }
    
    std::vector<long long> exponentiation(const std::vector<long long>& a, int deg) {
        if (deg == 0) return {};
        std::vector<long long> b = {1};
        int current_deg = 1;
        while (current_deg < deg) {
            current_deg <<= 1;
            auto log_b = logarithm(b, current_deg);
            std::vector<long long> next_b(current_deg);
            for (int i = 0; i < current_deg; ++i) {
                next_b[i] = (a[i] - log_b[i] + MOD) % MOD;
            }
            next_b[0] = (next_b[0] + 1) % MOD;
            b = multiply(b, next_b);
            b.resize(current_deg);
        }
        b.resize(deg);
        return b;
    }

} // namespace Poly

long long fact[MAXN], invFact[MAXN];

void precompute_factorials(int n) {
    fact[0] = 1;
    invFact[0] = 1;
    for (int i = 1; i <= n; ++i) {
        fact[i] = (fact[i - 1] * i) % MOD;
        invFact[i] = modInverse(fact[i]);
    }
}

int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(NULL);

    int t;
    long long k;
    std::cin >> t >> k;

    std::vector<int> queries(t);
    int max_n = 0;
    for (int i = 0; i < t; ++i) {
        std::cin >> queries[i];
        if (queries[i] > max_n) {
            max_n = queries[i];
        }
    }

    precompute_factorials(max_n);

    long long K_exp = (k + 1) % (MOD - 1);
    if (K_exp == 0) K_exp = MOD - 1;

    // Step 2: Construct F(x)
    std::vector<long long> poly_F(max_n + 1);
    for (int i = 0; i <= max_n; ++i) {
        poly_F[i] = power(invFact[i], K_exp);
    }

    // Step 3: Compute H(x) = ln(F(x))
    auto poly_H = Poly::logarithm(poly_F, max_n + 1);

    // Step 4: Compute D(x)
    std::vector<long long> poly_D(max_n + 1, 0);
    for (int r = 1; r <= max_n; ++r) {
        if (poly_H[r] == 0) continue;
        for (int i = 1; r * i <= max_n; ++i) {
            long long n_over_r_K_r = power(modInverse(i), (K_exp * (long long)r) % (MOD - 1));
            long long term = (poly_H[r] * n_over_r_K_r) % MOD;
            poly_D[r * i] = (poly_D[r * i] + term) % MOD;
        }
    }

    // Step 5: Compute B(x) = exp(D(x))
    auto poly_B = Poly::exponentiation(poly_D, max_n + 1);

    // Step 6: Precompute all answers
    std::vector<long long> answers(max_n + 1);
    for (int n = 1; n <= max_n; ++n) {
        long long n_fact_K = power(fact[n], K_exp);
        answers[n] = (n_fact_K * poly_B[n]) % MOD;
    }

    for (int n : queries) {
        std::cout << answers[n] << "\n";
    }

    return 0;
}

复杂度分析

  • 时间复杂度: O(NlogN)O(N \log N),其中 NN 是查询中最大的 nn

    • 预计算阶乘和逆元需要 O(N)O(N)
    • 构造多项式 F(x)F(x) 需要 O(NlogK)O(N \log K)
    • 计算多项式对数 H(x)=ln(F(x))H(x) = \ln(F(x)) 需要 O(NlogN)O(N \log N)
    • 计算多项式 D(x)D(x) 的系数需要 r=1NNrO(NlogN)\sum_{r=1}^N \frac{N}{r} \approx O(N \log N) 次运算,每次运算包含一个快速幂,所以总的是 O(NlogNlogK)O(N \log N \log K)
    • 计算多项式指数 B(x)=exp(D(x))B(x) = \exp(D(x)) 需要 O(NlogN)O(N \log N)
    • 计算最终答案需要 O(NlogK)O(N \log K)
    • 整个算法的瓶颈在于计算 D(x)D(x) 的系数和多项式运算,所以总时间复杂度是 O(NlogNlogK)O(N \log N \log K)。但由于 kk 很大,我们对 P1P-1 取模,所以可以看作 O(NlogN)O(N \log N)
  • 空间复杂度: O(N)O(N)

    • 我们需要存储几个大小为 NN 的多项式(F,H,D,BF, H, D, B)以及阶乘等预处理数组。NTT 过程中需要临时空间,但也在 O(N)O(N) 范围内。

知识点总结

这道题是群论、组合数学和多项式算法的完美结合,做出来之后一定会让你成就感满满的,喵~

  1. 群论与组合: 理解对称群 SnS_n 的共轭类和轮换结构的关系是解题的第一步。将问题从对排列求和转化为对共轭类(或整数分拆)求和是关键的简化。
  2. 生成函数: 面对复杂的组合求和式,生成函数是我们的最强武器。特别是将乘积关系转化为卷积,再通过 lnexp\ln-\exp 技巧处理乘积形式,是解决这类问题的经典思路。
  3. 多项式全家桶: 本题的核心计算依赖于一系列多项式操作,包括 NTT(快速数论变换)、多项式求逆、求对数、求指数。熟练掌握这些板子是解决高难度组合计数问题的基础。
  4. 算法优化: 通过 F(z)F(z)H(z)H(z) 的关系,我们将计算许多个 ln(Pi(x))\ln(P_i(x)) 的复杂问题,转化为了计算一次 ln(F(x))\ln(F(x)) 和一次 O(NlogN)O(N \log N) 的求和,大大提高了效率。

希望这篇题解能帮助你更好地理解这道题的精髓!下次再遇到难题,也要像小猫一样,充满好奇心和勇气去挑战哦!喵~