如何使用 Ray 并行化 Huggingface 的零样本分类(zero-shot classification)?

问题描述

我有大约 70 个类别(也可能是 20 或 30 个),希望使用 Ray 将分类过程并行化,但出现了以下错误:

import pandas as pd
import swifter
import json
import ray
from transformers import pipeline

classifier = pipeline("zero-shot-classification")

# Candidate labels for the zero-shot classifier. Each line starts with a broad
# category label followed by more specific items in that category. The label
# text itself is what the classifier scores against, so it must be spelled
# correctly ("mushrooms", "mangoes", "bakery", "sauce").
labels = [
    "vegetables", "potato", "bell pepper", "tomato", "onion", "carrot", "broccoli",
    "lettuce", "cucumber", "celery", "corn", "garlic", "mushrooms", "cabbage",
    "spinach", "beans", "cauliflower", "asparagus",
    "fruits", "bananas", "apples", "strawberries", "grapes", "oranges", "lemons",
    "avocados", "peaches", "blueberries", "pineapple", "cherries", "pears",
    "mangoes", "berries",
    "red meat", "beef", "pork", "mutton", "veal", "lamb", "venison", "goat", "mince",
    "white meat", "chicken", "turkey", "duck", "goose", "pheasant", "rabbit",
    "Processed meat", "sausages", "bacon", "ham", "hot dogs", "frankfurters",
    "tinned meat", "salami", "pâtés", "beef jerky", "chorizo", "pepperoni",
    "corned beef",
    "fish", "catfish", "cod", "pangasius", "pollock", "tilapia", "tuna", "salmon",
    "seafood", "shrimp", "squid", "mussels", "scallop", "octopus",
    "grains", "rice", "wheat", "bulgur", "oat", "quinoa", "buckwheat",
    "meals", "salad", "soup", "steak", "pizza", "pie", "burger",
    "bakery", "bread", "sauce", "pasta", "sandwich", "waffles", "barbecue",
    "roll", "wings", "ribs", "cookies",
]


ray.init(ignore_reinit_error=True)

# Put the model into Ray's object store once. If the remote function read the
# global `classifier` directly, Ray would pickle the entire model into the
# function definition and push it through Redis. That is what produces the
# "has size ... when pickled" warning and the ConnectionError.
model_ref = ray.put(classifier)


@ray.remote
def get_meal_category(seq, labels, n=3, classifier=None):
    """Return the top-``n`` zero-shot labels for one piece of text.

    Args:
        seq: the text to classify (e.g. a meal title).
        labels: candidate labels passed to the zero-shot pipeline.
        n: how many of the highest-scoring labels to keep.
        classifier: the zero-shot pipeline. Pass the ObjectRef returned by
            ``ray.put``; Ray resolves it to the pipeline before the call.
            Because this parameter shadows the module-level ``classifier``,
            the model is never captured in the function definition. If it is
            omitted, a new pipeline is built inside the worker on every call,
            which is slow.

    Returns:
        A list of up to ``n`` ``(seq, label, score)`` tuples, best label first.
    """
    if classifier is None:
        classifier = pipeline("zero-shot-classification")
    res_dict = classifier(seq, labels)
    top = zip(res_dict["labels"][:n], res_dict["scores"][:n])
    return [(seq, label, score) for label, score in top]


# .iloc is positional, so these are the first 10 rows even if merged_df's
# index is not a default 0..N-1 RangeIndex.
titles = merged_df["title"]
res_list = ray.get(
    [get_meal_category.remote(titles.iloc[i], labels, classifier=model_ref) for i in range(10)]
)

其中 merged_df 是一个大型数据框,其 title 列中包含餐点名称,例如:

['Cappuccino','Stove Top Stuffing Mix For Turkey (Kraft)','Roasted Dark Turkey Meat','Cappuccino','Low Fat 2% Small Curd Cottage Cheese (Daisy)','Rice Cereal (Gerber)','Oranges']

请指教如何避免这个 Ray 错误,并将分类过程并行化。

错误

2021-02-17 16:54:51,689 WARNING worker.py:1107 -- Warning: The remote function __main__.get_meal_category has size 1630925709 when pickled. It will be stored in Redis, which could cause memory issues. This may mean that its definition uses a large array or other object.
---------------------------------------------------------------------------
ConnectionResetError                      Traceback (most recent call last)
~/.local/lib/python3.8/site-packages/redis/connection.py in send_packed_command(self,command,check_health)
    705             for item in command:
--> 706                 sendall(self._sock,item)
    707         except socket.timeout:

~/.local/lib/python3.8/site-packages/redis/_compat.py in sendall(sock,*args,**kwargs)
      8 def sendall(sock,*args,**kwargs):
----> 9     return sock.sendall(*args,**kwargs)
     10 

ConnectionResetError: [Errno 104] Connection reset by peer

During handling of the above exception,another exception occurred:

ConnectionError                           Traceback (most recent call last)
<ipython-input-9-1a5345832fba> in <module>
----> 1 res_list = ray.get([get_meal_category.remote(merged_df["title"][i],labels) for i in range(10)])

<ipython-input-9-1a5345832fba> in <listcomp>(.0)
----> 1 res_list = ray.get([get_meal_category.remote(merged_df["title"][i],labels) for i in range(10)])

~/.local/lib/python3.8/site-packages/ray/remote_function.py in _remote_proxy(*args,**kwargs)
     99         @wraps(function)
    100         def _remote_proxy(*args,**kwargs):
--> 101             return self._remote(args=args,kwargs=kwargs)
    102 
    103         self.remote = _remote_proxy

~/.local/lib/python3.8/site-packages/ray/remote_function.py in _remote(self,args,kwargs,num_returns,num_cpus,num_gpus,memory,object_store_memory,accelerator_type,resources,max_retries,placement_group,placement_group_bundle_index,placement_group_capture_child_tasks,override_environment_variables,name)
    205 
    206             self._last_export_session_and_job = worker.current_session_and_job
--> 207             worker.function_actor_manager.export(self)
    208 
    209         kwargs = {} if kwargs is None else kwargs

~/.local/lib/python3.8/site-packages/ray/function_manager.py in export(self,remote_function)
    142         key = (b"RemoteFunction:" + self._worker.current_job_id.binary() + b":"
    143                + remote_function._function_descriptor.function_id.binary())
--> 144         self._worker.redis_client.hset(
    145             key,
    146             mapping={

~/.local/lib/python3.8/site-packages/redis/client.py in hset(self,name,key,value,mapping)
   3048                 items.extend(pair)
   3049 
-> 3050         return self.execute_command('HSET',*items)
   3051 
   3052     def hsetnx(self,value):

~/.local/lib/python3.8/site-packages/redis/client.py in execute_command(self,**options)
    898         conn = self.connection or pool.get_connection(command_name,**options)
    899         try:
--> 900             conn.send_command(*args)
    901             return self.parse_response(conn,command_name,**options)
    902         except (ConnectionError,TimeoutError) as e:

~/.local/lib/python3.8/site-packages/redis/connection.py in send_command(self,**kwargs)
    723     def send_command(self,**kwargs):
    724         "Pack and send a command to the Redis server"
--> 725         self.send_packed_command(self.pack_command(*args),
    726                                  check_health=kwargs.get('check_health',True))
    727 

~/.local/lib/python3.8/site-packages/redis/connection.py in send_packed_command(self,check_health)
    715                 errno = e.args[0]
    716                 errmsg = e.args[1]
--> 717             raise ConnectionError("Error %s while writing to socket. %s." %
    718                                   (errno,errmsg))
    719         except BaseException:

ConnectionError: Error 104 while writing to socket. Connection reset by peer.

解决方法

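发生此错误是因为向 Redis 发送了大对象。警告信息显示,远程函数 get_meal_category 序列化后约 1.6 GB:函数体引用了全局变量 classifier,Ray 会把整个模型连同函数定义一起序列化并存入 Redis。同样,merged_df 是一个大型数据帧,如果把它直接作为参数传给 get_meal_category 并调用 10 次,Ray 就会将其序列化 10 次。正确的做法是用 ray.put 将这类大对象只放入 Ray 对象存储一次,然后在每次调用时传递对该对象的引用。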

编辑:分类器(classifier)同样很大,也应对它做同样的处理,即先用 ray.put 放入对象存储,再传递引用。

可以试试下面的写法:

# ignore_reinit_error lets this cell be re-run in the same notebook/session
# without ray.init() raising because Ray is already running.
ray.init(ignore_reinit_error=True)

# Put the large objects into Ray's object store once and pass references, so
# the DataFrame and the model are serialized once instead of on every task
# submission (or inside the function definition).
df_ref = ray.put(merged_df)
model_ref = ray.put(classifier)


@ray.remote
def get_meal_category(classifier, df, i, labels, n=3):
    """Classify row ``i`` of ``df["title"]`` and return its top-``n`` labels.

    Args:
        classifier: the zero-shot classification pipeline. The caller passes
            an ObjectRef, and Ray resolves it to the pipeline before the call.
        df: DataFrame with a ``"title"`` column (also passed as an ObjectRef).
        i: positional row index into ``df``.
        labels: candidate labels for the classifier.
        n: how many of the highest-scoring labels to keep.

    Returns:
        A list of up to ``n`` ``(title, label, score)`` tuples, best label first.
    """
    # .iloc is positional, so the i-th row is used even when df's index is
    # not a default 0..N-1 RangeIndex.
    seq = df["title"].iloc[i]
    res_dict = classifier(seq, labels)
    top = zip(res_dict["labels"][:n], res_dict["scores"][:n])
    return [(seq, label, score) for label, score in top]


# The row index `i` must be passed: the signature is (classifier, df, i, labels).
# Leaving it out binds `labels` to `i` and fails with a missing-argument error.
res_list = ray.get(
    [get_meal_category.remote(model_ref, df_ref, i, labels) for i in range(10)]
)