【发布时间】:2020-08-20 12:53:37
【问题描述】:
我尝试在 keras 中使用自定义层。这是一个简单的层,只是一个带有可训练参数的 matmul。
from tensorflow import keras
import numpy as np
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input
from tensorflow.keras.optimizers import RMSprop
from keras.layers import Layer
from tensorflow.keras import backend as K
class MultiLayer(Layer):
def __init__(self, output_dim, **kwargs):
self.output_dim = output_dim
super(MultiLayer, self).__init__(**kwargs)
def build(self, input_shape):
# Create a trainable weight variable for this layer
if self.output_dim[0] != input_shape[1]:
raise Exception("expect input shape with [{},?], but get input with shape {}".format(self.output_dim[0],input_shape), input_shape)
self.kernel = self.add_weight(name='kernel',
shape=(input_shape[2], self.output_dim[0]),
initializer='uniform',
trainable=True)
super(MultiLayer, self).build(input_shape) # Be sure to call this at the end
def call(self, x):
return K.dot(x, self.kernel)
def compute_output_shape(self, input_shape):
return (input_shape[0], self.output_dim)
x_fake = np.random.random((10,28,28))
y_fake = [np.diag(np.ones(28))]*10
input_shape = np.shape(x_fake)[1:]
print(input_shape)
ipt = Input(name='inputs',shape=input_shape)
layer = MultiLayer((input_shape[0],input_shape[0]),name="dev")(ipt)
#layer = Flatten()(layer)
model = Model(inputs=ipt,outputs=layer)
model.summary()
rms = RMSprop()
model.compile(loss="rms", optimizer=rms, metrics=['accuracy'])
model.fit(x_fake,y_fake)
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
inputs (InputLayer) [(None, 28, 28)] 0
_________________________________________________________________
dev (MultiLayer) (None, 28, 28) 784
=================================================================
Total params: 784
Trainable params: 784
Non-trainable params: 0
_________________________________________________________________
但是当我拟合这个模型时,就会发生错误。
ValueError: Data cardinality is ambiguous:
x sizes: 10
y sizes: 28, 28, 28, 28, 28, 28, 28, 28, 28, 28
Please provide data which shares the same first dimension.
我不知道尺寸是什么意思。
如何解决?
【问题讨论】:
标签: python tensorflow keras tf.keras