CTC loss
依据RNN网络的性质,每个时刻输出一个字符,RNN的最终输出是字符序列
S
S
S,需要后处理才能得到标签
T
T
T。在实际应用中,例如文字识别过程中,
S
S
S和
T
T
T的长度是变化的,且不是等长的,那么就需要一种算法来完成对齐操作。CTC算法能够自动地完成
S
S
S和
T
T
T对齐。
假设训练数据集 S S S的空间分布为 D X × Z D_{X \times Z} DX×Z,输入空间 X = ( R ∗ ) ∗ X = \left( \mathbb{R}^{*} \right)^{*} X=(R∗)∗是 m m m维实向量所有序列的集合,目标空间 Z = L ∗ Z=L^{*} Z=L∗ 是有限字符表 L L L中的字符序列集合。对于训练集中的每一个实例,表示为 ( x , z ) \left( x,z\right) (x,z),其中目标序列 z = ( z 1 , ⋯ , z U ) z=\left(z_{1},\cdots,z_{U} \right) z=(z1,⋯,zU)的长度要小于等于输入序列 x = ( x 1 , ⋯ , x T ) x=\left(x_{1},\cdots,x_{T} \right) x=(x1,⋯,xT)的长度。
对于长度为 T T T的输入序列 x x x,定义一个RNN为 N ω : ( R m ) T ↦ ( R n ) T N_{\omega}: \left ( \mathbb{R}^{m} \right )^{T} \mapsto \left ( \mathbb{R}^{n} \right )^{T} Nω:(Rm)T↦(Rn)T,则RNN的输出序列为 y = N ω ( x ) y = N_{\omega}\left( x \right) y=Nω(x),其中 y k t y_{k}^{t} ykt是在时刻t,输出单元 k k k的**函数值,将此值作为在时刻 t t t,字符 k k k出现的概率,那么 ∑ t = 1 T y k t = 1 \sum_{t=1}^{T} y_{k}^{t} = 1 ∑t=1Tykt=1。定义字符集合为 L ′ T = L ∪ { b l a n k } {L}'^{T}=L \cup \left\{blank\right\} L′T=L∪{blank}, π \pi π为路径,假设某一条长度为 6 6 6的路径 π i \pi_{i} πi为 C C A A T T CCAATT CCAATT, ∏ t = 1 T y π t t = y C 1 × y C 2 × y A 3 × y A 4 × y T 5 × y T 6 \prod_{t=1}^{T} y_{\pi_{t}}^{t}= y_{C}^{1}\times y_{C}^{2} \times y_{A}^{3} \times y_{A}^{4} \times y_{T}^{5} \times y_{T}^{6} ∏t=1Tyπtt=yC1×yC2×yA3×yA4×yT5×yT6。
p
(
π
∣
x
)
=
∏
t
=
1
T
y
π
t
t
,
∀
π
∈
L
′
T
p\left( \pi | x \right) = \prod_{t=1}^{T} y_{\pi_{t}}^{t}, \forall \pi \in L ^{'^{T}}
p(π∣x)=t=1∏Tyπtt,∀π∈L′T
接下来就需要定义一个多对一的映射
B
:
L
′
T
↦
L
≤
T
B:{L}'^{T}\mapsto L^{\leq T}
B:L′T↦L≤T,目的是为了合并有相同输出的路径。举个例子,我们定一个规则:仅仅合并两个‘-’间多余的字符并且移除所有的‘-’,那么
B
(
a
−
a
b
−
)
=
B
(
−
a
a
−
−
a
b
b
)
=
a
a
b
B\left ( a-ab- \right ) = B\left ( -aa--abb \right )=aab
B(a−ab−)=B(−aa−−abb)=aab,也就是同一个输出有不同的路径。那么对于一个特定的字符序列
l
∈
L
≤
T
l \in L^{\leq T}
l∈L≤T,它的条件概率就是与它相关的路径概率之和。
p
(
l
∣
x
)
=
∑
π
∈
B
−
1
(
l
)
p
(
π
∣
x
)
p\left ( l |x \right ) = \sum_{\pi \in B^{-1}\left ( l \right )} p \left ( \pi | x \right )
p(l∣x)=π∈B−1(l)∑p(π∣x)
其中
B
−
1
(
l
)
B^{-1}\left( l\right)
B−1(l)表示经过
B
B
B变换后输出是
l
l
l的所有路径的集合。
为了能够有效地计算
p
(
l
∣
x
)
p\left ( l |x \right )
p(l∣x),借鉴了隐马尔科夫HMM的前向-后算计算思路。
前向
对于长度为
r
r
r的序列
q
q
q,定义
q
1
:
p
q_{1:p}
q1:p和
q
r
−
p
:
r
q_{r-p:r}
qr−p:r分别作为序列
q
q
q的前
p
p
p标识和后
p
p
p标识。对于一个字符序列
l
l
l,定义它的前向变量为
α
t
(
s
)
\alpha_{t}\left( s \right)
αt(s)为在时刻
t
t
t时
l
1
:
s
l_{1:s}
l1:s的总概率,用公式表示为:
α
t
(
s
)
=
∑
π
∈
N
T
B
(
π
1
:
t
)
=
l
1
:
s
∏
t
′
=
1
t
y
π
t
′
t
′
\alpha_{t} \left ( s \right ) = \sum _{\begin{matrix} \pi \in N^{T} \\ B \left ( \pi_{1:t} \right ) = l_{1:s} \end{matrix}} \prod _{{t}'=1}^{t} y_{\pi_{{t}'}}^{{t}'}
αt(s)=π∈NTB(π1:t)=l1:s∑t′=1∏tyπt′t′
在
l
l
l的最前面和最后面插入空格,以及每两个字符的中间插入空格,得到新的标签序列,记为
l
′
{l}'
l′,那么
l
′
{l}'
l′的长度将会变成
2
∣
l
∣
+
1
2\left | l \right | + 1
2∣l∣+1。例如字符序列
l
l
l为
c
a
t
cat
cat,我们用
−
-
−表示空格,那么新的字符序列
l
′
{l}'
l′则为
−
c
−
a
−
t
−
-c-a-t-
−c−a−t−。为了能够计算
l
′
{l}'
l′前缀的概率,我们允许空格和字符之间可以转移,还有任何独立的字符之间可以转移。论文中定义所有的前缀可以从一个空格或者字符开始,也就说初始化可以定义为:
α
1
(
1
)
=
y
b
1
α
1
(
2
)
=
y
l
1
1
α
1
(
s
)
=
0
,
∀
s
>
2
\begin{matrix} \alpha_{1}\left( 1 \right) = y_{b}^{1} \\ \alpha_{1}\left( 2 \right) = y_{l_{1}}^{1} \\ \alpha_{1}\left( s \right) = 0, \forall s > 2 \end{matrix}
α1(1)=yb1α1(2)=yl11α1(s)=0,∀s>2
α
t
(
s
)
\alpha_{t} \left( s \right)
αt(s)可以递推得到,分为两种情况:
- 如果 l s ′ {l}'_{s} ls′为空或者 l s − 2 ′ = l s ′ {l}'_{s-2} = {l}'_{s} ls−2′=ls′,那么 α t ( s ) \alpha_{t} \left( s \right) αt(s)只能从 α t − 1 ( s ) \alpha _{t-1}\left ( s \right ) αt−1(s)和 α t − 1 ( s − 1 ) \alpha _{t-1}\left ( s - 1 \right ) αt−1(s−1)得到。
- 如果不满足第一种情况,那么 α t ( s ) \alpha _{t}\left ( s \right ) αt(s)只能从 α t − 1 ( s ) \alpha _{t-1}\left ( s \right ) αt−1(s), α t − 1 ( s − 1 ) \alpha _{t-1}\left ( s -1 \right ) αt−1(s−1)和 α t − 1 ( s − 2 ) \alpha _{t-1}\left ( s - 2 \right ) αt−1(s−2)得到。
总公式如下所示:
α
t
(
s
)
=
{
(
α
t
−
1
(
s
)
+
α
t
−
1
(
s
−
1
)
)
y
l
s
′
t
i
f
l
s
′
=
b
l
a
n
k
o
r
l
s
−
2
′
=
l
s
′
(
α
t
−
1
(
s
)
+
α
t
−
1
(
s
−
1
)
+
α
t
−
1
(
s
−
2
)
)
y
l
s
′
t
o
t
h
e
r
s
i
z
e
\alpha _{t}\left ( s \right ) = \left\{\begin{matrix} \left ( \alpha _{t-1}\left ( s \right ) + \alpha _{t-1}\left ( s-1 \right ) \right )y_{{l}'_{s}}^{t} \qquad \qquad \qquad & if \; \; {l}'_{s} = blank \; or \; {l}'_{s-2} = {l}'_{s} \\ \left ( \alpha _{t-1}\left ( s \right ) + \alpha _{t-1}\left ( s-1 \right ) + \alpha _{t-1}\left ( s-2 \right ) \right )y_{{l}'_{s}}^{t} & othersize \end{matrix}\right.
αt(s)={(αt−1(s)+αt−1(s−1))yls′t(αt−1(s)+αt−1(s−1)+αt−1(s−2))yls′tifls′=blankorls−2′=ls′othersize
在时刻
T
T
T时,最后一个字符有可能是空格,有可能也是某个字符,那么字符标签
l
l
l的概率是这两种概率的综合,公式表示为:
p
(
l
∣
x
)
=
α
T
(
∣
l
′
∣
)
+
α
T
(
∣
l
′
∣
−
1
)
p\left ( l | x \right ) = \alpha_{T} \left ( \left | {l}' \right | \right ) + \alpha_{T} \left ( \left | {l}' \right | - 1 \right )
p(l∣x)=αT(∣l′∣)+αT(∣l′∣−1)
后向
和前向算法类似,定义
β
t
(
s
)
\beta_{t}\left( s\right)
βt(s)在时刻
t
t
t时
l
s
:
∣
l
∣
l_{s : \left | l \right |}
ls:∣l∣的总概率。
β
t
(
s
)
=
∑
π
∈
N
T
B
(
π
t
:
T
)
=
l
s
:
∣
l
∣
∏
t
′
=
t
T
y
π
t
′
t
′
\beta _{t} \left ( s \right ) = \sum _{\begin{matrix} \pi \in N^{T} \\ B \left ( \pi_{t:T} \right ) = l_{s:\left | l \right |} \end{matrix}} \prod _{{t}'=t}^{T} y_{\pi_{{t}'}}^{{t}'}
βt(s)=π∈NTB(πt:T)=ls:∣l∣∑t′=t∏Tyπt′t′
初始化为:
β
T
(
∣
l
′
∣
)
=
y
b
T
β
T
(
∣
l
′
−
1
∣
)
=
y
l
∣
l
∣
T
β
T
(
s
)
=
0
,
∀
s
<
∣
l
′
∣
−
1
\begin{matrix} \beta_{T}\left ( \left | {l}' \right | \right ) = y_{b}^{T} \qquad \qquad \\ \beta_{T}\left ( \left | {l}' - 1 \right | \right ) = y_{l_{\left | l \right |}}^{T} \qquad \\ \beta_{T}\left( s \right ) = 0 , \forall s < \left | {l}' \right | -1 \end{matrix}
βT(∣l′∣)=ybTβT(∣l′−1∣)=yl∣l∣TβT(s)=0,∀s<∣l′∣−1
递推公式为:
β
t
(
s
)
=
{
(
β
t
+
1
(
s
)
+
β
t
+
1
(
s
+
1
)
)
y
l
s
′
t
i
f
l
s
′
=
b
l
a
n
k
o
r
l
s
−
2
′
=
l
s
′
(
β
t
+
1
(
s
)
+
β
t
+
1
(
s
+
1
)
+
β
t
+
1
(
s
+
2
)
)
y
l
s
′
t
o
t
h
e
r
s
i
z
e
\beta_{t} \left ( s \right ) = \left\{\begin{matrix} \left ( \beta_{t+1} \left ( s \right ) + \beta_{t+1} \left ( s + 1 \right ) \right ) y_{{l}'_{s}}^{t} \qquad \qquad \qquad & if \; \; {l}'_{s} = blank \; or \; {l}'_{s-2} = {l}'_{s} \\ \left ( \beta_{t+1} \left ( s \right ) + \beta_{t+1} \left ( s + 1 \right ) + \beta_{t+1} \left ( s + 2 \right ) \right ) y_{{l}'_{s}}^{t} & othersize \end{matrix}\right.
βt(s)={(βt+1(s)+βt+1(s+1))yls′t(βt+1(s)+βt+1(s+1)+βt+1(s+2))yls′tifls′=blankorls−2′=ls′othersize
应用
Tensorflow中已经给出了CTC的接口:
tf.nn.ctc_loss( labels, logits, label_length, logit_length, logits_time_major=True, unique=None, blank_index=None, name=None )
-
labels: 格式为[batch_size, max_label_seq_length]的张量或者稀疏张量
-
logits:默认是格式为[frames, batch_size, num_lables]的张量,如果logits_time_major=True,格式为[batch_size,batch_size, num_lables]
特点
- 使用CTC作为loss的计算方法,训练样本无需对齐。而且CTC使用前向后向算法递推计算,提高了计算速度。
- CTC 假设每个时间片是相互独立的,熟悉RNN的朋友们都知道这是稍微不合理的。
CTC 的两种解码方法
上述CTC loss 应用于图像文字识别的训练过程中。
在预测过程中,当输入
x
x
x,我们希望能够得到使得
p
(
l
∣
x
)
p\left( l | x \right)
p(l∣x)概率最大的标签
l
l
l。在序列学习问题中,这个问题被称为解码,在有限的时间内得到条件概率最大的序列$l^{*} $。
l
∗
=
a
r
g
m
a
x
p
(
l
∣
x
)
l^{*} = argmax p\left( l | x \right)
l∗=argmaxp(l∣x)
假设有字符列表 ( ′ − ′ , ′ A ′ , ′ B ′ ) \left( '-', 'A', 'B'\right) (′−′,′A′,′B′),时刻 T = 3 T=3 T=3,并且定义在 t t t时刻时,字符 c c c出现的概率为 P ( c , t ) P\left( c, t\right) P(c,t)。如下表所示,以横轴作为时刻序列,纵轴为字符列表,表格中的数字为概率,我们的目标是在这个二维空间中搜索出概率最大的标签$l^{*} $。
greedy decode
贪心的思想是每次都要最好的,那也就是说每次选取当前时刻的最大概率的字符,最后 T T T个字符串成一个标签。如下图所示, T = 1 T=1 T=1时, P ( ’ − ‘ , 1 ) P\left( ’-‘, 1\right) P(’−‘,1)概率最大; T = 2 T=2 T=2时, P ( ’ − ‘ , 2 ) P\left( ’-‘, 2\right) P(’−‘,2)概率最大; T = 3 T=3 T=3时, P ( ’ − ‘ , 3 ) P\left( ’-‘, 3\right) P(’−‘,3)概率最大。最后的输出标签为 " b l a n k " "blank" "blank", p ( l = " b l a n k " ∣ x ) = 0.5 × 0.4 × 0.6 = 0.12 p\left( l = "blank" | x \right) =0.5 \times 0.4 \times 0.6 = 0.12 p(l="blank"∣x)=0.5×0.4×0.6=0.12。
贪心解码只考虑了一条路线,在CTC算法中,我们曾定义过一个多对一的映射
B
B
B,合并输出标签相同的路径。如下图所示,当输出标签为
"
A
"
"A"
"A"时,有三条路径
"
A
−
−
"
"A--"
"A−−",
"
−
A
−
"
"-A-"
"−A−"和
"
−
−
A
"
"--A"
"−−A",表示为
B
(
"
A
−
−
"
)
=
B
(
"
−
A
−
"
)
=
B
(
"
−
−
A
"
)
=
"
A
"
B\left( "A--" \right) = B\left( "-A-" \right) = B\left( "--A" \right)="A"
B("A−−")=B("−A−")=B("−−A")="A",那么$p\left( l = “A” | x \right) $应为这三条路径概率的总和
p
(
l
=
"
A
"
∣
x
)
=
p
(
l
=
"
A
−
−
"
∣
x
)
+
p
(
l
=
"
−
A
−
"
∣
x
)
+
p
(
l
=
"
−
−
A
"
∣
x
)
=
0.198
p\left( l = "A" | x \right) = p\left( l = "A--" | x \right) + p\left( l = "-A-" | x \right) + p\left( l = "--A" | x \right) = 0.198
p(l="A"∣x)=p(l="A−−"∣x)+p(l="−A−"∣x)+p(l="−−A"∣x)=0.198
p
(
l
=
"
A
"
∣
x
)
=
0.198
p\left( l = "A" | x \right) =0.198
p(l="A"∣x)=0.198远远大于
p
(
l
=
"
b
l
a
n
k
"
∣
x
)
=
0.12
p\left( l = "blank" | x \right) = 0.12
p(l="blank"∣x)=0.12,
"
A
"
"A"
"A"更应该成为输出标签。
Beam Search
定义 t t t时刻网络输出序列对应的标签为 s s s的概率 P r ( s , t ) Pr \left( s, t \right) Pr(s,t),定义 P r − ( s , t ) Pr^{-} \left( s, t \right) Pr−(s,t)为 t t t时刻输出空字符的概率, P r + ( s , t ) Pr^{+} \left( s, t \right) Pr+(s,t)为 t t t时刻输出非空字符的概率。那么 P r ( s , t ) = P r − ( s , t ) + P r + ( s , t ) Pr \left( s, t \right) = Pr^{-} \left( s, t \right) + Pr^{+} \left( s, t \right) Pr(s,t)=Pr−(s,t)+Pr+(s,t)。Beam Search每一步搜索选取概率$Pr \left( s, t \right) $最大的W个节点进行扩展,W称为Beam Width。
下面的例子,我们选 W = 2 W=2 W=2。在 T = 0 T=0 T=0时刻,标签为空。
在
T
=
1
T=1
T=1时,标签
"
A
"
,
"
B
"
,
"
b
l
a
n
k
"
"A", "B", "blank"
"A","B","blank"的概率如下:
P
r
(
"
A
"
,
1
)
=
0.2
P
r
(
"
B
"
,
1
)
=
0.3
P
r
(
"
b
l
a
n
k
"
,
1
)
=
0.5
Pr \left( "A" , 1 \right) = 0.2 \\ Pr \left( "B", 1 \right) = 0.3 \\ Pr \left( "blank", 1 \right) = 0.5
Pr("A",1)=0.2Pr("B",1)=0.3Pr("blank",1)=0.5
标签
"
B
"
"B"
"B"和
"
b
l
a
n
k
"
"blank"
"blank"的概率最高,以标签
′
′
A
′
′
''A''
′′A′′和
"
B
"
"B"
"B"进行下一步扩展。
当
T
=
2
T=2
T=2,标签
"
B
B
"
"BB"
"BB"出现的概率为:
P
r
−
(
"
B
B
"
,
2
)
=
0
P
r
+
(
"
B
B
"
,
2
)
=
P
r
(
"
B
"
,
1
)
∗
P
(
′
B
′
,
2
)
=
0.09
Pr^{-} \left( "BB", 2 \right) = 0 \\ Pr^{+} \left( "BB", 2 \right) = Pr \left( "B" , 1 \right) * P\left( 'B', 2\right)= 0.09
Pr−("BB",2)=0Pr+("BB",2)=Pr("B",1)∗P(′B′,2)=0.09
P
r
(
"
B
B
"
,
2
)
=
P
r
−
(
"
B
B
"
,
2
)
+
P
r
+
(
"
B
B
"
,
2
)
=
0.09
Pr \left( "BB", 2 \right) = Pr^{-} \left( "BB", 2 \right) + Pr^{+} \left( "BB", 2 \right) = 0.09
Pr("BB",2)=Pr−("BB",2)+Pr+("BB",2)=0.09
同理,可计算标签
"
B
A
"
,
"
b
l
a
n
k
"
"BA","blank"
"BA","blank"的概率。
然而,当
T
=
2
T=2
T=2,要计算标签
"
A
"
"A"
"A"的概率时,字符
′
−
′
'-'
′−′和
′
A
′
'A'
′A′都有出现的可能,$Pr^{-} \left( “A”, 2 \right)
不
为
零
,
因
此
标
签
不为零,因此标签
不为零,因此标签“A”$出现的概率计算如下:
P
r
−
(
"
A
"
,
2
)
=
P
r
(
"
A
"
,
1
)
∗
P
(
′
−
′
,
2
)
=
0.08
P
r
+
(
"
A
"
,
2
)
=
P
r
(
"
b
l
a
n
k
"
,
1
)
∗
P
(
′
A
′
,
2
)
=
0.15
Pr^{-} \left( "A", 2 \right) = Pr \left( "A" , 1 \right) * P\left( '-', 2\right)= 0.08 \\ Pr^{+} \left( "A", 2 \right) = Pr \left( "blank" , 1 \right) * P\left( 'A', 2\right)= 0.15
Pr−("A",2)=Pr("A",1)∗P(′−′,2)=0.08Pr+("A",2)=Pr("blank",1)∗P(′A′,2)=0.15
则
P
r
(
"
A
"
,
2
)
=
P
r
+
(
"
A
"
,
2
)
+
P
r
−
(
"
A
"
,
2
)
=
0.23
Pr \left( "A", 2 \right) = Pr^{+} \left( "A", 2 \right) + Pr^{-} \left( "A", 2 \right) = 0.23
Pr("A",2)=Pr+("A",2)+Pr−("A",2)=0.23
同理可计算 P r ( " B " , 2 ) = 0.27 Pr \left( "B", 2 \right) = 0.27 Pr("B",2)=0.27
标签 " A " "A" "A"和 " B " "B" "B"的概率最高,以标签’ ′ A ′ ′ 'A'' ′A′′和 " B " "B" "B"进行下一步扩展。
按照上面的计算公式,当 T = 3 T=3 T=3时,计算出标签 " A " "A" "A"的概率 0.198 0.198 0.198最高,那么就以 " A " "A" "A"为输出标签。
当 W = 1 W=1 W=1时,beam search 就是 greedy decode。
Tensorflow 函数
tf.nn.ctc_greedy_decoder(
inputs,
sequence_length,
merge_repeated=True
)
tf.nn.ctc_beam_search_decoder(
inputs,
sequence_length,
beam_width=100,
top_paths=1,
merge_repeated=True
)
参考
- Supervised Sequence Labelling with Recurrent Neural Networks
- https://xiaodu.io/ctc-explained-part2/
- https://distill.pub/2017/ctc/