【问题标题】:ResNet model in Tensorflow FederatedTensorflow Federated 中的 ResNet 模型
【发布时间】:2020-04-24 14:22:18
【问题描述】:

我尝试在 Tensorflow Federated 的“图像分类”教程中自定义模型。 (它最初使用的是顺序模型) 我用的是Keras ResNet50,但是开始训练的时候,总是报错“不兼容的形状”

这是我的代码:

NUM_CLIENTS = 4
NUM_EPOCHS = 10
BATCH_SIZE = 2
SHUFFLE_BUFFER = 5

def create_compiled_keras_model():
  model = tf.keras.applications.resnet.ResNet50(include_top=False, weights='imagenet', 
                                                input_tensor=tf.keras.layers.Input(shape=(100, 
                                                300, 3)), pooling=None)

  model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      optimizer=tf.keras.optimizers.SGD(learning_rate=0.02),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
  return model


def model_fn():
  keras_model = create_compiled_keras_model()
  return tff.learning.from_compiled_keras_model(keras_model, sample_batch)

iterative_process = tff.learning.build_federated_averaging_process(model_fn)

错误信息: enter image description here

我觉得形状不兼容,因为时代和客户信息不知何故丢失了。如果有人能给我提示,将非常感激。

更新:

tff.learning.build_federated_averaging_process期间发生断言错误

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-164-dac26193d9d8> in <module>()
----> 1 iterative_process = tff.learning.build_federated_averaging_process(model_fn)
      2 
      3 # iterative_process = build_federated_averaging_process(model_fn)

13 frames
/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/learning/federated_averaging.py in build_federated_averaging_process(model_fn, server_optimizer_fn, client_weight_fn, stateful_delta_aggregate_fn, stateful_model_broadcast_fn)
    165   return optimizer_utils.build_model_delta_optimizer_process(
    166       model_fn, client_fed_avg, server_optimizer_fn,
--> 167       stateful_delta_aggregate_fn, stateful_model_broadcast_fn)

/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/learning/framework/optimizer_utils.py in build_model_delta_optimizer_process(model_fn, model_to_client_delta_fn, server_optimizer_fn, stateful_delta_aggregate_fn, stateful_model_broadcast_fn)
    349   # still need this.
    350   with tf.Graph().as_default():
--> 351     dummy_model_for_metadata = model_utils.enhance(model_fn())
    352 
    353   # ===========================================================================

<ipython-input-159-b2763ace8e5b> in model_fn()
      1 def model_fn():
      2   keras_model = model
----> 3   return tff.learning.from_compiled_keras_model(keras_model, sample_batch)

/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/learning/keras_utils.py in from_compiled_keras_model(keras_model, dummy_batch)
    211   # Model.test_on_batch() once before asking for metrics.
    212   if isinstance(dummy_tensors, collections.Mapping):
--> 213     keras_model.test_on_batch(**dummy_tensors)
    214   else:
    215     keras_model.test_on_batch(*dummy_tensors)

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training.py in test_on_batch(self, x, y, sample_weight, reset_metrics)
   1007         sample_weight=sample_weight,
   1008         reset_metrics=reset_metrics,
-> 1009         standalone=True)
   1010     outputs = (
   1011         outputs['total_loss'] + outputs['output_losses'] + outputs['metrics'])

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_v2_utils.py in test_on_batch(model, x, y, sample_weight, reset_metrics, standalone)
    503       y,
    504       sample_weights=sample_weights,
--> 505       output_loss_metrics=model._output_loss_metrics)
    506 
    507   if reset_metrics:

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/def_function.py in __call__(self, *args, **kwds)
    568         xla_context.Exit()
    569     else:
--> 570       result = self._call(*args, **kwds)
    571 
    572     if tracing_count == self._get_tracing_count():

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/def_function.py in _call(self, *args, **kwds)
    606       # In this case we have not created variables on the first call. So we can
    607       # run the first trace but we should fail if variables are created.
--> 608       results = self._stateful_fn(*args, **kwds)
    609       if self._created_variables:
    610         raise ValueError("Creating variables on a non-first call to a function"

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py in __call__(self, *args, **kwargs)
   2407     """Calls a graph function specialized to the inputs."""
   2408     with self._lock:
-> 2409       graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
   2410     return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
   2411 

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   2765 
   2766       self._function_cache.missed.add(call_context_key)
-> 2767       graph_function = self._create_graph_function(args, kwargs)
   2768       self._function_cache.primary[cache_key] = graph_function
   2769       return graph_function, args, kwargs

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   2655             arg_names=arg_names,
   2656             override_flat_arg_shapes=override_flat_arg_shapes,
-> 2657             capture_by_value=self._capture_by_value),
   2658         self._function_attributes,
   2659         # Tell the ConcreteFunction to clean up its graph once it goes out of

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    979         _, original_func = tf_decorator.unwrap(python_func)
    980 
--> 981       func_outputs = python_func(*func_args, **func_kwargs)
    982 
    983       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    437         # __wrapped__ allows AutoGraph to swap in a converted function. We give
    438         # the function a weak reference to itself to avoid a reference cycle.
--> 439         return weak_wrapped_fn().__wrapped__(*args, **kwds)
    440     weak_wrapped_fn = weakref.ref(wrapped_fn)
    441 

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/func_graph.py in wrapper(*args, **kwargs)
    966           except Exception as e:  # pylint:disable=broad-except
    967             if hasattr(e, "ag_error_metadata"):
--> 968               raise e.ag_error_metadata.to_exception(e)
    969             else:
    970               raise

AssertionError: in user code:

    /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_eager.py:345 test_on_batch  *
        with backend.eager_learning_phase_scope(0):
    /usr/lib/python3.6/contextlib.py:81 __enter__
        return next(self.gen)
    /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/backend.py:425 eager_learning_phase_scope
        assert ops.executing_eagerly_outside_functions()

    AssertionError: 

【问题讨论】:

  • 请将您的堆栈跟踪复制为文本,而不是将其发布为图像。
  • 好问题!我们能看到sample_batch 的来源吗?
  • 对不起,我尝试以代码格式进行评论,但看起来很乱。 sample_batch 是批量训练数据。例如,如果批量大小为 2,则 sample_batchOrderedDict([('x', array([], []), ('y', array([], []))])

标签: tensorflow resnet tensorflow-federated federated-learning


【解决方案1】:

啊,我相信这个问题来自对sample_batch 的不匹配期望。 TFF 将sample_batch 传递给 Keras,Keras 调用此样本批次的前向传递来初始化 keras 模型的各种属性。 sample_batch 应该是您将要在服务器端提供模型的文字数据的样本,或者是与您将传入的数据的形状和类型相匹配的一批假数据。

前者的一个例子可以找到here(这里使用tf.data.Dataset),在测试代码中有几个后者的例子,比如here

根据我对模型定义的了解,您的 sample_batch 的 x 元素可能应该是形状为 [2, 100, 300, 3]ndarray(其中 2 表示批量大小,但从技术上讲,这可以是任何非零维度),并且y 元素也应该与您正在使用的数据中预期的y 结构相匹配。

我希望这会有所帮助,如果有任何问题,请回复!

有一点需要注意,这可能有助于思考 TFF——TFF 正在构建一个语法树,表示您通过 build_federated_averaging_process 定义的分布式计算。此错误实际上发生在此对象的构造 期间。 TFF 必须跟踪您传递给它的计算,以便知道要生成什么结构,这就是这里提出的问题。当您在返回的 IterativeProcess 上调用 next 时,会发生模型的实际训练

【讨论】:

  • 非常感谢!我调整了模型的输出层以匹配我的 y 形状并解决了羞耻不兼容的问题。我成功地从 tff.learning.from_compiled_keras_model 获得了一个 tff.learning.model,但是当我运行 iterative_process = tff.learning.build_federated_averaging_process(model_fn) 时,有一个 AssertionError: 这真的很难跟踪。花了很多时间,但不知道它是怎么来的。你能给我一些建议吗?该错误已附在问题的更新中。
  • 嗯,这很有趣,我们刚刚从另一个来源看到了类似的东西。我会深入研究
  • 听起来不错!请随时让我知道任何发现。非常感谢!
  • 只是一个快速更新:我们一直在深入研究内部重现,几乎准备提交错误 - 据我们所知,Keras 训练内部的某些内容没有正确设置某些原因。如果/当我们提交它时,我会在此处链接该错误。
  • 酷!并且不确定这些信息是否有帮助,但我发现断言发生在我应用 Keras Model 类时,即使我只是使用该类设置 1 个密集层。但是,如果我使用 keras.sequential(),则不存在这样的问题。像 keras.applications.resnet.ResNet50 这样的应用程序也没有问题。
【解决方案2】:

我有同样的问题: 如果我执行这一行 状态,指标 = iterative_process.next(状态,federated_train_data) print('round 1, metrics={}'.format(metrics))

我发现这个错误 InvalidArgumentError:找到 2 个根错误。 (0) 无效参数:默认 MaxPoolingOp 仅在设备类型 CPU 上支持 NHWC [[{{node StatefulPartitionedCall/StatefulPartitionedCall/sequential/vgg16/block1_pool/MaxPool}}]] [[子计算/StatefulPartitionedCall_1/ReduceDataset]] [[子计算/StatefulPartitionedCall_1/ReduceDataset/_140]] (1) 无效参数:默认 MaxPoolingOp 仅在设备类型 CPU 上支持 NHWC [[{{node StatefulPartitionedCall/StatefulPartitionedCall/sequential/vgg16/block1_pool/MaxPool}}]] [[子计算/StatefulPartitionedCall_1/ReduceDataset]] 0 次成功操作。 0 个衍生错误被忽略。

知道我使用的是 VGG16 你对这种类型的错误有什么想法吗

【讨论】:

    猜你喜欢
    • 2020-05-28
    • 2021-12-21
    • 2018-10-04
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2018-07-25
    • 1970-01-01
    相关资源
    最近更新 更多