??? - 题解
标签与难度
标签: 快速数论变换 (NTT), 生成函数, 分治, 动态规划 难度: 2600
题目大意喵~
一位叫 Zeeman 的朋友有一个长度为 的数组 和 个整数集合 呐。数组 中的每个元素 的值都在 范围内。
对于每个位置 ,我们可以进行一次操作:将 的值修改为集合 中的任意一个元素。这次操作的花费是 1,但有一个例外:如果 原本就是 ,那么这次修改是免费的!
我们有一个神奇的函数 ,它能把一个整数数组 变成一个单独的数字。它的工作方式是这样的:
- 如果数组 只有一个元素,那结果就是这个元素。
- 如果数组 有多个元素,就先计算所有元素的和
SUM。然后,把SUM写成 进制数,它的每一位数字就组成了新的数组 。最后,我们对这个新数组 再次调用 函数,也就是计算 。
Zeeman 想知道,对于每一个可能的最终结果 (从 到 ),有多少种不同的修改方案,可以使得最终的数组 经过 的计算后,得到的结果是 ,并且总花费不超过 1?
也就是说,我们要么不花钱(只修改那些 的位置),要么只花 1 块钱(额外选择一个 的位置进行修改)。
要把所有 的答案都算出来,并且对 取模哦,喵~
解题思路分析
喵哈~!这道题看起来有点复杂,又是修改数组又是递归函数,但别怕,让我带你一步步拆解它!
关键的第一步:看穿 的本质!
这个函数 一直在求和、取 进制位、再求和……这其实是一个经典的过程,和数根(Digital Root)非常像!
我们知道一个性质:一个数 和它的 进制下所有数字之和 ,在模 的意义下是同余的。也就是 。 这是因为 , ,所以 。因为 总能被 整除,所以 是 的倍数。
函数 不断重复这个过程,直到数组长度为 1,也就是上一步的和小于 。设最终的结果是 ,那么最初的数组元素总和 和这个最终结果 之间一定满足:
这里有个小小的例外:
- 如果 ,那么结果就是 。
- 如果 并且 ,结果会是 (因为结果是在
[0, k-1)里的,而大于0的 的倍数最小就是 )。 - 对于其他情况,如果 且 ,那么结果就是 。
这个性质是解题的基石!但是直接处理 sum=0 和 sum>0 的情况有点麻烦。有没有更统一的方法呢?我们可以等到最后再用这个性质从总和反推结果,而不是在计算过程中就纠结于此。
第二步:用生成函数来表示“选择”
这道题的本质是“组合计数”,每个位置 都有多种选择,我们要求所有选择方案最终导致的总和的分布。这种问题,生成函数是我们的好朋友!
我们可以用一个多项式来表示一个位置的所有选择。比如,对于位置 ,如果我们能选择的数值集合是 ,那么对应的生成函数就是 。多项式中 的系数为 1,表示可以选择数值 。
当我们把所有位置的选择组合起来时,总和的生成函数就是所有单个生成函数的乘积!也就是说,如果我们为每个位置 选择一个值 ,总和为 ,那么所有可能总和的生成函数就是 。 中 的系数,就代表着凑出总和为 的方案数。
第三步:处理“花费不超过1”的约束
我们有两种情况:
花费为 0:我们只能修改那些 的位置。对于 的位置,我们必须保持原样,即选择 。
- 如果 ,选择的集合是 ,生成函数是 。
- 如果 ,选择是固定的 ,生成函数是 。
所以,花费为 0 的总和生成函数是:
其中 根据 的值来确定。
花费为 1:我们选择一个 (其中 ),将它修改为 中的一个值。其他所有位置的选择规则和花费为 0 时一样。
对于一个特定的 ,它的生成函数从 变成了 。其他位置不变。 所以,只修改 的生成函数是:
总的花费为 1 的生成函数就是把所有可能的 的情况加起来:
那么,总花费不超过 1 的总生成函数就是 。 我们可以把它因式分解一下:
这真是个漂亮的形式!我们只需要计算两个多项式然后把它们乘起来:
- 第一个多项式是 。
- 第二个多项式是 ,我们叫它“修饰多项式”好了,喵~
第四步:用分治NTT加速多项式乘法
计算 需要做 次多项式乘法。如果一个个乘,度数会越来越大,时间上承受不住。 这里可以用分治的思想来优化! 我们可以把 个多项式分成两半,递归地计算左半部分的乘积和右半部分的乘积,最后再把这两个结果乘起来。 mult(1...n) = mult(1...n/2) * mult(n/2+1...n) 多项式乘法本身,可以用快速数论变换 (NTT) 来实现,时间复杂度是 ,其中 是多项式的度数。
第五步:控制多项式的“体型”
一个问题是,多项式的度数(也就是最大可能的总和)可能会非常大( 的级别),直接用NTT会很慢。但我们还记得吗? 的结果只和总和模 有关!
这启发我们对多项式进行“瘦身”。如果一个总和 sum >= k,它的最终结果和 f(digits of sum) 一样。而 sum 和 digits of sum 模 是同余的。 所以,我们可以把所有度数 的项 合并到某个代表性的低次项上。 一个非常巧妙的方法是,我们维护一个度数不超过 的多项式。
- 对于度数 的项 ,我们保留它。
- 对于度数 的项 ,我们把它合并到 上。
这样,在分治NTT的每一步乘法之后,我们都进行一次这样的“折叠”操作,把结果多项式的度数控制在 左右,再折叠回 以内,NTT的规模就不会无限增长啦!
最终的作战计划!
- 准备NTT:写好一个标准的多项式乘法板子。
- 计算 :
- 对每个 ,生成初始多项式 。
- 使用分治+NTT,计算这 个多项式的乘积。在每一步分治的合并(乘法)后,都进行“折叠”操作,将多项式大小保持在 左右。
- 计算修饰多项式 :
- 创建一个多项式 。
- 为了处理 可能出现的负指数,我们给所有指数加上一个偏移量,比如 。
- 初始化为 (代表常数1)。
- 对于每个 使得 ,我们遍历 中的每个元素 ,给 的 项的系数加 1。
- 最终相乘:
- 计算最终的总生成函数 。
- 统计答案:
- 中 的系数,代表总花费 时,凑出总和为 (因为有偏移)的方案数。
- 遍历 的所有非零项,对于每个系数 ,它对应总和 。
- 我们用一个递归函数
get_f(S)计算出 对应的最终结果 。 - 将 累加到最终答案
ans[x]上。
- 输出结果:喵~ 大功告成,输出
ans数组就好啦!
这个方法把所有情况都优雅地统一到了生成函数里,是不是很酷?下面就来看看代码实现吧!
代码实现
#include <iostream>
#include <vector>
#include <numeric>
#include <algorithm>
// MOD and NTT parameters
const int MOD = 998244353;
const int NTT_PRIMITIVE_ROOT = 3;
// Fast power function
long long power(long long base, long long exp) {
long long res = 1;
base %= MOD;
while (exp > 0) {
if (exp % 2 == 1) res = (res * base) % MOD;
base = (base * base) % MOD;
exp /= 2;
}
return res;
}
// NTT function
void ntt(std::vector<long long>& a, bool invert) {
int n = a.size();
for (int i = 1, j = 0; i < n; i++) {
int bit = n >> 1;
for (; j & bit; bit >>= 1)
j ^= bit;
j ^= bit;
if (i < j)
std::swap(a[i], a[j]);
}
for (int len = 2; len <= n; len <<= 1) {
long long wlen = power(NTT_PRIMITIVE_ROOT, (MOD - 1) / len);
if (invert)
wlen = power(wlen, MOD - 2);
for (int i = 0; i < n; i += len) {
long long w = 1;
for (int j = 0; j < len / 2; j++) {
long long u = a[i + j], v = (a[i + j + len / 2] * w) % MOD;
a[i + j] = (u + v) % MOD;
a[i + j + len / 2] = (u - v + MOD) % MOD;
w = (w * wlen) % MOD;
}
}
}
if (invert) {
long long n_inv = power(n, MOD - 2);
for (long long& x : a)
x = (x * n_inv) % MOD;
}
}
// Polynomial multiplication
std::vector<long long> multiply(std::vector<long long> a, std::vector<long long> b) {
int sz = 1;
while (sz < a.size() + b.size()) sz <<= 1;
a.resize(sz);
b.resize(sz);
ntt(a, false);
ntt(b, false);
for (int i = 0; i < sz; i++) a[i] = (a[i] * b[i]) % MOD;
ntt(a, true);
return a;
}
int N, K;
// The recursive function f(b) from the problem statement
int get_final_value(int sum) {
if (sum < K) return sum;
int digit_sum = 0;
while (sum > 0) {
digit_sum += sum % K;
sum /= K;
}
return get_final_value(digit_sum);
}
// Folds a polynomial to keep its size manageable
std::vector<long long> fold_poly(const std::vector<long long>& p) {
if (p.empty()) return {};
// Target size is 2*K, representing sums < K and equivalence classes for sums >= K
std::vector<long long> folded(2 * K, 0);
for (int i = 0; i < p.size(); ++i) {
if (p[i] == 0) continue;
int target_idx;
if (i < K) {
target_idx = i;
} else {
// For sum >= K, map to a representative index >= K
target_idx = K + (i - K) % (K - 1);
}
folded[target_idx] = (folded[target_idx] + p[i]) % MOD;
}
return folded;
}
// Divide and conquer function to multiply polynomials
std::vector<long long> multiply_and_fold_dq(const std::vector<std::vector<long long>>& polys, int l, int r) {
if (l == r) {
return fold_poly(polys[l]);
}
if (l > r) {
// Return identity polynomial (1)
std::vector<long long> identity(1, 1);
return identity;
}
int mid = l + (r - l) / 2;
auto left_prod = multiply_and_fold_dq(polys, l, mid);
auto right_prod = multiply_and_fold_dq(polys, mid + 1, r);
auto result = multiply(left_prod, right_prod);
return fold_poly(result);
}
int main() {
std::ios_base::sync_with_stdio(false);
std::cin.tie(NULL);
std::cin >> N >> K;
std::vector<int> a(N);
int max_val = 0;
for (int i = 0; i < N; ++i) {
std::cin >> a[i];
if (a[i] > max_val) max_val = a[i];
}
std::vector<std::vector<int>> S(N);
std::vector<std::vector<long long>> initial_polys;
for (int i = 0; i < N; ++i) {
int sz;
std::cin >> sz;
S[i].resize(sz);
std::vector<long long> p;
int current_max_deg = 0;
if (a[i] == -1) {
for (int j = 0; j < sz; ++j) {
std::cin >> S[i][j];
if (S[i][j] > current_max_deg) current_max_deg = S[i][j];
}
p.assign(current_max_deg + 1, 0);
for (int val : S[i]) p[val] = 1;
} else {
// Read and store S_i for later use
for (int j = 0; j < sz; ++j) std::cin >> S[i][j];
// For cost 0, this position is fixed
p.assign(a[i] + 1, 0);
p[a[i]] = 1;
}
initial_polys.push_back(p);
}
// Calculate F_cost0(z)
auto cost0_poly = multiply_and_fold_dq(initial_polys, 0, N - 1);
// Calculate the modifier polynomial M(z)
std::vector<long long> modifier_poly(2 * K, 0);
modifier_poly[K] = 1; // Represents the '1 +' part, shifted by K
for (int i = 0; i < N; ++i) {
if (a[i] != -1) {
for (int v : S[i]) {
int exponent = v - a[i] + K;
if (exponent >= 0 && exponent < 2 * K) {
modifier_poly[exponent] = (modifier_poly[exponent] + 1) % MOD;
}
}
}
}
// Get the final total polynomial
auto final_poly = multiply(cost0_poly, modifier_poly);
// Tally the results
std::vector<long long> ans(K, 0);
for (int i = 0; i < final_poly.size(); ++i) {
if (final_poly[i] > 0) {
int total_sum = i - K; // Adjust for the shift in modifier_poly
if (total_sum >= 0) {
int result_val = get_final_value(total_sum);
ans[result_val] = (ans[result_val] + final_poly[i]) % MOD;
}
}
}
for (int i = 0; i < K; ++i) {
std::cout << ans[i] << (i == K - 1 ? "" : " ");
}
std::cout << std::endl;
return 0;
}复杂度分析
时间复杂度:
- 分治计算 的过程有 层。在每一层,我们需要做几次大小约为 的 NTT 乘法。总的来说,这一部分的复杂度是 。
- 构建修饰多项式 的时间取决于所有集合 的总大小,记为 ,复杂度是 。
- 最后一次NTT乘法是 。
- 统计答案时,遍历最终多项式的复杂度是 ,
get_final_value函数的调用非常快。 - 瓶颈在于分治NTT部分,所以总体复杂度近似为 。
空间复杂度:
- 在分治递归的栈中,我们需要存储中间结果多项式,最坏情况下空间复杂度是 。
- 存储初始的多项式 可能需要 的空间,在最坏情况下是 。
- NTT本身需要 的辅助空间。
- 所以总的空间复杂度是 。
知识点总结
- 生成函数: 它是解决组合计数问题的强大工具。将“选择”转化为多项式,将“组合”转化为多项式乘法,思路非常直观。
- 快速数论变换 (NTT): 在模意义下实现多项式快速乘法的算法,是FFT在数论中的对应。它是解决许多生成函数问题的关键。
- 分治: 当需要计算多个对象的累积效应(如连乘)时,分治是一个非常有效的优化策略,能将复杂度从平方级别降低到对数级别。
- 模运算性质与数根: 理解 N \equiv \text{sum_digits}(N) \pmod{k-1} 这个性质是简化问题的突破口。它让我们能从复杂递归中找到规律。
- 问题建模: 将“花费不超过1”的约束分解为“花费0”和“花费1”两种情况,并用生成函数统一表示,是本题建模的核心思想。把问题分解成更小的、可以独立计算的部分,然后优雅地组合起来,这就是算法的魅力所在呀,喵~