【问题标题】:Pytorch: How to get all model's parameters that require grad?Pytorch:如何获取需要grad的所有模型参数?
【发布时间】:2019-07-11 16:33:21
【问题描述】:

我在 pytorch 中有一些模型,我想手动访问和更改其可更新的权重。

如何正确完成?

理想情况下,我想要这些权重的张量。

在我看来

for parameter in model.parameters():
    do_something_to_parameter(parameter)

这不是正确的方法,因为

  1. 它不使用 GPU,也不能
  2. 它甚至不使用低级实现

手动访问模型权重的正确方法是什么(不是通过loss.backwardoptimizer.step)?

【问题讨论】:

  • 不使用 GPU 是什么意思?您对 GPU 上的张量执行的任何张量操作都使用 GPU。这包括模型参数。
  • @Coolness 我的意思是任何使用 python for 循环遍历参数的操作都不会通过任何 GPU 并行性。为了发生任何并行性,必须传递一些函数,这显然不是这里发生的事情。例如,loss.backward() 可以在 GPU 上计算。
  • 但参数集具有不同的形状和含义。不可能将它们作为单个张量,除非你将它们展平为一个向量,在这种情况下它们就失去了意义。你想对它们进行什么样的操作?
  • 我希望有某种方法可以使用一些向量操作一起更新所有参数。一条蟒蛇不可能是要走的路。更具体地说,我正在尝试实现stackoverflow.com/questions/54734556/…,如果这样实现,恐怕大型网络将永远无法完成优化。
  • 因为“将它们放在一个大向量中”会创建所有参数的副本。您需要单独存储它们以保持形状,例如将偏差与权重分开(想想线性层,A x + b 将如何处理单个连接向量?)。 for 循环中的开销完全由模型的实际前向/后向传递控制。

标签: python machine-learning pytorch backpropagation


【解决方案1】:

这是我的方法,您通常可以在此处输入任何模型,它会返回所有 torch.nn.* 事物的列表,只需在其周围添加一个环绕以返回不是模块而是权重

def flatten_model(modules):
    def flatten_list(_2d_list):
        flat_list = []
        # Iterate through the outer list
        for element in _2d_list:
            if type(element) is list:
                # If the element is of type list, iterate through the sublist
                for item in element:
                    flat_list.append(item)
            else:
                flat_list.append(element)
        return flat_list

    ret = []
    try:
        for _, n in modules:
            ret.append(loopthrough(n))
    except:
        try:
            if str(modules._modules.items()) == "odict_items([])":
                ret.append(modules)
            else:
                for _, n in modules._modules.items():
                    ret.append(loopthrough(n))
        except:
            ret.append(modules)
    return flatten_list(ret)

【讨论】:

    猜你喜欢
    • 2019-07-17
    • 2021-07-23
    • 2020-08-31
    • 2014-03-09
    • 2018-11-17
    • 2021-12-04
    • 1970-01-01
    • 2017-10-31
    • 2016-12-05
    相关资源
    最近更新 更多