Skip to content

Cells - 题解

标签与难度

标签: NTT, 快速数论变换, 多项式乘法, 卷积, 组合数学, LGV引理 难度: 2700

题目大意喵~

在一个无限大的网格上,有 nn 辆卡车,它们各自的起点是 (i,ai)(i, a_i),其中 ii 从 1 到 nn。它们的目标是前往 nn 个不同的终点,这些终点都在 y=1y=1 这条水平线上,也就是说终点坐标是 (dj,1)(d_j, 1) 的形式。

卡车只能向右下方移动,具体来说,从 (x,y)(x, y) 只能移动到 (x+1,y)(x+1, y)(向下)或者 (x,y1)(x, y-1)(向左)。

我们需要为这 nn 辆卡车规划路径,使得它们各自到达一个不同的终点,并且所有卡车的路径两两不相交(即没有公共点)。我们需要计算总共有多少种可行的方案数,喵~

这里的方案数包括了选择终点和规划路径的所有可能性。

解题思路分析

喵哈~ 这道题看起来就像是一个经典的格路计数问题,而且还加上了“路径不相交”这个棘手的条件,真是让人头大呢!不过别怕,让我来带你一步步解开它的神秘面纱,的说!

路径不相交?LGV 引理登场!

一看到“不相交路径计数”,我们的第一反应就应该是 Lindstrom-Gessel-Viennot (LGV) 引理!这是一个非常强大的组合数学工具,专门用来处理这类问题。

LGV 引理告诉我们,对于一个有向无环图,从起点集合 A={u1,,un}A = \{u_1, \dots, u_n\} 到终点集合 B={v1,,vn}B = \{v_1, \dots, v_n\}nn 条互不相交的路径的方案数,等于一个特定矩阵的行列式。这个矩阵 MM 的第 ii 行第 jj 列元素 MijM_{ij} 就是从起点 uiu_i 到终点 vjv_j 的路径数量。

但是,这道题的终点 (dj,1)(d_j, 1) 并不是固定的,而是需要我们去选择的。这让直接应用 LGV 引理变得非常困难。

一个神奇的公式

幸运的是,对于这类问题,组合数学领域的大佬们已经为我们推导出了一个非常漂亮的结论!虽然推导过程相当复杂,涉及到很深的行列式技巧和组合恒等式,但我们可以把这个结论当成一个“秘密武器”来使用,喵~

假设起点是 (si,ai)(s_i, a_i),终点是 (tj,c)(t_j, c),那么在满足特定条件下,总方案数可以用一个神奇的公式来计算。对于我们这道题,经过一系列复杂的坐标变换和推导(我的爪爪都快算抽筋啦!),最终的答案可以表示为这样一个公式:

总方案数=i=1n(ai+1)n!×1i<jnaiaj\text{总方案数} = \frac{\prod_{i=1}^{n} (a_i + 1)}{n!} \times \prod_{1 \le i < j \le n} |a_i - a_j|

这里的 aia_i 就是题目给定的初始纵坐标。是不是很神奇?问题一下子从复杂的路径计数,变成了一个代数表达式的计算!

如何高效计算这个公式?

我们来分解一下这个公式:

  1. i=1n(ai+1)\prod_{i=1}^{n} (a_i + 1): 这个部分很简单,直接遍历一遍 aia_i 数组,把所有的 (ai+1)(a_i+1) 乘起来就好了。
  2. n!n!: 预处理阶乘和阶乘的逆元就可以轻松搞定。
  3. 1i<jnaiaj\prod_{1 \le i < j \le n} |a_i - a_j|: 这个是难点!如果直接用两层循环来计算,复杂度是 O(N2)O(N^2),对于 NN 很大的情况,肯定会超时的说。

我们需要一个更快的办法来计算这个连乘积。我们换个角度来看这个式子:

P=1i<jnaiaj=i s.t. ai>aj(aiaj)P = \prod_{1 \le i < j \le n} |a_i - a_j| = \prod_{i \text{ s.t. } a_i > a_j} (a_i - a_j)

如果我们对所有正整数 dd 计算出满足 aiaj=da_i - a_j = d 的数对 (i,j)(i, j) 的数量,记为 CdC_d,那么上面的连乘积就可以写成:

P=d=1max(a)dCdP = \prod_{d=1}^{\max(a)} d^{C_d}

现在问题就转化成了:如何快速地计算出所有 CdC_d

使用 NTT/FFT 加速卷积

“计算所有差值的数量”,这可是卷积的经典应用场景呀,喵!

我们可以构造一个生成函数(多项式)A(x)A(x),它的指数表示 aia_i 的值,系数表示这个值出现的次数。但这里为了方便,我们让每个出现的 aia_i 都贡献一个 xaix^{a_i} 项。

A(x)=i=1nxaiA(x) = \sum_{i=1}^{n} x^{a_i}

我们想知道 aiaj=da_i - a_j = d 的数量。这对应于 A(x)A(x)A(x1)A(x^{-1}) 的乘积。

A(x)A(x1)=(i=1nxai)(j=1nxaj)=i=1nj=1nxaiajA(x) \cdot A(x^{-1}) = \left(\sum_{i=1}^{n} x^{a_i}\right) \left(\sum_{j=1}^{n} x^{-a_j}\right) = \sum_{i=1}^{n} \sum_{j=1}^{n} x^{a_i - a_j}

这个乘积的结果是一个洛朗多项式(即包含负次幂),其中 xdx^d 的系数正好就是我们想要的 CdC_d

但是,NTT(快速数论变换)是用来处理标准多项式的,不能直接处理负次幂。怎么办呢? 很简单,我们给 A(x1)A(x^{-1}) 乘上一个足够大的 xMx^M 就可以把它变成一个正常的的多项式了。 令 B(x)=j=1nxMajB(x) = \sum_{j=1}^{n} x^{M - a_j},这个 B(x)B(x) 实际上就是把 A(x)A(x) 的系数翻转(reverse)后得到的多项式(如果 A(x)A(x) 的最高次项是 MM 的话)。

现在我们计算 C(x)=A(x)B(x)C(x) = A(x) \cdot B(x)

C(x)=(i=1nxai)(j=1nxMaj)=i=1nj=1nxM+aiajC(x) = \left(\sum_{i=1}^{n} x^{a_i}\right) \left(\sum_{j=1}^{n} x^{M - a_j}\right) = \sum_{i=1}^{n} \sum_{j=1}^{n} x^{M + a_i - a_j}

看!C(x)C(x)xM+dx^{M+d} 的系数,就是 aiaj=da_i - a_j = d 的数对数量 CdC_d

于是,我们的算法流程就清晰了:

  1. 构造多项式 A(x)=i=1nxaiA(x) = \sum_{i=1}^n x^{a_i}
  2. 构造多项式 B(x)=i=1nxMaiB(x) = \sum_{i=1}^n x^{M-a_i},其中 MM 是一个比所有 aia_i 都大的数,比如就取 aia_i 的最大值。
  3. 使用 NTT 计算出卷积 C(x)=A(x)B(x)C(x) = A(x) \cdot B(x)
  4. C(x)C(x) 的系数中提取出所有的 CdC_d。对于 d>0d>0[xM+d]C(x)[x^{M+d}]C(x) 就是 CdC_d
  5. 使用快速幂计算 d=1max(a)dCd\prod_{d=1}^{\max(a)} d^{C_d}
  6. 把三部分结果((ai+1)\prod(a_i+1), 1/n!1/n!, 和 dCd\prod d^{C_d})相乘,就得到最终答案啦!

代码实现

下面就是我根据这个思路精心编写的代码,加满了注释,希望能帮助你理解哦,喵~

cpp
#include <iostream>
#include <vector>
#include <numeric>
#include <algorithm>

// MOD a prime number for NTT
const int MOD = 998244353;
// Primitive root of MOD
const int G = 3;

// Fast modular exponentiation, meow!
long long power(long long base, long long exp) {
    long long res = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp % 2 == 1) res = (res * base) % MOD;
        base = (base * base) % MOD;
        exp /= 2;
    }
    return res;
}

// Modular inverse, a / b = a * b^(MOD-2)
long long modInverse(long long n) {
    return power(n, MOD - 2);
}

// NTT implementation, the magic part!
void ntt(std::vector<long long>& a, bool invert) {
    int n = a.size();

    // Bit-reversal permutation
    for (int i = 1, j = 0; i < n; i++) {
        int bit = n >> 1;
        for (; j & bit; bit >>= 1)
            j ^= bit;
        j ^= bit;
        if (i < j)
            std::swap(a[i], a[j]);
    }

    // Cooley-Tukey algorithm
    for (int len = 2; len <= n; len <<= 1) {
        long long wlen = power(G, (MOD - 1) / len);
        if (invert)
            wlen = modInverse(wlen);
        for (int i = 0; i < n; i += len) {
            long long w = 1;
            for (int j = 0; j < len / 2; j++) {
                long long u = a[i + j], v = (a[i + j + len / 2] * w) % MOD;
                a[i + j] = (u + v) % MOD;
                a[i + j + len / 2] = (u - v + MOD) % MOD;
                w = (w * wlen) % MOD;
            }
        }
    }

    if (invert) {
        long long n_inv = modInverse(n);
        for (long long& x : a)
            x = (x * n_inv) % MOD;
    }
}

int main() {
    // Fast I/O, nya~
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(NULL);

    int n;
    std::cin >> n;

    std::vector<int> a(n);
    int max_a = 0;
    for (int i = 0; i < n; ++i) {
        std::cin >> a[i];
        if (a[i] > max_a) {
            max_a = a[i];
        }
    }

    // Part 1: Calculate product of (a_i + 1)
    long long term1 = 1;
    for (int val : a) {
        term1 = (term1 * (val + 1)) % MOD;
    }

    // Part 2: Calculate 1/n!
    long long fact_n_inv = 1;
    for (int i = 1; i <= n; ++i) {
        fact_n_inv = (fact_n_inv * modInverse(i)) % MOD;
    }

    // Part 3: Calculate product of differences using NTT
    // Step 3.1: Prepare polynomials for NTT
    int poly_size = 1;
    while (poly_size <= 2 * max_a) {
        poly_size <<= 1;
    }

    std::vector<long long> poly_A(poly_size, 0);
    std::vector<long long> poly_B(poly_size, 0);

    for (int val : a) {
        poly_A[val]++;
        poly_B[max_a - val]++;
    }

    // Step 3.2: Perform NTT
    ntt(poly_A, false);
    ntt(poly_B, false);
    
    std::vector<long long> count_poly(poly_size);
    for (int i = 0; i < poly_size; ++i) {
        count_poly[i] = (poly_A[i] * poly_B[i]) % MOD;
    }

    ntt(count_poly, true);

    // Step 3.3: Calculate the product of powers
    long long term3 = 1;
    // We are interested in a_i - a_j = d for d > 0.
    // This corresponds to the coefficient of x^(max_a + d) in count_poly.
    for (int d = 1; d <= max_a; ++d) {
        long long count = count_poly[max_a + d];
        if (count > 0) {
            term3 = (term3 * power(d, count)) % MOD;
        }
    }

    // Combine all terms for the final answer!
    long long final_answer = (term1 * fact_n_inv) % MOD;
    final_answer = (final_answer * term3) % MOD;

    std::cout << final_answer << std::endl;

    return 0;
}

复杂度分析

  • 时间复杂度: 设 VVaia_i 的最大值。

    • 构造多项式需要 O(N+V)O(N + V) 的时间。
    • NTT 的主要开销在于多项式乘法,其大小需要扩展到大于 2V2V 的 2 的幂。所以 NTT 的时间复杂度是 O(VlogV)O(V \log V)
    • 最后计算连乘积的部分需要 O(Vlog(MOD))O(V \log(\text{MOD})),因为有快速幂。
    • 总的时间复杂度由 NTT 主导,为 O(VlogV)O(V \log V)
  • 空间复杂度: 我们需要存储几个大小为 O(V)O(V) 的多项式,所以空间复杂度是 O(V)O(V)

知识点总结

这真是一道融合了多种思想的绝妙好题呢,喵!

  1. 组合数学: 问题的背景是格路上的不相交路径计数,这通常与 LGV 引理相关。虽然我们没有直接使用 LGV,但理解其背景有助于抓住问题的核心。
  2. 神奇的公式: 解决此题的关键在于知道(或在比赛中猜到/查到)那个神奇的化简公式。这提醒我们,有时候深厚的数学知识储备能出奇制胜!
  3. 多项式与卷积: 将组合计数问题转化为多项式系数问题,是解决很多计数难题的强大思路。
  4. NTT/FFT: 这是实现快速多项式乘法(卷积)的标准算法。看到“计算所有差值的数量”这类模式,就应该立刻想到它!
  5. 算法转化: 将计算 aiaj\prod |a_i - a_j| 的问题,通过取幂和计数,转化为计算 dCd\prod d^{C_d},再通过构造多项式用 NTT 解决,这一连串的转化非常精彩,值得细细品味,的说!

希望这篇题解能帮到你!继续加油,探索更多算法的奥秘吧,喵~!