【Posted】: 2020-11-15 17:38:18
【Question】:
I have a multi-class classifier that takes its input from a generator:
import numpy as np
from tensorflow.keras.utils import to_categorical

def generate_train_data(path, x_shape):
    genres = {"hip-hop": 0, "r&b": 1, "pop": 2, "jazz": 3}
    # One one-hot row per genre, indexed by the genre's integer id
    genre_labels = to_categorical(list(genres.values()), num_classes=len(genres))
    # some processing to create variables x and genre...
    # (mock values)
    x = np.zeros(x_shape)
    x = x[None, :, :, :]  # add a leading batch axis
    genre = "hip-hop"
    yield (x, genre_labels[genres[genre]])
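The classifier is defined as follows:
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Flatten, Dense
from tensorflow.keras.models import Model

input_shape = (96, 84, 5)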
i = Input(shape=input_shape, name='encoder_input')
cx = Conv2D(filters=8, kernel_size=3, strides=2, padding='same', activation='relu')(i)
cx = BatchNormalization()(cx)
cx = Conv2D(filters=16, kernel_size=3, strides=2, padding='same', activation='relu')(cx)
cx = BatchNormalization()(cx)
x = Flatten()(cx)
x = Dense(20, activation='relu')(x)
x = BatchNormalization()(x)
x = Dense(4, activation='softmax')(x)
classifier = Model(i, x, name='genre_classifier')
classifier.summary()
classifier.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
However, when I try to fit the classifier:
classifier.fit(generate_train_data(path, input_shape), epochs=30, validation_data=generate_test_data(path, input_shape), verbose=verbosity)
I get the following error:
ValueError: in user code:

    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:806 train_function *
        return step_function(self, iterator)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:796 step_function **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:1211 run
        return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:2585 call_for_each_replica
        return self._call_for_each_replica(fn, args, kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:2945 _call_for_each_replica
        return fn(*args, **kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:789 run_step **
        outputs = model.train_step(data)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:749 train_step
        y, y_pred, sample_weight, regularization_losses=self.losses)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/compile_utils.py:204 __call__
        loss_value = loss_obj(y_t, y_p, sample_weight=sw)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/losses.py:149 __call__
        losses = ag_call(y_true, y_pred)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/losses.py:253 call **
        return ag_fn(y_true, y_pred, **self._fn_kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/util/dispatch.py:201 wrapper
        return target(*args, **kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/losses.py:1535 categorical_crossentropy
        return K.categorical_crossentropy(y_true, y_pred, from_logits=from_logits)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/util/dispatch.py:201 wrapper
        return target(*args, **kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/backend.py:4687 categorical_crossentropy
        target.shape.assert_is_compatible_with(output.shape)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/tensor_shape.py:1134 assert_is_compatible_with
        raise ValueError("Shapes %s and %s are incompatible" % (self, other))

    ValueError: Shapes (None, 1) and (None, 4) are incompatible
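The class label returned by the generator is an array of length 4, so why does Keras report its size as 1?
Note: this code runs on Colab with TensorFlow 2.3. A mock version that reproduces the error is available at this Colab link: https://colab.research.google.com/drive/1SQZFspj3UOwP2ApIiaI2lvB2Z59bdVOk?usp=sharing
Edit: added mock values to generate_train_data so that the code is reproducible.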
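One possible reading of the shape mismatch, offered as a sketch and not as a confirmed diagnosis: as far as I understand, Keras treats each tuple a generator yields as one whole batch. In generate_train_data, x is given a leading batch axis but the label is not, so a label of shape (4,) may be read as four scalar targets, which would explain the (None, 1) in the error. The NumPy-only snippet below compares the shapes involved. The helper one_hot and both generator names are hypothetical stand-ins introduced for illustration, and adding a batch axis to the label is an assumed fix, not one verified against the Colab notebook.

```python
import numpy as np

NUM_CLASSES = 4  # matches the four genres in the question


def one_hot(index, num_classes=NUM_CLASSES):
    """Hypothetical stand-in for one row of to_categorical(): a 1-D one-hot vector."""
    vec = np.zeros(num_classes, dtype="float32")
    vec[index] = 1.0
    return vec


def generate_unbatched_label(x_shape):
    # Mirrors the question: x gets a batch axis, the label does not.
    x = np.zeros(x_shape)[None, ...]
    yield x, one_hot(0)


def generate_batched_label(x_shape):
    # Assumed fix: give the label a batch axis that matches x.
    x = np.zeros(x_shape)[None, ...]
    yield x, one_hot(0)[None, :]


x_a, y_a = next(generate_unbatched_label((96, 84, 5)))
x_b, y_b = next(generate_batched_label((96, 84, 5)))

print(x_a.shape, y_a.shape)  # (1, 96, 84, 5) (4,)
print(x_b.shape, y_b.shape)  # (1, 96, 84, 5) (1, 4)
```

In the second case the leading dimensions of x and the label agree (one sample each), and the label's trailing dimension matches the 4-unit softmax output of the classifier.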
【Comments】:
-
Could you provide reproducible code?
-
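@M.Innat The genre variable can simply be set to one of "hip-hop", "r&b", etc., and the x variable can be any binary numpy array, as long as its shape matches the input_shape argument given to the classifier's input layer.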
-
The code works for me. Which TensorFlow version are you using?
-
@Marcus The code runs on Colab, so it should be a 2.x version.
-
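It runs fine for me on tf 2.3. Your logic looks right to me, and I can't work out why you get that error. Your reproducible code did need a few modifications; for example, in your generator I had to reshape x as
x = x[None,:,:,:]. This may be unrelated to your error, but I thought it was worth mentioning.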
Tags: python-3.x tensorflow machine-learning keras google-colaboratory