【发布时间】:2020-06-25 15:18:12
【问题描述】:
有人知道如何使用 TensorFlow Federated 实现 FedProx 优化算法吗?似乎可以在线获得的唯一实现是直接使用 TensorFlow 开发的。 TFF 实现可以更轻松地与使用框架支持的 FedAvg 的实验进行比较。
这是 FedProx 存储库的链接:https://github.com/litian96/FedProx
【问题讨论】:
有人知道如何使用 TensorFlow Federated 实现 FedProx 优化算法吗?似乎可以在线获得的唯一实现是直接使用 TensorFlow 开发的。 TFF 实现可以更轻松地与使用框架支持的 FedAvg 的实验进行比较。
这是 FedProx 存储库的链接:https://github.com/litian96/FedProx
【问题讨论】:
目前,FedProx 实施不可用。我同意这将是一个有价值的算法。
如果您有兴趣贡献 FedProx,最好的起点是 simple_fedavg,它是 FedAvg 的最小实现,旨在作为扩展的起点——有关更多详细信息,请参阅那里的自述文件。
我认为client_update 方法需要发生重大变化,您可以将取决于model_weights 和initial_weights 的近端项添加到前向传递计算的损失中。
【讨论】:
client_update 中添加以下行(紧跟在client_optimizer.apply_gradients(zip(grads, model_weights.trainable)) 之后)是否可以从simple_fedavg 开始实施FedProx:prox_term = tf.nest.map_structure(lambda a, b: (-1)*mu*(a - b), model_weights.trainable, initial_weights.trainable) 和client_optimizer.apply_gradients(zip(prox_term, model_weights.trainable)) 考虑到@ 987654332@ 为client_optimizer。我知道实现自定义优化器会更优雅。
tf.GradientTape处理。
我在下面提供了我在 TFF 中的 FedProx 实现。我不是 100% 确定这是正确的实现;我发布这个答案也是为了讨论实际的代码示例。
我尝试遵循 Jacub Konecny 的回答和评论中的建议。
从simple_fedavg(参考TFF Github repo)开始,我只是修改了client_update方法,并专门更改了GradientTape计算梯度的输入参数,即不只是传入输入outputs.loss,磁带计算梯度时考虑了 outputs.loss + proximal_term 之前(和迭代地)计算的。
@tf.function
def client_update(model, dataset, server_message, client_optimizer):
"""Performans client local training of "model" on "dataset".Args:
model: A "tff.learning.Model".
dataset: A "tf.data.Dataset".
server_message: A "BroadcastMessage" from server.
client_optimizer: A "tf.keras.optimizers.Optimizer".
Returns:
A "ClientOutput".
"""
def difference_model_norm_2_square(global_model, local_model):
"""Calculates the squared l2 norm of a model difference (i.e.
local_model - global_model)
Args:
global_model: the model broadcast by the server
local_model: the current, in-training model
Returns: the squared norm
"""
model_difference = tf.nest.map_structure(lambda a, b: a - b,
local_model,
global_model)
squared_norm = tf.square(tf.linalg.global_norm(model_difference))
return squared_norm
model_weights = model.weights
initial_weights = server_message.model_weights
tf.nest.map_structure(lambda v, t: v.assign(t), model_weights,
initial_weights)
num_examples = tf.constant(0, dtype=tf.int32)
loss_sum = tf.constant(0, dtype=tf.float32)
# Explicit use `iter` for dataset is a trick that makes TFF more robust in
# GPU simulation and slightly more performant in the unconventional usage
# of large number of small datasets.
for batch in iter(dataset):
with tf.GradientTape() as tape:
outputs = model.forward_pass(batch)
# ------ FedProx ------
mu = tf.constant(0.2, dtype=tf.float32)
prox_term =(mu/2)*difference_model_norm_2_square(model_weights.trainable, initial_weights.trainable)
fedprox_loss = outputs.loss + prox_term
# Letting GradientTape dealing with the FedProx's loss
grads = tape.gradient(fedprox_loss, model_weights.trainable)
client_optimizer.apply_gradients(zip(grads, model_weights.trainable))
batch_size = tf.shape(batch['x'])[0]
num_examples += batch_size
loss_sum += outputs.loss * tf.cast(batch_size, tf.float32)
weights_delta = tf.nest.map_structure(lambda a, b: a - b,
model_weights.trainable,
initial_weights.trainable)
client_weight = tf.cast(num_examples, tf.float32)
return ClientOutput(weights_delta, client_weight, loss_sum / client_weight)
【讨论】:
tf.linalg.global_norm来简化新的部分。