Skip to content

抽奖 - 题解

标签与难度

标签: 期望DP, 生成函数, 组合数学, 数论, NTT, 线性递推, 斯特林数 难度: 2800

题目大意喵~

一位叫做 NaCly Fish 的朋友要去抽奖,每次抽奖,她都有 p=abp = \frac{a}{b} 的概率抽中一个奖品。她非常想要 mm 个奖品,所以她会一直抽呀抽,直到集齐 mm 个为止。

现在,我们设她总共抽奖的次数是一个随机变量 XX。题目想让我们计算 XkX^k 的数学期望,也就是 E[Xk]E[X^k]

因为答案可能是一个分数,所以需要我们输出它对 998244353998244353 取模的结果。

简单来说,就是求一个负二项分布的 kk 阶原点矩,喵~

解题思路分析

这道题是概率期望和组合数学的结合,看起来就超级有料,本喵已经兴奋起来了呢,喵~

直接根据定义计算 E[Xk]=n=mnkP(X=n)E[X^k] = \sum_{n=m}^{\infty} n^k P(X=n) 是非常困难的,因为这是一个无穷级数,而且 nkn^k 这个项处理起来很麻烦。

所以,我们要换一个更聪明的思路!在处理和式的时候,如果幂函数 nkn^k 让我们头疼,一个常见的技巧就是把它转换成下降幂或者组合数,因为后者在和式运算中通常有更好的性质。

这里,我们选择组合数作为桥梁。我们知道 nkn^k 可以表示为组合数的线性组合:

nk=i=0kS2(k,i)i!(ni)n^k = \sum_{i=0}^{k} S_2(k, i) \cdot i! \cdot \binom{n}{i}

其中 S2(k,i)S_2(k, i) 是第二类斯特林数,表示将 kk 个不同元素划分成 ii 个非空子集的方案数。

根据期望的线性性质,我们可以把求和符号交换一下:

E[Xk]=E[i=0kS2(k,i)i!(Xi)]=i=0kS2(k,i)i!E[(Xi)]E[X^k] = E\left[\sum_{i=0}^{k} S_2(k, i) \cdot i! \cdot \binom{X}{i}\right] = \sum_{i=0}^{k} S_2(k, i) \cdot i! \cdot E\left[\binom{X}{i}\right]

这样,问题就分解成两个子问题了,喵~

  1. 计算 gi=E[(Xi)]g_i = E[\binom{X}{i}] 对于 i=0,1,,ki = 0, 1, \dots, k 的值。
  2. 计算系数 Ci=S2(k,i)i!C_i = S_2(k, i) \cdot i! 对于 i=0,1,,ki = 0, 1, \dots, k 的值。

只要解决了这两个问题,把它们乘起来再相加,就能得到最终答案啦!

Step 1: 计算 gi=E[(Xi)]g_i = E\left[\binom{X}{i}\right]

这个期望值 gig_i 被称为“二项矩”。直接计算它仍然不容易,但我们可以借助强大的生成函数

XX 的概率生成函数 (PGF) 为 G(z)=E[zX]G(z) = E[z^X]。对于这个问题中的负二项分布,它的 PGF 是:

G(z)=(pz1(1p)z)mG(z) = \left(\frac{pz}{1-(1-p)z}\right)^m

其中 p=a/bp = a/b 是单次中奖概率。为了方便,我们记 q=1pq = 1-p

我们想求的 gi=E[(Xi)]g_i = E[\binom{X}{i}] 和另一个叫做“阶乘矩生成函数”的东西关系密切。考虑 E[(1+y)X]E[(1+y)^X]

E[(1+y)X]=E[i=0X(Xi)yi]=i=0E[(Xi)]yi=i=0giyiE[(1+y)^X] = E\left[\sum_{i=0}^X \binom{X}{i} y^i\right] = \sum_{i=0}^{\infty} E\left[\binom{X}{i}\right] y^i = \sum_{i=0}^{\infty} g_i y^i

同时,我们发现 E[(1+y)X]=G(1+y)E[(1+y)^X] = G(1+y)。所以,我们只需要把 z=1+yz=1+y 代入 G(z)G(z),然后展开成 yy 的幂级数,第 ii 项的系数就是 gig_i

H(y)=G(1+y)H(y) = G(1+y):

H(y)=(p(1+y)1q(1+y))m=(p(1+y)pqy)m=(1+y1(q/p)y)mH(y) = \left(\frac{p(1+y)}{1-q(1+y)}\right)^m = \left(\frac{p(1+y)}{p-qy}\right)^m = \left(\frac{1+y}{1 - (q/p)y}\right)^m

直接从这个式子展开求系数还是有点复杂。但对于一个生成函数,我们总可以尝试找到它满足的微分方程,从而得到其系数的递推式!

H(y)H(y) 求导,经过一番推导(喵~,这里的数学推导有点小魔法,但相信我哦!),可以得到下面这个等式:

(1(q/p)y)(1+y)H(y)=m(1+q/p)H(y)(1 - (q/p)y)(1+y) H'(y) = m(1 + q/p) H(y)

H(y)=giyiH(y) = \sum g_i y^iH(y)=igiyi1H'(y) = \sum i g_i y^{i-1} 代入,比较等式两边 yi1y^{i-1} 的系数,我们就能得到 gig_i 的线性递推关系:

ipgi=(m(pq)(i1))gi1+q(i2)gi2i \cdot p \cdot g_i = (m - (p-q)(i-1)) g_{i-1} + q(i-2) g_{i-2}

这个递推式太棒啦!我们可以从 g0=E[(X0)]=1g_0 = E[\binom{X}{0}] = 1g1=E[(X1)]=E[X]=m/pg_1 = E[\binom{X}{1}] = E[X] = m/p 开始,在 O(k)O(k) 的时间内计算出所有我们需要的 gig_i

Step 2: 计算系数 Ci=S2(k,i)i!C_i = S_2(k, i) \cdot i!

现在我们来解决第二个子问题。根据第二类斯特林数的公式,我们有:

S2(k,i)=1i!j=0i(1)ij(ij)jkS_2(k, i) = \frac{1}{i!} \sum_{j=0}^{i} (-1)^{i-j} \binom{i}{j} j^k

所以,我们想要的系数是:

Ci=S2(k,i)i!=j=0i(1)ij(ij)jkC_i = S_2(k, i) \cdot i! = \sum_{j=0}^{i} (-1)^{i-j} \binom{i}{j} j^k

这个形式是一个典型的卷积!让我们把它写得更清楚一点:

Ci=j=0i((1)ij1(ij)!)(jkj!)1j!    Cii!=j=0i((1)ij(ij)!)(jkj!)C_i = \sum_{j=0}^{i} \left( (-1)^{i-j} \frac{1}{(i-j)!} \right) \cdot (j^k \cdot j!) \cdot \frac{1}{j!} \implies \frac{C_i}{i!} = \sum_{j=0}^{i} \left(\frac{(-1)^{i-j}}{(i-j)!}\right) \cdot \left(\frac{j^k}{j!}\right)

At=(1)tt!A_t = \frac{(-1)^t}{t!}Bt=tkt!B_t = \frac{t^k}{t!}。那么 Cii!\frac{C_i}{i!} 就是序列 AABB 的卷积结果的第 ii 项。

我们可以用NTT (快速数论变换)O(klogk)O(k \log k) 的时间内计算这个卷积。 具体步骤是:

  1. 预处理阶乘、逆阶乘,以及用线性筛预处理出 jk(mod998244353)j^k \pmod{998244353}
  2. 构造多项式 A(x)=AtxtA(x) = \sum A_t x^tB(x)=BtxtB(x) = \sum B_t x^t
  3. 用 NTT 计算出 D(x)=A(x)B(x)D(x) = A(x)B(x)
  4. D(x)D(x) 的第 ii 项系数 did_i 就是 Cii!\frac{C_i}{i!},所以 Ci=dii!C_i = d_i \cdot i!

重要提示喵:题目的数据范围 k107k \le 10^7 非常大,O(klogk)O(k \log k) 的 NTT 算法会超时。给出的参考代码使用了一种更高级的 O(k)O(k) 算法,它通过寻找 CiC_i 自身的线性递推关系来计算。那个推导过程非常复杂,涉及到更深的生成函数和微分方程技巧。不过,理解 O(klogk)O(k \log k) 的 NTT 方法对于解决类似问题已经是非常重要的一步了,所以本喵将为你展示这个版本的实现,它足以通过 kk 较小的情况,并且更有教育意义哦!

总结一下

我们的解题路线图是:

  1. E[Xk]E[X^k] 转化为 E[(Xi)]E[\binom{X}{i}] 的线性组合。
  2. 用生成函数推导出 E[(Xi)]E[\binom{X}{i}]O(k)O(k) 线性递推式,并计算出它们。
  3. 用 NTT 在 O(klogk)O(k \log k) 时间内计算出组合系数 S2(k,i)i!S_2(k, i) \cdot i!
  4. 将两部分结果合并,得到最终答案。

准备好了吗?让我们开始写代码吧,喵~

代码实现

cpp
#include <iostream>
#include <vector>
#include <numeric>
#include <algorithm>

using namespace std;

// 使用 long long 防止溢出
using ll = long long;

const int MOD = 998244353;
const int NTT_G = 3; // NTT的原根

// 快速幂,喵~
ll power(ll base, ll exp) {
    ll res = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp % 2 == 1) res = (res * base) % MOD;
        base = (base * base) % MOD;
        exp /= 2;
    }
    return res;
}

// 模逆元
ll modInverse(ll n) {
    return power(n, MOD - 2);
}

// NTT 核心实现
namespace NTT {
    vector<ll> rev;
    void ntt(vector<ll>& a, bool invert) {
        int n = a.size();
        if (rev.size() != n) {
            rev.resize(n);
            int log_n = __builtin_ctz(n);
            for (int i = 0; i < n; i++) {
                rev[i] = 0;
                for (int j = 0; j < log_n; j++) {
                    if ((i >> j) & 1) {
                        rev[i] |= 1 << (log_n - 1 - j);
                    }
                }
            }
        }

        for (int i = 0; i < n; i++) {
            if (i < rev[i]) {
                swap(a[i], a[rev[i]]);
            }
        }

        for (int len = 2; len <= n; len <<= 1) {
            ll wlen = power(NTT_G, (MOD - 1) / len);
            if (invert) {
                wlen = modInverse(wlen);
            }
            for (int i = 0; i < n; i += len) {
                ll w = 1;
                for (int j = 0; j < len / 2; j++) {
                    ll u = a[i + j], v = (a[i + j + len / 2] * w) % MOD;
                    a[i + j] = (u + v) % MOD;
                    a[i + j + len / 2] = (u - v + MOD) % MOD;
                    w = (w * wlen) % MOD;
                }
            }
        }

        if (invert) {
            ll n_inv = modInverse(n);
            for (ll& x : a) {
                x = (x * n_inv) % MOD;
            }
        }
    }

    vector<ll> multiply(vector<ll> a, vector<ll> b) {
        int sz = 1;
        while (sz < a.size() + b.size()) sz <<= 1;
        a.resize(sz);
        b.resize(sz);

        ntt(a, false);
        ntt(b, false);

        for (int i = 0; i < sz; i++) {
            a[i] = (a[i] * b[i]) % MOD;
        }

        ntt(a, true);
        return a;
    }
}

// 线性筛,用于计算 i^k
vector<ll> power_k;
void sieve(int k) {
    power_k.resize(k + 1);
    vector<int> primes;
    vector<bool> is_prime(k + 1, true);
    is_prime[0] = is_prime[1] = false;
    power_k[1] = 1;
    for (int i = 2; i <= k; i++) {
        if (is_prime[i]) {
            primes.push_back(i);
            power_k[i] = power(i, k);
        }
        for (int p : primes) {
            if ((ll)i * p > k) break;
            is_prime[i * p] = false;
            power_k[i * p] = (power_k[i] * power_k[p]) % MOD;
            if (i % p == 0) break;
        }
    }
}


int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int m, k, a, b;
    cin >> m >> k >> a >> b;

    ll p_prob = (ll)a * modInverse(b) % MOD;
    ll q_prob = (1 - p_prob + MOD) % MOD;

    // --- Step 1: 计算 g_i = E[binom(X, i)] ---
    vector<ll> g(k + 1);
    g[0] = 1;
    ll p_inv = modInverse(p_prob);
    g[1] = (ll)m * p_inv % MOD;

    ll p_minus_q = (p_prob - q_prob + MOD) % MOD;

    for (int i = 2; i <= k; i++) {
        ll term1_coeff = (m - (p_minus_q * (i - 1)) % MOD + MOD) % MOD;
        ll term1 = (term1_coeff * g[i - 1]) % MOD;
        ll term2 = (q_prob * (i - 2)) % MOD;
        term2 = (term2 * g[i - 2]) % MOD;
        
        ll rhs = (term1 + term2) % MOD;
        g[i] = (rhs * p_inv) % MOD;
        g[i] = (g[i] * modInverse(i)) % MOD;
    }

    // --- Step 2: 计算系数 C_i = S2(k, i) * i! ---
    // 预处理
    sieve(k);
    vector<ll> fact(k + 1);
    vector<ll> inv_fact(k + 1);
    fact[0] = 1;
    inv_fact[0] = 1;
    for (int i = 1; i <= k; i++) {
        fact[i] = (fact[i - 1] * i) % MOD;
        inv_fact[i] = modInverse(fact[i]);
    }

    // 构造卷积多项式
    vector<ll> A(k + 1), B(k + 1);
    for (int i = 0; i <= k; i++) {
        A[i] = inv_fact[i];
        if (i % 2 == 1) {
            A[i] = (MOD - A[i]) % MOD;
        }
        B[i] = (power_k[i] * inv_fact[i]) % MOD;
    }

    // NTT计算卷积
    vector<ll> C_div_fact = NTT::multiply(A, B);
    
    vector<ll> C(k + 1);
    for (int i = 0; i <= k; i++) {
        C[i] = (C_div_fact[i] * fact[i]) % MOD;
    }

    // --- Final Step: 合并结果 ---
    ll final_ans = 0;
    for (int i = 0; i <= k; i++) {
        final_ans = (final_ans + C[i] * g[i]) % MOD;
    }

    cout << final_ans << endl;

    return 0;
}

复杂度分析

  • 时间复杂度: O(klogk)O(k \log k)

    • 计算 gig_i 的递推需要 O(k)O(k)
    • 线性筛预处理 iki^k 需要 O(k)O(k)
    • 预处理阶乘和逆阶乘需要 O(k)O(k)
    • 构造多项式 A(x)A(x)B(x)B(x) 需要 O(k)O(k)
    • NTT 卷积是主要瓶颈,需要 O(klogk)O(k \log k)
    • 最后合并结果需要 O(k)O(k)
    • 所以总时间复杂度由 NTT 决定,为 O(klogk)O(k \log k)
  • 空间复杂度: O(k)O(k)

    • 我们需要存储 gig_i, CiC_i, 阶乘,逆阶乘,以及 NTT 所需的数组,它们的大小都与 kk 线性相关。

知识点总结

  1. 期望的线性性质: 它是我们能够将复杂期望分解为简单期望之和的基础,是解决期望问题的万能钥匙,喵~
  2. 组合恒等式: 核心是将 nkn^k 通过第二类斯特林数展开成 (ni)\binom{n}{i} 的线性组合,这是从幂函数到组合世界的桥梁。
  3. 生成函数: 解决序列和计数问题的超级大杀器!通过将序列表示为函数,我们可以用求导、积分等分析工具来研究序列的性质,比如推导递推关系。
  4. 线性递推: 许多组合计数问题最终都可以归结为求解一个线性递推式,这通常比直接计算要快得多。
  5. NTT/FFT: 快速计算多项式卷积的利器。看到形如 Ci=j=0iAjBijC_i = \sum_{j=0}^i A_j B_{i-j} 的式子,就要立刻想到它!
  6. 线性筛: 不仅可以筛素数,还可以用来在 O(N)O(N) 时间内计算一些积性函数或者像 iki^k 这样的值。

这道题综合了好多知识点,是一道非常好的练习题!虽然完全解法非常硬核,但一步步分解下来,每一步都是我们应该掌握的经典技巧。希望这篇题解能帮到你,加油哦,喵~!