在这种情况下，您需要对 logits 和 labels 都进行填充，使它们在序列维度上具有相同的长度。假设张量 logits 的形状为 (batch_size, length_logits, vocab_size)，labels 的形状为 (batch_size, length_labels)，其中 length_logits 和 length_labels 分别是两者各自的序列长度（二者可能不相等）。首先，需要沿第二个维度将它们填充到相同的长度：
def _pad_tensors_to_same_length(logits, labels):
    """Right-pad ``logits`` and ``labels`` along axis 1 to a common length.

    Args:
      logits: Rank-3 tensor, e.g. [batch_size, length_logits, vocab_size]
        (the padding spec below has three entries, so rank 3 is required).
      labels: Rank-2 tensor, e.g. [batch_size, length_labels].

    Returns:
      A ``(logits, labels)`` tuple whose dimension 1 equals the larger of the
      two input lengths. Newly added positions are filled with zeros
      (``tf.pad``'s default constant value).
    """
    with tf.name_scope("pad_to_same_length"):
        len_logits = tf.shape(logits)[1]
        len_labels = tf.shape(labels)[1]
        target_len = tf.maximum(len_logits, len_labels)

        # Only the sequence axis (dim 1) grows, and only at its end; the
        # batch and vocabulary axes are left untouched.
        no_pad = [0, 0]
        logits_paddings = [no_pad, [0, target_len - len_logits], no_pad]
        labels_paddings = [no_pad, [0, target_len - len_labels]]

        padded_logits = tf.pad(logits, logits_paddings)
        padded_labels = tf.pad(labels, labels_paddings)
        return padded_logits, padded_labels
然后即可计算忽略填充位置的交叉熵损失：
def padded_cross_entropy_loss(logits, labels, vocab_size):
    """Calculate per-position cross-entropy loss while ignoring padding.

    Args:
      logits: Tensor of size [batch_size, length_logits, vocab_size] holding
        unnormalized scores.
      labels: Integer tensor of size [batch_size, length_labels] holding token
        ids; id 0 is treated as padding.
      vocab_size: int size of the vocabulary, used as the depth of the one-hot
        targets.

    Returns:
      Tensor of size [batch_size, max(length_logits, length_labels)] with the
      cross-entropy at every position, multiplied by 0 wherever the (padded)
      label is 0 and by 1 elsewhere.
    """
    with tf.name_scope("loss", values=[logits, labels]):
        logits, labels = _pad_tensors_to_same_length(logits, labels)

        # Calculate cross entropy
        with tf.name_scope("cross_entropy", values=[logits, labels]):
            # softmax_cross_entropy_with_logits_v2 expects a target
            # distribution with the same shape as `logits`, so the integer ids
            # are expanded to one-hot vectors of length vocab_size. (The
            # original code referenced an undefined name `targets` here.)
            targets = tf.one_hot(tf.cast(labels, tf.int32), depth=vocab_size)
            xentropy = tf.nn.softmax_cross_entropy_with_logits_v2(
                logits=logits, labels=targets)

        # Zero out positions whose label is the padding id 0. This includes
        # the positions appended by _pad_tensors_to_same_length above.
        weights = tf.to_float(tf.not_equal(labels, 0))
        return xentropy * weights