Skip to content

集合划分 - 题解

标签与难度

标签: 动态规划, 分治, NTT, 卷积, 组合计数, 数论 难度: 2600

题目大意喵~

主人你好呀,喵~ 这道题是这样的:

我们有一个长度为 n+1n+1 的序列 a0,a1,,ana_0, a_1, \dots, a_n 和一个常数 KK。 首先,我们要理解一个集合的“价值”是什么。对于一个由 {0,1,,n}\{0, 1, \dots, n\} 的下标组成的集合 S={s1,s2,,sm}S = \{s_1, s_2, \dots, s_m\},它的价值被定义为 Mex({as1,as2,,asm})Mex(\{a_{s_1}, a_{s_2}, \dots, a_{s_m}\})。这里的 MexMex (Minimum Excluded value) 是指“最小的未出现的非负整数”。比如说,Mex({0,1,3,4})Mex(\{0, 1, 3, 4\}) 就是 22 啦。

我们的任务是,要把全集 {0,1,,n}\{0, 1, \dots, n\} 划分成两个不相交的集合 S1S_1S2S_2(它们的并集是全集)。我们需要对每个 kk(从 00KK),计算出有多少种这样的划分方法,能够使得 S1S_1 的价值和 S2S_2 的价值之和恰好等于 kk

也就是说,如果 v1=Mex(S1)v_1 = Mex(S_1) 并且 v2=Mex(S2)v_2 = Mex(S_2),我们要对每个 k[0,K]k \in [0, K],求出满足 v1+v2=kv_1 + v_2 = k 的有序对 (S1,S2)(S_1, S_2) 的数量。

解题思路分析

这道题看起来好复杂呀,又是集合划分又是 Mex 的,猫猫的脑袋都要绕晕了呢 >.<。但是别怕,只要我们一步一步来,肯定能想明白的!

第一步:理解 Mex 和划分条件

首先,我们来仔细分析一下 Mex 的定义。Mex(S) = v 意味着两件事:

  1. 对于所有非负整数 j<vj < v,数字 jj 必须在集合 {assS}\{a_s \mid s \in S\} 中出现过。
  2. 数字 vv 本身,绝对不能在集合 {assS}\{a_s \mid s \in S\} 中出现。

为了方便讨论,我们先预处理一下,计算出每个数值 jj 在序列 aa 中出现了多少次。我们用 cjc_j 表示值等于 jj 的元素个数,也就是 cj={iai=j}c_j = |\{i \mid a_i = j\}|

现在,我们想要求有多少对 (S1,S2)(S_1, S_2) 使得 Mex(S1)=v1Mex(S_1)=v_1Mex(S2)=v2Mex(S_2)=v_2。直接计算满足这个条件的划分数(我们记作 P(v1,v2)P(v_1, v_2))非常困难,因为它同时对 S1S_1S2S_2 提出了严格的等于条件。

第二步:化“等于”为“大于等于”

在组合计数问题中,当“等于”不好处理时,一个常见的技巧是先计算“大于等于”的情况,然后通过容斥原理变回来。

我们定义一个辅助函数 f(v1,v2)f(v_1, v_2),表示满足 Mex(S1)v1Mex(S_1) \ge v_1Mex(S2)v2Mex(S_2) \ge v_2 的划分方案数。

Mex(S1)v1Mex(S_1) \ge v_1 的条件是:对于所有 j<v1j < v_1,数值 jj 必须在 {assS1}\{a_s \mid s \in S_1\} 中出现。 Mex(S2)v2Mex(S_2) \ge v_2 的条件是:对于所有 j<v2j < v_2,数值 jj 必须在 {assS2}\{a_s \mid s \in S_2\} 中出现。

现在我们来考虑如何分配所有下标 i{0,,n}i \in \{0, \dots, n\}S1S_1S2S_2 中,来满足 Mex(S1)v1Mex(S_1) \ge v_1Mex(S2)v2Mex(S_2) \ge v_2。我们按 aia_i 的值来分类讨论:

  • 对于一个值 j<min(v1,v2)j < \min(v_1, v_2):我们既要保证 jjS1S_1 的值集合里,也要保证它在 S2S_2 的值集合里。这意味着,所有值为 jj 的下标(也就是 cjc_j 个)不能全部被分到 S1S_1,也不能全部被分到 S2S_2。总共有 2cj2^{c_j} 种分配方式,排除掉这两种极端情况,就有 2cj22^{c_j} - 2 种方式。
  • 对于一个值 jj 满足 min(v1,v2)j<max(v1,v2)\min(v_1, v_2) \le j < \max(v_1, v_2):假设 v1v2v_1 \le v_2,那么对于 v1j<v2v_1 \le j < v_2,我们只需要保证 jjS2S_2 的值集合里。这意味着 cjc_j 个下标不能全部被分到 S1S_1。所以有 2cj12^{c_j} - 1 种方式。
  • 对于一个值 jmax(v1,v2)j \ge \max(v_1, v_2):没有任何限制!cjc_j 个下标可以任意分配,有 2cj2^{c_j} 种方式。

把它们全部乘起来,我们就得到了 f(v1,v2)f(v_1, v_2) 的表达式。假设 v1v2v_1 \le v_2

f(v1,v2)=(j=0v11(2cj2))(j=v1v21(2cj1))(j=v22cj)f(v_1, v_2) = \left( \prod_{j=0}^{v_1-1} (2^{c_j}-2) \right) \cdot \left( \prod_{j=v_1}^{v_2-1} (2^{c_j}-1) \right) \cdot \left( \prod_{j=v_2}^{\infty} 2^{c_j} \right)

有了 f(v1,v2)f(v_1, v_2),我们就可以用容斥原理(或者说是二维差分)来求 P(v1,v2)P(v_1, v_2) 了! P(v1,v2)=(ways for Mex1v1,Mex2v2)(ways for Mex1v1+1,Mex2v2)(ways for Mex1v1,Mex2v2+1)+(ways for Mex1v1+1,Mex2v2+1)P(v_1, v_2) = (\text{ways for } Mex_1 \ge v_1, Mex_2 \ge v_2) - (\text{ways for } Mex_1 \ge v_1+1, Mex_2 \ge v_2) - (\text{ways for } Mex_1 \ge v_1, Mex_2 \ge v_2+1) + (\text{ways for } Mex_1 \ge v_1+1, Mex_2 \ge v_2+1)

P(v1,v2)=f(v1,v2)f(v1+1,v2)f(v1,v2+1)+f(v1+1,v2+1)P(v_1, v_2) = f(v_1, v_2) - f(v_1+1, v_2) - f(v_1, v_2+1) + f(v_1+1, v_2+1)

这个式子看起来很可怕,但是经过一番神奇的代数化简(把 ff 的表达式代入,提取公因式),对于 v1<v2v_1 < v_2,我们可以得到一个非常简洁的结果,喵~

P(v1,v2)=(j=0v11(2cj2))(j=v1+1v21(2cj1))(j=v2+12cj)P(v_1, v_2) = \left( \prod_{j=0}^{v_1-1} (2^{c_j}-2) \right) \cdot \left( \prod_{j=v_1+1}^{v_2-1} (2^{c_j}-1) \right) \cdot \left( \prod_{j=v_2+1}^{\infty} 2^{c_j} \right)

是不是很神奇?在 v1v_1v2v_2 位置的项,它们的贡献互相抵消,最后变成了 11

第三步:卷积!是卷积的味道!

我们的目标是求 Ansk=v1+v2=kP(v1,v2)Ans_k = \sum_{v_1+v_2=k} P(v_1, v_2)。 这个形式让我们立刻想到了多项式乘法(卷积)!

我们将 AnskAns_k 的计算分为三部分:v1<v2v_1 < v_2, v1>v2v_1 > v_2, 和 v1=v2v_1 = v_2

Case 1: v1<v2v_1 < v_2 我们想计算 v1+v2=k,v1<v2P(v1,v2)\sum_{v_1+v_2=k, v_1<v_2} P(v_1, v_2)。 这个 P(v1,v2)P(v_1, v_2) 的表达式里,连乘的范围 j=v1+1v21\prod_{j=v_1+1}^{v_2-1}v1v_1v2v_2 耦合在了一起,不方便直接卷积。但是我们可以把它拆开! 令 PrefAi=j=0i1(2cj2)PrefA_i = \prod_{j=0}^{i-1} (2^{c_j}-2)PrefBi=j=0i1(2cj1)PrefB_i = \prod_{j=0}^{i-1} (2^{c_j}-1)SuffCi=j=i2cjSuffC_i = \prod_{j=i}^{\infty} 2^{c_j}。 那么 P(v1,v2)=PrefAv1PrefBv21PrefBv1+1SuffCv2+1P(v_1, v_2) = PrefA_{v_1} \cdot \frac{PrefB_{v_2-1}}{PrefB_{v_1+1}} \cdot SuffC_{v_2+1}=(PrefAv1(PrefBv1+1)1)(PrefBv21SuffCv2+1)= (PrefA_{v_1} \cdot (PrefB_{v_1+1})^{-1}) \cdot (PrefB_{v_2-1} \cdot SuffC_{v_2+1}) 看!我们成功地把 P(v1,v2)P(v_1, v_2) 分解成了只和 v1v_1 有关的项与只和 v2v_2 有关的项的乘积!

定义两个多项式(生成函数): A(x)=i=0K(j=0i1(2cj2)j=0i(2cj1))xiA(x) = \sum_{i=0}^{K} \left( \frac{\prod_{j=0}^{i-1} (2^{c_j}-2)}{\prod_{j=0}^{i} (2^{c_j}-1)} \right) x^iB(x)=j=0K((l=0j1(2cl1))(l=j+12cl))xjB(x) = \sum_{j=0}^{K} \left( (\prod_{l=0}^{j-1} (2^{c_l}-1)) \cdot (\prod_{l=j+1}^{\infty} 2^{c_l}) \right) x^j

那么 A(x)B(x)A(x) \cdot B(x) 的第 kk 项系数 [xk](A(x)B(x))=i+j=kAiBj=i+j=kP(i,j)[x^k](A(x)B(x)) = \sum_{i+j=k} A_i B_j = \sum_{i+j=k} P(i,j)。 哇!我们要求的和就是这个卷积的结果!我们可以用 NTT (快速数论变换) 在 O(KlogK)O(K \log K) 的时间里计算这个卷积。

但是,这个卷积计算了所有 i,ji,j 对的贡献。我们只想要 v1<v2v_1 < v_2 的。直接这样算会把 v1>v2v_1 > v_2v1=v2v_1 = v_2 的情况也算进去。 一个聪明的办法是使用分治 NTT (CDQ-NTT)。我们将 v1v_1v2v_2 的范围 [0, K] 一分为二。 solve(l, r) 的任务是计算所有 v1[l,mid],v2[mid+1,r]v_1 \in [l, mid], v_2 \in [mid+1, r] (以及对称的 v2[l,mid],v1[mid+1,r]v_2 \in [l, mid], v_1 \in [mid+1, r])的贡献。 在 solve(l, r) 中:

  1. 递归调用 solve(l, mid)solve(mid+1, r)
  2. 构造两个多项式,一个的系数是 v1[l,mid]v_1 \in [l, mid] 对应的项,另一个是 v2[mid+1,r]v_2 \in [mid+1, r] 对应的项。
  3. 用 NTT 计算它们的卷积,把结果加到全局的答案数组 Ans 中。

Case 2: v1>v2v_1 > v_2 由于 P(v1,v2)P(v_1, v_2)P(v2,v1)P(v_2, v_1) 的表达式结构是对称的,所以 v1>v2v_1 > v_2 的总贡献和 v1<v2v_1 < v_2 的总贡献是一样的。所以我们把 CDQ-NTT 的结果乘以 2 就好啦。

Case 3: v1=v2v_1 = v_2 最后,我们来处理对角线上的情况 P(i,i)P(i, i)。 通过类似的容斥推导,可以发现 P(i,i)P(i,i) 只有在 ci=0c_i=0 (即序列 aa 中不存在值 ii) 时才可能不为零。 当 ci=0c_i=0 时,要满足 Mex(S1)=iMex(S_1)=iMex(S2)=iMex(S_2)=i

  • 对于 j<ij < icjc_j 必须被划分到 S1S_1S2S_2 中,有 2cj22^{c_j}-2 种方法。
  • 对于 jij \ge i,由于值 ii 根本不存在,所以 MexMexii 的第二个条件(值 ii 未出现)自动满足。所以对 jij \ge i 的下标分配没有限制,有 2cj2^{c_j} 种方法。 所以,如果 ci=0c_i=0,则 P(i,i)=(j=0i1(2cj2))(j=i2cj)P(i,i) = \left( \prod_{j=0}^{i-1} (2^{c_j}-2) \right) \cdot \left( \prod_{j=i}^{\infty} 2^{c_j} \right)。 我们单独计算这些项,并加到对应的 Ans2iAns_{2i} 上。

总结一下我们的算法:

  1. 预处理出每个值的计数 cjc_j
  2. 预处理计算过程中需要用到的各种前缀积和后缀积,以及它们的逆元。
  3. 使用 CDQ-NTT 计算所有 v1<v2v_1 < v_2P(v1,v2)P(v_1, v_2)Ansv1+v2Ans_{v_1+v_2} 的贡献。
  4. 将 CDQ-NTT 的结果乘以 2,得到所有 v1v2v_1 \neq v_2 的贡献。
  5. 单独计算并累加所有 v1=v2v_1=v_2 的情况的贡献。
  6. 输出最终的 Ans 数组。

这样,一只聪明的我就把一个复杂的问题分解成可以解决的小块啦,喵~

代码实现

cpp
#include <iostream>
#include <vector>
#include <numeric>
#include <algorithm>

using namespace std;

// 使用 long long 防止溢出
using ll = long long;

// 模数
const int MOD = 998244353;
// NTT的原根
const int G = 3;

// 快速幂,用于计算乘法逆元
ll power(ll base, ll exp) {
    ll res = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp % 2 == 1) res = (res * base) % MOD;
        base = (base * base) % MOD;
        exp /= 2;
    }
    return res;
}

ll modInverse(ll n) {
    return power(n, MOD - 2);
}

// NTT 实现
void ntt(vector<ll>& a, bool invert) {
    int n = a.size();

    // 位逆序置换
    for (int i = 1, j = 0; i < n; i++) {
        int bit = n >> 1;
        for (; j & bit; bit >>= 1)
            j ^= bit;
        j ^= bit;
        if (i < j)
            swap(a[i], a[j]);
    }

    // 蝴蝶操作
    for (int len = 2; len <= n; len <<= 1) {
        ll wlen = power(G, (MOD - 1) / len);
        if (invert) wlen = modInverse(wlen);
        for (int i = 0; i < n; i += len) {
            ll w = 1;
            for (int j = 0; j < len / 2; j++) {
                ll u = a[i + j], v = (a[i + j + len / 2] * w) % MOD;
                a[i + j] = (u + v) % MOD;
                a[i + j + len / 2] = (u - v + MOD) % MOD;
                w = (w * wlen) % MOD;
            }
        }
    }

    if (invert) {
        ll n_inv = modInverse(n);
        for (ll& x : a)
            x = (x * n_inv) % MOD;
    }
}

const int MAX_N_K = 200005;

int n, K;
int a[MAX_N_K];
ll counts[MAX_N_K * 2]; // 值的范围可能比n大

// 预处理各种乘积
ll pref_termA[MAX_N_K * 2]; // prod(2^c - 2)
ll pref_termB[MAX_N_K * 2]; // prod(2^c - 1)
ll inv_pref_termB[MAX_N_K * 2];
ll suff_termC[MAX_N_K * 2]; // prod(2^c)

ll ans[MAX_N_K * 2];

// 分治NTT
void cdq_solve(int l, int r) {
    if (l >= r) {
        return;
    }
    int mid = l + (r - l) / 2;
    cdq_solve(l, mid);
    cdq_solve(mid + 1, r);

    // 构造多项式 A(x) 和 B(x)
    vector<ll> poly_A, poly_B;
    for (int i = l; i <= mid; ++i) {
        // A_i = (prefA_{i}) / (prefB_{i+1})
        ll term_A = (pref_termA[i] * inv_pref_termB[i + 1]) % MOD;
        poly_A.push_back(term_A);
    }
    for (int i = mid + 1; i <= r; ++i) {
        // B_i = prefB_i * suffC_{i+1}
        ll term_B = (pref_termB[i] * suff_termC[i + 1]) % MOD;
        poly_B.push_back(term_B);
    }

    int len_A = poly_A.size();
    int len_B = poly_B.size();
    if (len_A == 0 || len_B == 0) return;

    int fft_size = 1;
    while (fft_size < len_A + len_B) fft_size <<= 1;

    poly_A.resize(fft_size);
    poly_B.resize(fft_size);

    ntt(poly_A, false);
    ntt(poly_B, false);

    for (int i = 0; i < fft_size; ++i) {
        poly_A[i] = (poly_A[i] * poly_B[i]) % MOD;
    }

    ntt(poly_A, true);

    for (int i = 0; i < len_A; ++i) {
        for (int j = 0; j < len_B; ++j) {
            int k = (l + i) + (mid + 1 + j);
            if (k <= K) {
                // P(i,j)的贡献是 poly_A[i+j]
                // 乘以2是因为P(i,j)和P(j,i)贡献相同
                ans[k] = (ans[k] + 2 * poly_A[i + j]) % MOD;
            }
        }
    }
}


int main() {
    // 加速喵~
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    cin >> n >> K;
    int max_val = 0;
    for (int i = 0; i <= n; ++i) {
        cin >> a[i];
        counts[a[i]]++;
        max_val = max(max_val, a[i]);
    }
    
    // 值的范围最大可能到K
    int limit = max(K, max_val) + 2;

    // 预处理
    pref_termA[0] = 1;
    pref_termB[0] = 1;
    for (int i = 0; i < limit; ++i) {
        ll p2c = power(2, counts[i]);
        pref_termA[i + 1] = (pref_termA[i] * (p2c - 2 + MOD)) % MOD;
        pref_termB[i + 1] = (pref_termB[i] * (p2c - 1 + MOD)) % MOD;
    }
    
    inv_pref_termB[limit] = modInverse(pref_termB[limit]);
    for (int i = limit - 1; i >= 0; --i) {
        ll p2c = power(2, counts[i]);
        inv_pref_termB[i] = (inv_pref_termB[i + 1] * (p2c - 1 + MOD)) % MOD;
    }
    
    suff_termC[limit] = 1;
    for (int i = limit - 1; i >= 0; --i) {
        suff_termC[i] = (suff_termC[i + 1] * power(2, counts[i])) % MOD;
    }

    // 分治NTT计算 v1 != v2 的情况
    cdq_solve(0, K);

    // 处理 v1 = v2 的情况
    for (int i = 0; 2 * i <= K; ++i) {
        if (counts[i] == 0) {
            ll term_diag = (pref_termA[i] * suff_termC[i]) % MOD;
            ans[2 * i] = (ans[2 * i] + term_diag) % MOD;
        }
    }

    for (int k = 0; k <= K; ++k) {
        cout << ans[k] << (k == K ? "" : " ");
    }
    cout << endl;

    return 0;
}

复杂度分析

  • 时间复杂度: O(Vlog2V)O(V \log^2 V), 其中 V=max(n,K)V = \max(n, K)

    • 预处理计数和前缀积需要 O(Vlog(count))O(V \log(\text{count})) 的时间(主要是快速幂)。
    • CDQ-NTT 的递归关系是 T(N)=2T(N/2)+O(NlogN)T(N) = 2T(N/2) + O(N \log N),解得 T(N)=O(Nlog2N)T(N) = O(N \log^2 N)。在这里 NNKK 的范围。
    • 处理对角线情况需要 O(K)O(K)
    • 所以总时间复杂度由 CDQ-NTT 主导。
  • 空间复杂度: O(V)O(V)

    • counts, pref_term, suff_term 等数组都需要 O(V)O(V) 的空间。
    • NTT 过程中需要 O(V)O(V) 的临时空间来存储多项式。

知识点总结

这真是一道精彩的题目,融合了好多有趣的知识点呢,喵!

  1. 组合计数与容斥原理: 题目的核心在于计数。当直接计算“恰好”很困难时,转而计算“至少”,再通过容斥原理(在这里是二维差分)得到精确解,是非常强大和常用的思路。
  2. 生成函数与卷积: 我们将复杂的求和式 i+j=kAiBj\sum_{i+j=k} A_i B_j 识别为卷积的形式,这是从组合问题迈向多项式算法的关键一步。
  3. NTT (快速数论变换): NTT 是在模意义下计算卷积的利器,是解决这类问题的标准工具。
  4. 分治 (CDQ): 对于一些不能直接用一次卷积解决,但具有良好分治结构的问题(例如本题中 v1<v2v_1 < v_2 的限制),CDQ 分治可以和数据结构或FFT/NTT等算法结合,有效地处理区间之间的贡献。
  5. 代数变形: 将 P(v1,v2)P(v_1, v_2) 的复杂表达式变形为可以卷积的形式,需要一定的代数推导能力。这是解题过程中最需要灵感和技巧的一步!

希望这篇题解能帮助到你,主人!如果还有不懂的地方,随时可以再来问我哦,喵~