【发布时间】:2023-03-04 14:55:02
【问题描述】:
我正在实施 WGAN,需要裁剪权重变量。
我目前使用 Tensorflow 和 Keras 作为高级 API。因此使用 Keras 构建层以避免手动创建和初始化变量。
问题是 WGAN 需要裁剪权重变量,这可以在我获得这些权重变量张量后使用 tf.clip_by_value(x, v0, v1) 完成,但我不知道如何安全地获得它们。
一种可能的解决方案可能是使用tf.get_collection() 来获取所有可训练变量。但我不知道如何在没有 bias 变量的情况下仅获得 weight 变量。
另一个解决方案是layer.get_weights(),但它得到numpy 数组,虽然我可以使用numpy API 剪辑它们并使用layer.set_weights() 设置它们,但这可能需要CPU-GPU 公司,并且可能不是不错的选择,因为每个 train step 都需要执行剪辑操作。
我知道的唯一方法是使用 exact 变量名直接访问它们,我可以从 TF 低级 API 或 TensorBoard 获得,但这可能不安全,因为 Keras 的命名规则无法保证保持稳定。
是否有任何干净的方法可以使用 Tensorflow 和 Keras 仅在 Ws 上执行 clip_by_value?
【问题讨论】:
标签: python tensorflow deep-learning keras