【Question Title】: Serve TensorFlow models in parallel with Ray
【Posted】: 2020-10-16 23:41:59
【Question】:

I was looking at a StackOverflow thread about using ray.serve to run predictions on a saved TF model in parallel: https://stackoverflow.com/a/62459372

I tried something similar:

import ray
from ray import serve; serve.init()
import tensorflow as tf

class A:
    def __init__(self):
        self.model = tf.constant(1.0) # dummy example

    @serve.accept_batch
    def __call__(self, *, input_data=None):
        print(input_data) # test if method is entered
        # do stuff, serve model

if __name__ == '__main__':
    serve.create_backend("tf", A,
        # configure resources
        ray_actor_options={"num_cpus": 2},
        # configure replicas
        config={
            "num_replicas": 2, 
            "max_batch_size": 24,
            "batch_wait_timeout": 0.1
        }
    )
    serve.create_endpoint("tf", backend="tf")
    handle = serve.get_handle("tf")

    args = [1,2,3]

    futures = [handle.remote(input_data=i) for i in args]
    result = ray.get(futures)

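However, I get the following error: TypeError: __call__() takes 1 positional argument but 2 positional arguments (and 1 keyword-only argument) were given. Something is wrong with the arguments being passed to __call__.

This looks like a simple mistake. How should I change the args array so that the __call__ method is actually entered?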
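The error can be reproduced without Ray. The message says __call__ received one positional argument besides self, plus the input_data keyword, while the signature def __call__(self, *, input_data=None) accepts no positional arguments other than self. In the minimal sketch below, the list [1] is an assumed stand-in for whatever object Serve passes positionally; no Ray code is involved.

```python
class A:
    # Same signature as in the question: everything after `*` is keyword-only.
    def __call__(self, *, input_data=None):
        return input_data


a = A()

# A purely keyword-based call is accepted.
print(a(input_data=1))  # prints 1

# One extra positional argument (a stand-in for the request object that
# Serve supplies) reproduces the TypeError from the question.
try:
    a([1], input_data=1)
except TypeError as err:
    print(err)
```

Because keyword arguments are bound before the positional count is checked, Python reports both the extra positional argument and the keyword-only one, which matches the wording of the error above.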

【Discussion】:

    Tags: python tensorflow ray


    【Solution 1】:

    The Ray Serve API changed in Ray 1.0. See the migration guide: https://gist.github.com/simon-mo/6d23dfed729457313137aef6cfbc7b54

    For the specific code example you posted, you can update it as follows:

    import ray
    from ray import serve
    import tensorflow as tf
    
    class A:
        def __init__(self):
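            self.model = tf.constant(1.0) # dummy example
    
        @serve.accept_batch
        def __call__(self, requests):
            # with accept_batch, requests is a list of the queued requests
            for req in requests:
                print(req.data) # test if method is entered
            
            # do stuff, serve model
            # a batched backend must return one result per request, in order
            return [req.data for req in requests]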
    
    if __name__ == '__main__':
        client = serve.start()
        client.create_backend("tf", A,
            # configure resources
            ray_actor_options={"num_cpus": 2},
            # configure replicas
            config={
                "num_replicas": 2, 
                "max_batch_size": 24,
                "batch_wait_timeout": 0.1
            }
        )
        client.create_endpoint("tf", backend="tf")
        handle = client.get_handle("tf")
    
        args = [1,2,3]
    
        futures = [handle.remote(i) for i in args]
        result = ray.get(futures)
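To make the batching contract concrete, here is a plain-Python simulation that does not use Ray at all. FakeRequest and run_in_batches are hypothetical stand-ins: the first mimics only the .data attribute read in __call__, and the second groups inputs into batches of at most max_batch_size (mirroring the config key of the same name) and checks that the handler returns one result per request, since Ray Serve 1.0 expects a batched backend to return a list as long as its input batch. It does not model batch_wait_timeout, replicas, or Serve's real scheduling.

```python
from dataclasses import dataclass
from typing import Any, Callable, List


@dataclass
class FakeRequest:
    """Hypothetical stand-in for a Serve request; only exposes `.data`."""
    data: Any


def run_in_batches(handler: Callable[[List[FakeRequest]], List[Any]],
                   inputs: List[Any],
                   max_batch_size: int) -> List[Any]:
    """Hypothetical batcher: split `inputs` into chunks of at most
    `max_batch_size`, call `handler` once per chunk, flatten the results."""
    results: List[Any] = []
    for start in range(0, len(inputs), max_batch_size):
        batch = [FakeRequest(x) for x in inputs[start:start + max_batch_size]]
        out = list(handler(batch))
        # A batched handler must return exactly one result per request.
        if len(out) != len(batch):
            raise ValueError(
                f"handler returned {len(out)} results for {len(batch)} requests")
        results.extend(out)
    return results


class A:
    """Batched handler shaped like the answer's `__call__(self, requests)`."""

    def __call__(self, requests: List[FakeRequest]) -> List[Any]:
        # Placeholder "model": double each input instead of running TensorFlow.
        return [req.data * 2 for req in requests]


print(run_in_batches(A(), [1, 2, 3], max_batch_size=2))  # prints [2, 4, 6]
```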
    

    【Comments】:

    • Thanks, that works. Does the function have to be named __call__, or can the batch function have a different name?