所以我自己尝试了一个解决方案,这似乎有效。不过,我希望有更简单的东西。
我认为第二次打开模型文件并不是最佳选择。如果有人可以做得更好,一定要做到。
import h5py
from keras.models import load_model
from keras.models import save_model
def load_model_ext(filepath, custom_objects=None):
model = load_model(filepath, custom_objects=None)
f = h5py.File(filepath, mode='r')
meta_data = None
if 'my_meta_data' in f.attrs:
meta_data = f.attrs.get('my_meta_data')
f.close()
return model, meta_data
def save_model_ext(model, filepath, overwrite=True, meta_data=None):
save_model(model, filepath, overwrite)
if meta_data is not None:
f = h5py.File(filepath, mode='a')
f.attrs['my_meta_data'] = meta_data
f.close()
由于 h5 文件不接受 python 容器,您应该考虑将元数据转换为字符串。假设您的元数据以字典或列表的形式存在,您可以使用 json 进行转换。这还允许您在模型中存储更复杂的数据结构。
完整用法示例:
import json
import keras
# prepare model and label lookup
model = keras.Sequential();
model.add(keras.layers.Dense(10, input_dim=8, activation='relu'));
model.add(keras.layers.Dense(3, activation='softmax'))
model.compile()
filepath = r".\mymodel.h5"
labels = ["dog", "cat", "automobile"]
# save
labels_string = json.dumps(labels)
save_model_ext(model, filepath, meta_data=labels_string)
# load
loaded_model, loaded_labels_string = load_model_ext(filepath)
loaded_labels = json.loads(loaded_labels_string)
# label of class 0: "dog"
print(loaded_labels[0])
如果您希望为您的类提供字典,请注意 json 会将数字字典键转换为字符串,因此您必须在加载后将它们转换回数字。