【Posted】: 2021-01-21 15:17:16
【Question】:
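In TensorFlow Federated (TFF), you can pass a broadcast_process and an aggregation_process to tff.learning.build_federated_averaging_process, and these can embed custom encoders, e.g. to apply custom compression.
Coming to my issue: I am trying to implement an encoder that sparsifies model updates/model weights.
I am trying to build such an encoder by implementing the EncodingStageInterface from tensorflow_model_optimization.python.core.internal.
However, I am struggling to implement a (local) state that accumulates, round by round, the zeroed-out coordinates of the model updates/model weights. Note that this state should not be communicated and only needs to be maintained locally (so the AdaptiveEncodingStageInterface should not be helpful). In general, the question is how to maintain a local state inside an encoder that is then passed to the fedavg process.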
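To make the intended behaviour concrete, here is a minimal, framework-free sketch of top-k sparsification with a locally accumulated residual. It is plain Python rather than TFF code, and the names `ResidualTopK`, `fraction` and `decode` are hypothetical; it only illustrates what the state is supposed to do from round to round, not how to wire it into an encoding stage.

```python
class ResidualTopK:
    """Keeps the k largest-magnitude coordinates and remembers the rest locally."""

    def __init__(self, size, fraction=0.4):
        self.residual = [0.0] * size  # local state, never transmitted
        self.k = max(1, round(fraction * size))

    def encode(self, update):
        # Add back whatever was zeroed out in earlier rounds.
        total = [r + u for r, u in zip(self.residual, update)]
        # Indices of the k entries with the largest absolute value.
        top = sorted(range(len(total)),
                     key=lambda i: abs(total[i]), reverse=True)[:self.k]
        kept = set(top)
        # Everything that is not sent stays in the residual for the next round.
        self.residual = [0.0 if i in kept else v for i, v in enumerate(total)]
        # The signed values are sent together with their indices.
        return {i: total[i] for i in sorted(top)}


def decode(encoded, size):
    """Scatters the transmitted index/value pairs back into a dense list."""
    dense = [0.0] * size
    for i, v in encoded.items():
        dense[i] = v
    return dense
```

Calling `encode` twice on the same object shows the accumulation: coordinates dropped in the first round are added to the second round's update before the top-k selection.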
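I have attached the code of my encoder implementation (apart from the state I would like to add, it works as expected in its stateless form). After that, I have attached an excerpt of the code where I use the encoder implementation. If I uncomment the commented-out parts in stateful_encoding_stage_topk.py, the code does not work: I cannot figure out how to manage the state (i.e. tensors) in TF non-eager mode.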
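For context, one general pattern for state in graph-style code is to stop assigning to Python attributes and instead pass the state in and return the updated state explicitly. The sketch below shows that pattern in plain Python under that assumption; `encode_with_state` and `run_rounds` are hypothetical names, and the sketch does not show whether or how TFF 0.17.0 lets a client keep such state between rounds.

```python
def encode_with_state(update, residual, k):
    """Pure function: returns (encoded, new_residual) without side effects."""
    total = [r + u for r, u in zip(residual, update)]
    # Indices of the k largest-magnitude entries.
    top = sorted(range(len(total)),
                 key=lambda i: abs(total[i]), reverse=True)[:k]
    kept = set(top)
    encoded = [(i, total[i]) for i in sorted(top)]
    # Entries that were not selected are carried over to the next call.
    new_residual = [0.0 if i in kept else v for i, v in enumerate(total)]
    return encoded, new_residual


def run_rounds(updates_per_round, size, k):
    """Threads the residual through successive rounds as an explicit value."""
    residual = [0.0] * size
    sent = []
    for update in updates_per_round:
        encoded, residual = encode_with_state(update, residual, k)
        sent.append(encoded)
    return sent, residual
```

Because the residual is an ordinary argument and return value, the caller decides where it lives between rounds, which is the part the question is about.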
stateful_encoding_stage_topk.py
import tensorflow as tf
import numpy as np
from tensorflow_model_optimization.python.core.internal import tensor_encoding as te


@te.core.tf_style_encoding_stage
class StatefulTopKEncodingStage(te.core.EncodingStageInterface):

    ENCODED_VALUES_KEY = 'stateful_topk_values'
    INDICES_KEY = 'indices'

    def __init__(self):
        super().__init__()
        # Here I would like to init my state
        # self.A = tf.zeros([800], dtype=tf.float32)

    @property
    def name(self):
        """See base class."""
        return 'stateful_topk'

    @property
    def compressible_tensors_keys(self):
        """See base class."""
        return [self.ENCODED_VALUES_KEY]

    @property
    def commutes_with_sum(self):
        """See base class."""
        return True

    @property
    def decode_needs_input_shape(self):
        """See base class."""
        return True

    def get_params(self):
        """See base class."""
        return {}, {}

    def encode(self, x, encode_params):
        """See base class."""
        del encode_params  # Unused.

        dW = tf.reshape(x, [-1])
        # Here I would like to retrieve the state
        A = tf.zeros([800], dtype=tf.float32)
        # A = self.A
        dW_and_A = tf.math.add(A, dW)

        # Keep 40% of the coordinates.
        percentage = tf.constant(0.4, dtype=tf.float32)
        k_float = tf.multiply(percentage, tf.cast(tf.size(dW), tf.float32))
        k_int = tf.cast(tf.math.round(k_float), dtype=tf.int32)

        # Select the k largest-magnitude entries, then gather the signed values;
        # the values returned by top_k on the absolute tensor would lose their sign.
        _, indices = tf.math.top_k(tf.math.abs(dW_and_A), k=k_int, sorted=False)
        values = tf.gather(dW_and_A, indices)
        indices = tf.expand_dims(indices, 1)
        sparse_dW = tf.scatter_nd(indices, values, tf.shape(dW_and_A))

        # Here I would like to update the state
        A_updated = tf.math.subtract(dW_and_A, sparse_dW)
        # self.A = A_updated

        encoded_x = {self.ENCODED_VALUES_KEY: values,
                     self.INDICES_KEY: indices}
        return encoded_x

    def decode(self,
               encoded_tensors,
               decode_params,
               num_summands=None,
               shape=None):
        """See base class."""
        del decode_params, num_summands  # Unused.

        indices = encoded_tensors[self.INDICES_KEY]
        values = encoded_tensors[self.ENCODED_VALUES_KEY]
        # Scatter the transmitted values back into a dense zero tensor.
        tensor = tf.fill([800], 0.0)
        decoded_values = tf.tensor_scatter_nd_update(tensor, indices, values)
        return tf.reshape(decoded_values, shape)


def sparse_quantizing_encoder():
    encoder = te.core.EncoderComposer(
        StatefulTopKEncodingStage())
    return encoder.make()
fedavg_with_sparsification.py
[...]
def sparsification_broadcast_encoder_fn(value):
    spec = tf.TensorSpec(value.shape, value.dtype)
    return te.encoders.as_simple_encoder(te.encoders.identity(), spec)


def sparsification_mean_encoder_fn(value):
    spec = tf.TensorSpec(value.shape, value.dtype)
    if value.shape.num_elements() == 800:
        return te.encoders.as_gather_encoder(
            stateful_encoding_stage_topk.sparse_quantizing_encoder(), spec)
    else:
        return te.encoders.as_gather_encoder(te.encoders.identity(), spec)


encoded_broadcast_process = (
    tff.learning.framework.build_encoded_broadcast_process_from_model(
        model_fn, sparsification_broadcast_encoder_fn))
encoded_mean_process = (
    tff.learning.framework.build_encoded_mean_process_from_model(
        model_fn, sparsification_mean_encoder_fn))

iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.004),
    client_weight_fn=lambda _: tf.constant(1.0),
    broadcast_process=encoded_broadcast_process,
    aggregation_process=encoded_mean_process)
[...]
I am using:
- tensorflow 2.4.0
- tensorflow federated 0.17.0
【Discussion】:
Tags: python tensorflow-federated