Skip to content

权值和 plus - 题解

标签与难度

标签: 数论, NTT, 快速傅里叶变换, 卷积, 生成函数, 原根, 离散对数, 动态规划, 组合数学 难度: 2700

题目大意喵~

nyaa~ 各位算法大师们好呀!今天我们来解决一道有点挑战性的题目,不过别怕,本喵会一步步带你拆解它的~

题目是这样哒:给定一个素数 PP,一个包含 QQ 个整数的集合 SS,还有三个整数 N,M,KN, M, K。我们需要构造两个长度为 KK 的数组 a=[a1,,aK]a = [a_1, \dots, a_K]b=[b1,,bK]b = [b_1, \dots, b_K],它们需要满足下面三个条件哦:

  1. 数组 aa 中所有元素的乘积,在模 PP 意义下等于 NN。也就是 i=1KaiN(modP)\prod_{i=1}^{K} a_i \equiv N \pmod{P}
  2. 数组 bb 中所有元素的乘积,在模 PP 意义下等于 MM。也就是 i=1KbiM(modP)\prod_{i=1}^{K} b_i \equiv M \pmod{P}
  3. 两个数组里的所有元素 ai,bia_i, b_i 都必须从集合 SS 中选取。

对于每一种满足条件的构造方案(即一对数组 aabb),我们会计算一个“权值”,这个权值是 i=1Kmin(ai,bi)\sum_{i=1}^K \min(a_i, b_i)

我们的最终任务,就是计算出所有合法的构造方案的权值之和,结果对 998244353998244353 取模。

解题思路分析

这道题看起来好复杂呀,要计算所有方案的权值和,直接枚举肯定是不行的说!K那么大,方案数可能是天文数字呢。所以,我们需要用更聪明的办法,喵~

核心思想:贡献法

我们要求的最终答案是:

TotalSum=所有合法的 (a,b) 方案(i=1Kmin(ai,bi))\text{TotalSum} = \sum_{\text{所有合法的 }(a, b) \text{ 方案}} \left( \sum_{i=1}^{K} \min(a_i, b_i) \right)

这种“和的期望”或者“和的和”的问题,一个经典的思路是贡献法,或者叫线性性质。我们可以交换求和的顺序:

TotalSum=i=1K(所有合法的 (a,b) 方案min(ai,bi))\text{TotalSum} = \sum_{i=1}^{K} \left( \sum_{\text{所有合法的 }(a, b) \text{ 方案}} \min(a_i, b_i) \right)

仔细观察一下,对于任何一个位置 ii(比如 i=1i=1)和任何另一个位置 jj(比如 j=2j=2),它们对总和的贡献计算方式是完全一样的,没有任何区别。这说明每个位置的贡献是相同的!所以,我们只需要计算出位置 i=1i=1 的总贡献,然后乘以 KK 就好啦。

TotalSum=K(所有合法的 (a,b) 方案min(a1,b1))\text{TotalSum} = K \cdot \left( \sum_{\text{所有合法的 }(a, b) \text{ 方案}} \min(a_1, b_1) \right)

现在问题就变成了,对于所有合法的 (a,b)(a, b) 方案,min(a1,b1)\min(a_1, b_1) 的总和是多少。我们可以进一步把这个和式拆开,枚举 a1a_1b1b_1 的所有可能取值。让 a1=x,b1=ya_1=x, b_1=y,其中 x,yx, y 都来自集合 SS

Contribution of i=1=xSySmin(x,y)(以 x,y 开头的合法方案数)\text{Contribution of } i=1 = \sum_{x \in S} \sum_{y \in S} \min(x, y) \cdot (\text{以 } x, y \text{ 开头的合法方案数})

“以 x,yx, y 开头的合法方案数”指的是,在固定 a1=x,b1=ya_1=x, b_1=y 的情况下,有多少种方法可以选出剩下的 a2,,aKa_2, \dots, a_Kb2,,bKb_2, \dots, b_K 使得总的乘积条件依然满足。

这可以分解成两个独立的问题:

  1. 有多少种长度为 K1K-1 的序列 (a2,,aK)(a_2, \dots, a_K),元素都来自 SS,使得 i=2KaiNx1(modP)\prod_{i=2}^{K} a_i \equiv N \cdot x^{-1} \pmod P?我们把这个方案数记为 Ca(x)C_a(x)
  2. 有多少种长度为 K1K-1 的序列 (b2,,bK)(b_2, \dots, b_K),元素都来自 SS,使得 i=2KbiMy1(modP)\prod_{i=2}^{K} b_i \equiv M \cdot y^{-1} \pmod P?我们把这个方案数记为 Cb(y)C_b(y)

于是,位置 11 的总贡献就是:

xSySmin(x,y)Ca(x)Cb(y)\sum_{x \in S} \sum_{y \in S} \min(x, y) \cdot C_a(x) \cdot C_b(y)

关键瓶颈:如何计算 Ca(x)C_a(x)Cb(y)C_b(y)

计算 Ca(x)C_a(x)(和 Cb(y)C_b(y))是这道题的核心难点。我们需要计算从集合 SS 中选 K1K-1 个数,使其乘积模 PP 等于某个目标值(比如 Nx1N \cdot x^{-1})的方案数。

直接在乘法上做动态规划太困难了。但是,我们知道 PP 是一个素数,在模 PP 的世界里,所有非零的数构成一个循环群 ZP\mathbb{Z}_P^*,它的阶是 P1P-1。这意味着存在一个原根 gg,使得所有非零的数 zz 都可以表示成 gg 的幂次,即 zgk(modP)z \equiv g^k \pmod P。这个 kk 就是 zzgg 为底的离散对数(或称为指标)。

这个性质超级棒!因为它可以将乘法问题转化为加法问题

i=2KaiNx1(modP)    i=2Kdlog(ai)dlog(Nx1)(modP1)\prod_{i=2}^{K} a_i \equiv N \cdot x^{-1} \pmod P \iff \sum_{i=2}^{K} \text{dlog}(a_i) \equiv \text{dlog}(N \cdot x^{-1}) \pmod{P-1}

其中 dlog(z)\text{dlog}(z) 表示 zz 的离散对数。

现在问题就变成了:从集合 SS 中元素的离散对数集合里,选 K1K-1 个数,使其和模 P1P-1 等于某个目标值的方案数。

这是一个经典的可以用生成函数卷积解决的问题!

  1. 我们先构造一个多项式 A(z)=sS,s≢0(modP)zdlog(s)A(z) = \sum_{s \in S, s \not\equiv 0 \pmod P} z^{\text{dlog}(s)}。这个多项式的第 ii 项系数表示 SS 中有多少个非零元素的离散对数是 ii
  2. 我们要求的是从 SS 中选 K1K-1 个非零元素,它们的离散对数之和。这正好对应着多项式 A(z)A(z)K1K-1 次幂!即 (A(z))K1(A(z))^{K-1}
  3. 因为我们的加法是模 P1P-1 的,所以我们实际上是在计算循环卷积,也就是多项式乘法的结果要对 zP11z^{P-1}-1 取模。
  4. 计算 (A(z))K1(modzP11)(A(z))^{K-1} \pmod{z^{P-1}-1} 的最高效方法就是NTT (数论变换),配合快速幂。我们先把 A(z)A(z) 用 NTT 变换到点值表示,在点值表示下,乘法就是对应点值的乘法,幂运算就是点值的幂运算。算出结果的点值表示后,再用 INTT (逆NTT) 变换回系数表示。

F(z)=(A(z))K1(modzP11)=j=0P2fjzjF(z) = (A(z))^{K-1} \pmod{z^{P-1}-1} = \sum_{j=0}^{P-2} f_j z^j。那么,fjf_j 就代表从 SS 中选 K1K-1 个非零元素,其离散对数之和模 P1P-1 等于 jj 的方案数。

处理0的情况 喵~ 我们差点忘了0这个特殊的小家伙。

  • 如果目标乘积 T≢0(modP)T \not\equiv 0 \pmod P,那选出的数里一个都不能是0。方案数就是我们上面用 NTT 算出来的。
  • 如果目标乘积 T0(modP)T \equiv 0 \pmod P,那选出的数里至少要有一个是0。我们可以用总方案数减去全是非零的方案数。假设 SS 中有 s0s_0 个0,非零的有 Qs0Q-s_0 个。选 K1K-1 个数的总方案是 QK1Q^{K-1},选出的数全是非零的方案是 (Qs0)K1(Q-s_0)^{K-1}。所以,乘积为0的方案数是 QK1(Qs0)K1Q^{K-1} - (Q-s_0)^{K-1}

有了这些,我们就可以计算出任意 Ca(x)C_a(x)Cb(y)C_b(y) 啦!

最后的求和:一个聪明的技巧

现在我们有了计算 Ca(x)C_a(x)Cb(y)C_b(y) 的方法,但要直接计算 xSySmin(x,y)Ca(x)Cb(y)\sum_{x \in S} \sum_{y \in S} \min(x, y) \cdot C_a(x) \cdot C_b(y) 还是有点慢。

这里有一个很棒的优化!我们先把集合 SS 从小到大排序,得到 s1<s2<<sQs_1 < s_2 < \dots < s_Q。然后我们从大到小遍历 sis_i。 当我们遍历到 sis_i 时,我们计算它对总和的贡献。哪些项的 min\min 值是 sis_i 呢?是所有形如 (si,sj)(s_i, s_j) 其中 jij \ge i 的配对,以及 (sj,si)(s_j, s_i) 其中 j>ij > i 的配对。

这启发我们维护两个后缀和:

  • suffix_sum_a: j=i+1QCa(sj)\sum_{j=i+1}^Q C_a(s_j)
  • suffix_sum_b: j=i+1QCb(sj)\sum_{j=i+1}^Q C_b(s_j)

当我们从 i=Qi=Q 递减到 11 时:

  1. 计算当前的 cai=Ca(si)ca_i = C_a(s_i)cbi=Cb(si)cb_i = C_b(s_i)
  2. 考虑所有 min(x,y)=si\min(x,y)=s_i 的情况。
    • 对所有 j>ij > i,配对 (si,sj)(s_i, s_j) 的贡献是 sicaicbjs_i \cdot ca_i \cdot cb_j。总和是 sicaij=i+1Qcbj=sicaisuffix_sum_bs_i \cdot ca_i \cdot \sum_{j=i+1}^Q cb_j = s_i \cdot ca_i \cdot \text{suffix\_sum\_b}
    • 对所有 j>ij > i,配对 (sj,si)(s_j, s_i) 的贡献是 sicajcbis_i \cdot ca_j \cdot cb_i。总和是 sicbij=i+1Qcaj=sicbisuffix_sum_as_i \cdot cb_i \cdot \sum_{j=i+1}^Q ca_j = s_i \cdot cb_i \cdot \text{suffix\_sum\_a}
    • 还有配对 (si,si)(s_i, s_i),贡献是 sicaicbis_i \cdot ca_i \cdot cb_i
  3. 把这些贡献加起来,就是 si(caisuffix_sum_b+cbisuffix_sum_a+caicbi)s_i \cdot (ca_i \cdot \text{suffix\_sum\_b} + cb_i \cdot \text{suffix\_sum\_a} + ca_i \cdot cb_i)
  4. 更新后缀和:suffix_sum_a += ca_isuffix_sum_b += cb_i

这样,我们只需要一次遍历就可以算出总和了!最后别忘了乘以 KK 哦。

总结一下步骤

  1. 找到模 PP 的一个原根 gg
  2. 预计算离散对数表 dlog 和反向表 val_of_log
  3. 构造初始多项式 A(z)A(z),其系数表示 SS 中非零元素离散对数的分布。
  4. 使用 NTT 和快速幂计算 F(z)=(A(z))K1(modzP11)F(z) = (A(z))^{K-1} \pmod{z^{P-1}-1}
  5. 对集合 SS 排序。
  6. 从大到小遍历排好序的 siSs_i \in S,使用后缀和技巧计算 min(x,y)Ca(x)Cb(y)\sum \min(x,y) C_a(x) C_b(y)
  7. 将结果乘以 KK 得到最终答案。

好啦,思路理清了,接下来就看本喵的代码实现吧!喵~

代码实现

cpp
#include <iostream>
#include <vector>
#include <numeric>
#include <algorithm>
#include <map>

using namespace std;

using ll = long long;

const int MOD = 998244353;
const int NTT_G = 3;

// 快速幂
ll power(ll base, ll exp) {
    ll res = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp % 2 == 1) res = (res * base) % MOD;
        base = (base * base) % MOD;
        exp /= 2;
    }
    return res;
}

// 求逆元
ll mod_inverse(ll n) {
    return power(n, MOD - 2);
}

// NTT 实现
void ntt(vector<ll>& a, bool invert) {
    int n = a.size();
    for (int i = 1, j = 0; i < n; i++) {
        int bit = n >> 1;
        for (; j & bit; bit >>= 1)
            j ^= bit;
        j ^= bit;
        if (i < j)
            swap(a[i], a[j]);
    }

    for (int len = 2; len <= n; len <<= 1) {
        ll wlen = power(NTT_G, (MOD - 1) / len);
        if (invert) wlen = mod_inverse(wlen);
        for (int i = 0; i < n; i += len) {
            ll w = 1;
            for (int j = 0; j < len / 2; j++) {
                ll u = a[i + j], v = (a[i + j + len / 2] * w) % MOD;
                a[i + j] = (u + v) % MOD;
                a[i + j + len / 2] = (u - v + MOD) % MOD;
                w = (w * wlen) % MOD;
            }
        }
    }
    if (invert) {
        ll n_inv = mod_inverse(n);
        for (ll& x : a)
            x = (x * n_inv) % MOD;
    }
}

// 多项式循环卷积
vector<ll> multiply(vector<ll> a, vector<ll> b, int mod_len) {
    int n = 1;
    while (n < 2 * mod_len) n <<= 1;
    a.resize(n);
    b.resize(n);

    ntt(a, false);
    ntt(b, false);
    for (int i = 0; i < n; i++) a[i] = (a[i] * b[i]) % MOD;
    ntt(a, true);
    
    a.resize(mod_len * 2 -1);
    // 折叠,实现循环卷积
    for (int i = mod_len; i < a.size(); ++i) {
        a[i - mod_len] = (a[i - mod_len] + a[i]) % MOD;
    }
    a.resize(mod_len);
    return a;
}

// 多项式快速幂 (循环卷积版)
vector<ll> poly_pow(vector<ll> base, ll exp, int mod_len) {
    vector<ll> res(mod_len, 0);
    res[0] = 1;
    while (exp > 0) {
        if (exp % 2 == 1) res = multiply(res, base, mod_len);
        base = multiply(base, base, mod_len);
        exp /= 2;
    }
    return res;
}

// 找原根
int find_primitive_root(int p) {
    if (p == 2) return 1;
    int phi = p - 1;
    vector<int> factors;
    int temp_phi = phi;
    for (int i = 2; i * i <= temp_phi; ++i) {
        if (temp_phi % i == 0) {
            factors.push_back(i);
            while (temp_phi % i == 0) temp_phi /= i;
        }
    }
    if (temp_phi > 1) factors.push_back(temp_phi);

    for (int g = 2; g < p; ++g) {
        bool ok = true;
        for (int factor : factors) {
            if (power(g, phi / factor) == 1) {
                ok = false;
                break;
            }
        }
        if (ok) return g;
    }
    return -1; // Should not happen for a prime
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int P, Q, N, M;
    ll K;
    cin >> P >> Q >> N >> M >> K;

    vector<int> s(Q);
    int zero_count = 0;
    for (int i = 0; i < Q; ++i) {
        cin >> s[i];
        if (s[i] % P == 0) {
            zero_count++;
        }
    }

    if (K == 0) {
        if (N % P == 1 && M % P == 1) {
            cout << 0 << endl;
        } else {
            cout << 0 << endl;
        }
        return 0;
    }

    int g = find_primitive_root(P);
    map<int, int> dlog;
    vector<int> val_of_log(P - 1);
    ll current_val = 1;
    for (int i = 0; i < P - 1; ++i) {
        dlog[current_val] = i;
        val_of_log[i] = current_val;
        current_val = (current_val * g) % P;
    }

    int mod_len = P - 1;
    vector<ll> A(mod_len, 0);
    for (int val : s) {
        if (val % P != 0) {
            A[dlog[val % P]]++;
        }
    }

    vector<ll> F = poly_pow(A, K - 1, mod_len);

    ll ways_for_prod_zero = (power(Q, K - 1) - power(Q - zero_count, K - 1) + MOD) % MOD;

    auto get_count = [&](int target_prod, int fixed_val) {
        fixed_val %= P;
        if (fixed_val == 0) {
            if (target_prod % P == 0) {
                return power(Q, K - 1);
            } else {
                return 0LL;
            }
        }
        
        if (target_prod % P == 0) {
            return ways_for_prod_zero;
        }

        ll inv_fixed = power(fixed_val, P - 2);
        ll final_target = (1LL * target_prod * inv_fixed) % P;
        return F[dlog[final_target]];
    };

    sort(s.begin(), s.end());

    ll total_contribution = 0;
    ll suffix_sum_a = 0, suffix_sum_b = 0;

    for (int i = Q - 1; i >= 0; --i) {
        ll count_a = get_count(N, s[i]);
        ll count_b = get_count(M, s[i]);

        ll term1 = (count_a * suffix_sum_b) % MOD;
        ll term2 = (count_b * suffix_sum_a) % MOD;
        ll term3 = (count_a * count_b) % MOD;
        
        ll current_s_contribution = (term1 + term2 + term3) % MOD;
        total_contribution = (total_contribution + (ll)s[i] * current_s_contribution) % MOD;

        suffix_sum_a = (suffix_sum_a + count_a) % MOD;
        suffix_sum_b = (suffix_sum_b + count_b) % MOD;
    }

    ll final_ans = (total_contribution * K) % MOD;
    cout << final_ans << endl;

    return 0;
}

复杂度分析

  • 时间复杂度: O(PlogPlogK+QlogQ)O(P \log P \log K + Q \log Q)

    • 寻找原根:一个比较朴素的方法需要对 P1P-1 进行质因数分解,复杂度大约是 O(P)O(\sqrt{P}),然后对每个候选 gg 进行验证,总复杂度可以接受。我们的实现是 O(glog(phi)factors)O(g \cdot \log(\text{phi}) \cdot |\text{factors}|),其中 gg 很小。
    • 预计算离散对数:O(P)O(P)
    • 多项式快速幂:poly_pow 函数需要执行 O(logK)O(\log K) 次多项式乘法。每次乘法基于 NTT,复杂度为 O(PlogP)O(P' \log P'),其中 PP' 是大于 2(P1)2(P-1) 的最小2的幂次。所以这部分是 O(PlogPlogK)O(P \log P \log K)
    • 排序:对集合 SS 排序需要 O(QlogQ)O(Q \log Q)
    • 主循环计算贡献:O(Q)O(Q)
    • 综上,时间复杂度的瓶颈在于多项式快速幂。
  • 空间复杂度: O(P)O(P)

    • NTT 需要的数组大小与 PP 相关,为 O(P)O(P)
    • 离散对数表 dlogval_of_log 也需要 O(P)O(P) 的空间。

知识点总结

这道题是数论和多项式算法的完美结合,喵~ 解决它需要融会贯通好几个知识点:

  1. 贡献法: 将对总和的求和,转化为计算每个元素的贡献再求和,是组合计数问题中的一个核心思想。
  2. 数论基础:
    • 原根: 模素数 PP 的乘法群是循环群,原根是这个群的生成元。这是将乘法问题转化为加法问题的钥匙。
    • 离散对数: 将乘法群中的元素映射到加法群(模 P1P-1)中的指数,是原根性质的具体应用。
  3. 生成函数/多项式:
    • 将组合计数问题转化为多项式系数问题。在这里,我们用 zkz^k 的系数表示和为 kk 的方案数。
  4. NTT (数论变换):
    • 快速计算多项式卷积的算法。它是FFT在有限域上的变种,避免了浮点数精度问题。
    • 循环卷积: NTT 本身计算的是线性卷积,但通过对结果进行“折叠”,可以得到循环卷积的结果,恰好对应模 P1P-1 的加法。
  5. 算法技巧:
    • 快速幂: 不仅可以用于数值,也可以用于任何满足结合律的运算,比如多项式乘法。
    • 后缀和/前缀和优化: 在计算累加式时,通过维护动态的和来避免重复计算,将 O(Q2)O(Q^2) 的暴力求和优化到 O(Q)O(Q)

希望这篇题解能帮到你哦!如果还有不明白的地方,随时可以再来问本喵,喵~ >w<