为Ray actor功能实现缓存 更通用的方法

问题描述

我的目标是使下面的代码在大约0.3秒而不是0.5秒内执行。我尝试在functools.lru_cache上使用toolz.functoolz.memoizekids.cache.cachefoo的修饰符,但是没有一个起作用(错误消息或执行不正确)。我该怎么做才能使这项工作成功?

import ray


@ray.remote
class Foo:
    def foo(self,x):
        print("executing foo with {}".format(x))
        time.sleep(0.1)


ray.init()
f = Foo.remote()
s = time.time()
ray.get([f.foo.remote(x=i) for i in [1,2,1,4,1]])
print(time.time()-s)
ray.shutdown()

解决方法

一般警告:如果函数产生副作用,则缓存任意函数调用可能很危险。

在这种情况下,大概是您希望程序输出

executing foo with 1 
executing foo with 2 
executing foo with 4 

您提到的其他高速缓存工具不太适合与Ray一起使用,因为它们试图将高速缓存存储在某种全局状态下,并且没有将该状态存储在可通过分布式方式访问的位置道路。由于您已经有演员,因此您可以将全局状态存储在演员中。

@ray.remote
class Foo:
    def __init__(self):
        self.foo_cache = {}

    def foo(self,x):
        def real_foo(x):
            print("executing foo with {}".format(x))
            time.sleep(0.1)
        if x not in self.foo_cache:
            self.foo_cache[x] = real_foo(x)
        return self.foo_cache[x]

这是一种非常通用的缓存技术,这里唯一重要的区别是我们必须将状态存储在参与者中。

更通用的方法

我们还可以通过定义通用函数缓存将这种方法推广到任何Ray函数:

@ray.remote
class FunctionCache:
    def __init__(self,func):
        self.func = ray.remote(func)
        self.cache = {}

    def call(self,*args,**kwargs):
        if (args,kwargs) not in cache:
            cache[(args,kwargs)] = self.func(*args,**kwargs)
        return cache[(args,kwargs)]

然后为了清理使用方式,我们可以定义一个装饰器:

class RemoteFunctionLookAlike:
    def __init__(self,func):
        self.func = func

    def remote(self,**kwargs):
        return self.func(*args,**kwargs)


def ray_cache(func):
    cache = FunctionCache.remote(func)
    def call_with_cache(*args,**kwargs):
        return cache.call.remote(*args,**kwargs)
    return RayFunctionLookAlike(call_with_cache)

最后,要使用此缓存:

@ray_cache
def foo(x):
    print("Function called!")
    return abc

ray.get([foo.remote("constant") for _ in range(100)]) # Only prints "Function called!" once.

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...