【发布时间】:2022-01-11 14:01:33
【问题描述】:
我通过内存映射(memory map)读取保存在硬盘上的 4D 图像 numpy 数组,并将其与目标变量一起传入 tf.data.Dataset.from_tensor_slices。该 numpy 数组的大小超过 10 GB。在训练 3D CNN 模型时,内核因内存耗尽而崩溃。该如何解决这个问题?
下面是部分代码:
def load_train(x_path='trainX.npy', y_path='trainY.npy', mmap_mode='c'):
    """Open the training features and labels as memory-mapped arrays.

    The files are opened with ``np.load(..., mmap_mode=...)`` rather than
    read with a plain ``np.load``.

    Args:
        x_path: Path of the ``.npy`` file holding the training images.
        y_path: Path of the ``.npy`` file holding the training targets.
        mmap_mode: Memory-map mode forwarded to ``np.load``. The default
            ``'c'`` is NumPy's copy-on-write mode.

    Returns:
        A ``(X, y)`` tuple of the two memory-mapped arrays.
    """
    X = np.load(x_path, mmap_mode=mmap_mode)
    y = np.load(y_path, mmap_mode=mmap_mode)
    return X, y
# Open the training features and labels by calling load_train().
x_train, y_train = load_train()
def load_val(x_path='testX.npy', y_path='testY.npy', mmap_mode='c'):
    """Open the validation features and labels as memory-mapped arrays.

    The files are opened with ``np.load(..., mmap_mode=...)`` rather than
    read with a plain ``np.load``.

    Args:
        x_path: Path of the ``.npy`` file holding the validation images.
        y_path: Path of the ``.npy`` file holding the validation targets.
        mmap_mode: Memory-map mode forwarded to ``np.load``. The default
            ``'c'`` is NumPy's copy-on-write mode.

    Returns:
        A ``(X, y)`` tuple of the two memory-mapped arrays.
    """
    X = np.load(x_path, mmap_mode=mmap_mode)
    y = np.load(y_path, mmap_mode=mmap_mode)
    return X, y
# Open the validation features and labels by calling load_val().
x_val, y_val = load_val()
def _lazy_slices(features, labels, shuffle=True):
    """Build a dataset of ``(feature, label)`` pairs read one sample at a time.

    ``tf.data.Dataset.from_tensor_slices`` converts its arguments to tensors,
    which pulls the whole memory-mapped array into RAM, and shuffling the
    samples themselves with a buffer of ``len(features)`` holds every sample
    in memory a second time. Here only integer indices are shuffled; each
    sample is fetched from the array when the pipeline asks for it.

    Args:
        features: Array-like indexed along its first axis (e.g. an
            ``np.memmap``); must expose ``dtype`` and ``shape``.
        labels: Array-like with the same first-axis length as ``features``.
        shuffle: When true, shuffle the sample order.

    Returns:
        A ``tf.data.Dataset`` yielding ``(feature, label)`` tensors.
    """
    num_samples = len(features)
    indices = tf.data.Dataset.range(num_samples)
    if shuffle:
        # Shuffling int64 indices is cheap even with a full-size buffer.
        indices = indices.shuffle(num_samples)

    def _read(index):
        # Runs as plain Python/NumPy: touches only one sample of the array.
        return np.asarray(features[index]), np.asarray(labels[index])

    def _load(index):
        x, y = tf.numpy_function(
            _read,
            [index],
            (tf.as_dtype(features.dtype), tf.as_dtype(labels.dtype)),
        )
        # numpy_function loses static shapes; restore them for Keras.
        x.set_shape(features.shape[1:])
        y.set_shape(labels.shape[1:])
        return x, y

    return indices.map(_load)


# Lazily loaded (and already shuffled) sample streams.
train_loader = _lazy_slices(x_train, y_train)
validation_loader = _lazy_slices(x_val, y_val)

batch_size = 2

train_dataset = (
    train_loader.map(train_preprocessing)
    .batch(batch_size)
    .prefetch(2)
)
validation_dataset = (
    validation_loader.map(validation_preprocessing)
    .batch(batch_size)
    .prefetch(2)
)
def get_model(width=512, height=512, depth=645):
    """Build the 3D CNN used for binary classification.

    The network is four Conv3D -> MaxPool3D -> BatchNormalization stages
    (64, 64, 128 and 256 filters), followed by global average pooling, a
    512-unit ReLU dense layer, dropout of 0.3 and a single sigmoid output.

    Args:
        width: Size of the first spatial axis of the input volume.
        height: Size of the second spatial axis of the input volume.
        depth: Size of the third spatial axis of the input volume.

    Returns:
        An uncompiled ``keras.Model`` named ``"3dcnn"`` whose input shape is
        ``(width, height, depth, 1)``.
    """
    # NOTE(review): with the default 512x512x645 input, the first 64-filter
    # convolution alone produces a very large activation volume per sample;
    # confirm it fits in memory or downsample the input volumes.
    inputs = keras.Input((width, height, depth, 1))

    x = inputs
    for filters in (64, 64, 128, 256):
        x = layers.Conv3D(filters=filters, kernel_size=3, activation="relu")(x)
        x = layers.MaxPool3D(pool_size=2)(x)
        x = layers.BatchNormalization()(x)

    x = layers.GlobalAveragePooling3D()(x)
    x = layers.Dense(units=512, activation="relu")(x)
    x = layers.Dropout(0.3)(x)
    outputs = layers.Dense(units=1, activation="sigmoid")(x)

    return keras.Model(inputs, outputs, name="3dcnn")
# Build the network for 512 x 512 x 645 single-channel volumes.
model = get_model(width=512, height=512, depth=645)

# Binary cross-entropy with Adam; accuracy is reported as "acc".
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=lr_schedule),
    loss="binary_crossentropy",
    metrics=["acc"],
)

# Train, validating on validation_dataset and running the checkpoint and
# early-stopping callbacks.
# NOTE(review): `shuffle` is believed to be ignored by Keras when the input
# is a tf.data.Dataset -- confirm against the Model.fit documentation.
model.fit(
    train_dataset,
    epochs=epochs,
    validation_data=validation_dataset,
    callbacks=[checkpoint_cb, early_stopping_cb],
    shuffle=True,
    verbose=2,
)
【问题讨论】:
标签: python tensorflow machine-learning image-processing conv-neural-network