CTC loss

依据RNN网络的性质,每个时刻输出一个字符,RNN的最终输出是字符序列 S S S,需要后处理才能得到标签 T T T。在实际应用中,例如文字识别过程中, S S S T T T的长度是变化的,且不是等长的,那么就需要一种算法来完成对齐操作。CTC算法能够自动地完成 S S S T T T对齐。
CTC loss

假设训练数据集 S S S的空间分布为 D X × Z D_{X \times Z} DX×Z,输入空间 X = ( R ∗ ) ∗ X = \left( \mathbb{R}^{*} \right)^{*} X=(R) m m m维实向量所有序列的集合,目标空间 Z = L ∗ Z=L^{*} Z=L 是有限字符表 L L L中的字符序列集合。对于训练集中的每一个实例,表示为 ( x , z ) \left( x,z\right) (x,z),其中目标序列 z = ( z 1 , ⋯   , z U ) z=\left(z_{1},\cdots,z_{U} \right) z=(z1,,zU)的长度要小于等于输入序列 x = ( x 1 , ⋯   , x T ) x=\left(x_{1},\cdots,x_{T} \right) x=(x1,,xT)的长度。

对于长度为 T T T的输入序列 x x x,定义一个RNN为 N ω : ( R m ) T ↦ ( R n ) T N_{\omega}: \left ( \mathbb{R}^{m} \right )^{T} \mapsto \left ( \mathbb{R}^{n} \right )^{T} Nω:(Rm)T(Rn)T,则RNN的输出序列为 y = N ω ( x ) y = N_{\omega}\left( x \right) y=Nω(x),其中 y k t y_{k}^{t} ykt是在时刻t,输出单元 k k k的**函数值,将此值作为在时刻 t t t,字符 k k k出现的概率,那么 ∑ t = 1 T y k t = 1 \sum_{t=1}^{T} y_{k}^{t} = 1 t=1Tykt=1。定义字符集合为 L ′ T = L ∪ { b l a n k } {L}'^{T}=L \cup \left\{blank\right\} LT=L{blank} π \pi π为路径,假设某一条长度为 6 6 6的路径 π i \pi_{i} πi C C A A T T CCAATT CCAATT ∏ t = 1 T y π t t = y C 1 × y C 2 × y A 3 × y A 4 × y T 5 × y T 6 \prod_{t=1}^{T} y_{\pi_{t}}^{t}= y_{C}^{1}\times y_{C}^{2} \times y_{A}^{3} \times y_{A}^{4} \times y_{T}^{5} \times y_{T}^{6} t=1Tyπtt=yC1×yC2×yA3×yA4×yT5×yT6

p ( π ∣ x ) = ∏ t = 1 T y π t t , ∀ π ∈ L ′ T p\left( \pi | x \right) = \prod_{t=1}^{T} y_{\pi_{t}}^{t}, \forall \pi \in L ^{'^{T}} p(πx)=t=1Tyπtt,πLT
接下来就需要定义一个多对一的映射 B : L ′ T ↦ L ≤ T B:{L}'^{T}\mapsto L^{\leq T} B:LTLT,目的是为了合并有相同输出的路径。举个例子,我们定一个规则:仅仅合并两个‘-’间多余的字符并且移除所有的‘-’,那么 B ( a − a b − ) = B ( − a a − − a b b ) = a a b B\left ( a-ab- \right ) = B\left ( -aa--abb \right )=aab B(aab)=B(aaabb)=aab,也就是同一个输出有不同的路径。那么对于一个特定的字符序列 l ∈ L ≤ T l \in L^{\leq T} lLT,它的条件概率就是与它相关的路径概率之和。
p ( l ∣ x ) = ∑ π ∈ B − 1 ( l ) p ( π ∣ x ) p\left ( l |x \right ) = \sum_{\pi \in B^{-1}\left ( l \right )} p \left ( \pi | x \right ) p(lx)=πB1(l)p(πx)
其中 B − 1 ( l ) B^{-1}\left( l\right) B1(l)表示经过 B B B变换后输出是 l l l的所有路径的集合。

为了能够有效地计算 p ( l ∣ x ) p\left ( l |x \right ) p(lx),借鉴了隐马尔科夫HMM的前向-后算计算思路。
CTC loss

前向

对于长度为 r r r的序列 q q q,定义 q 1 : p q_{1:p} q1:p q r − p : r q_{r-p:r} qrp:r分别作为序列 q q q的前 p p p标识和后 p p p标识。对于一个字符序列 l l l,定义它的前向变量为 α t ( s ) \alpha_{t}\left( s \right) αt(s)为在时刻 t t t l 1 : s l_{1:s} l1:s的总概率,用公式表示为:
α t ( s ) = ∑ π ∈ N T B ( π 1 : t ) = l 1 : s ∏ t ′ = 1 t y π t ′ t ′ \alpha_{t} \left ( s \right ) = \sum _{\begin{matrix} \pi \in N^{T} \\ B \left ( \pi_{1:t} \right ) = l_{1:s} \end{matrix}} \prod _{{t}'=1}^{t} y_{\pi_{{t}'}}^{{t}'} αt(s)=πNTB(π1:t)=l1:st=1tyπtt
l l l的最前面和最后面插入空格,以及每两个字符的中间插入空格,得到新的标签序列,记为 l ′ {l}' l,那么 l ′ {l}' l的长度将会变成 2 ∣ l ∣ + 1 2\left | l \right | + 1 2l+1。例如字符序列 l l l c a t cat cat,我们用 − - 表示空格,那么新的字符序列 l ′ {l}' l则为 − c − a − t − -c-a-t- cat。为了能够计算 l ′ {l}' l前缀的概率,我们允许空格和字符之间可以转移,还有任何独立的字符之间可以转移。论文中定义所有的前缀可以从一个空格或者字符开始,也就说初始化可以定义为:
α 1 ( 1 ) = y b 1 α 1 ( 2 ) = y l 1 1 α 1 ( s ) = 0 , ∀ s > 2 \begin{matrix} \alpha_{1}\left( 1 \right) = y_{b}^{1} \\ \alpha_{1}\left( 2 \right) = y_{l_{1}}^{1} \\ \alpha_{1}\left( s \right) = 0, \forall s > 2 \end{matrix} α1(1)=yb1α1(2)=yl11α1(s)=0,s>2
α t ( s ) \alpha_{t} \left( s \right) αt(s)可以递推得到,分为两种情况:

  1. 如果 l s ′ {l}'_{s} ls为空或者 l s − 2 ′ = l s ′ {l}'_{s-2} = {l}'_{s} ls2=ls,那么 α t ( s ) \alpha_{t} \left( s \right) αt(s)只能从 α t − 1 ( s ) \alpha _{t-1}\left ( s \right ) αt1(s) α t − 1 ( s − 1 ) \alpha _{t-1}\left ( s - 1 \right ) αt1(s1)得到。
  2. 如果不满足第一种情况,那么 α t ( s ) \alpha _{t}\left ( s \right ) αt(s)只能从 α t − 1 ( s ) \alpha _{t-1}\left ( s \right ) αt1(s) α t − 1 ( s − 1 ) \alpha _{t-1}\left ( s -1 \right ) αt1(s1) α t − 1 ( s − 2 ) \alpha _{t-1}\left ( s - 2 \right ) αt1(s2)得到。

总公式如下所示:
α t ( s ) = { ( α t − 1 ( s ) + α t − 1 ( s − 1 ) ) y l s ′ t i f      l s ′ = b l a n k    o r    l s − 2 ′ = l s ′ ( α t − 1 ( s ) + α t − 1 ( s − 1 ) + α t − 1 ( s − 2 ) ) y l s ′ t o t h e r s i z e \alpha _{t}\left ( s \right ) = \left\{\begin{matrix} \left ( \alpha _{t-1}\left ( s \right ) + \alpha _{t-1}\left ( s-1 \right ) \right )y_{{l}'_{s}}^{t} \qquad \qquad \qquad & if \; \; {l}'_{s} = blank \; or \; {l}'_{s-2} = {l}'_{s} \\ \left ( \alpha _{t-1}\left ( s \right ) + \alpha _{t-1}\left ( s-1 \right ) + \alpha _{t-1}\left ( s-2 \right ) \right )y_{{l}'_{s}}^{t} & othersize \end{matrix}\right. αt(s)={(αt1(s)+αt1(s1))ylst(αt1(s)+αt1(s1)+αt1(s2))ylstifls=blankorls2=lsothersize
在时刻 T T T时,最后一个字符有可能是空格,有可能也是某个字符,那么字符标签 l l l的概率是这两种概率的综合,公式表示为:
p ( l ∣ x ) = α T ( ∣ l ′ ∣ ) + α T ( ∣ l ′ ∣ − 1 ) p\left ( l | x \right ) = \alpha_{T} \left ( \left | {l}' \right | \right ) + \alpha_{T} \left ( \left | {l}' \right | - 1 \right ) p(lx)=αT(l)+αT(l1)

后向

和前向算法类似,定义 β t ( s ) \beta_{t}\left( s\right) βt(s)在时刻 t t t l s : ∣ l ∣ l_{s : \left | l \right |} ls:l的总概率。
β t ( s ) = ∑ π ∈ N T B ( π t : T ) = l s : ∣ l ∣ ∏ t ′ = t T y π t ′ t ′ \beta _{t} \left ( s \right ) = \sum _{\begin{matrix} \pi \in N^{T} \\ B \left ( \pi_{t:T} \right ) = l_{s:\left | l \right |} \end{matrix}} \prod _{{t}'=t}^{T} y_{\pi_{{t}'}}^{{t}'} βt(s)=πNTB(πt:T)=ls:lt=tTyπtt
初始化为:
β T ( ∣ l ′ ∣ ) = y b T β T ( ∣ l ′ − 1 ∣ ) = y l ∣ l ∣ T β T ( s ) = 0 , ∀ s < ∣ l ′ ∣ − 1 \begin{matrix} \beta_{T}\left ( \left | {l}' \right | \right ) = y_{b}^{T} \qquad \qquad \\ \beta_{T}\left ( \left | {l}' - 1 \right | \right ) = y_{l_{\left | l \right |}}^{T} \qquad \\ \beta_{T}\left( s \right ) = 0 , \forall s < \left | {l}' \right | -1 \end{matrix} βT(l)=ybTβT(l1)=yllTβT(s)=0,s<l1
递推公式为:
β t ( s ) = { ( β t + 1 ( s ) + β t + 1 ( s + 1 ) ) y l s ′ t i f      l s ′ = b l a n k    o r    l s − 2 ′ = l s ′ ( β t + 1 ( s ) + β t + 1 ( s + 1 ) + β t + 1 ( s + 2 ) ) y l s ′ t o t h e r s i z e \beta_{t} \left ( s \right ) = \left\{\begin{matrix} \left ( \beta_{t+1} \left ( s \right ) + \beta_{t+1} \left ( s + 1 \right ) \right ) y_{{l}'_{s}}^{t} \qquad \qquad \qquad & if \; \; {l}'_{s} = blank \; or \; {l}'_{s-2} = {l}'_{s} \\ \left ( \beta_{t+1} \left ( s \right ) + \beta_{t+1} \left ( s + 1 \right ) + \beta_{t+1} \left ( s + 2 \right ) \right ) y_{{l}'_{s}}^{t} & othersize \end{matrix}\right. βt(s)={(βt+1(s)+βt+1(s+1))ylst(βt+1(s)+βt+1(s+1)+βt+1(s+2))ylstifls=blankorls2=lsothersize

应用

Tensorflow中已经给出了CTC的接口:

tf.nn.ctc_loss( labels, logits, label_length, logit_length, logits_time_major=True, unique=None, blank_index=None, name=None )

  • labels: 格式为[batch_size, max_label_seq_length]的张量或者稀疏张量

  • logits:默认是格式为[frames, batch_size, num_lables]的张量,如果logits_time_major=True,格式为[batch_size,batch_size, num_lables]

特点

  1. 使用CTC作为loss的计算方法,训练样本无需对齐。而且CTC使用前向后向算法递推计算,提高了计算速度。
  2. CTC 假设每个时间片是相互独立的,熟悉RNN的朋友们都知道这是稍微不合理的。

CTC 的两种解码方法

上述CTC loss 应用于图像文字识别的训练过程中。

在预测过程中,当输入 x x x,我们希望能够得到使得 p ( l ∣ x ) p\left( l | x \right) p(lx)概率最大的标签 l l l。在序列学习问题中,这个问题被称为解码,在有限的时间内得到条件概率最大的序列$l^{*} $。
l ∗ = a r g m a x p ( l ∣ x ) l^{*} = argmax p\left( l | x \right) l=argmaxp(lx)

假设有字符列表 ( ′ − ′ , ′ A ′ , ′ B ′ ) \left( '-', 'A', 'B'\right) (,A,B),时刻 T = 3 T=3 T=3,并且定义在 t t t时刻时,字符 c c c出现的概率为 P ( c , t ) P\left( c, t\right) P(c,t)。如下表所示,以横轴作为时刻序列,纵轴为字符列表,表格中的数字为概率,我们的目标是在这个二维空间中搜索出概率最大的标签$l^{*} $。

CTC loss

greedy decode

贪心的思想是每次都要最好的,那也就是说每次选取当前时刻的最大概率的字符,最后 T T T个字符串成一个标签。如下图所示, T = 1 T=1 T=1时, P ( ’ − ‘ , 1 ) P\left( ’-‘, 1\right) P(,1)概率最大; T = 2 T=2 T=2时, P ( ’ − ‘ , 2 ) P\left( ’-‘, 2\right) P(,2)概率最大; T = 3 T=3 T=3时, P ( ’ − ‘ , 3 ) P\left( ’-‘, 3\right) P(,3)概率最大。最后的输出标签为 " b l a n k " "blank" "blank" p ( l = " b l a n k " ∣ x ) = 0.5 × 0.4 × 0.6 = 0.12 p\left( l = "blank" | x \right) =0.5 \times 0.4 \times 0.6 = 0.12 p(l="blank"x)=0.5×0.4×0.6=0.12

CTC loss
贪心解码只考虑了一条路线,在CTC算法中,我们曾定义过一个多对一的映射 B B B,合并输出标签相同的路径。如下图所示,当输出标签为 " A " "A" "A"时,有三条路径 " A − − " "A--" "A", " − A − " "-A-" "A" " − − A " "--A" "A",表示为 B ( " A − − " ) = B ( " − A − " ) = B ( " − − A " ) = " A " B\left( "A--" \right) = B\left( "-A-" \right) = B\left( "--A" \right)="A" B("A")=B("A")=B("A")="A",那么$p\left( l = “A” | x \right) $应为这三条路径概率的总和
p ( l = " A " ∣ x ) = p ( l = " A − − " ∣ x ) + p ( l = " − A − " ∣ x ) + p ( l = " − − A " ∣ x ) = 0.198 p\left( l = "A" | x \right) = p\left( l = "A--" | x \right) + p\left( l = "-A-" | x \right) + p\left( l = "--A" | x \right) = 0.198 p(l="A"x)=p(l="A"x)+p(l="A"x)+p(l="A"x)=0.198

p ( l = " A " ∣ x ) = 0.198 p\left( l = "A" | x \right) =0.198 p(l="A"x)=0.198远远大于 p ( l = " b l a n k " ∣ x ) = 0.12 p\left( l = "blank" | x \right) = 0.12 p(l="blank"x)=0.12 " A " "A" "A"更应该成为输出标签。
CTC loss

Beam Search

定义 t t t时刻网络输出序列对应的标签为 s s s的概率 P r ( s , t ) Pr \left( s, t \right) Pr(s,t),定义 P r − ( s , t ) Pr^{-} \left( s, t \right) Pr(s,t) t t t时刻输出空字符的概率, P r + ( s , t ) Pr^{+} \left( s, t \right) Pr+(s,t) t t t时刻输出非空字符的概率。那么 P r ( s , t ) = P r − ( s , t ) + P r + ( s , t ) Pr \left( s, t \right) = Pr^{-} \left( s, t \right) + Pr^{+} \left( s, t \right) Pr(s,t)=Pr(s,t)+Pr+(s,t)。Beam Search每一步搜索选取概率$Pr \left( s, t \right) $最大的W个节点进行扩展,W称为Beam Width。

CTC loss

下面的例子,我们选 W = 2 W=2 W=2。在 T = 0 T=0 T=0时刻,标签为空。

T = 1 T=1 T=1时,标签 " A " , " B " , " b l a n k " "A", "B", "blank" "A","B","blank"的概率如下:
P r ( " A " , 1 ) = 0.2 P r ( " B " , 1 ) = 0.3 P r ( " b l a n k " , 1 ) = 0.5 Pr \left( "A" , 1 \right) = 0.2 \\ Pr \left( "B", 1 \right) = 0.3 \\ Pr \left( "blank", 1 \right) = 0.5 Pr("A",1)=0.2Pr("B",1)=0.3Pr("blank",1)=0.5
标签 " B " "B" "B" " b l a n k " "blank" "blank"的概率最高,以标签 ′ ′ A ′ ′ ''A'' A " B " "B" "B"进行下一步扩展。

T = 2 T=2 T=2,标签 " B B " "BB" "BB"出现的概率为:
P r − ( " B B " , 2 ) = 0 P r + ( " B B " , 2 ) = P r ( " B " , 1 ) ∗ P ( ′ B ′ , 2 ) = 0.09 Pr^{-} \left( "BB", 2 \right) = 0 \\ Pr^{+} \left( "BB", 2 \right) = Pr \left( "B" , 1 \right) * P\left( 'B', 2\right)= 0.09 Pr("BB",2)=0Pr+("BB",2)=Pr("B",1)P(B,2)=0.09
P r ( " B B " , 2 ) = P r − ( " B B " , 2 ) + P r + ( " B B " , 2 ) = 0.09 Pr \left( "BB", 2 \right) = Pr^{-} \left( "BB", 2 \right) + Pr^{+} \left( "BB", 2 \right) = 0.09 Pr("BB",2)=Pr("BB",2)+Pr+("BB",2)=0.09
同理,可计算标签 " B A " , " b l a n k " "BA","blank" "BA""blank"的概率。

然而,当 T = 2 T=2 T=2,要计算标签 " A " "A" "A"的概率时,字符 ′ − ′ '-' ′ A ′ 'A' A都有出现的可能,$Pr^{-} \left( “A”, 2 \right) 不 为 零 , 因 此 标 签 不为零,因此标签 “A”$出现的概率计算如下:
P r − ( " A " , 2 ) = P r ( " A " , 1 ) ∗ P ( ′ − ′ , 2 ) = 0.08 P r + ( " A " , 2 ) = P r ( " b l a n k " , 1 ) ∗ P ( ′ A ′ , 2 ) = 0.15 Pr^{-} \left( "A", 2 \right) = Pr \left( "A" , 1 \right) * P\left( '-', 2\right)= 0.08 \\ Pr^{+} \left( "A", 2 \right) = Pr \left( "blank" , 1 \right) * P\left( 'A', 2\right)= 0.15 Pr("A",2)=Pr("A",1)P(,2)=0.08Pr+("A",2)=Pr("blank",1)P(A,2)=0.15
P r ( " A " , 2 ) = P r + ( " A " , 2 ) + P r − ( " A " , 2 ) = 0.23 Pr \left( "A", 2 \right) = Pr^{+} \left( "A", 2 \right) + Pr^{-} \left( "A", 2 \right) = 0.23 Pr("A",2)=Pr+("A",2)+Pr("A",2)=0.23

同理可计算 P r ( " B " , 2 ) = 0.27 Pr \left( "B", 2 \right) = 0.27 Pr("B",2)=0.27

标签 " A " "A" "A" " B " "B" "B"的概率最高,以标签’ ′ A ′ ′ 'A'' A " B " "B" "B"进行下一步扩展。

按照上面的计算公式,当 T = 3 T=3 T=3时,计算出标签 " A " "A" "A"的概率 0.198 0.198 0.198最高,那么就以 " A " "A" "A"为输出标签。

W = 1 W=1 W=1时,beam search 就是 greedy decode。

Tensorflow 函数

tf.nn.ctc_greedy_decoder(
inputs,
sequence_length,
merge_repeated=True
)

tf.nn.ctc_beam_search_decoder(
inputs,
sequence_length,
beam_width=100,
top_paths=1,
merge_repeated=True
)

CTC loss

参考

  1. Supervised Sequence Labelling with Recurrent Neural Networks
  2. https://xiaodu.io/ctc-explained-part2/
  3. https://distill.pub/2017/ctc/

相关文章: