【问题标题】:Proper way of saving clients' federated datasets保存客户端联合数据集的正确方法
【发布时间】:2021-03-24 06:53:00
【问题描述】:

我想使用emnist 数据集训练两个独立的TFF 模型。每个模型都应该在从数据集中随机抽取的 1000 不同的参与者上进行训练。

代码如下

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

participants_ids = np.random.choice(a=emnist_train.client_ids, 
                                    size=1000,
                                    replace=False)

federated_dataset = 
        [data_train.create_tf_dataset_for_client(i) for i in participants_ids]

nested_dataset = tf.data.Dataset.from_tensor_slices(federated_dataset)

尝试保存数据集

tf.data.experimental.save(nested_dataset, 'model_dataset')

生成以下警告。但是,保存已完成。

E tensorflow/core/framework/dataset.cc:89] The Encode() method is not implemented for DatasetVariantWrapper objects.

加载数据集并尝试检查其内容时会出现问题

dataset = tf.data.experimental.load('model_dataset', 
                      element_spec= 
                      DatasetSpec(collections.OrderedDict([
                         ('label', TensorSpec(shape=(), dtype=tf.int32)),
                         ('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32))]), 
                      TensorShape([])

# verifying elements
for example in dataset:
        print(example)

以下错误

tensorflow.python.framework.errors_impl.DataLossError: Unable to parse tensor from stored proto.

尝试其他方法如pickle.dumpnp.save,均出现如下错误

tensorflow.python.framework.errors_impl.InternalError: Tensorflow type 21 not convertible to numpy dtype.

有什么好的方法可以保存新创建的数据集吗?

【问题讨论】:

  • 阅读上面的描述:数据集真的需要保存吗? TFF 应该在每次调用 load_data 时返回相同的 cliend_id -> 数据集映射——所以我们可以通过仅保存 ID 并在函数调用中重新创建单个数据集对象来实现相同的效果吗?
  • @KeithRush,最终目标是使用联合数据集训练一个federated model,然后将相同的数据集转换为flatten dataset,并保存以供以后训练几个传统的centralized models。这是为了在 federated 模型和 centralized 模型之间建立一些比较,在相同的数据集上进行训练。

标签: tensorflow tensorflow-datasets tensorflow-federated


【解决方案1】:

不是保存数据集的数据集,而是保存采样的客户端 ID 并在加载时构建数据集?

# Create a dataset of the participating IDs.
id_ds = tf.data.Dataset.from_tensor_slices(participants_ids)
tf.data.experimental.save(id_ds, '/tmp/id_dataset')

# Loaded the dataset back later.
loaded_ds = tf.data.experimental.load(
  '/tmp/id_dataset',
  element_spec=tf.TensorSpec(shape=[], dtype=tf.string))

# Create a federated dataset that yield (client_id, dataset).
federated_dataset = loaded_ds.map(
    lambda id: (id, emnist_train.serializable_dataset_fn(id)))
print(f'Loaded dataset with {tf.data.Dataset.cardinality(federated_dataset)} clients')
>>> Loaded dataset with 1000 clients.

print(f'Dataset element types: {next(iter(federated_dataset))[1].element_spec}')
>>> Dataset element types: OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)), ('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])

for id, dataset in federated_dataset.take(5):
  print(f'Client [{id}] has {sum(1 for _ in iter(dataset))} examples')
>>> Client [b'f2174_61'] has 106 examples
>>> Client [b'f1378_08'] has 99 examples
>>> Client [b'f1550_34'] has 106 examples
>>> Client [b'f3817_22'] has 106 examples
>>> Client [b'f1000_45'] has 109 examples

然后可以通过将tf.data.Dataset.map 替换为tf.data.Dataset.flat_map 来创建扁平化数据集:

flat_dataset = loaded_ds.flat_map(
    lambda id: emnist_train.serializable_dataset_fn(id))

print(f'Dataset element types: {next(iter(federated_dataset))[1].element_spec}')
>>> Dataset element types: OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)), ('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])

print(f'Flat dataset has {sum(1 for _ in flat_dataset):,} examples.')
>>> Flat dataset has 101,619 examples.

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 2016-04-15
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2012-04-25
    • 1970-01-01
    • 1970-01-01
    • 2021-07-30
    相关资源
    最近更新 更多