Skip to content

排列 - 题解

标签与难度

标签: 动态规划, 分治, NTT, 多项式, 卷积, 生成函数, 组合计数 难度: 2600

题目大意喵~

各位Master,大家好呀!咱是我小助手,最喜欢的就是和大家一起挑战算法题啦,喵~

这道题是这样的:我们有一个从 11nn 的数字组成的排列 pp。我们可以对这个排列做一种特殊的操作:选择一个位置 ii1i<n1 \le i < n),如果 pip_ipi+1p_{i+1} 这两个相邻的数,它们的差的绝对值大于等于 22(也就是 pipi+12|p_i - p_{i+1}| \ge 2),我们就可以交换它们的位置。

这个操作可以进行任意多次。现在的问题是,从最初的排列 pp 出发,我们总共能变出多少种不同的排列呢?答案需要对 998244353998244353 取模哦!

举个栗子:如果 n=3n=3,排列是 (1, 3, 2)

  • p1=1,p2=3p_1=1, p_2=313=22|1-3|=2 \ge 2,所以可以交换它们,得到 (3, 1, 2)
  • p2=3,p3=2p_2=3, p_3=232=1<2|3-2|=1 < 2,所以不能交换它们。 从 (3, 1, 2) 出发,也只能交换 31 回到 (1, 3, 2)。所以总共能得到 2 种排列。

解题思路分析

这道题看起来是在问图的连通性问题,所有能相互到达的排列构成一个连通块,我们要求的就是这个连通块的大小,喵~ 但直接去建图然后搜索,状态空间是 n!n! 级别的,肯定会超时的说。

所以,我们需要找到一些不变量或者规律来简化问题!

关键洞察:不变的相对顺序

操作的核心是交换相邻元素。我们想想什么情况下不能交换。那就是当 pipi+1=1|p_i - p_{i+1}| = 1 时,也就是当两个数字是连续整数时(比如 3 和 4,或者 6 和 5)。

那么,对于任意一对连续整数 kk+1,它们的相对顺序会改变吗? 比如说,排列里 kk+1 的前面。要想让 k+1跑到 k 的前面,它们俩必须在某个时刻成为邻居,然后交换位置。就像这样:... k, k+1 ... -> ... k+1, k ...。 但是!这个操作恰好是被禁止的!因为 k(k+1)=1|k - (k+1)| = 1

所以,对于任意 k[1,n1]k \in [1, n-1],如果初始排列中 kk 的位置在 k+1k+1 的位置之前,那么在任何可以通过操作到达的排列中,kk 的位置永远在 k+1k+1 的位置之前。反之亦然。

这个性质是解题的突破口!它告诉我们,所有可到达的排列都必须遵守这 n1n-1 个关于连续整数的相对位置关系。 例如,如果初始排列是 (1, 4, 2, 3)

  • pos[1] < pos[2] (1在2前面)
  • pos[2] > pos[3] (2在3后面)
  • pos[3] < pos[4] (3在4前面) 那么我们能生成的所有排列,都必须满足这三个条件。

问题转化:排列计数

现在问题就变成了一个纯粹的组合计数问题: 给定 n1n-1 个约束条件,形如 pos[k] < pos[k+1]pos[k] > pos[k+1],求有多少个 11nn 的排列满足所有这些条件?

这其实是一个经典问题:统计具有给定“签名”(up/down序列)的排列数量。 为了方便,我们把这些约束关系转换成一个签名序列 SS。如果 pos[k] < pos[k+1],我们记为 Sk=<S_k = '<';如果 pos[k] > pos[k+1],我们记为 Sk=>S_k = '>'.

一个有趣且重要的事实是:满足 pos[k] < pos[k+1] 这种约束的排列数量,和满足 p_k < p_{k+1} 这种约束的排列数量是一样多的(对于同一个签名序列)。这可以通过一个叫做 "Foata's correspondence" 的双射来证明,不过我们在这里只需要知道这个结论就好啦,喵~

所以,我们的问题最终等价于: 给定一个长度为 n1n-1 的签名序列 SS,求有多少个排列 pp 满足 pi<pi+1p_i < p_{i+1} (如果 Si=<S_i='<')或 pi>pi+1p_i > p_{i+1} (如果 Si=>S_i='>'

生成函数与分治NTT

这个问题可以用动态规划在 O(n2)O(n^2) 的时间内解决,但对于这道题的 nn 来说太慢了。我们需要更强大的工具——生成函数!

这道题的数学推导有点复杂呢,喵~ 直接从第一性原理推导它的生成函数关系式需要不少篇幅。简单来说,这类排列计数问题可以用指数生成函数(EGF)来解决。通过一系列推导(咱在这里就省略掉啦,不然小脑袋要过载了~),可以得到一个关于系数的递推关系。

fif_i 是一个辅助序列中的第 ii 项,我们想通过 f0,,fn1f_0, \dots, f_{n-1} 来求出 fnf_n。最终答案就是 n!fnn! \cdot f_n。 这个递推关系是:

fi=j=0i1I(Sj=>)fj1(ij)!(i>0)f_i = -\sum_{j=0}^{i-1} \mathbb{I}(S_j = '>') \cdot f_j \cdot \frac{1}{(i-j)!} \quad (i > 0)

其中,I(Sj=>)\mathbb{I}(S_j = '>') 是一个指示函数,当 Sj=>S_j = '>' 时为 11,否则为 00。 初始条件是 f0=1f_0 = 1。 (注:参考代码中为了实现的方便,对这个公式做了一些符号上的变换,但本质是一样的)。

这个公式是不是很眼熟?它是一个卷积的形式!fif_i 由所有之前的 fjf_j 和一个只与 iji-j 有关的项 1(ij)!\frac{1}{(i-j)!} 组合而成。

形如 fi=j=0i1fjgijf_i = \sum_{j=0}^{i-1} f_j g_{i-j} 的递推式,我们可以用分治+NTT来优化计算。 具体做法是:

  1. 分治: 我们定义一个函数 solve(l, r) 来计算区间 [l, r) 内的 ff 值。
  2. 首先,递归调用 solve(l, mid),计算出左半区间 [l, mid) 的所有 ff 值。
  3. 卷积: 左半区间的 fjf_j (j[l,mid)j \in [l, \text{mid})) 会对右半区间的 fif_i (i[mid,r)i \in [\text{mid}, r)) 产生贡献。这个贡献正好是 j=lmid1I(Sj=>)fj1(ij)!\sum_{j=l}^{\text{mid}-1} \mathbb{I}(S_j='>') f_j \cdot \frac{1}{(i-j)!}。我们可以把 [fl,,fmid1][f_l, \dots, f_{\text{mid}-1}]('>'位置)和 [11!,12!,][\frac{1}{1!}, \frac{1}{2!}, \dots] 这两个序列做卷积(用NTT加速),然后把结果加到右半区间的 ff 值上。
  4. 最后,递归调用 solve(mid, r)。这个调用会处理右半区间内部的贡献(即 j,ij, i 都在 [mid, r) 内的情况)。

这个算法的时间复杂度是 T(N)=2T(N/2)+O(NlogN)T(N) = 2T(N/2) + O(N \log N),解出来是 O(Nlog2N)O(N \log^2 N),对于本题的数据范围来说就完全没问题啦!

代码实现

下面是咱根据这个思路重构的 C++ 代码,加了详细的注释,希望能帮助 Master 理解哦!

cpp
#include <iostream>
#include <vector>
#include <string>
#include <numeric>
#include <algorithm>

// MOD and primitive root for NTT
const int MOD = 998244353;
const int PRIMITIVE_ROOT = 3;

// 快速幂,喵~
long long power(long long base, long long exp) {
    long long res = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp % 2 == 1) res = (res * base) % MOD;
        base = (base * base) % MOD;
        exp /= 2;
    }
    return res;
}

// 模块逆元
long long modInverse(long long n) {
    return power(n, MOD - 2);
}

// NTT实现,这是魔法的核心哦!
void ntt(std::vector<long long>& a, bool invert) {
    int n = a.size();
    for (int i = 1, j = 0; i < n; i++) {
        int bit = n >> 1;
        for (; j & bit; bit >>= 1)
            j ^= bit;
        j ^= bit;
        if (i < j)
            std::swap(a[i], a[j]);
    }

    for (int len = 2; len <= n; len <<= 1) {
        long long wlen = power(PRIMITIVE_ROOT, (MOD - 1) / len);
        if (invert) wlen = modInverse(wlen);
        for (int i = 0; i < n; i += len) {
            long long w = 1;
            for (int j = 0; j < len / 2; j++) {
                long long u = a[i + j], v = (a[i + j + len / 2] * w) % MOD;
                a[i + j] = (u + v) % MOD;
                a[i + j + len / 2] = (u - v + MOD) % MOD;
                w = (w * wlen) % MOD;
            }
        }
    }

    if (invert) {
        long long n_inv = modInverse(n);
        for (long long& x : a)
            x = (x * n_inv) % MOD;
    }
}

// 卷积函数
std::vector<long long> convolution(std::vector<long long> a, std::vector<long long> b) {
    int sz = 1;
    while (sz < a.size() + b.size()) sz <<= 1;
    a.resize(sz);
    b.resize(sz);

    ntt(a, false);
    ntt(b, false);

    std::vector<long long> c(sz);
    for (int i = 0; i < sz; i++)
        c[i] = (a[i] * b[i]) % MOD;

    ntt(c, true);
    return c;
}

int n;
std::vector<char> signature;
std::vector<long long> f;
std::vector<long long> inv_factorial;

// 分治函数
void cdq_solve(int l, int r) {
    if (r - l <= 1) {
        return;
    }

    int mid = l + (r - l) / 2;
    cdq_solve(l, mid); // 1. 递归解决左半部分

    // 2. 计算左半部分对右半部分的贡献
    int len = r - l;
    std::vector<long long> left_poly(mid - l);
    for (int i = 0; i < mid - l; ++i) {
        if (signature[l + i] == '>') {
            left_poly[i] = f[l + i];
        }
    }

    std::vector<long long> right_poly(len);
    for (int i = 1; i < len; ++i) { // 1/(i-j)! -> i-j >= 1
        right_poly[i] = inv_factorial[i];
    }

    std::vector<long long> conv_res = convolution(left_poly, right_poly);

    // 3. 将贡献累加到右半部分
    for (int i = mid; i < r; ++i) {
        long long contribution = conv_res[i - l];
        f[i] = (f[i] - contribution + MOD) % MOD;
    }

    cdq_solve(mid, r); // 4. 递归解决右半部分
}


int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(NULL);

    std::cin >> n;
    std::vector<int> p(n), pos(n + 1);
    for (int i = 0; i < n; ++i) {
        std::cin >> p[i];
        pos[p[i]] = i;
    }

    // 生成签名序列
    // signature[i] 对应 p_i 和 p_{i+1} 的关系,或者说 k=i 和 k=i+1 的关系
    // 为了和递推公式对应,我们让 signature[k] 表示 k 和 k+1 的关系
    signature.resize(n + 1);
    signature[0] = '>'; // 这是一个小技巧,为了让递推在边界情况下也成立
    for (int k = 1; k < n; ++k) {
        if (pos[k] < pos[k + 1]) {
            signature[k] = '<';
        } else {
            signature[k] = '>';
        }
    }

    // 预处理阶乘的逆元
    int max_val = 1;
    while(max_val <= n) max_val <<= 1;
    
    inv_factorial.resize(max_val + 1);
    std::vector<long long> factorial(max_val + 1);
    factorial[0] = 1;
    for(int i = 1; i <= max_val; ++i) {
        factorial[i] = (factorial[i-1] * i) % MOD;
    }
    inv_factorial[max_val] = modInverse(factorial[max_val]);
    for(int i = max_val - 1; i >= 0; --i) {
        inv_factorial[i] = (inv_factorial[i+1] * (i+1)) % MOD;
    }
    
    // 初始化f数组并开始分治
    f.resize(max_val + 1, 0);
    f[0] = 1;

    cdq_solve(0, max_val);

    long long ans = (f[n] * factorial[n]) % MOD;
    std::cout << (ans + MOD) % MOD << std::endl;

    return 0;
}

复杂度分析

  • 时间复杂度: O(Nlog2N)O(N \log^2 N) 我们使用分治算法,状态转移方程是 T(N)=2T(N/2)+O(NlogN)T(N) = 2T(N/2) + O(N \log N)。其中 O(NlogN)O(N \log N) 是长度为 O(N)O(N) 的多项式卷积(NTT)的复杂度。根据主定理,这个递推式解出来的时间复杂度就是 O(Nlog2N)O(N \log^2 N),喵~

  • 空间复杂度: O(N)O(N) 我们需要 O(N)O(N) 的空间来存储阶乘、逆元、f 数组以及NTT临时使用的数组。

知识点总结

这真是一道融合了多种思想的超棒的题目呀!让咱来总结一下吧:

  1. 问题转化: 解决组合问题的关键一步,常常是找到问题的不变量,将其转化为一个我们更熟悉的模型。本题中,我们将一个动态的排列操作问题,转化为了一个静态的、满足特定约束的排列计数问题。
  2. 排列签名: “pi<pi+1p_i < p_{i+1}” 和 “pos[i]<pos[i+1]pos[i] < pos[i+1]” 这两种约束定义的排列计数问题是等价的,这是一个非常有用(但不太直观)的组合学结论。
  3. 生成函数: 对于复杂的计数问题,特别是与排列组合相关的,生成函数(尤其是指数生成函数)是推导递推关系的强大武器。
  4. 分治+NTT: 对于形如 fi=j<ifjgijf_i = \sum_{j<i} f_j g_{i-j} 的卷积递推式,分治+NTT是一种经典的优化技巧,能将复杂度从 O(N2)O(N^2) 降至 O(Nlog2N)O(N \log^2 N)

希望这篇题解能帮到你,Master!如果还有不明白的地方,随时可以再来问咱哦,喵~