Skip to content

区间Mex(简单版本) - 题解

标签与难度

标签: 数学, 贡献法, 分治, 模运算, 快速幂, 数论, 求和公式 难度: 2300

题目大意喵~

哈喵~!各位算法大师们,今天我们来挑战一道有点意思的数学题,喵~

题目给我们一个长度为 nn 的排列(也就是 11nnnn 个数,每个都只出现一次)。我们需要对这个排列的所有连续子数组 a[i...j] 进行计算。

首先,我们要找到每个子数组的 mex 值。mex(i, j) 指的是在子数组 a[i...j] 中,没有出现过的最小正整数。比如说,对于子数组 {3, 1, 4},它包含了 1,但不包含 2,所以它的 mex 就是 2,喵~

接着,我们定义一个值 t(i,j)=(子数组长度)×mex(i,j)t(i, j) = (\text{子数组长度}) \times \text{mex}(i, j)。也就是 (ji+1)×mex(i,j)(j - i + 1) \times \text{mex}(i, j)

最终的目标,是计算下面这个超级大的和,并对 998244353998244353 取模:

F=i=1nj=int(i,j)2t(i,j)F = \sum_{i=1}^{n} \sum_{j=i}^{n} t(i,j) \cdot 2^{t(i,j)}

简单来说,就是遍历所有子数组,算出各自的 t(i,j)t(i, j),然后把 t(i,j)2t(i,j)t(i, j) \cdot 2^{t(i, j)} 这个值加起来。因为是简单版本,所以我们只需要对一个固定的排列计算一次就好啦,没有修改操作哦!

解题思路分析

这么复杂的求和公式,直接暴力遍历所有 O(n2)O(n^2) 个子数组,再为每个子数组计算 mex,肯定会超时的说!我的直觉告诉我,肯定有更聪明的办法,喵!

我们不妨换个角度思考。与其枚举子数组 [i, j],不如我们来枚举 mex 的值,然后计算所有 mex 等于这个值的子数组的贡献。这种方法叫做贡献法,是解决这类求和问题的常用技巧哦!

核心思想:按 mex 值计算贡献

一个子数组 [i, j]mex 值为 kk,意味着:

  1. 数字 1,2,,k11, 2, \dots, k-1 全部都出现在子数组 a[i...j] 中。
  2. 数字 kk 没有出现在子数组 a[i...j] 中。

为了方便处理,我们先预处理出每个数字 vv 在排列中的位置 pos[v]

步骤一:mex = 1 的情况

mex(i, j) = 1 是最简单的情况。这意味着数字 1 没有出现在子数组 a[i...j] 中。 找到数字 1 的位置 pos[1]。那么,任何一个完全在 pos[1] 左边或者右边的子数组,其 mex 都为 1

  • pos[1] 左边有 L = pos[1] - 1 个元素。
  • pos[1] 右边有 R = n - pos[1] 个元素。

对于左边的部分,我们可以枚举子数组长度 k1L。长度为 k 的子数组有 L-k+1 个。它们的 t 值都是 k×1=kk \times 1 = k。所以它们对总答案的贡献是 k=1L(Lk+1)k2k\sum_{k=1}^{L} (L-k+1) \cdot k \cdot 2^k。 右边的部分同理,贡献是 k=1R(Rk+1)k2k\sum_{k=1}^{R} (R-k+1) \cdot k \cdot 2^k。 这两个和式可以直接用一个 O(N)O(N) 的循环来计算,对于本题的数据范围是完全可以接受的,喵~

步骤二:mex = k (当 k > 1) 的情况(最关键的一步!)

现在我们来考虑一般情况。假设我们正在计算 mex = m+1 的贡献。 这意味着我们关心的是那些包含了 {1,2,,m}\{1, 2, \dots, m\} 但不包含 m+1m+1 的子数组。

我们可以采用一种增量的方式来思考。我们维护一个当前最小的区间 [l, r],它恰好包含了数字集合 {1,2,,m}\{1, 2, \dots, m\}

  • m=1m=1 时,[l, r] 就是 [pos[1], pos[1]]
  • 当我们从 mm 推进到 m+1m+1 时,我们找到数字 m+1m+1 的位置 p = pos[m+1]

如果 p 已经在当前的 [l, r] 区间内了(即 l < p < r),那么任何包含 {1,,m}\{1, \dots, m\} 的子数组(也就是包含 [l, r] 的子数组)也必然会包含 m+1m+1。这种情况下,不存在 mex = m+1 的子数组。我们就直接把 p 并入 [l, r](更新 l=min(l,p), r=max(r,p)),然后继续考虑下一个 mex 值。

最有趣的情况是当 p[l, r] 的外部时!假设 p < lp > r 的情况是完全对称的)。 此时,任何满足 mex = m+1 的子数组 [i, j] 必须满足:

  1. 包含 {1,,m}\{1, \dots, m\},所以它必须覆盖 [l, r],即 i <= lj >= r
  2. 不包含 m+1m+1,所以它不能覆盖 p,即 i > p

综合起来,这样的子数组 [i, j] 的左端点 i 只能在 [p+1, l] 中选取,右端点 j 只能在 [r, n] 中选取。

len_base = r - l + 1 是核心区间的长度。 一个子数组 [i, j] 的长度可以表示为 len_base + (l-i) + (j-r)。 令 dl = l-i 为向左延伸的长度,dr = j-r 为向右延伸的长度。

  • dl 的取值范围是 0l - (p+1) = l-p-1
  • dr 的取值范围是 0 to n-r

对于每一个这样的 (dl, dr) 组合,我们都找到了一个 mexm+1 的子数组。它的长度是 len_new = len_base + dl + dr,它的 t 值是 len_new * (m+1)。 我们需要计算所有这些子数组的贡献总和:

Contribution=dl=0lp1dr=0nr(m+1)(lenbase+dl+dr)2(m+1)(lenbase+dl+dr)\text{Contribution} = \sum_{dl=0}^{l-p-1} \sum_{dr=0}^{n-r} (m+1) \cdot (len_{base} + dl + dr) \cdot 2^{(m+1) \cdot (len_{base} + dl + dr)}

这个式子看起来太可怕了!但是不要怕,我来帮你拆解它! 令 M=m+1M = m+1q=2Mq = 2^M

Contribution=Mqlenbasedl=0lp1dr=0nr(lenbase+dl+dr)qdlqdr\text{Contribution} = M \cdot q^{len_{base}} \sum_{dl=0}^{l-p-1} \sum_{dr=0}^{n-r} (len_{base} + dl + dr) \cdot q^{dl} \cdot q^{dr}

这个双重求和可以被分解成三个部分:

dldr(lenbaseqdlqdr+dlqdlqdr+drqdlqdr)\sum_{dl} \sum_{dr} (len_{base} \cdot q^{dl}q^{dr} + dl \cdot q^{dl}q^{dr} + dr \cdot q^{dl}q^{dr})

=lenbase(qdl)(qdr)+(dlqdl)(qdr)+(qdl)(drqdr)= len_{base} (\sum q^{dl})(\sum q^{dr}) + (\sum dl \cdot q^{dl})(\sum q^{dr}) + (\sum q^{dl})(\sum dr \cdot q^{dr})

看呐!我们成功地把复杂的求和变成了几个简单求和的乘积!这些求和是:

  1. 等比数列求和 (Geometric Series): S0(q,N)=k=0NqkS_0(q, N) = \sum_{k=0}^{N} q^k
  2. 差比数列求和 (Arithmetico-geometric Series): S1(q,N)=k=0NkqkS_1(q, N) = \sum_{k=0}^{N} k \cdot q^k

这两个求和都有封闭的数学公式!我们可以写出对应的辅助函数来计算它们。这样,在主循环中,我们每次只需要 O(logMOD)O(\log MOD) 的时间(主要用于快速幂)来计算贡献,总的来说,这部分的复杂度就是 O(NlogMOD)O(N \log MOD)

步骤三:mex = n+1 的情况

最后,别忘了 mex = n+1 的情况。这只可能发生在子数组包含了所有数字 {1,,n}\{1, \dots, n\} 的时候。因为输入是排列,所以只有整个数组 a[1...n] 才满足这个条件。 它的长度是 nnmexn+1n+1,所以 t=n(n+1)t = n \cdot (n+1)。 它的贡献就是 t2tt \cdot 2^t。直接计算加到总答案里就好啦!

总结一下我们的策略:

  1. 预处理 pos 数组。
  2. 用循环计算 mex = 1 的贡献。
  3. 维护一个区间 [l, r],从 m = 1n-1 循环,计算 mex = m+1 的贡献。如果 pos[m+1] 扩展了 [l, r],就用求和公式计算新产生的子数组贡献;否则跳过。
  4. 加上 mex = n+1 的贡献。
  5. 输出总答案!

这样一套组合拳下来,问题就迎刃而解啦!喵~

代码实现

cpp
#include <iostream>
#include <vector>
#include <numeric>
#include <algorithm>

// 我是聪明的我,喵~
// 这是我的解题代码,注释很详细哦!

// 定义模数
const long long MOD = 998244353;

// 快速幂,用于计算 a^b % MOD
long long power(long long base, long long exp) {
    long long res = 1;
    // 指数要对 MOD-1 取模 (费马小定理)
    // 也要处理负数指数的情况
    exp %= (MOD - 1);
    if (exp < 0) exp += (MOD - 1);
    
    base %= MOD;
    while (exp > 0) {
        if (exp % 2 == 1) res = (static_cast<__int128>(res) * base) % MOD;
        base = (static_cast<__int128>(base) * base) % MOD;
        exp /= 2;
    }
    return res;
}

// 模块逆元,a^{-1} % MOD
long long modInverse(long long n) {
    return power(n, MOD - 2);
}

// 等比数列求和: S_0(q, n) = sum_{k=0 to n} q^k
long long sum_geometric(long long q, int n) {
    if (n < 0) return 0;
    q = (q % MOD + MOD) % MOD;
    if (q == 1) return (n + 1) % MOD;
    
    long long num = (power(q, n + 1) - 1 + MOD) % MOD;
    long long den = modInverse((q - 1 + MOD) % MOD);
    return (num * den) % MOD;
}

// 差比数列求和: S_1(q, n) = sum_{k=0 to n} k * q^k
// 公式: (n*q^{n+2} - (n+1)*q^{n+1} + q) / (q-1)^2
long long sum_arith_geometric(long long q, int n) {
    if (n < 0) return 0;
    q = (q % MOD + MOD) % MOD;
    if (q == 1) {
        long long n_ll = n;
        long long term1 = n_ll * (n_ll + 1) % MOD;
        long long term2 = modInverse(2);
        return (term1 * term2) % MOD;
    }

    long long inv_q_minus_1 = modInverse((q - 1 + MOD) % MOD);
    long long inv_q_minus_1_sq = (inv_q_minus_1 * inv_q_minus_1) % MOD;

    long long term1 = (static_cast<__int128>(n) * power(q, n + 2)) % MOD;
    long long term2 = (static_cast<__int128>(n + 1) * power(q, n + 1)) % MOD;
    
    long long num = (term1 - term2 + MOD) % MOD;
    num = (num + q) % MOD;
    
    return (num * inv_q_minus_1_sq) % MOD;
}


int main() {
    // 加速输入输出,让程序跑得更快,喵~
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(NULL);

    int n;
    std::cin >> n;
    std::vector<int> a(n + 1), pos(n + 1);
    for (int i = 1; i <= n; ++i) {
        std::cin >> a[i];
        pos[a[i]] = i;
    }

    long long total_ans = 0;

    // --- Part 1: mex = 1 ---
    // 在 pos[1] 左侧的子数组
    int left_len = pos[1] - 1;
    if (left_len > 0) {
        long long mex = 1;
        long long q = power(2, mex);
        for (int k = 1; k <= left_len; ++k) {
            long long len = k;
            long long count = left_len - k + 1;
            long long t = len * mex; // t = k
            long long term_val = (t % MOD * power(2, t)) % MOD;
            total_ans = (total_ans + (count * term_val) % MOD) % MOD;
        }
    }
    // 在 pos[1] 右侧的子数组
    int right_len = n - pos[1];
    if (right_len > 0) {
        long long mex = 1;
        long long q = power(2, mex);
        for (int k = 1; k <= right_len; ++k) {
            long long len = k;
            long long count = right_len - k + 1;
            long long t = len * mex; // t = k
            long long term_val = (t % MOD * power(2, t)) % MOD;
            total_ans = (total_ans + (count * term_val) % MOD) % MOD;
        }
    }

    // --- Part 2: mex = 2 to n ---
    int l = pos[1], r = pos[1];
    for (int m = 1; m < n; ++m) {
        int mex = m + 1;
        int p = pos[mex];

        // 如果 p 已经在 [l,r] 内部,则不存在 mex = m+1 的子数组
        if (p > l && p < r) {
            continue;
        }
        
        long long current_len = r - l + 1;
        long long q = power(2, mex);

        int L_ext_count, R_ext_count;
        if (p < l) {
            L_ext_count = l - p - 1;
            R_ext_count = n - r;
        } else { // p > r
            L_ext_count = l - 1;
            R_ext_count = p - r - 1;
        }
        
        long long s0_L = sum_geometric(q, L_ext_count);
        long long s0_R = sum_geometric(q, R_ext_count);
        long long s1_L = sum_arith_geometric(q, L_ext_count);
        long long s1_R = sum_arith_geometric(q, R_ext_count);

        long long term_len = (current_len % MOD * s0_L % MOD * s0_R) % MOD;
        long long term_dl = (s1_L * s0_R) % MOD;
        long long term_dr = (s0_L * s1_R) % MOD;

        long long sum_of_lens = (term_len + term_dl + term_dr) % MOD;
        
        long long t_base_exp = current_len * mex;
        long long factor = (mex % MOD * power(2, t_base_exp)) % MOD;

        long long contribution = (factor * sum_of_lens) % MOD;
        total_ans = (total_ans + contribution) % MOD;
        
        // 更新 [l,r] 以包含新的数字
        l = std::min(l, p);
        r = std::max(r, p);
    }

    // --- Part 3: mex = n + 1 ---
    long long final_len = n;
    long long final_mex = n + 1;
    long long t_final = final_len * final_mex;
    long long final_term = (t_final % MOD * power(2, t_final)) % MOD;
    total_ans = (total_ans + final_term) % MOD;

    std::cout << total_ans << std::endl;

    return 0;
}

复杂度分析

  • 时间复杂度: O(NlogMOD)O(N \log MOD)

    • 预处理 pos 数组需要 O(N)O(N)
    • 计算 mex=1 贡献的部分,我们用了两个循环,总共是 O(N)O(N)
    • 主循环从 m=1n-1,共 O(N)O(N) 次迭代。在每次迭代中,我们调用了 power, sum_geometric, sum_arith_geometric 等函数。这些函数内部都依赖于 power 函数,其复杂度是 O(logMOD)O(\log MOD)。因此,主循环的总复杂度是 O(NlogMOD)O(N \log MOD)
    • 最后计算 mex=n+1 的贡献是 O(logMOD)O(\log MOD)
    • 所以,总的时间复杂度由主循环决定,为 O(NlogMOD)O(N \log MOD),对于本题的数据范围来说非常高效,喵!
  • 空间复杂度: O(N)O(N)

    • 我们主要使用了 a 数组和 pos 数组来存储输入和数字位置,它们的大小都是 N+1N+1。所以空间复杂度是 O(N)O(N)

知识点总结

这道题真是一次愉快的思维探险呢!它融合了多种算法思想,我们来总结一下吧:

  1. 贡献法: 这是解决复杂求和问题的核心思想。当直接枚举求和项困难时,可以反过来枚举每个元素的“值”(在这里是mex),然后计算它对总和的贡献。
  2. 增量思想: 我们不是一次性考虑所有情况,而是通过 mex11 递增到 n+1n+1,逐步构建和求解问题。每一步都基于上一步的结果,这让问题变得更加清晰可控。
  3. 求和公式: 解决本题的关键在于将复杂的双重求和分解,并识别出其中的等比数列和差比数列结构。掌握这些基础的求和公式,可以在处理组合计数和数学问题时派上大用场!
  4. 模运算: 在处理大数问题时,别忘了所有的加减乘除都要在模意义下进行。特别是除法,要用乘以乘法逆元来代替。快速幂是计算模逆元和模下大数幂的基础工具,一定要熟练掌握哦!

希望这篇题解能帮到你,喵~!如果还有不懂的地方,随时可以再来问我哦!一起加油,成为更厉害的算法大师吧!