我有同样的问题,我查看了源代码。
在tf2.0中,update_state函数的末尾有:
current_cm = confusion_matrix.confusion_matrix(
y_true,
y_pred,
self.num_classes,
weights=sample_weight,
dtype=dtypes.float64)
我又研究了 confusion_matrix 函数的实现:
with ops.name_scope(name, 'confusion_matrix',
(predictions, labels, num_classes, weights)) as name:
labels, predictions = remove_squeezable_dimensions(
ops.convert_to_tensor(labels, name='labels'),
ops.convert_to_tensor(
predictions, name='predictions'))
predictions = math_ops.cast(predictions, dtypes.int64)
labels = math_ops.cast(labels, dtypes.int64)
# Sanity checks - underflow or overflow can cause memory corruption.
labels = control_flow_ops.with_dependencies(
[check_ops.assert_non_negative(
labels, message='`labels` contains negative values')],
labels)
predictions = control_flow_ops.with_dependencies(
[check_ops.assert_non_negative(
predictions, message='`predictions` contains negative values')],
predictions)
if num_classes is None:
num_classes = math_ops.maximum(math_ops.reduce_max(predictions),
math_ops.reduce_max(labels)) + 1
else:
num_classes_int64 = math_ops.cast(num_classes, dtypes.int64)
labels = control_flow_ops.with_dependencies(
[check_ops.assert_less(
labels, num_classes_int64, message='`labels` out of bound')],
labels)
predictions = control_flow_ops.with_dependencies(
[check_ops.assert_less(
predictions, num_classes_int64,
message='`predictions` out of bound')],
predictions)
if weights is not None:
weights = ops.convert_to_tensor(weights, name='weights')
predictions.get_shape().assert_is_compatible_with(weights.get_shape())
weights = math_ops.cast(weights, dtype)
shape = array_ops.stack([num_classes, num_classes])
indices = array_ops.stack([labels, predictions], axis=1)
values = (array_ops.ones_like(predictions, dtype)
if weights is None else weights)
cm_sparse = sparse_tensor.SparseTensor(
indices=indices,
values=values,
dense_shape=math_ops.cast(shape, dtypes.int64))
zero_matrix = array_ops.zeros(math_ops.cast(shape, dtypes.int32), dtype)
return sparse_ops.sparse_add(zero_matrix, cm_sparse)
关键在于上面代码中的 predictions = math_ops.cast(predictions, dtypes.int64) 这一行:tf 使用 math_ops.cast 将预测值转换为 int64,而浮点数转换为整数时小数部分会被直接截断。因此,当您把 [0.3, 0.6, 0.2, 0.9] 传给这个转换函数时,它会返回 [0, 0, 0, 0]。
所以,这就是你得到如下混淆矩阵的原因:
[[2., 0.],
[2., 0.]]