在算法题中,有一种题型:计算排列组合数,并使用乘法逆元实现对结果的取模运算。

由于计算机的编码方式,形如 (a * b) % base 这样的运算,乘法的结果可能导致溢出,我们希望找到一种技巧,能够化简这种表达式,避免溢出同时得到结果。

实际上,乘法逆元是应用在除法求模上的。因为加法、减法和乘法的求模都比较简单,但是乘法的取模容易造成溢出。

乘法逆元介绍

a 的逆元 是 $a^{(p-2)}$。

注意:乘法逆元不一定是存在的。a 存在乘法逆元的充要条件是 a 与模数 p 互质。当模数 p 为质数时,$a^{(p-2)}$ 即为 a 的乘法逆元。

当我们要计算一大串连续的阶乘的逆元时,采用费马小定理或扩展欧几里得算法就有可能超时,所以我们必须采用一个更快的算法。

Alt text

from 乘法逆元通俗易懂的理解方法

从上图可以看出,除法的取模是不满足分配律的。

所以 (b * a) mod p 等价于 (b / a 的乘法逆元) mod p

这样就可以把 乘法 转换为 除法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
// 乘法逆元模板
private static final int MOD = (int) 1e9 + 7;

private int mod(long numerator, long denominator) {
return (int) (numerator * quickPow(denominator, MOD - 2) % MOD);
}

private long quickPow(long x, int n) {
long res = 1L;
for (; n > 0; n /= 2) {
if (n % 2 > 0)
res = res * x % MOD;
x = x * x % MOD;
}
return res;
}

快速幂函数

快速幂通过每次将 n 除以 2,使得时间复杂度为达到 O(log n)。

1
2
3
4
5
6
7
8
private long pow(long x, int n) {
var res = 1L;
for (; n > 0; n /= 2) {
if (n % 2 > 0) res = res * x % MOD;
x = x * x % MOD;
}
return res;
}

例题

2514. 统计同位异构字符串数目

1
2
3
4
5
给你一个字符串 s ,它包含一个或者多个单词。单词之间用单个空格 ' ' 隔开。

如果字符串 t 中第 i 个单词是 s 中第 i 个单词的一个 排列 ,那么我们称字符串 t 是字符串 s 的同位异构字符串。

比方说,"acb dfe" 是 "abc def" 的同位异构字符串,但是 "def cab" 和 "adc bef" 不是。请你返回 s 的同位异构字符串的数目,由于答案可能很大,请你将它对 $10^9 + 7$ 取余 后返回。

对于一个长为 n 的单词,其全排列的个数为 n!,但由于相同的字母不做区分,所以如果有 x 个字母 a,还需要除以这些 a 的全排列的个数,即 x!,对于其余字母同理。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class Solution {
public int countAnagrams(String s) {
String[] strs = s.split(" ");
long numerator = 1L; // 分子:所有字母的全排列
long denominator = 1L; // 分母:每个字母出现的次数的全排列
for (String str: strs) {
int[] cnts = new int[26];
int len = 0;
for (char c: str.toCharArray()) {
cnts[c - 'a']++;
len++;
denominator = denominator * cnts[c - 'a'] % MOD;
numerator = numerator * len % MOD;
}
}
return (int)mod(numerator, denominator);
}

private static final int MOD = (int) 1e9 + 7;

private int mod(long numerator, long denominator) {
return (int) (numerator * quickPow(denominator, MOD - 2) % MOD);
}

private long quickPow(long x, int n) {
long res = 1L;
for (; n > 0; n /= 2) {
if (n % 2 > 0)
res = res * x % MOD;
x = x * x % MOD;
}
return res;
}
}

附:求组合数 C(n, m)

1
2
3
4
5
6
7
public long comb(int n, int m) {        
long ans = 1;
for (int x = n - m + 1, y = 1; y <= m; ++x, ++y) {
ans = ans * x / y;
}
return ans;
}