Skip to content

233玩游戏 - 题解

标签与难度

标签: 生成函数, 多项式, NTT, 概率DP, 分治 难度: 2500

题目大意喵~

各位Master,下午好喵~ 让我们一起来玩个有趣的游戏吧!

游戏有 nn 轮,我们从 1级 开始,初始血量是 n+on+o

在每一轮,如果我们的等级是 xx,血量是 yy,会发生以下事件:

  1. pp 的概率 升级
    • 如果当前是1级,会升到2级,但血量会减少 oo 点。
    • 如果当前是2级,再次“升级”的话,血量会直接变成0,游戏就结束啦!
  2. 1p1-p 的概率 不升级
    • 在这种情况下,又有 1y\frac{1}{y} 的概率血量减少1点。
    • 剩下 11y1-\frac{1}{y} 的概率血量不变。

每一次事件(升级、掉血、或无事发生)都标志着一轮的结束。

我们的任务是,对于从第1轮到第 nn 轮的每一轮 ii,计算出在第 ii 轮结束后,血量恰好为 tt 的概率是多少。所有概率都要对 998244353998244353 取模哦!

解题思路分析

这道题看起来像一个动态规划问题,但是状态包含了轮数、等级和血量,而且血量的转移概率还和当前的血量有关,直接DP的话状态空间太大了,会超时的说!(>ω<)

当遇到这种复杂的转移和需要求精确值的概率DP时,通常是我们的好朋友——生成函数——大显身手的时候了,喵!

我们可以用生成函数来描述概率分布。不过,这次我们不把多项式的变量 zz 对应血量,而是对应轮数

建立生成函数模型

我们定义两个系列的生成函数:

  • Fh(z)=i=0fi(h)ziF_h(z) = \sum_{i=0}^{\infty} f_i(h) z^i:其中 fi(h)f_i(h) 是在第 ii 轮后,等级为1、血量为 hh 的概率。
  • Gh(z)=i=0gi(h)ziG_h(z) = \sum_{i=0}^{\infty} g_i(h) z^i:其中 gi(h)g_i(h) 是在第 ii 轮后,等级为2、血量为 hh 的概率。

我们的最终目标是求出 fi(t)+gi(t)f_i(t) + g_i(t),也就是 [zi](Ft(z)+Gt(z))[z^i](F_t(z) + G_t(z))

根据题意,我们可以写出概率 fi(h)f_i(h)gi(h)g_i(h) 的递推关系: fi(h)=(1p)(11h)fi1(h)+(1p)1h+1fi1(h+1)f_i(h) = (1-p)(1-\frac{1}{h})f_{i-1}(h) + (1-p)\frac{1}{h+1}f_{i-1}(h+1)gi(h)=(1p)(11h)gi1(h)+(1p)1h+1gi1(h+1)+pfi1(h+o)g_i(h) = (1-p)(1-\frac{1}{h})g_{i-1}(h) + (1-p)\frac{1}{h+1}g_{i-1}(h+1) + p \cdot f_{i-1}(h+o)

初始条件是 f0(n+o)=1f_0(n+o) = 1,其他所有在第0轮的概率都是0。

把这些递推关系转换成生成函数的形式,两边同时乘以 ziz^i 再对所有 i1i \ge 1 求和,经过一番魔法推导(涉及到分离 i=0i=0 的项和重新索引求和),我们可以得到关于 Fh(z)F_h(z)Gh(z)G_h(z) 的关系式:

Fh(z)(1z(1p)h1h)=δh,n+o+z(1p)1h+1Fh+1(z)F_h(z) \left(1 - z(1-p)\frac{h-1}{h}\right) = \delta_{h, n+o} + z(1-p)\frac{1}{h+1} F_{h+1}(z)Gh(z)(1z(1p)h1h)=z(1p)1h+1Gh+1(z)+zpFh+o(z)G_h(z) \left(1 - z(1-p)\frac{h-1}{h}\right) = z(1-p)\frac{1}{h+1} G_{h+1}(z) + zp F_{h+o}(z)

这里 δh,n+o\delta_{h, n+o} 是一个克罗内克符号,当 h=n+oh=n+o 时为1,否则为0。

求解 Ft(z)F_t(z) (等级1的概率)

我们来简化一下记号,让我的爪爪打字轻松一点~ 令 q=1pq = 1-p, Dh(z)=1zqh1hD_h(z) = 1 - zq\frac{h-1}{h}, Eh(z)=zq1h+1E_h(z) = zq\frac{1}{h+1}

关系式变成了: Fh(z)Dh(z)=δh,n+o+Eh(z)Fh+1(z)F_h(z) D_h(z) = \delta_{h, n+o} + E_h(z) F_{h+1}(z)

因为血量最高从 n+on+o 开始,所以 Fh(z)=0F_h(z)=0 for h>n+oh > n+o。 我们可以从 h=n+oh=n+o 开始向下递推: Fn+o(z)Dn+o(z)=1    Fn+o(z)=1Dn+o(z)F_{n+o}(z) D_{n+o}(z) = 1 \implies F_{n+o}(z) = \frac{1}{D_{n+o}(z)}Fn+o1(z)Dn+o1(z)=En+o1(z)Fn+o(z)    Fn+o1(z)=En+o1(z)Dn+o1(z)Fn+o(z)F_{n+o-1}(z) D_{n+o-1}(z) = E_{n+o-1}(z) F_{n+o}(z) \implies F_{n+o-1}(z) = \frac{E_{n+o-1}(z)}{D_{n+o-1}(z)} F_{n+o}(z) ... 一直推到我们关心的血量 tt,可以得到:

Ft(z)=k=tn+o1Ek(z)k=tn+oDk(z)F_t(z) = \frac{\prod_{k=t}^{n+o-1} E_k(z)}{\prod_{k=t}^{n+o} D_k(z)}

这个表达式的分子和分母都是关于 zz 的多项式!

  • 分母 P(z)=k=tn+oDk(z)P(z) = \prod_{k=t}^{n+o} D_k(z) 是一个度数为 n+ot+1n+o-t+1 的多项式。我们可以用分治+NTTO(Nlog2N)O(N \log^2 N) 的时间里把它算出来。
  • 分子 NF(z)=k=tn+o1Ek(z)N_F(z) = \prod_{k=t}^{n+o-1} E_k(z) 更简单,因为 Ek(z)E_k(z)zz 的一次项,所以分子是 CFzn+otC_F \cdot z^{n+o-t} 的形式,其中 CFC_F 是一个可以 O(N)O(N) 预处理逆元后算出的常数。

于是,Ft(z)=NF(z)P(z)1F_t(z) = N_F(z) \cdot P(z)^{-1}。我们只需求出 P(z)P(z)多项式逆元,再和分子乘起来,就得到了 Ft(z)F_t(z) 的系数,也就是各轮在等级1、血量为 tt 的概率啦!

求解 Gt(z)G_t(z) (等级2的概率)

等级2的生成函数 Gt(z)G_t(z) 稍微复杂一点,喵~ Gh(z)Dh(z)=Eh(z)Gh+1(z)+zpFh+o(z)G_h(z) D_h(z) = E_h(z) G_{h+1}(z) + zp F_{h+o}(z)

这是一个关于 hh 的线性递推式。把它展开,可以得到 Gt(z)G_t(z) 的表达式:

Gt(z)=j=tn(zpFj+o(z)Dj(z)k=tj1Ek(z)Dk(z))G_t(z) = \sum_{j=t}^{n} \left( \frac{zp F_{j+o}(z)}{D_j(z)} \prod_{k=t}^{j-1} \frac{E_k(z)}{D_k(z)} \right)

把我们之前解出的 Fh(z)F_h(z) 代入,并通分,最后可以化简成:

Gt(z)=1P(z)(zpj=tn(Numj(z)k=j+1j+o1Dk(z)))G_t(z) = \frac{1}{P(z)} \left( zp \sum_{j=t}^{n} \left( \text{Num}_j(z) \cdot \prod_{k=j+1}^{j+o-1} D_k(z) \right) \right)

其中 P(z)P(z) 是我们上面求出的同一个分母,Numj(z)\text{Num}_j(z) 是一个只含常数和 zz 的幂次的项。

这个式子的核心是计算分子那个巨大的和式。我们令它为 NG(z)N_G(z)。 仔细观察,k=j+1j+o1Dk(z)\prod_{k=j+1}^{j+o-1} D_k(z) 是一个度数只有 o1o-1 的多项式,我们可以把它看作一个大小为 o1o-1滑动窗口。当求和的索引 jj 变化时,这个窗口也跟着滑动。 而 Numj(z)\text{Num}_j(z) 部分化简后是 CjzntC_j \cdot z^{n-t} 的形式。

所以,NG(z)=znt+1j=tnCj(k=j+1j+o1Dk(z))N_G(z) = z^{n-t+1} \sum_{j=t}^{n} C_j' \left( \prod_{k=j+1}^{j+o-1} D_k(z) \right)

我们可以写一个循环,从 j=nj=ntt

  1. 维护一个大小为 o1o-1 的滑动窗口,计算出窗口内多项式 k=j+1j+o1Dk(z)\prod_{k=j+1}^{j+o-1} D_k(z)。窗口每滑动一次,就乘上一个新的线性多项式,再除以一个旧的。因为是线性多项式,乘除法都是 O(o)O(o) 的。
  2. 计算出常数 CjC_j'
  3. CjC_j' 乘上窗口多项式,累加到一个总和多项式 S(z)S(z) 上。

这个过程的复杂度是 O(no)O(n \cdot o)。只要 oo 不太大,这就是可以接受的! 计算出 S(z)S(z) 后,Gt(z)=znt+1S(z)P(z)1G_t(z) = z^{n-t+1} S(z) \cdot P(z)^{-1}

总结与合成

  1. 分治NTT求出公共分母 P(z)=k=tn+o(1z(1p)k1k)P(z) = \prod_{k=t}^{n+o} (1 - z(1-p)\frac{k-1}{k})
  2. 多项式求逆得到 P(z)1P(z)^{-1}
  3. 计算 Ft(z)F_t(z) 的系数。
  4. O(no)O(n \cdot o)滑动窗口方法计算出 Gt(z)G_t(z) 的分子多项式 NG(z)N_G(z)
  5. 计算 Gt(z)G_t(z) 的系数。
  6. Ft(z)F_t(z)Gt(z)G_t(z) 的系数相加,得到每一轮 ii 的最终答案 [zi](Ft(z)+Gt(z))[z^i](F_t(z)+G_t(z))

好啦,思路清晰了,让我们用代码来实现这个魔法吧!喵~

代码实现

cpp
#include <iostream>
#include <vector>
#include <numeric>
#include <algorithm>

// MOD and NTT parameters
const int MOD = 998244353;
const int G = 3;

// Fast power function
long long power(long long base, long long exp) {
    long long res = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp % 2 == 1) res = (res * base) % MOD;
        base = (base * base) % MOD;
        exp /= 2;
    }
    return res;
}

// Modular inverse function
long long modInverse(long long n) {
    return power(n, MOD - 2);
}

// NTT implementation
namespace Poly {
    std::vector<int> rev;
    std::vector<long long> w;

    void precompute_fft(int n) {
        if (w.size() >= n) return;
        w.resize(n);
        int l = 0;
        while ((1 << l) < n) l++;
        rev.resize(n);
        for (int i = 0; i < n; i++) {
            rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (l - 1));
        }
        long long wn = power(G, (MOD - 1) / n);
        w[0] = 1;
        for (int i = 1; i < n; i++) {
            w[i] = w[i - 1] * wn % MOD;
        }
    }

    void ntt(std::vector<long long>& a, bool invert) {
        int n = a.size();
        for (int i = 0; i < n; i++) {
            if (i < rev[i]) std::swap(a[i], a[rev[i]]);
        }
        for (int len = 2; len <= n; len <<= 1) {
            int mid = len >> 1;
            for (int i = 0; i < n; i += len) {
                for (int j = 0; j < mid; j++) {
                    long long t = w[n / len * j] * a[i + j + mid] % MOD;
                    a[i + j + mid] = (a[i + j] - t + MOD) % MOD;
                    a[i + j] = (a[i + j] + t) % MOD;
                }
            }
        }
        if (invert) {
            std::reverse(a.begin() + 1, a.end());
            long long inv_n = modInverse(n);
            for (int i = 0; i < n; i++) {
                a[i] = a[i] * inv_n % MOD;
            }
        }
    }

    std::vector<long long> multiply(std::vector<long long> a, std::vector<long long> b) {
        int sz = 1;
        while (sz < a.size() + b.size()) sz <<= 1;
        a.resize(sz);
        b.resize(sz);

        precompute_fft(sz);
        ntt(a, false);
        ntt(b, false);

        for (int i = 0; i < sz; i++) a[i] = a[i] * b[i] % MOD;

        ntt(a, true);
        return a;
    }

    std::vector<long long> inverse(const std::vector<long long>& a, int n) {
        if (n == 1) return {modInverse(a[0])};
        
        std::vector<long long> a0 = inverse(a, (n + 1) / 2);
        int sz = 1;
        while (sz < 2 * n) sz <<= 1;
        
        std::vector<long long> current_a(a.begin(), a.begin() + std::min((int)a.size(), n));
        current_a.resize(sz);
        a0.resize(sz);

        precompute_fft(sz);
        ntt(current_a, false);
        ntt(a0, false);

        for (int i = 0; i < sz; i++) {
            a0[i] = a0[i] * (2 - current_a[i] * a0[i] % MOD + MOD) % MOD;
        }

        ntt(a0, true);
        a0.resize(n);
        return a0;
    }
}

// Function to compute product of linear polynomials using divide and conquer
std::vector<long long> product_tree(const std::vector<std::vector<long long>>& polys, int l, int r) {
    if (l == r) return polys[l];
    if (l + 1 == r) return Poly::multiply(polys[l], polys[r]);
    int mid = l + (r - l) / 2;
    return Poly::multiply(product_tree(polys, l, mid), product_tree(polys, mid + 1, r));
}

int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(NULL);

    int n, a, b, t, o;
    std::cin >> n >> a >> b >> t >> o;

    long long p = (long long)a * modInverse(b) % MOD;
    long long q = (1 - p + MOD) % MOD;

    // Precompute modular inverses for k and coefficients v_k
    std::vector<long long> invs(n + o + 2);
    std::vector<long long> v(n + o + 2);
    invs[1] = 1;
    for (int k = 2; k <= n + o + 1; ++k) {
        invs[k] = MOD - (long long)(MOD / k) * invs[MOD % k] % MOD;
    }
    for (int k = 1; k <= n + o + 1; ++k) {
        v[k] = q * (k - 1) % MOD * invs[k] % MOD;
    }

    // Part 1: Compute the common denominator polynomial P(z)
    std::vector<std::vector<long long>> d_polys;
    for (int k = t; k <= n + o; ++k) {
        d_polys.push_back({1, (MOD - v[k]) % MOD});
    }
    std::vector<long long> P = product_tree(d_polys, 0, d_polys.size() - 1);
    P.resize(n + 1, 0);

    // Get its inverse
    std::vector<long long> P_inv = Poly::inverse(P, n + 1);

    // Part 2: Compute F_t(z)
    int shift_f = n + o - t;
    long long C_f = 1;
    for (int k = t; k <= n + o - 1; ++k) {
        C_f = C_f * q % MOD * invs[k + 1] % MOD;
    }
    std::vector<long long> prob_l1(n + 1, 0);
    for (int i = 0; i <= n; ++i) {
        if (i >= shift_f) {
            prob_l1[i] = C_f * P_inv[i - shift_f] % MOD;
        }
    }

    // Part 3: Compute G_t(z)'s numerator part S(z)
    std::vector<long long> S(o + 1, 0);
    if (t <= n) {
        // Sliding window for polynomial product W_j(z)
        std::vector<long long> W = {1};
        for (int k = n + 1; k <= n + o - 1; ++k) {
            std::vector<long long> next_term = {1, (MOD - v[k]) % MOD};
            W = Poly::multiply(W, next_term);
            W.resize(o + 1);
        }

        // Sliding window for constant C_j'
        long long C_prefix = 1;
        for (int k = t; k <= n - 1; ++k) C_prefix = C_prefix * q % MOD * invs[k + 1] % MOD;
        long long C_suffix = 1;
        // C_j = p * C_prefix * C_suffix
        
        for (int j = n; j >= t; --j) {
            long long C_j_prime = p * C_prefix % MOD * C_suffix % MOD;
            for (int k = 0; k < W.size(); ++k) {
                S[k] = (S[k] + C_j_prime * W[k]) % MOD;
            }

            // Update for j-1
            if (j > t) {
                // Update W
                std::vector<long long> D_j_plus_o_inv = {1, v[j + o - 1]}; // (1-v_k z)^-1 ~ 1+v_k z
                for (int k = o; k > 0; --k) D_j_plus_o_inv[k] = (D_j_plus_o_inv[k] + v[j + o - 1] * D_j_plus_o_inv[k - 1]) % MOD;
                W = Poly::multiply(W, D_j_plus_o_inv);
                std::vector<long long> D_j = {1, (MOD - v[j]) % MOD};
                W = Poly::multiply(W, D_j);
                W.resize(o+1);
                
                // Update C
                C_prefix = C_prefix * (j) % MOD * modInverse(q) % MOD;
                C_suffix = C_suffix * q % MOD * invs[j + o] % MOD;
            }
        }
    }
    
    // Compute G_t(z)
    int shift_g = n - t + 1;
    std::vector<long long> N_g;
    if (t <= n) {
       N_g.resize(shift_g, 0);
       for(long long coeff : S) N_g.push_back(coeff);
    }
    std::vector<long long> prob_l2_num = Poly::multiply(N_g, P_inv);

    // Final answer
    for (int i = 1; i <= n; ++i) {
        long long ans = prob_l1[i];
        if (i < prob_l2_num.size()) {
            ans = (ans + prob_l2_num[i]) % MOD;
        }
        std::cout << ans << "\n";
    }

    return 0;
}

复杂度分析

  • 时间复杂度: O(NlogN+nologo)O(N \log N + n \cdot o \log o)

    • 计算分母多项式 P(z)P(z) 使用分治NTT,复杂度为 O(Nlog2N)O(N \log^2 N),但通过更优化的实现可以做到 O(NlogN)O(N \log N)。这里 NN 是多项式度数,约等于 n+o
    • 多项式求逆的复杂度是 O(NlogN)O(N \log N)
    • 计算 Gt(z)G_t(z) 分子的循环部分,有 ntn-t 次迭代,每次迭代中涉及滑动窗口,主要是度数为 oo 的多项式乘法,复杂度是 O(ologo)O(o \log o)。总共是 O(nologo)O(n \cdot o \log o)。如果用 O(o2)O(o^2) 的朴素乘法代替NTT,则是 O(no2)O(n \cdot o^2)。我的代码里偷懒用了 O(ologo)O(o \log o) 的NTT乘法。
    • 最终的各项计算都是 O(NlogN)O(N \log N) 级别。
    • 主导复杂度的是 O(NlogN+nologo)O(N \log N + n \cdot o \log o)
  • 空间复杂度: O(N)O(N)

    • 主要空间开销来自于存储各个多项式,以及NTT需要的辅助数组,与多项式的最高次(即 n+on+o)成正比。

知识点总结

这道题是一道非常硬核的生成函数应用题,将它攻克下来,Master一定能收获满满,喵!

  1. 概率与生成函数: 对于复杂的概率DP问题,特别是转移与状态值本身相关时,生成函数是一个强有力的工具。将轮数作为变量,可以把递推关系转化为代数方程。
  2. 多项式全家桶:
    • NTT (快速数论变换): 在模意义下快速计算多项式乘法(卷积)的核心算法。
    • 多项式求逆: 求解 A(x)B(x)1(modxn)A(x)B(x) \equiv 1 \pmod{x^n} 的基本操作,通常用牛顿迭代法实现。
    • 分治求多项式乘积: 高效计算 (1ciz)\prod (1-c_i z) 形式的多项式。
  3. 滑动窗口思想: 在计算 Gt(z)G_t(z) 的分子时,我们巧妙地将一个复杂乘积的维护看作一个滑动窗口,每次迭代只需要 O(o)O(o) 的代价来更新,避免了重复计算。
  4. 化繁为简的数学推导: 解题的关键在于将复杂的递推关系一步步转化为可以计算的多项式形式。虽然过程可能有点绕,但这是通往正确解法的必经之路,呐!

希望这篇题解能帮助到你,Master!要继续加油哦!喵~ (ฅ'ω'ฅ)