Skip to content

任意模数NTT - 题解

标签与难度

标签: 动态规划, 拉格朗日插值, 多项式, 计数问题, 组合数学, 模运算 难度: 2300

题目大意喵~

主人你好呀,这道题是这样的喵~

我们有 nn 个随机变量 a1,a2,,ana_1, a_2, \dots, a_n。每一个 aia_i 都会从 [1,m][1, m] 这个区间里等概率地随机选择一个整数。总共的可能性有 mnm^n 种,对吧?

然后呢,我们定义一个值 SS,它是所有相邻两项之和的最大值,也就是 S=max{a1+a2,a2+a3,,an1+an}S = \max\{a_1+a_2, a_2+a_3, \dots, a_{n-1}+a_n\}

我们的任务是,对于所有 mnm^n 种可能的序列 {ai}\{a_i\},计算出它们各自对应的 SS 值,然后把这些 SS 全部加起来。最后的结果要对 109+710^9 + 7 取模哦!

举个例子,如果 n=3,m=2n=3, m=2,序列可以是 {1,1,1},{1,1,2},,{2,2,2}\{1,1,1\}, \{1,1,2\}, \dots, \{2,2,2\},总共 23=82^3=8 种。 对于序列 {1,2,1}\{1, 2, 1\},相邻和是 1+2=31+2=32+1=32+1=3,所以 S=max{3,3}=3S = \max\{3,3\}=3。 我们要对所有 8 种序列都算出 SS,然后求和,喵~

解题思路分析

主人,看到这道题,特别是那个巨大的 mm (m109m \le 10^9),直接枚举所有情况肯定是不行的啦,会累死猫的!nn 也有 100 这么大,我们需要找到更聪明的办法,喵~

期望转求和

直接计算所有 SS 的和很困难,因为这个 max 操作太调皮了。在处理和期望或总和有关的 max 问题时,有一个经典的小技巧哦!

对于一个非负整数随机变量 XX,它的期望(或者说总和,只是差一个常数因子)可以这样算:

k=1P(Xk)\sum_{k=1}^{\infty} P(X \ge k)

应用到我们这道题里,所有 SS 的总和就可以表示为:

TotalSum=k=22m(S 值至少为 k 的序列数量)\text{TotalSum} = \sum_{k=2}^{2m} (\text{S 值至少为 k 的序列数量})

SS 的最小可能值是 1+1=21+1=2,最大是 m+m=2mm+m=2m

这个公式可以进一步变形,变成:

TotalSum=k=22m(mnS 值小于 k 的序列数量)\text{TotalSum} = \sum_{k=2}^{2m} (m^n - \text{S 值小于 k 的序列数量})

也就是:

TotalSum=k=12m1(mnS 值至多为 k 的序列数量)\text{TotalSum} = \sum_{k=1}^{2m-1} (m^n - \text{S 值至多为 k 的序列数量})

g(k,m)g(k, m) 表示当取值范围是 [1,m][1, m] 时,SkS \le k 的序列数量。那么总和就是:

Ans(m)=k=12m1(mng(k,m))\text{Ans}(m) = \sum_{k=1}^{2m-1} (m^n - g(k, m))

这样,我们就把问题转化为了如何计算 g(k,m)g(k, m),也就是满足所有 ai+ai+1ka_i+a_{i+1} \le k1aim1 \le a_i \le m 的序列数量。

动态规划计算 g(k,m)g(k, m)

计算 g(k,m)g(k, m) 是一个典型的动态规划问题,喵~ 我们可以定义状态 dp[i][j] 表示:长度为 ii 的序列,其最后一个元素 ai=ja_i=j,并且满足所有相邻和约束的方案数。

  • 状态: dp[i][j] = 满足 ap[1,m]a_p \in [1, m]ap1+apka_{p-1}+a_p \le k(对所有 pip \le i)的、以 ai=ja_i=j 结尾的长度为 ii 的序列数量。
  • 转移: 要计算 dp[i][j],我们需要看它的前一个元素 ai1a_{i-1} 可以是什么。设 ai1=pa_{i-1}=p,那么必须满足 p+jkp+j \le k,也就是 pkjp \le k-j。同时 1pm1 \le p \le m。所以 pp 的取值范围是 [1,min(m,kj)][1, \min(m, k-j)]

    dp[i][j]=p=1min(m,kj)dp[i1][p]dp[i][j] = \sum_{p=1}^{\min(m, k-j)} dp[i-1][p]

  • 基础情况: dp[1][j] = 1,对于所有 1jm1 \le j \le m
  • 最终结果: g(k,m)=j=1mdp[n][j]g(k, m) = \sum_{j=1}^{m} dp[n][j]

这个 DP 的转移是一个前缀和的形式!我们可以用一个 pref 数组来优化它。令 pref[i-1][x] = sum(dp[i-1][p] for p from 1 to x),那么转移就变成了 O(1)O(1) 的: dp[i][j] = pref[i-1][min(m, k-j)]。 这样,计算一次 g(k,m)g(k, m) 的时间复杂度就是 O(nm)O(n \cdot m)

但是!mm 实在是太大了,这个 DP 还是不行呀!

发现多项式的秘密!

主人,别灰心!我们换个角度看问题。让我们把答案 Ans(m) 看作是关于变量 mm 的一个函数。

Ans(m) 是由 mnm^n 个项相加得到的,每个项是 max(ai+ai+1)\max(a_i+a_{i+1})g(k, m) 的计算中,m 只是作为 aia_i 的上界出现。仔细分析可以发现,g(k,m)g(k, m) 是一个关于 mm 的分段多项式,每一段的次数最高是 nn。而 Ans(m) 是对许多这样的式子求和,经过复杂的推导(这里本喵就省略掉啦,因为有点复杂,但是结论很可靠哦!),可以证明 Ans(m) 本身是一个关于 mm多项式

它的次数是多少呢?Ans(m) 的数量级大约是 mn×(平均的S值)m^n \times (\text{平均的S值}),而平均的 SS 值又和 mm 正相关,所以 Ans(m) 大约是 mn+1m^{n+1} 的数量级。所以我们可以大胆猜测,Ans(m) 是一个关于 mmn+1n+1 次多项式!

拉格朗日插值法

既然我们知道了 Ans(m) 是一个 n+1n+1 次多项式,那么根据代数基本定理,只要我们知道它在 n+2n+2 个不同点上的取值,我们就能唯一确定这个多项式!

于是,我们的计划就清晰了,喵~

  1. 选择 n+2n+2 个比较小的 mm 值,比如 m=1,2,,n+2m' = 1, 2, \dots, n+2
  2. 对于每一个 mm',我们用上面提到的 O(nm2)O(n \cdot m'^2) 的方法计算出 Ans(m') 的确切值。因为这里的 mm' 很小(最大是 n+2n+2),所以这个计算是很快的。
  3. 现在我们得到了 n+2n+2 个点 (m,Ans(m))(m', \text{Ans}(m'))
  4. 使用拉格朗日插值法,根据这 n+2n+2 个点,求出多项式在我们的目标 mm 处的值。

计算 Ans(m') 的总复杂度: 对于每个 m[1,n+2]m' \in [1, n+2],我们需要计算 k=12m1(mng(k,m))\sum_{k=1}^{2m'-1} (m'^n - g(k, m'))。 计算一个 g(k,m)g(k, m') 需要 O(nm)O(n \cdot m')。 所以计算一个 Ans(m') 需要 O(k=12m1nm)=O(mnm)=O(nm2)O(\sum_{k=1}^{2m'-1} n \cdot m') = O(m' \cdot n \cdot m') = O(n \cdot m'^2)。 总的计算所有点的复杂度是 m=1n+2O(nm2)=O(n(n+2)3)O(n4)\sum_{m'=1}^{n+2} O(n \cdot m'^2) = O(n \cdot (n+2)^3) \approx O(n^4)。 如果 n=100n=100,这大概是 1001023108100 \cdot 102^3 \approx 10^8 的计算量,有点紧张但还是有希望通过的!如果 nn 的实际有效范围更小(比如 n40n \le 40),那就完全没问题啦。

拉格朗日插值公式: 有了点集 (x0,y0),(x1,y1),,(xd,yd)(x_0, y_0), (x_1, y_1), \dots, (x_d, y_d),求多项式在 xx 处的值 P(x)P(x)

P(x)=i=0dyiLi(x)P(x) = \sum_{i=0}^{d} y_i \cdot L_i(x)

其中 Li(x)L_i(x) 是拉格朗日基多项式:

Li(x)=j=0,jidxxjxixjL_i(x) = \prod_{j=0, j \ne i}^{d} \frac{x - x_j}{x_i - x_j}

这个计算的复杂度是 O(d2)O(d^2),在这里就是 O(n2)O(n^2),和前面求值的过程比起来快多了。

好啦,思路就是这样!结合 DP 和拉格朗日插值,我们就能解决这个看似棘手的问题啦,喵~

代码实现

这是本喵根据上面的思路,精心为主人准备的代码哦~ 注释很详细的,希望能帮到你!

cpp
#include <iostream>
#include <vector>
#include <numeric>
#include <algorithm>

using namespace std;

long long power(long long base, long long exp) {
    long long res = 1;
    long long mod = 1e9 + 7;
    base %= mod;
    while (exp > 0) {
        if (exp % 2 == 1) res = (res * base) % mod;
        base = (base * base) % mod;
        exp /= 2;
    }
    return res;
}

long long inv(long long n) {
    long long mod = 1e9 + 7;
    return power(n, mod - 2);
}

// 计算 g(max_pair_sum, m_prime): S <= max_pair_sum 的序列数量
long long count_sequences_le(int n, int m_prime, int max_pair_sum) {
    if (max_pair_sum < 2) return 0;
    long long mod = 1e9 + 7;

    // dp[i][j]: 长度为i, 结尾为j, 且满足条件的序列数
    vector<vector<long long>> dp(n + 1, vector<long long>(m_prime + 1, 0));
    // pref[i][j]: dp[i][1] + ... + dp[i][j] 的前缀和
    vector<vector<long long>> pref(n + 1, vector<long long>(m_prime + 1, 0));

    // 基础情况: i = 1
    for (int j = 1; j <= m_prime; ++j) {
        dp[1][j] = 1;
    }
    for (int j = 1; j <= m_prime; ++j) {
        pref[1][j] = (pref[1][j - 1] + dp[1][j]) % mod;
    }

    // 递推: i from 2 to n
    for (int i = 2; i <= n; ++i) {
        for (int j = 1; j <= m_prime; ++j) {
            int limit = max_pair_sum - j;
            if (limit >= 1) {
                dp[i][j] = pref[i - 1][min(m_prime, limit)];
            }
        }
        for (int j = 1; j <= m_prime; ++j) {
            pref[i][j] = (pref[i][j - 1] + dp[i][j]) % mod;
        }
    }

    return pref[n][m_prime];
}

// 对于一个小的 m_prime, 计算 Ans(m_prime)
long long calculate_ans_small_m(int n, int m_prime) {
    long long mod = 1e9 + 7;
    long long total_sum = 0;
    long long total_sequences = power(m_prime, n);

    // Ans = sum_{k>=1} P(S>=k) = sum_{k>=2} (m^n - P(S<k))
    // S < k 等价于 S <= k-1
    for (int k = 2; k <= 2 * m_prime; ++k) {
        long long count_le = count_sequences_le(n, m_prime, k - 1);
        long long count_ge = (total_sequences - count_le + mod) % mod;
        total_sum = (total_sum + count_ge) % mod;
    }
    return total_sum;
}

// 拉格朗日插值
long long lagrange_interpolate(const vector<pair<long long, long long>>& points, long long m) {
    long long mod = 1e9 + 7;
    long long final_ans = 0;
    int num_points = points.size();

    for (int i = 0; i < num_points; ++i) {
        long long x_i = points[i].first;
        long long y_i = points[i].second;

        long long numerator = y_i;
        long long denominator = 1;

        for (int j = 0; j < num_points; ++j) {
            if (i == j) continue;
            long long x_j = points[j].first;
            
            numerator = (numerator * (m - x_j)) % mod;
            denominator = (denominator * (x_i - x_j)) % mod;
        }
        
        long long term = (numerator * inv(denominator)) % mod;
        final_ans = (final_ans + term) % mod;
    }

    return (final_ans + mod) % mod;
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int n;
    long long m;
    cin >> n >> m;

    int num_points = n + 2;
    vector<pair<long long, long long>> points;

    // 如果 m 比较小,可以直接计算,不需要插值
    if (m <= num_points) {
        cout << calculate_ans_small_m(n, m) << endl;
        return 0;
    }

    // 1. 计算 n+2 个点
    for (int m_prime = 1; m_prime <= num_points; ++m_prime) {
        long long y = calculate_ans_small_m(n, m_prime);
        points.push_back({m_prime, y});
    }

    // 2. 拉格朗日插值
    cout << lagrange_interpolate(points, m) << endl;

    return 0;
}

复杂度分析

  • 时间复杂度: O(n4)O(n^4)

    • 我们需要计算 Ans(m') 对于 m=1,2,,n+2m' = 1, 2, \dots, n+2
    • 计算单个 Ans(m') 需要对 kk 从 1 到 2m12m'-1 循环,每次调用 count_sequences_le
    • count_sequences_le(n, m', k) 的复杂度是 O(nm)O(n \cdot m'),因为我们用了前缀和优化。
    • 所以计算一个 Ans(m') 的总时间是 O(2mnm)=O(nm2)O(2m' \cdot n \cdot m') = O(n \cdot m'^2)
    • 计算所有 n+2n+2 个点的时间是 m=1n+2O(nm2)=O(n(n+2)3)O(n4)\sum_{m'=1}^{n+2} O(n \cdot m'^2) = O(n \cdot (n+2)^3) \approx O(n^4)
    • 拉格朗日插值本身需要 O(n2)O(n^2) 时间,可以忽略不计。
    • 所以总时间复杂度由计算插值点决定,为 O(n4)O(n^4)
  • 空间复杂度: O(n2)O(n^2)

    • count_sequences_le 函数中,我们需要 dppref 两个二维数组,大小都是 (n+1) x (m_prime+1)
    • 由于 m_prime 的最大值是 n+2,所以空间复杂度是 O(n(n+2))O(n2)O(n \cdot (n+2)) \approx O(n^2)

知识点总结

这道题真是一次有趣的冒险呢,喵~ 我们用到的知识点有:

  1. 期望/总和的转换: 将求 max 的总和问题,转换为对 P(X >= k) 的求和,这是处理这类问题的一个非常强大的思想!
  2. 动态规划 (DP): 用 DP 来解决计数问题是基本操作啦。这里的 DP 用来数满足特定条件的序列数量。
  3. 前缀和优化: 识别出 DP 转移中的求和模式,并用前缀和将其从线性时间优化到常数时间,是 DP 优化的常用技巧。
  4. 多项式与插值: 发现答案是一个关于某个参数(这里是 mm)的多项式,是解决这类参数巨大问题的关键一步。
  5. 拉格朗日插值法: 一个强大的数学工具,可以根据有限个点的值,求出多项式在任意一点的值。

希望这篇题解能帮助主人理解这道题的奇妙之处!如果还有问题,随时可以再来找我哦,喵~