Skip to content

math - 题解

标签与难度

标签: 数学, 组合数学, 线性代数, 高斯消元, 线性递推, 数论 难度: 2500

题目大意喵~

主人你好呀,喵~ 这是一道看起来超级复杂的数学题,但是别怕,我会陪你一起把它解决掉的!

题目给了我们一个整数 nn,然后定义了一个函数 f(i)f(i)

f(i)=j=0i(1)j(i+j)!(ij)(ni+j)f(i) = \sum_{j=0}^{i} (-1)^j (i+j)! \binom{i}{j} \binom{n}{i+j}

我们的任务是,计算出所有 f(i)f(i) (从 i=0i=0nn) 在模 10000000071000000007 意义下的值,然后把它们全部按位异或(XOR)起来,得到最终的答案,的说。

也就是要求这个式子的值:

i=0n(f(i)(mod1000000007))\bigoplus_{i=0}^{n} (f(i) \pmod{1000000007})

这里的 \bigoplus 符号就是异或和的意思哦。

解题思路分析

呜喵... 这个 f(i)f(i) 的公式看起来就像一团缠在一起的毛线球,好复杂呀!里面有阶乘、组合数,还有一个求和符号。如果我们要对每个 ii (从 00nn) 都老老实实地按照公式计算 f(i)f(i),会怎么样呢?

对于每个 f(i)f(i),我们都需要做一个 jj00ii 的循环。所以总的计算量大概是 O(n2)O(n^2) 级别的。考虑到 nn 可能很大(比如到 10710^7),O(n2)O(n^2) 的算法肯定会超时,会被关进小黑屋的,喵~

那该怎么办呢?直接计算行不通,我们得找点捷径!

神奇的递推关系

在算法竞赛中,遇到这种复杂的组合数学求和式子,一个非常强大的魔法就是——寻找线性递推关系

很多看起来很复杂的序列 a0,a1,a2,a_0, a_1, a_2, \dots,实际上都满足一个递推关系。最简单的就像斐波那契数列 Fn=Fn1+Fn2F_n = F_{n-1} + F_{n-2}。而我们这道题的 f(i)f(i),可能也满足一个更复杂的递推关系,它的系数可能不是常数,而是关于 ii 的多项式。这种序列被称为 P-递归序列整性序列

我们可以大胆地猜测,f(i)f(i) 满足一个这样的递推关系:

f(i)=k=1dCk(i)f(ik)f(i) = \sum_{k=1}^{d} C_k(i) f(i-k)

其中,dd 是递推的阶数(表示 f(i)f(i) 和前面多少项有关),Ck(i)C_k(i) 是一些关于 ii 的简单多项式,比如 Ck(i)=aki2+bki+ckC_k(i) = a_k i^2 + b_k i + c_k

怎么找到这个“魔法公式”呢?

这就像一个侦探游戏!我们不知道递推公式的具体模样(不知道阶数 dd 和多项式系数 ak,bk,cka_k, b_k, c_k 是多少),但我们可以通过已有的线索来把它推理出来!

线索就是序列 f(i)f(i) 的前几项。我们可以先用最原始、最笨的方法计算出 f(0),f(1),f(2),,f(M)f(0), f(1), f(2), \dots, f(M),这里 MM 是一个不大不小的数,比如 20。这个计算量不大,完全可以接受。

有了这些初始值,我们就可以开始构建方程了。假设我们猜递推阶数是 d=3d=3,多项式次数是 p=2p=2。那么递推公式就是:

f(i)=(a1i2+b1i+c1)f(i1)+(a2i2+b2i+c2)f(i2)+(a3i2+b3i+c3)f(i3)f(i) = (a_1 i^2 + b_1 i + c_1)f(i-1) + (a_2 i^2 + b_2 i + c_2)f(i-2) + (a_3 i^2 + b_3 i + c_3)f(i-3)

这里面有 3×3=93 \times 3 = 9 个我们不知道的系数 (a1,b1,c1,a2,,c3)(a_1, b_1, c_1, a_2, \dots, c_3)。为了解出这9个未知数,我们就需要9个独立的线性方程。

怎么得到方程呢?很简单!我们把已知的 ff 值代进去。

  • i=3i=3 时,我们得到一个关于系数的方程。
  • i=4i=4 时,我们得到第二个方程。
  • ...
  • i=11i=11 时,我们就得到了第 113+1=911-3+1=9 个方程。

这样,我们就得到了一个包含9个未知数和9个方程的线性方程组。解这个方程组,我们就能得到所有的系数 ak,bk,cka_k, b_k, c_k!解线性方程组的经典方法就是高斯消元法,喵~

完整的解题步骤

好啦,思路已经清晰了,我们来总结一下具体的步骤吧:

  1. 预处理:因为要大量计算组合数,所以我们需要预处理阶乘和阶乘的逆元,这样就能 O(1)O(1) 计算 (nk)\binom{n}{k} 啦。
  2. 计算初始值:写一个函数,用题目给的原始公式,暴力计算出 f(0),f(1),,f(19)f(0), f(1), \dots, f(19) 的值。这个范围足够我们建立方程组了。
  3. 建立方程组:根据我们猜测的递推形式(比如阶数为3,多项式次数为2),利用 f(3),,f(11)f(3), \dots, f(11) (或者更多) 的值,建立一个线性方程组。
    • 比如,对于每个 i[d,M1]i \in [d, M-1],我们都有一个方程:

      k=1d(aki2+bki+ck)f(ik)f(i)=0\sum_{k=1}^{d} (a_k i^2 + b_k i + c_k)f(i-k) - f(i) = 0

      整理一下,就是关于 ak,bk,cka_k, b_k, c_k 的线性方程。
  4. 高斯消元:用高斯消元法解出这个方程组,得到我们梦寐以求的递推公式的系数。
  5. 快速计算:有了递推公式,我们就可以从 f(19)f(19) 开始,一路递推出 f(20),f(21),,f(n)f(20), f(21), \dots, f(n)。每计算一项都只需要常数时间,所以这部分的复杂度是 O(n)O(n)
  6. 统计答案:在计算 f(i)f(i) 的过程中,同时将它们的值(对 109+710^9+7 取模后)进行异或累加。

这样,整个算法的瓶颈就在于预处理和最后的递推计算,总时间复杂度是 O(n)O(n),空间复杂度也是 O(n)O(n)(用来存 f(i)f(i) 的值),完美地解决了问题,喵~

代码实现

下面是我根据上面的思路,精心重构的一份代码。注释写得很详细,希望能帮助主人更好地理解哦!

cpp
#include <iostream>
#include <vector>
#include <numeric>
#include <algorithm>

using namespace std;

// 使用 long long 防止计算过程中溢出
using ll = long long;

// 模数
const int MOD = 1000000007;
// 递推阶数
const int RECURRENCE_ORDER = 3;
// 多项式次数
const int POLY_DEGREE = 2;
// 解方程需要的初始项数量
const int INITIAL_TERMS_COUNT = 20;

// 模块化整数类,处理取模运算,喵~
struct ModInt {
    int val;

    ModInt(ll v = 0) {
        v %= MOD;
        if (v < 0) v += MOD;
        val = v;
    }

    ModInt& operator+=(const ModInt& other) {
        val += other.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    ModInt& operator-=(const ModInt& other) {
        val -= other.val;
        if (val < 0) val += MOD;
        return *this;
    }
    ModInt& operator*=(const ModInt& other) {
        val = (ll)val * other.val % MOD;
        return *this;
    }

    // 快速幂求逆元
    ModInt inv() const {
        return power(*this, MOD - 2);
    }

    ModInt& operator/=(const ModInt& other) {
        return *this *= other.inv();
    }

    friend ModInt operator+(ModInt a, const ModInt& b) { return a += b; }
    friend ModInt operator-(ModInt a, const ModInt& b) { return a -= b; }
    friend ModInt operator*(ModInt a, const ModInt& b) { return a *= b; }
    friend ModInt operator/(ModInt a, const ModInt& b) { return a /= b; }

    static ModInt power(ModInt base, ll exp) {
        ModInt res = 1;
        while (exp > 0) {
            if (exp % 2 == 1) res *= base;
            base *= base;
            exp /= 2;
        }
        return res;
    }
};

// 组合数工具,需要预先计算阶乘
struct Combinatorics {
    vector<ModInt> fact, invFact;

    Combinatorics(int n) {
        fact.resize(n + 1);
        invFact.resize(n + 1);
        fact[0] = 1;
        invFact[0] = 1;
        for (int i = 1; i <= n; ++i) {
            fact[i] = fact[i - 1] * i;
        }
        invFact[n] = fact[n].inv();
        for (int i = n - 1; i >= 1; --i) {
            invFact[i] = invFact[i + 1] * (i + 1);
        }
    }

    ModInt C(int n, int k) {
        if (k < 0 || k > n) {
            return 0;
        }
        return fact[n] * invFact[k] * invFact[n - k];
    }
};

// 全局组合数工具实例
unique_ptr<Combinatorics> comb;

// 按照原始定义计算 f(i)
ModInt calculate_f_directly(int i, int n) {
    ModInt res = 0;
    for (int j = 0; j <= i; ++j) {
        if (i + j > n) break; // C(n, i+j) 会是 0
        ModInt term = comb->fact[i + j] * comb->C(i, j) * comb->C(n, i + j);
        if (j % 2 == 1) {
            res -= term;
        } else {
            res += term;
        }
    }
    return res;
}

// 高斯消元解线性方程组 Ax = b
// A 是系数矩阵,b 是结果向量
vector<ModInt> gaussian_elimination(vector<vector<ModInt>>& A, vector<ModInt>& b) {
    int n = A.size();
    for (int i = 0; i < n; ++i) {
        // 找到主元最大的行
        int pivot = i;
        for (int j = i + 1; j < n; ++j) {
            if (A[j][i].val > A[pivot][i].val) {
                pivot = j;
            }
        }
        swap(A[i], A[pivot]);
        swap(b[i], b[pivot]);

        // 把主元变成 1
        ModInt inv_pivot = A[i][i].inv();
        for (int j = i; j < n; ++j) {
            A[i][j] *= inv_pivot;
        }
        b[i] *= inv_pivot;

        // 消去其他行的第 i 个变量
        for (int j = 0; j < n; ++j) {
            if (i != j) {
                ModInt factor = A[j][i];
                for (int k = i; k < n; ++k) {
                    A[j][k] -= factor * A[i][k];
                }
                b[j] -= factor * b[i];
            }
        }
    }
    return b;
}

int main() {
    // 加速输入输出,喵~
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int n;
    cin >> n;

    // 预处理阶乘,最大可能需要 2n
    comb = make_unique<Combinatorics>(2 * max(n, INITIAL_TERMS_COUNT));

    vector<ModInt> f(n + 1);
    int initial_count = min(n, INITIAL_TERMS_COUNT - 1);

    // 1. 计算初始值
    for (int i = 0; i <= initial_count; ++i) {
        f[i] = calculate_f_directly(i, n);
    }
    
    // 如果 n 很小,直接计算完了
    if (n < INITIAL_TERMS_COUNT) {
        int xor_sum = 0;
        for (int i = 0; i <= n; ++i) {
            xor_sum ^= f[i].val;
        }
        cout << xor_sum << endl;
        return 0;
    }

    // 2. 建立并解方程组来找到递推系数
    const int num_unknowns = RECURRENCE_ORDER * (POLY_DEGREE + 1);
    vector<vector<ModInt>> A(num_unknowns, vector<ModInt>(num_unknowns));
    vector<ModInt> b(num_unknowns);

    for (int row = 0; row < num_unknowns; ++row) {
        int i = RECURRENCE_ORDER + row;
        b[row] = f[i];
        int current_unknown = 0;
        for (int k = 1; k <= RECURRENCE_ORDER; ++k) {
            ModInt i_poly = 1;
            for (int p = 0; p <= POLY_DEGREE; ++p) {
                A[row][current_unknown++] = f[i - k] * i_poly;
                i_poly *= i;
            }
        }
    }
    
    vector<ModInt> coeffs = gaussian_elimination(A, b);

    // 3. 使用递推公式计算剩下的 f[i]
    for (int i = INITIAL_TERMS_COUNT; i <= n; ++i) {
        f[i] = 0;
        int coeff_idx = 0;
        for (int k = 1; k <= RECURRENCE_ORDER; ++k) {
            ModInt poly_val = 0;
            ModInt i_poly = 1;
            for (int p = 0; p <= POLY_DEGREE; ++p) {
                poly_val += coeffs[coeff_idx++] * i_poly;
                i_poly *= i;
            }
            f[i] += poly_val * f[i - k];
        }
    }

    // 4. 计算异或和
    int xor_sum = 0;
    for (int i = 0; i <= n; ++i) {
        xor_sum ^= f[i].val;
    }

    cout << xor_sum << endl;

    return 0;
}

复杂度分析

  • 时间复杂度: O(N+C3)O(N + C^3)

    • 预处理阶乘和逆元需要 O(N)O(N) 的时间。
    • 计算初始的 INITIAL_TERMS_COUNT (我们设为20) 个 f(i)f(i) 值,每个值计算需要 O(i)O(i),总共是 O(C2)O(C^2),其中 CC 是个小常数 (20)。
    • 高斯消元解一个 k×kk \times k 的方程组需要 O(k3)O(k^3) 的时间。我们的 kkRECURRENCE_ORDER * (POLY_DEGREE + 1),也是一个很小的常数。
    • 最后,使用递推公式计算剩下的 f(i)f(i) 值,需要 O(N)O(N) 的时间。
    • 所以总的时间复杂度由最大的部分决定,即 O(N)O(N)
  • 空间复杂度: O(N)O(N)

    • 我们需要一个大小为 N+1N+1 的数组来存储所有的 f(i)f(i) 值。
    • 阶乘和逆元数组也需要 O(N)O(N) 的空间。
    • 因此,总的空间复杂度是 O(N)O(N)

知识点总结

这道题真是一次有趣的冒险呢,喵~ 我们来总结一下学到的知识点吧:

  1. 组合恒等式问题: 面对复杂的组合求和式,要意识到直接计算可能行不通,需要寻找更巧妙的方法。
  2. P-递归序列 (整性序列): 学习到了一个重要的思想,即许多组合序列都满足一个系数为多项式的线性递推关系。
  3. 通过初始项求解递推关系: 这是一个非常实用的“黑盒”技巧。通过计算序列的前几项,建立线性方程组,然后用高斯消元法解出递推系数。
  4. 高斯消元: 这是线性代数中的基本工具,用于求解线性方程组,在很多问题中都有应用。
  5. 模块化算术: 在处理大数和模运算时,封装一个 ModInt 类能让代码更清晰、更不容易出错。

希望我的讲解对主人有帮助!下次遇到难题,也请不要害怕,我们一起解决它,喵~!