【问题标题】:FedProx with TensorFlow FederatedFedProx 与 TensorFlow 联合
【发布时间】:2020-06-25 15:18:12
【问题描述】:

有人知道如何使用 TensorFlow Federated 实现 FedProx 优化算法吗?似乎可以在线获得的唯一实现是直接使用 TensorFlow 开发的。 TFF 实现可以更轻松地与使用框架支持的 FedAvg 的实验进行比较。

这是 FedProx 存储库的链接:https://github.com/litian96/FedProx

论文链接:https://arxiv.org/abs/1812.06127

【问题讨论】:

    标签: tensorflow-federated


    【解决方案1】:

    目前,FedProx 实施不可用。我同意这将是一个有价值的算法。

    如果您有兴趣贡献 FedProx,最好的起点是 simple_fedavg,它是 FedAvg 的最小实现,旨在作为扩展的起点——有关更多详细信息,请参阅那里的自述文件。

    我认为client_update 方法需要发生重大变化,您可以将取决于model_weightsinitial_weights 的近端项添加到前向传递计算的损失中。

    【讨论】:

    • 我想知道在client_update 中添加以下行(紧跟在client_optimizer.apply_gradients(zip(grads, model_weights.trainable)) 之后)是否可以从simple_fedavg 开始实施FedProx:prox_term = tf.nest.map_structure(lambda a, b: (-1)*mu*(a - b), model_weights.trainable, initial_weights.trainable)client_optimizer.apply_gradients(zip(prox_term, model_weights.trainable)) 考虑到@ 987654332@ 为client_optimizer。我知道实现自定义优化器会更优雅。
    • 可能是的,但如果没有适当的例子,很难真正知道。但是在loss中加上prox term可能会更好更通用,让tf.GradientTape处理。
    • 我试着听从你的建议(谢谢!)。我在这篇文章的另一个答案中报告了我的实现。
    【解决方案2】:

    我在下面提供了我在 TFF 中的 FedProx 实现。我不是 100% 确定这是正确的实现;我发布这个答案也是为了讨论实际的代码示例。

    我尝试遵循 Jacub Konecny 的回答和评论中的建议。

    simple_fedavg(参考TFF Github repo)开始,我只是修改了client_update方法,并专门更改了GradientTape计算梯度的输入参数,即不只是传入输入outputs.loss,磁带计算梯度时考虑了 outputs.loss + proximal_term 之前(和迭代地)计算的。

    @tf.function
    def client_update(model, dataset, server_message, client_optimizer):
    """Performans client local training of "model" on "dataset".Args:
    model: A "tff.learning.Model".
    dataset: A "tf.data.Dataset".
    server_message: A "BroadcastMessage" from server.
    client_optimizer: A "tf.keras.optimizers.Optimizer". 
    Returns:
    A "ClientOutput".
    """ 
    
    def difference_model_norm_2_square(global_model, local_model):
        """Calculates the squared l2 norm of a model difference (i.e.
        local_model - global_model)
        Args:
            global_model: the model broadcast by the server
            local_model: the current, in-training model
    
        Returns: the squared norm
    
        """
        model_difference = tf.nest.map_structure(lambda a, b: a - b,
                                               local_model,
                                               global_model)
        squared_norm = tf.square(tf.linalg.global_norm(model_difference))
        return squared_norm
    
    model_weights = model.weights
    initial_weights = server_message.model_weights
    tf.nest.map_structure(lambda v, t: v.assign(t), model_weights,
                         initial_weights)
    
    num_examples = tf.constant(0, dtype=tf.int32)
    loss_sum = tf.constant(0, dtype=tf.float32)
    # Explicit use `iter` for dataset is a trick that makes TFF more robust in
    # GPU simulation and slightly more performant in the unconventional usage
    # of large number of small datasets.
    
    for batch in iter(dataset):
        with tf.GradientTape() as tape:
            outputs = model.forward_pass(batch)
    
            # ------ FedProx ------
            mu = tf.constant(0.2, dtype=tf.float32)
            prox_term =(mu/2)*difference_model_norm_2_square(model_weights.trainable, initial_weights.trainable)
            fedprox_loss = outputs.loss + prox_term
    
        # Letting GradientTape dealing with the FedProx's loss
        grads = tape.gradient(fedprox_loss, model_weights.trainable)
    
        client_optimizer.apply_gradients(zip(grads, model_weights.trainable))
    
        batch_size = tf.shape(batch['x'])[0]
        num_examples += batch_size
        loss_sum += outputs.loss * tf.cast(batch_size, tf.float32)
    
    weights_delta = tf.nest.map_structure(lambda a, b: a - b,
                                          model_weights.trainable,
                                          initial_weights.trainable)
    client_weight = tf.cast(num_examples, tf.float32)
    return ClientOutput(weights_delta, client_weight, loss_sum / client_weight)
    

    【讨论】:

    • 看起来不错,你也可以考虑tf.linalg.global_norm来简化新的部分。
    猜你喜欢
    • 1970-01-01
    • 2021-11-28
    • 1970-01-01
    • 2017-06-17
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2017-08-21
    • 2020-07-10
    相关资源
    最近更新 更多