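This may be a late answer for you, but for future reference, here is a starter code base built on the model you envisioned. Keras currently has three built-in attention layers: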
- MultiHeadAttention layer
- Attention layer (a.k.a. Luong-style attention)
- AdditiveAttention layer (a.k.a. Bahdanau-style attention)
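To make the difference between the last two concrete, here is a minimal NumPy sketch of the two score functions as I understand them from the Keras docs: dot-product scores for Luong style, and a sum over `tanh(query + key)` for Bahdanau style. The helper names (`luong_attention`, `bahdanau_attention`) are my own, the key is assumed to equal the value, and the learned `scale` vector is fixed to ones here, so treat it as an illustration rather than the layers' actual implementation.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def luong_attention(query, value):
    # Dot-product scores: scores[b, i, j] = query[b, i] . value[b, j]
    scores = query @ value.transpose(0, 2, 1)        # (batch, Tq, Tv)
    weights = softmax(scores, axis=-1)               # normalize over Tv
    return weights @ value, weights                  # (batch, Tq, dim)

def bahdanau_attention(query, value, scale=None):
    # Additive scores: sum_d scale[d] * tanh(query[b, i, d] + value[b, j, d])
    if scale is None:
        scale = np.ones(query.shape[-1])             # assumption: untrained scale
    summed = np.tanh(query[:, :, None, :] + value[:, None, :, :])
    scores = (scale * summed).sum(axis=-1)           # (batch, Tq, Tv)
    weights = softmax(scores, axis=-1)
    return weights @ value, weights

rng = np.random.default_rng(0)
query = rng.normal(size=(2, 5, 4))   # (batch, Tq, dim)
value = rng.normal(size=(2, 7, 4))   # (batch, Tv, dim)
luong_out, luong_w = luong_attention(query, value)
bahdanau_out, bahdanau_w = bahdanau_attention(query, value)
print(luong_out.shape, bahdanau_out.shape)
```

Both variants return a tensor shaped like the query, `[batch_size, Tq, dim]`; only the way the attention weights are scored differs.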
For the starter code, we'll use Luong-style attention in the encoder and Bahdanau-style attention in the decoder. The overall autoencoder architecture will be
a. encoder: input -> embedding -> gru -> luong-style-attn
b. decoder: input -> lstm -> bahdanau-style-attn -> gap -> classifier
   (the decoder input also feeds bahdanau-style-attn directly, as the value)
# whole model
autoencoder: encoder + decoder
Let's build the model accordingly.
Encoder
import numpy as np  # needed for the dummy training data further down
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import *
from tensorflow.keras import backend
from tensorflow.keras import utils
backend.clear_session()
# int sequences.
enc_inputs = Input(shape=(20,), name='enc_inputs')
# Embedding lookup and GRU
embedding = Embedding(input_dim=100, output_dim=64)(enc_inputs)
whole_sequence = GRU(4, return_sequences=True)(embedding)
# Query-value attention of shape [batch_size, Tq, filters].
query_value_attention_seq = Attention()([whole_sequence, whole_sequence])
# build encoder model
encoder = Model(enc_inputs, query_value_attention_seq, name='encoder')
Check the layout.
utils.plot_model(encoder, show_shapes=True)
Decoder
# encoder output sequences of shape (20, 4).
dec_input = Input(shape=(20, 4), name='dec_inputs')
# LSTM
whole_sequence = LSTM(4, return_sequences=True)(dec_input)
# Query-value attention of shape [batch_size, Tq, filters].
query_value_attention_seq = AdditiveAttention()([whole_sequence, dec_input])
# Reduce over the sequence axis to produce encodings of shape
# [batch_size, filters].
query_value_attention = GlobalAveragePooling1D()(query_value_attention_seq)
# classification
dec_output = Dense(1, activation='sigmoid')(query_value_attention)
# build decoder model
decoder = Model(dec_input, dec_output, name='decoder')
Check the layout.
utils.plot_model(decoder, show_shapes=True)
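For anyone unsure what the last two decoder steps (gap -> classifier) do to the shapes, here is a small NumPy sketch. It is only an illustration: the function names are mine, and the `kernel` and `bias` values are hypothetical random stand-ins for the weights the `Dense(1, activation='sigmoid')` layer would learn.

```python
import numpy as np

def global_average_pool_1d(x):
    # Average over the sequence axis: (batch, steps, features) -> (batch, features)
    return x.mean(axis=1)

def sigmoid_dense(x, kernel, bias):
    # Single-unit dense layer followed by a sigmoid: (batch, features) -> (batch, 1)
    return 1.0 / (1.0 + np.exp(-(x @ kernel + bias)))

rng = np.random.default_rng(0)
attn_seq = rng.normal(size=(3, 20, 4))   # stand-in for the attention output
kernel = rng.normal(size=(4, 1))         # hypothetical weights, not trained
bias = np.zeros(1)

pooled = global_average_pool_1d(attn_seq)
probs = sigmoid_dense(pooled, kernel, bias)
print(pooled.shape, probs.shape)
```

So each sample's 20 x 4 attention sequence is collapsed to a 4-dimensional vector and then to a single probability, which is why the decoder output has shape `(None, 1)`.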
Autoencoder
# encoder
encoder_init = Input(shape=(20, ))
encoder_output = encoder(encoder_init); print(encoder_output.shape)
# decoder
decoder_output = decoder(encoder_output); print(decoder_output.shape)
# bind all: autoencoder
autoencoder = Model(encoder_init, decoder_output)
# check layout
utils.plot_model(autoencoder, show_shapes=True, expand_nested=True)
Dummy training
x_train = np.random.randint(0, 10, (100,20)); print(x_train.shape)
y_train = np.random.randint(2, size=(100, 1)); print(y_train.shape)
(100, 20)
(100, 1)
autoencoder.compile('adam', 'binary_crossentropy')
autoencoder.fit(x_train, y_train, epochs=5, verbose=2)
Epoch 1/5
4/4 - 4s - loss: 0.6674
Epoch 2/5
4/4 - 0s - loss: 0.6637
Epoch 3/5
4/4 - 0s - loss: 0.6600
Epoch 4/5
4/4 - 0s - loss: 0.6590
Epoch 5/5
4/4 - 0s - loss: 0.6571
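A note on reading these numbers: the labels here are random, so there is nothing real to learn, and the loss should hover near the chance level for binary cross-entropy. The sketch below (plain Python, with a helper name of my own) computes that baseline for a model that always predicts 0.5.

```python
import math

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    # Mean binary cross-entropy, with predictions clipped away from 0 and 1.
    total = 0.0
    for y, p in zip(y_true, y_pred):
        p = min(max(p, eps), 1.0 - eps)
        total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
    return total / len(y_true)

labels = [0, 1, 1, 0, 1]
# A model that always outputs 0.5 scores ln(2), whatever the labels are.
chance_loss = binary_crossentropy(labels, [0.5] * len(labels))
print(round(chance_loss, 4))  # 0.6931
```

The losses above sit only slightly below ln(2) ≈ 0.693, which is what I would expect from a model fitting a little noise in 100 random samples.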
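Resources
You can also read my other answers on the attention mechanism.
And this is my favorite resource on the multi-head transformer; it is a 3-part video series.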