【问题标题】:NameError: name 'FileCheckpointManager' is not defined while saving modelNameError:保存模型时未定义名称“FileCheckpointManager”
【发布时间】:2020-05-04 23:44:29
【问题描述】:

在模拟了图像分类的联邦学习代码后,我想保存我的模型,所以我添加了这两行

ckpt_manager = FileCheckpointManager("model.h5")
ckpt_manager.save_checkpoint(ServerState.from_anon_tuple(state), round_num=2) 

这是我所有的代码:

import collections
import time

import tensorflow as tf
tf.compat.v1.enable_v2_behavior()

import tensorflow_federated as tff

source, _ = tff.simulation.datasets.emnist.load_data()


def map_fn(example):
  return collections.OrderedDict(
      x=tf.reshape(example['pixels'], [-1, 784]), y=example['label'])
def client_data(n):
  ds = source.create_tf_dataset_for_client(source.client_ids[n])
  return ds.repeat(10).shuffle(500).batch(20).map(map_fn)


train_data = [client_data(n) for n in range(10)]
element_spec = train_data[0].element_spec

def model_fn():
  model = tf.keras.models.Sequential([
      tf.keras.layers.Input(shape=(784,)),
      tf.keras.layers.Dense(units=10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ])
  return tff.learning.from_keras_model(
      model,
      input_spec=element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])


trainer = tff.learning.build_federated_averaging_process(
    model_fn, client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.02))

....
NUM_ROUNDS = 11
for round_num in range(2, NUM_ROUNDS):
  state, metrics = trainer.next(state, federated_train_data)
  print('round {:2d}, metrics={}'.format(round_num, metrics))


ckpt_manager = FileCheckpointManager("model.h5")
ckpt_manager.save_checkpoint(ServerState.from_anon_tuple(state), round_num=9)

但确实出现了这个错误:

NameError: name 'FileCheckpointManager' is not defined

如果您告诉我如何解决这个问题,我将不胜感激

【问题讨论】:

    标签: tensorflow-federated


    【解决方案1】:

    看起来代码缺少检查点管理器模块的导入。

    FileCheckpointMangercheckpoint_manager 模块中定义:tensorflow_federated/python/research/utils/checkpoint_manager.py

    尝试像这样在文件顶部添加导入(以下示例假设 tensorflow 联合 github 存储库位于导入搜索路径中):

    from tensorflow_federated.python.research.utils import checkpoint_manager
    # ...
    ckpt_manager = checkpoint_manager.FileCheckpointManager("model.h5")
    

    【讨论】:

    • 非常感谢您的回复,但我发现这个错误ModuleNotFoundError: No module named 'tensorflow_federated.python.research' 知道我使用的是Tensorflow 联合版本0.12.0
    • 两个注意事项:CheckpointManager 不在 pip 包中,research/ 目录不是官方 TFF 版本的一部分,支持由个别研究人员处理。错误是说 Python 不知道 tensorflow_federated/python/research/utils/checkpoint_manager.py. 在哪里。一种选择是在本地克隆 github 存储库并在存储库的根目录中运行您的脚本。
    • 很高兴您的回答,非常感谢您
    • 你能告诉我吗,现在我发现了其他问题NameError: name 'ServerState' is not defined我看到了这个link,但我不知道如何使用它?我必须在本地克隆这个 github repo 并在 repo 的根目录中运行我的脚本吗?或如何?谢谢
    猜你喜欢
    • 2023-01-23
    • 1970-01-01
    • 2018-01-24
    • 1970-01-01
    • 1970-01-01
    • 2018-05-18
    • 1970-01-01
    • 2021-04-15
    相关资源
    最近更新 更多