使用Rabin-Karp算法查找最长回文字符串

利用Rabin-Karp,尝试解决LeetCode005_longest-palindromic-substring问题。Python实现,意义大于性能🙃

Rabin-Karp的基本运用是子字符串查找,因为利用hash值比较,因此也被称作指纹字符串查找算法。比如在文本S找查找长为M的字符串,在已知S[i: i+M]字符串的hash值时,Rabin-Karp可以在常数时间内计算出S[i+1: i+M+1]的hash值,因此可以在线性时间复杂度内,扫描S中所有长为M的子串。

回到LeetCode005_longest-palindromic-substring问题,一个回文字符串,正反的hash值一定相同,然后利用二分查找,可以找到一个最大的子字符串长度。在实现之前,先说明关于hash计算的数学基础。

mod运算的一些性质

Rabin-Karp算法在计算hash值时,主要利用mod操作的几个性质来加速计算

线性组合的分配律

$$
(a+b)\mod n = [(a \mod n) + (b\mod n)]\mod n \\
(ab)\mod n = [(a \mod n)(b\mod n)]\mod n
$$

简单的证明,设$a\mod n=x$, 有$a=pn+x$,同理有$b=qn+y$,所以加法等式替换为
$$
(pn+x+qn+y) \mod n = (x+y) \mod n
$$
这个就自然相等了。实际上mod分配一个还是两个都无所谓。基本的Rabin-Karp算法利用这个性质就可以了。如果要逆向计算字符串的hash值,就必须得借助下面两个性质。

除法定义

$$
\frac ab\mod n=(ab^{-1})\mod n
$$

根据分配率,求解需要$(b^{-1}\mod n)$,它的定义是,设$(b^{-1}\mod n)=I$,则有$bI\equiv1\space(\mod n)$,这里的$I$就被称为b的模反元素。也就是存在关系,$bI-1=pn$。

接下来就是证明,$\frac ab \mod n =aI\mod n$。把上述关系乘上等式左边,变为n的整数倍,余数为0,有$(aI-\frac ab)\mod n = 0$,即可证明。

如何求模反元素

关系$bI-1=pn$换一种写法就是贝祖等式,$bI+(-p)n=1$

然后可以通过拓展欧几里得方法求得$I, p$

这里贴一个wiki上的例子

Z6ZnRH.png

把图中的过程利用递归表达出来,就是所谓的拓展欧几里得算法:在原始的欧几里得算法递归完成后,输出当前贝祖等式的参数。

1
2
3
4
5
6
def exgcd(a, b):
if b == 0:
return 1, 0 # 1 = 1*1 + 0*0
x, y = exgcd(b, a % b)
x, y = y, x - (a//b)*y
return x, y

正反扫描的Rabin-Karp

1
2
3
4
5
6
7
class RabinKarp:
def __init__(self, text, M, R=128, Q=100000037):
self.R = R
self.M = M
self.Q = Q
self.text = text
self.text_inv = text[::-1]

一些符号约定:

  • R,把字符串看作R进制数。比如对于ASCII字符,不妨R=128,”abc”的hash值为:
    $$
    (97\times128^2+98\times128+99) \mod Q
    $$

  • Q,对Q取模,Q必须为质数,保证R、Q互质,使R存在模反元素

  • M,子字符串长度

实现分为正向和逆向两部分,正向就是原始的Rabin-Karp算法,逆向是为了求回文字符串做的增加。

正向

首先计算出起始位置上,长为M的子字符串的hash值——计算M位的R进制数对Q取模。利用mod的分配律,在每次计算时,都可以对Q取模,而不影响结果。

1
2
3
4
5
def dummy_hash(self, s): # 传入text[:M]子字符串,直接计算hash值
h = 0
for c in s:
h = (h * self.R + ord(c)) % self.Q
return h

记第一个子字符串text[:M]的hash值为:
$$
hash(text[:M])=t_0R^{M-1}+t_1R^{M-2}+…+t_{M-1} \mod Q
$$

第二个子字符串text[1:M+1]同样利用mod分配率,可以在常数时间内得到,因为由前一个R进制数,可以在常数时间内计算当前的R进制数——

  1. 减去最高位
  2. 乘以R
  3. 再加上新增的末位

$$
(t_0R^{M-1}+t_1R^{M-2}+…+t_{M-1} - t_0R^{M-1}) \times R +t_M
$$

考虑mod,上式等于:

$$
(((t_0R^{M-1}+t_1R^{M-2}+…+t_{M-1}) \mod Q-(R^{M-1} \mod Q)\times t_0)\times R+t_M) \mod Q
$$

第一项就是上一个字符串的hash值,第二项是最高位的取模,因为$(R^{M-1}\mod Q)$每次都会用到,所以事先计算出来保存。第三项就是多出来的末位。

1
2
3
4
5
6
def __init__(self, text, M, R=128, Q=100000037):
# ...
pow = 1
for _ in range(1, M):
pow = (R * pow) % Q
self.pow = pow # R^(M-1) % Q

然后使用生成器来返回从左往右的长为M的字符串hash值。在实际的实现中,减去最高项时使用又加上了Q,保证始终是正数,让%运算符有正确的结果。

1
2
3
4
5
6
7
8
def gen_hash(self):
txt, pow, R, M, Q = self.text, self.pow, self.R, self.M, self.Q
txt_hash = self.dummy_hash(txt[:M])
yield txt_hash
for i in range(M, len(txt)):
txt_hash = (txt_hash + Q - pow * ord(txt[i-M]) % Q) % Q
txt_hash = (txt_hash * R + ord(txt[i])) % Q
yield txt_hash

逆向

在常数时间内计算出逆序字符串的hash值,和正向的思路相似,也分三步:

  1. 减去最低位
  2. 除以R
  3. 加上最高位

最关键的第二步,除以R并且取模,就需要使用模反元素,把除法变为乘法,继续按照分配率的保证来分解运算。模拟元素这里使用拓展欧几里得算法获得,并在初始化中保存。

1
2
3
4
5
6
7
8
9
10
11
12
13
def __init__(self, text, M, R=128, Q=100000037):
# ...
self.invR, _ = self._exgcd(R, Q)
if self.invR < 0:
self.invR = self.invR % Q + Q

def _exgcd(self, a, b):
"""omit remainder here
"""
if b == 0:
return 1, 0 # 1 = 1*1 + 0*0
x, y = self._exgcd(b, a % b)
return y, x - (a//b)*y

随后再利用生成器,对等地生成逆序子字符串的hash值

1
2
3
4
5
6
7
8
9
10
11
   
def gen_hash_inv(self):
txt, pow, invR, M, Q = self.text_inv, self.pow, self.invR, self.M, self.Q
txt_hash = self.dummy_hash(txt[-M:])
yield txt_hash
for i in range(len(txt)-M-1, -1, -1):
txt_hash -= ord(txt[i+M])
txt_hash = (txt_hash * invR) % Q
txt_hash += ord(txt[i])*pow
txt_hash = txt_hash % Q
yield txt_hash

二分查找

关键思想是根据回文的性质,分为奇数和偶数的情况,分开查找。

如果对于一个回文字符串长为M,只能断言,除去首尾两个字符,剩下的M-2的字符串也是回文字符串。也就是存在一个函数g(x),返回值是True or False,表示在给定文本中,是否存在长为x的回文字符串,一个M使得g(M)为True,那么g(M-2)也为True,而g(M-1)无法确定。因此想要利用二分查找的方法,寻找使g(x)返回值切换的边界$M’$,需要分为奇数和偶数来查找。

首先定义出g(x),用函数can_solve_at(s, L)来表示在文本s中,是否存在长为L的回文字符串,如果存在,返回字串起始位置,否则返回-1。当遇到hash值相同时,再用字符串比较来确定,防止$\frac 1Q$概率的hash值相同的误判。也就是拉斯维加斯算法。

1
2
3
4
5
6
7
8
def can_solve_at(s, L):
rk = RabinKarp(s, L)
i = 0
for h, h_inv in zip(rk.gen_hash(), rk.gen_hash()):
if h == h_inv and s[i:i+L] == ''.join(reversed(s[i:i+L])):
return i
i += 1
return -1

然后利用can_solve_at实现二分查找。这里的的L分为两半,左边为True,右边为False。当猜测的L=mid时为True,说明还可能存在更大的L使can_solve_at也为True,所以lo = mid + 1向右半区收缩。虽然最后并不关心搜索位置,lo的物理意义也不重要,但是也不妨分析一下,和常规的二分g(x)做比较。结果就是lo表示使得can_solve_at为False的最小值。

此外,实现中还有两点可注意:

  1. mid的计算使用了等差数列的二分技巧

  2. 因为求最长回文字符串,所以优先判断2mid + 1的结果

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def longest_RK(s):
if not s:
return ''
res = s[0]
lo, hi = 1, len(s) // 2
while lo <= hi:
mid = (lo + hi) // 2
idx_even = can_solve_at(s, 2*mid)
idx_odd = can_solve_at(s, 2*mid+1)
if idx_odd >= 0:
idx, L = idx_odd, 2*mid+1
else:
idx, L = idx_even, 2*mid
if idx >= 0:
if L > len(res):
res = s[idx:idx+L]
lo = mid + 1
else:
hi = mid - 1
return res

总结

  1. Rabin-Karp的性能,应用

    理论上Rabin-Karp的性能是稳定的线性级别,算上二分查找,应该也是$O(nlogn)$的解法,比常规的$O(n^2)$解法要好,但是当前实现在LeetCode仅beats 42%。当然和Python本身运行速度有关,排名意义不大。

    Rabin-Karp这种用指纹的查找方法,在思路上还可以更开阔一些,比如可以在二维结构中找寻找特定模式。另外一点好处是它不需要额外的内存空间。

  2. 二分查找的变化

    这里又是二分查找的一种变化,除了g(x)本身的含义改变,连搜索空间也不再是连续自然数。所以需要更灵活地认识二分查找的一些核心,还要再复习一下中位数