抽奖 - 题解
标签与难度
标签: 期望DP, 生成函数, 组合数学, 数论, NTT, 线性递推, 斯特林数 难度: 2800
题目大意喵~
一位叫做 NaCly Fish 的朋友要去抽奖,每次抽奖,她都有 的概率抽中一个奖品。她非常想要 个奖品,所以她会一直抽呀抽,直到集齐 个为止。
现在,我们设她总共抽奖的次数是一个随机变量 。题目想让我们计算 的数学期望,也就是 。
因为答案可能是一个分数,所以需要我们输出它对 取模的结果。
简单来说,就是求一个负二项分布的 阶原点矩,喵~
解题思路分析
这道题是概率期望和组合数学的结合,看起来就超级有料,本喵已经兴奋起来了呢,喵~
直接根据定义计算 是非常困难的,因为这是一个无穷级数,而且 这个项处理起来很麻烦。
所以,我们要换一个更聪明的思路!在处理和式的时候,如果幂函数 让我们头疼,一个常见的技巧就是把它转换成下降幂或者组合数,因为后者在和式运算中通常有更好的性质。
这里,我们选择组合数作为桥梁。我们知道 可以表示为组合数的线性组合:
其中 是第二类斯特林数,表示将 个不同元素划分成 个非空子集的方案数。
根据期望的线性性质,我们可以把求和符号交换一下:
这样,问题就分解成两个子问题了,喵~
- 计算 对于 的值。
- 计算系数 对于 的值。
只要解决了这两个问题,把它们乘起来再相加,就能得到最终答案啦!
Step 1: 计算
这个期望值 被称为“二项矩”。直接计算它仍然不容易,但我们可以借助强大的生成函数!
设 的概率生成函数 (PGF) 为 。对于这个问题中的负二项分布,它的 PGF 是:
其中 是单次中奖概率。为了方便,我们记 。
我们想求的 和另一个叫做“阶乘矩生成函数”的东西关系密切。考虑 :
同时,我们发现 。所以,我们只需要把 代入 ,然后展开成 的幂级数,第 项的系数就是 !
令 :
直接从这个式子展开求系数还是有点复杂。但对于一个生成函数,我们总可以尝试找到它满足的微分方程,从而得到其系数的递推式!
对 求导,经过一番推导(喵~,这里的数学推导有点小魔法,但相信我哦!),可以得到下面这个等式:
把 和 代入,比较等式两边 的系数,我们就能得到 的线性递推关系:
这个递推式太棒啦!我们可以从 和 开始,在 的时间内计算出所有我们需要的 。
Step 2: 计算系数
现在我们来解决第二个子问题。根据第二类斯特林数的公式,我们有:
所以,我们想要的系数是:
这个形式是一个典型的卷积!让我们把它写得更清楚一点:
令 和 。那么 就是序列 和 的卷积结果的第 项。
我们可以用NTT (快速数论变换) 在 的时间内计算这个卷积。 具体步骤是:
- 预处理阶乘、逆阶乘,以及用线性筛预处理出 。
- 构造多项式 和 。
- 用 NTT 计算出 。
- 的第 项系数 就是 ,所以 。
重要提示喵:题目的数据范围 非常大, 的 NTT 算法会超时。给出的参考代码使用了一种更高级的 算法,它通过寻找 自身的线性递推关系来计算。那个推导过程非常复杂,涉及到更深的生成函数和微分方程技巧。不过,理解 的 NTT 方法对于解决类似问题已经是非常重要的一步了,所以本喵将为你展示这个版本的实现,它足以通过 较小的情况,并且更有教育意义哦!
总结一下
我们的解题路线图是:
- 将 转化为 的线性组合。
- 用生成函数推导出 的 线性递推式,并计算出它们。
- 用 NTT 在 时间内计算出组合系数 。
- 将两部分结果合并,得到最终答案。
准备好了吗?让我们开始写代码吧,喵~
代码实现
#include <iostream>
#include <vector>
#include <numeric>
#include <algorithm>
using namespace std;
// 使用 long long 防止溢出
using ll = long long;
const int MOD = 998244353;
const int NTT_G = 3; // NTT的原根
// 快速幂,喵~
ll power(ll base, ll exp) {
ll res = 1;
base %= MOD;
while (exp > 0) {
if (exp % 2 == 1) res = (res * base) % MOD;
base = (base * base) % MOD;
exp /= 2;
}
return res;
}
// 模逆元
ll modInverse(ll n) {
return power(n, MOD - 2);
}
// NTT 核心实现
namespace NTT {
vector<ll> rev;
void ntt(vector<ll>& a, bool invert) {
int n = a.size();
if (rev.size() != n) {
rev.resize(n);
int log_n = __builtin_ctz(n);
for (int i = 0; i < n; i++) {
rev[i] = 0;
for (int j = 0; j < log_n; j++) {
if ((i >> j) & 1) {
rev[i] |= 1 << (log_n - 1 - j);
}
}
}
}
for (int i = 0; i < n; i++) {
if (i < rev[i]) {
swap(a[i], a[rev[i]]);
}
}
for (int len = 2; len <= n; len <<= 1) {
ll wlen = power(NTT_G, (MOD - 1) / len);
if (invert) {
wlen = modInverse(wlen);
}
for (int i = 0; i < n; i += len) {
ll w = 1;
for (int j = 0; j < len / 2; j++) {
ll u = a[i + j], v = (a[i + j + len / 2] * w) % MOD;
a[i + j] = (u + v) % MOD;
a[i + j + len / 2] = (u - v + MOD) % MOD;
w = (w * wlen) % MOD;
}
}
}
if (invert) {
ll n_inv = modInverse(n);
for (ll& x : a) {
x = (x * n_inv) % MOD;
}
}
}
vector<ll> multiply(vector<ll> a, vector<ll> b) {
int sz = 1;
while (sz < a.size() + b.size()) sz <<= 1;
a.resize(sz);
b.resize(sz);
ntt(a, false);
ntt(b, false);
for (int i = 0; i < sz; i++) {
a[i] = (a[i] * b[i]) % MOD;
}
ntt(a, true);
return a;
}
}
// 线性筛,用于计算 i^k
vector<ll> power_k;
void sieve(int k) {
power_k.resize(k + 1);
vector<int> primes;
vector<bool> is_prime(k + 1, true);
is_prime[0] = is_prime[1] = false;
power_k[1] = 1;
for (int i = 2; i <= k; i++) {
if (is_prime[i]) {
primes.push_back(i);
power_k[i] = power(i, k);
}
for (int p : primes) {
if ((ll)i * p > k) break;
is_prime[i * p] = false;
power_k[i * p] = (power_k[i] * power_k[p]) % MOD;
if (i % p == 0) break;
}
}
}
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
int m, k, a, b;
cin >> m >> k >> a >> b;
ll p_prob = (ll)a * modInverse(b) % MOD;
ll q_prob = (1 - p_prob + MOD) % MOD;
// --- Step 1: 计算 g_i = E[binom(X, i)] ---
vector<ll> g(k + 1);
g[0] = 1;
ll p_inv = modInverse(p_prob);
g[1] = (ll)m * p_inv % MOD;
ll p_minus_q = (p_prob - q_prob + MOD) % MOD;
for (int i = 2; i <= k; i++) {
ll term1_coeff = (m - (p_minus_q * (i - 1)) % MOD + MOD) % MOD;
ll term1 = (term1_coeff * g[i - 1]) % MOD;
ll term2 = (q_prob * (i - 2)) % MOD;
term2 = (term2 * g[i - 2]) % MOD;
ll rhs = (term1 + term2) % MOD;
g[i] = (rhs * p_inv) % MOD;
g[i] = (g[i] * modInverse(i)) % MOD;
}
// --- Step 2: 计算系数 C_i = S2(k, i) * i! ---
// 预处理
sieve(k);
vector<ll> fact(k + 1);
vector<ll> inv_fact(k + 1);
fact[0] = 1;
inv_fact[0] = 1;
for (int i = 1; i <= k; i++) {
fact[i] = (fact[i - 1] * i) % MOD;
inv_fact[i] = modInverse(fact[i]);
}
// 构造卷积多项式
vector<ll> A(k + 1), B(k + 1);
for (int i = 0; i <= k; i++) {
A[i] = inv_fact[i];
if (i % 2 == 1) {
A[i] = (MOD - A[i]) % MOD;
}
B[i] = (power_k[i] * inv_fact[i]) % MOD;
}
// NTT计算卷积
vector<ll> C_div_fact = NTT::multiply(A, B);
vector<ll> C(k + 1);
for (int i = 0; i <= k; i++) {
C[i] = (C_div_fact[i] * fact[i]) % MOD;
}
// --- Final Step: 合并结果 ---
ll final_ans = 0;
for (int i = 0; i <= k; i++) {
final_ans = (final_ans + C[i] * g[i]) % MOD;
}
cout << final_ans << endl;
return 0;
}复杂度分析
时间复杂度:
- 计算 的递推需要 。
- 线性筛预处理 需要 。
- 预处理阶乘和逆阶乘需要 。
- 构造多项式 和 需要 。
- NTT 卷积是主要瓶颈,需要 。
- 最后合并结果需要 。
- 所以总时间复杂度由 NTT 决定,为 。
空间复杂度:
- 我们需要存储 , , 阶乘,逆阶乘,以及 NTT 所需的数组,它们的大小都与 线性相关。
知识点总结
- 期望的线性性质: 它是我们能够将复杂期望分解为简单期望之和的基础,是解决期望问题的万能钥匙,喵~
- 组合恒等式: 核心是将 通过第二类斯特林数展开成 的线性组合,这是从幂函数到组合世界的桥梁。
- 生成函数: 解决序列和计数问题的超级大杀器!通过将序列表示为函数,我们可以用求导、积分等分析工具来研究序列的性质,比如推导递推关系。
- 线性递推: 许多组合计数问题最终都可以归结为求解一个线性递推式,这通常比直接计算要快得多。
- NTT/FFT: 快速计算多项式卷积的利器。看到形如 的式子,就要立刻想到它!
- 线性筛: 不仅可以筛素数,还可以用来在 时间内计算一些积性函数或者像 这样的值。
这道题综合了好多知识点,是一道非常好的练习题!虽然完全解法非常硬核,但一步步分解下来,每一步都是我们应该掌握的经典技巧。希望这篇题解能帮到你,加油哦,喵~!