【发布时间】:2021-10-19 12:22:11
【问题描述】:
为了复制Multimodal Few-Shot Learning with Frozen Language Models,我尝试在 TPUv3-32 上训练一个 ~7B 参数子类化 TF2 模型。在 7B 参数中,大约有 6B 参数被冻结。
我想使用模型和数据并行性来尽可能高效地训练它。据我所知,MeshTensorflow只能用于TF1编写的模型。
我尝试使用 TPUStrategy 中的 experimental_device_assignment,但它只将所有变量放在 TPU 的第一个(第 0 个)核心上,这很快就会耗尽内存。
在 TPUv3-8 上,我尝试保持 compute_shape = [2, 2, 1, 2] 和 [1, 1, 1, 2] 和 num_replicas = 1 但没有用。
我也愿意使用 GPU 来训练它。
【问题讨论】:
标签: tensorflow machine-learning google-cloud-platform deep-learning tpu