Skip to content

??? - 题解

标签与难度

标签: 快速数论变换 (NTT), 生成函数, 分治, 动态规划 难度: 2600

题目大意喵~

一位叫 Zeeman 的朋友有一个长度为 nn 的数组 a={a1,,an}a = \{a_1, \dots, a_n\}nn 个整数集合 S1,,SnS_1, \dots, S_n 呐。数组 aa 中的每个元素 aia_i 的值都在 [1,k1)[-1, k-1) 范围内。

对于每个位置 ii,我们可以进行一次操作:将 aia_i 的值修改为集合 SiS_i 中的任意一个元素。这次操作的花费是 1,但有一个例外:如果 aia_i 原本就是 1-1,那么这次修改是免费的!

我们有一个神奇的函数 f(b)f(b),它能把一个整数数组 bb 变成一个单独的数字。它的工作方式是这样的:

  1. 如果数组 bb 只有一个元素,那结果就是这个元素。
  2. 如果数组 bb 有多个元素,就先计算所有元素的和 SUM。然后,把 SUM 写成 kk 进制数,它的每一位数字就组成了新的数组 bb'。最后,我们对这个新数组 bb' 再次调用 ff 函数,也就是计算 f(b)f(b')

Zeeman 想知道,对于每一个可能的最终结果 xx(从 00k1k-1),有多少种不同的修改方案,可以使得最终的数组 aa' 经过 f(a)f(a') 的计算后,得到的结果是 {x}\{x\},并且总花费不超过 1

也就是说,我们要么不花钱(只修改那些 ai=1a_i = -1 的位置),要么只花 1 块钱(额外选择一个 aj1a_j \ne -1 的位置进行修改)。

要把所有 x[0,k1)x \in [0, k-1) 的答案都算出来,并且对 998244353998244353 取模哦,喵~

解题思路分析

喵哈~!这道题看起来有点复杂,又是修改数组又是递归函数,但别怕,让我带你一步步拆解它!

关键的第一步:看穿 f(b)f(b) 的本质!

这个函数 f(b)f(b) 一直在求和、取 kk 进制位、再求和……这其实是一个经典的过程,和数根(Digital Root)非常像!

我们知道一个性质:一个数 NN 和它的 kk 进制下所有数字之和 SS,在模 k1k-1 的意义下是同余的。也就是 NS(modk1)N \equiv S \pmod{k-1}。 这是因为 N=dikiN = \sum d_i k^i, S=diS = \sum d_i,所以 NS=di(ki1)N-S = \sum d_i(k^i-1)。因为 ki1k^i-1 总能被 k1k-1 整除,所以 NSN-Sk1k-1 的倍数。

函数 f(b)f(b) 不断重复这个过程,直到数组长度为 1,也就是上一步的和小于 kk。设最终的结果是 yy,那么最初的数组元素总和 bi\sum b_i 和这个最终结果 yy 之间一定满足:

i=1nbiy(modk1)\sum_{i=1}^n b_i \equiv y \pmod{k-1}

这里有个小小的例外:

  • 如果 bi=0\sum b_i = 0,那么结果就是 00
  • 如果 bi>0\sum b_i > 0 并且 bi0(modk1)\sum b_i \equiv 0 \pmod{k-1},结果会是 k1k-1(因为结果是在 [0, k-1) 里的,而大于0的 k1k-1 的倍数最小就是 k1k-1)。
  • 对于其他情况,如果 bir(modk1)\sum b_i \equiv r \pmod{k-1}r[1,k2]r \in [1, k-2],那么结果就是 rr

这个性质是解题的基石!但是直接处理 sum=0sum>0 的情况有点麻烦。有没有更统一的方法呢?我们可以等到最后再用这个性质从总和反推结果,而不是在计算过程中就纠结于此。

第二步:用生成函数来表示“选择”

这道题的本质是“组合计数”,每个位置 ii 都有多种选择,我们要求所有选择方案最终导致的总和的分布。这种问题,生成函数是我们的好朋友!

我们可以用一个多项式来表示一个位置的所有选择。比如,对于位置 ii,如果我们能选择的数值集合是 ViV_i,那么对应的生成函数就是 Pi(z)=vVizvP_i(z) = \sum_{v \in V_i} z^v。多项式中 zvz^v 的系数为 1,表示可以选择数值 vv

当我们把所有位置的选择组合起来时,总和的生成函数就是所有单个生成函数的乘积!也就是说,如果我们为每个位置 ii 选择一个值 bib_i,总和为 bi\sum b_i,那么所有可能总和的生成函数就是 F(z)=i=1nPi(z)F(z) = \prod_{i=1}^n P_i(z)F(z)F(z)zSz^S 的系数,就代表着凑出总和为 SS 的方案数。

第三步:处理“花费不超过1”的约束

我们有两种情况:

  1. 花费为 0:我们只能修改那些 ai=1a_i = -1 的位置。对于 ai1a_i \ne -1 的位置,我们必须保持原样,即选择 aia_i

    • 如果 ai=1a_i = -1,选择的集合是 SiS_i,生成函数是 Pi(z)=vSizvP_i(z) = \sum_{v \in S_i} z^v
    • 如果 ai1a_i \ne -1,选择是固定的 aia_i,生成函数是 Pi(z)=zaiP_i(z) = z^{a_i}

    所以,花费为 0 的总和生成函数是:

    Fcost0(z)=i=1nPi(z)F_{cost0}(z) = \prod_{i=1}^n P_i(z)

    其中 Pi(z)P_i(z) 根据 aia_i 的值来确定。

  2. 花费为 1:我们选择一个 jj(其中 aj1a_j \ne -1),将它修改为 SjS_j 中的一个值。其他所有位置的选择规则和花费为 0 时一样。

    对于一个特定的 jj,它的生成函数从 zajz^{a_j} 变成了 vSjzv\sum_{v \in S_j} z^v。其他位置不变。 所以,只修改 jj 的生成函数是:

    Fcost1,j(z)=Fcost0(z)zaj(vSjzv)F_{cost1, j}(z) = \frac{F_{cost0}(z)}{z^{a_j}} \cdot \left(\sum_{v \in S_j} z^v\right)

    总的花费为 1 的生成函数就是把所有可能的 jj 的情况加起来:

    Fcost1(z)=j:aj1Fcost1,j(z)F_{cost1}(z) = \sum_{j: a_j \ne -1} F_{cost1, j}(z)

那么,总花费不超过 1 的总生成函数就是 Ftotal(z)=Fcost0(z)+Fcost1(z)F_{total}(z) = F_{cost0}(z) + F_{cost1}(z)。 我们可以把它因式分解一下:

Ftotal(z)=Fcost0(z)+j:aj1Fcost0(z)zajPj(z)=Fcost0(z)(1+j:aj1zajPj(z))F_{total}(z) = F_{cost0}(z) + \sum_{j: a_j \ne -1} \frac{F_{cost0}(z)}{z^{a_j}} \cdot P_j(z) = F_{cost0}(z) \left(1 + \sum_{j: a_j \ne -1} z^{-a_j}P_j(z)\right)

这真是个漂亮的形式!我们只需要计算两个多项式然后把它们乘起来:

  • 第一个多项式是 Fcost0(z)F_{cost0}(z)
  • 第二个多项式是 M(z)=1+j:aj1zajPj(z)M(z) = 1 + \sum_{j: a_j \ne -1} z^{-a_j}P_j(z),我们叫它“修饰多项式”好了,喵~

第四步:用分治NTT加速多项式乘法

计算 Fcost0(z)=i=1nPi(z)F_{cost0}(z) = \prod_{i=1}^n P_i(z) 需要做 n1n-1 次多项式乘法。如果一个个乘,度数会越来越大,时间上承受不住。 这里可以用分治的思想来优化! 我们可以把 nn 个多项式分成两半,递归地计算左半部分的乘积和右半部分的乘积,最后再把这两个结果乘起来。 mult(1...n) = mult(1...n/2) * mult(n/2+1...n) 多项式乘法本身,可以用快速数论变换 (NTT) 来实现,时间复杂度是 O(NlogN)O(N \log N),其中 NN 是多项式的度数。

第五步:控制多项式的“体型”

一个问题是,多项式的度数(也就是最大可能的总和)可能会非常大(n×kn \times k 的级别),直接用NTT会很慢。但我们还记得吗?f(b)f(b) 的结果只和总和模 k1k-1 有关!

这启发我们对多项式进行“瘦身”。如果一个总和 sum >= k,它的最终结果和 f(digits of sum) 一样。而 sumdigits of sumk1k-1 是同余的。 所以,我们可以把所有度数 k\ge k 的项 zsumz^{\text{sum}} 合并到某个代表性的低次项上。 一个非常巧妙的方法是,我们维护一个度数不超过 2k12k-1 的多项式。

  • 对于度数 d<kd < k 的项 zdz^d,我们保留它。
  • 对于度数 dkd \ge k 的项 zdz^d,我们把它合并到 zk+(dk)(modk1)z^{k + (d-k) \pmod{k-1}} 上。

这样,在分治NTT的每一步乘法之后,我们都进行一次这样的“折叠”操作,把结果多项式的度数控制在 4k4k 左右,再折叠回 2k2k 以内,NTT的规模就不会无限增长啦!

最终的作战计划!

  1. 准备NTT:写好一个标准的多项式乘法板子。
  2. 计算 Fcost0(z)F_{cost0}(z)
    • 对每个 i[1,n]i \in [1, n],生成初始多项式 Pi(z)P_i(z)
    • 使用分治+NTT,计算这 nn 个多项式的乘积。在每一步分治的合并(乘法)后,都进行“折叠”操作,将多项式大小保持在 2k2k 左右。
  3. 计算修饰多项式 M(z)M(z)
    • 创建一个多项式 M(z)M(z)
    • 为了处理 zajz^{-a_j} 可能出现的负指数,我们给所有指数加上一个偏移量,比如 kk
    • M(z)M(z) 初始化为 zkz^k(代表常数1)。
    • 对于每个 jj 使得 aj1a_j \ne -1,我们遍历 SjS_j 中的每个元素 vv,给 M(z)M(z)zvaj+kz^{v-a_j+k} 项的系数加 1。
  4. 最终相乘
    • 计算最终的总生成函数 Ftotal(z)=Fcost0(z)M(z)F_{total}(z) = F_{cost0}(z) \cdot M(z)
  5. 统计答案
    • Ftotal(z)F_{total}(z)ziz^i 的系数,代表总花费 1\le 1 时,凑出总和为 iki-k(因为有偏移)的方案数。
    • 遍历 Ftotal(z)F_{total}(z) 的所有非零项,对于每个系数 cic_i,它对应总和 S=ikS = i-k
    • 我们用一个递归函数 get_f(S) 计算出 SS 对应的最终结果 xx
    • cic_i 累加到最终答案 ans[x] 上。
  6. 输出结果:喵~ 大功告成,输出 ans 数组就好啦!

这个方法把所有情况都优雅地统一到了生成函数里,是不是很酷?下面就来看看代码实现吧!

代码实现

cpp
#include <iostream>
#include <vector>
#include <numeric>
#include <algorithm>

// MOD and NTT parameters
const int MOD = 998244353;
const int NTT_PRIMITIVE_ROOT = 3;

// Fast power function
long long power(long long base, long long exp) {
    long long res = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp % 2 == 1) res = (res * base) % MOD;
        base = (base * base) % MOD;
        exp /= 2;
    }
    return res;
}

// NTT function
void ntt(std::vector<long long>& a, bool invert) {
    int n = a.size();
    for (int i = 1, j = 0; i < n; i++) {
        int bit = n >> 1;
        for (; j & bit; bit >>= 1)
            j ^= bit;
        j ^= bit;
        if (i < j)
            std::swap(a[i], a[j]);
    }

    for (int len = 2; len <= n; len <<= 1) {
        long long wlen = power(NTT_PRIMITIVE_ROOT, (MOD - 1) / len);
        if (invert)
            wlen = power(wlen, MOD - 2);
        for (int i = 0; i < n; i += len) {
            long long w = 1;
            for (int j = 0; j < len / 2; j++) {
                long long u = a[i + j], v = (a[i + j + len / 2] * w) % MOD;
                a[i + j] = (u + v) % MOD;
                a[i + j + len / 2] = (u - v + MOD) % MOD;
                w = (w * wlen) % MOD;
            }
        }
    }

    if (invert) {
        long long n_inv = power(n, MOD - 2);
        for (long long& x : a)
            x = (x * n_inv) % MOD;
    }
}

// Polynomial multiplication
std::vector<long long> multiply(std::vector<long long> a, std::vector<long long> b) {
    int sz = 1;
    while (sz < a.size() + b.size()) sz <<= 1;
    a.resize(sz);
    b.resize(sz);

    ntt(a, false);
    ntt(b, false);
    for (int i = 0; i < sz; i++) a[i] = (a[i] * b[i]) % MOD;
    ntt(a, true);

    return a;
}

int N, K;

// The recursive function f(b) from the problem statement
int get_final_value(int sum) {
    if (sum < K) return sum;
    int digit_sum = 0;
    while (sum > 0) {
        digit_sum += sum % K;
        sum /= K;
    }
    return get_final_value(digit_sum);
}

// Folds a polynomial to keep its size manageable
std::vector<long long> fold_poly(const std::vector<long long>& p) {
    if (p.empty()) return {};
    // Target size is 2*K, representing sums < K and equivalence classes for sums >= K
    std::vector<long long> folded(2 * K, 0);
    for (int i = 0; i < p.size(); ++i) {
        if (p[i] == 0) continue;
        int target_idx;
        if (i < K) {
            target_idx = i;
        } else {
            // For sum >= K, map to a representative index >= K
            target_idx = K + (i - K) % (K - 1);
        }
        folded[target_idx] = (folded[target_idx] + p[i]) % MOD;
    }
    return folded;
}

// Divide and conquer function to multiply polynomials
std::vector<long long> multiply_and_fold_dq(const std::vector<std::vector<long long>>& polys, int l, int r) {
    if (l == r) {
        return fold_poly(polys[l]);
    }
    if (l > r) {
        // Return identity polynomial (1)
        std::vector<long long> identity(1, 1);
        return identity;
    }
    int mid = l + (r - l) / 2;
    auto left_prod = multiply_and_fold_dq(polys, l, mid);
    auto right_prod = multiply_and_fold_dq(polys, mid + 1, r);
    auto result = multiply(left_prod, right_prod);
    return fold_poly(result);
}


int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(NULL);

    std::cin >> N >> K;

    std::vector<int> a(N);
    int max_val = 0;
    for (int i = 0; i < N; ++i) {
        std::cin >> a[i];
        if (a[i] > max_val) max_val = a[i];
    }

    std::vector<std::vector<int>> S(N);
    std::vector<std::vector<long long>> initial_polys;
    for (int i = 0; i < N; ++i) {
        int sz;
        std::cin >> sz;
        S[i].resize(sz);
        std::vector<long long> p;
        int current_max_deg = 0;
        if (a[i] == -1) {
            for (int j = 0; j < sz; ++j) {
                std::cin >> S[i][j];
                if (S[i][j] > current_max_deg) current_max_deg = S[i][j];
            }
            p.assign(current_max_deg + 1, 0);
            for (int val : S[i]) p[val] = 1;
        } else {
            // Read and store S_i for later use
             for (int j = 0; j < sz; ++j) std::cin >> S[i][j];
            // For cost 0, this position is fixed
            p.assign(a[i] + 1, 0);
            p[a[i]] = 1;
        }
        initial_polys.push_back(p);
    }
    
    // Calculate F_cost0(z)
    auto cost0_poly = multiply_and_fold_dq(initial_polys, 0, N - 1);

    // Calculate the modifier polynomial M(z)
    std::vector<long long> modifier_poly(2 * K, 0);
    modifier_poly[K] = 1; // Represents the '1 +' part, shifted by K
    for (int i = 0; i < N; ++i) {
        if (a[i] != -1) {
            for (int v : S[i]) {
                int exponent = v - a[i] + K;
                if (exponent >= 0 && exponent < 2 * K) {
                    modifier_poly[exponent] = (modifier_poly[exponent] + 1) % MOD;
                }
            }
        }
    }

    // Get the final total polynomial
    auto final_poly = multiply(cost0_poly, modifier_poly);

    // Tally the results
    std::vector<long long> ans(K, 0);
    for (int i = 0; i < final_poly.size(); ++i) {
        if (final_poly[i] > 0) {
            int total_sum = i - K; // Adjust for the shift in modifier_poly
            if (total_sum >= 0) {
                int result_val = get_final_value(total_sum);
                ans[result_val] = (ans[result_val] + final_poly[i]) % MOD;
            }
        }
    }

    for (int i = 0; i < K; ++i) {
        std::cout << ans[i] << (i == K - 1 ? "" : " ");
    }
    std::cout << std::endl;

    return 0;
}

复杂度分析

  • 时间复杂度: O(NKlogK+KlogK)O(N \cdot K \log K + K \log K)

    • 分治计算 Fcost0(z)F_{cost0}(z) 的过程有 logN\log N 层。在每一层,我们需要做几次大小约为 2K×2K2K \times 2K 的 NTT 乘法。总的来说,这一部分的复杂度是 O(NKlogK)O(N \cdot K \log K)
    • 构建修饰多项式 M(z)M(z) 的时间取决于所有集合 SiS_i 的总大小,记为 ΣS|\Sigma S|,复杂度是 O(ΣS)O(|\Sigma S|)
    • 最后一次NTT乘法是 O(KlogK)O(K \log K)
    • 统计答案时,遍历最终多项式的复杂度是 O(K)O(K)get_final_value 函数的调用非常快。
    • 瓶颈在于分治NTT部分,所以总体复杂度近似为 O(NKlogK)O(N \cdot K \log K)
  • 空间复杂度: O(NK)O(N \cdot K)

    • 在分治递归的栈中,我们需要存储中间结果多项式,最坏情况下空间复杂度是 O(KlogN)O(K \log N)
    • 存储初始的多项式 Pi(z)P_i(z) 可能需要 O(Si+ai)O(\sum |S_i| + \sum a_i) 的空间,在最坏情况下是 O(NK)O(N \cdot K)
    • NTT本身需要 O(K)O(K) 的辅助空间。
    • 所以总的空间复杂度是 O(NK)O(N \cdot K)

知识点总结

  1. 生成函数: 它是解决组合计数问题的强大工具。将“选择”转化为多项式,将“组合”转化为多项式乘法,思路非常直观。
  2. 快速数论变换 (NTT): 在模意义下实现多项式快速乘法的算法,是FFT在数论中的对应。它是解决许多生成函数问题的关键。
  3. 分治: 当需要计算多个对象的累积效应(如连乘)时,分治是一个非常有效的优化策略,能将复杂度从平方级别降低到对数级别。
  4. 模运算性质与数根: 理解 N \equiv \text{sum_digits}(N) \pmod{k-1} 这个性质是简化问题的突破口。它让我们能从复杂递归中找到规律。
  5. 问题建模: 将“花费不超过1”的约束分解为“花费0”和“花费1”两种情况,并用生成函数统一表示,是本题建模的核心思想。把问题分解成更小的、可以独立计算的部分,然后优雅地组合起来,这就是算法的魅力所在呀,喵~