【问题标题】:Python: difference between *args and list in a remote functionPython:*args 和远程函数中的列表之间的区别
【发布时间】:2019-01-21 00:45:18
【问题描述】:

定义一个列表

g = [1, 2, 3, 4]

版本 1

@ray.remote
def f(*g):     
    return np.mean(g, axis=0)
f_ids = [f.remote(*g) for _ in range(10)]
print(ray.get(f.remote(*f_ids)))

第 2 版

@ray.remote
def f(g):    # g is object ID list
    return np.mean(g, axis=0)
f_ids = [f.remote(g) for _ in range(10)]
print(ray.get(f.remote(f_ids)))

第一个代码工作正常,但版本 2 不工作。错误信息是

ray.get(f.remote(f_ids)) + 不支持的操作数类型:“common.ObjectID”和“common.ObjectID”

我想做第 2 版之类的事情的原因是我实际上想做以下事情

@remote
def f(g1, g2):    # g1 and g2 are object ID lists
    ...           # do something here

我不知道如何将g1g2 设置为*g1*g2,所以我想出了第2 版。为什么第2 版不起作用?我该如何解决?

参考代码在这里 https://ray.readthedocs.io/en/latest/example-parameter-server.html#synchronous-parameter-server

【问题讨论】:

    标签: python synchronization ray


    【解决方案1】:

    当参数被传递到 Ray 远程函数时,ray.ObjectID 类型的任何参数都会自动替换为解压后的值(这意味着 ray.get 在后台被调用)。所有其他参数均未更改。

    这就是为什么如果你定义一个远程函数像

    # Assuming you already called "import ray" and "ray.init()".
    
    @ray.remote
    def g(x):
        print(x)
    

    你会看到的

    g.remote(1)  # This prints '1'
    g.remote(ray.put(1))  # This also prints '1'
    g.remote([ray.put(1)])  # This prints '[ObjectID(feffffffe3f2116088b37cb305fbb2537b9783ee)]'
    

    在第三行,因为参数是一个列表,所以列表里面的ObjectID没有被它对应的值替换。

    在你的例子中,你有

    @ray.remote
    def f(*xs):
        print(xs)
    

    版本 1版本 2 之间的区别在于,在 版本 1 中,您传递了多个 ObjectID 参数。在版本 2 中,您传入一个参数,该参数是一个包含多个 ObjectIDs 的列表。

    xs = [ray.put(1), ray.put(2)]
    f.remote(*xs)  # This prints '(1, 2)'
    f.remote(xs)  # This prints '([ObjectID(fcffffffe3f2116088b37cb305fbb2537b9783ee), ObjectID(fbffffffe3f2116088b37cb305fbb2537b9783ee)],)'
    

    要做你想做的事,你可能需要做这样的事情(基本上将两个列表合二为一)。这不是最漂亮的,但应该可以。

    @ray.remote
    def h(num_xs, *xs_and_ys):
        xs = xs_and_ys[:num_xs]
        ys = xs_and_ys[num_xs:]
        print(xs, ys)
    
    x_ids = [ray.put(1), ray.put(2)]
    y_ids = [ray.put(3), ray.put(4), ray.put(5)]
    
    h.remote(len(x_ids), *(x_ids + y_ids))  # This prints '(1, 2) (3, 4, 5)'
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2020-10-01
      • 2021-01-30
      • 2013-09-16
      • 1970-01-01
      • 2017-12-28
      • 2021-07-29
      • 2011-08-25
      • 1970-01-01
      相关资源
      最近更新 更多