predict() 得到了一个意外的关键字参数“stats”

问题描述

我正在尝试从 AI 平台上提供的 Tensor Flow 自定义例程中获取预测。

我设法使用以下设置为其提供服务:--runtime-version 2.3 --python-version 3.7 --machine-type mls1-c4-m2

但是当我尝试做出任何预测时,我总是收到此错误

ERROR:root:Prediction Failed: predict() got an unexpected keyword argument 'stats'
ERROR:root:Prediction Failed: unkNown error.

该例程有两个步骤:

  1. 获取输入(一个字符串)并使用 .pkl 格式的弓形模型将其转换为嵌入
  2. 使用嵌入以使用保存为 .h5 文件的 keras 模型获取预测

这是我的 setup.py

from setuptools import setup

required_PACKAGES = ['Keras==2.3.1','sklearn==0.0','h5py<3.0.0','numpy==1.16.0','scipy==1.4.1','pyyaml==5.2']

setup(
        name='my_custom_code',version='0.1',scripts=['predictor.py'],install_requires=required_PACKAGES,packages=find_packages(),include_package_data=False,description=''
) 

这是我的predictor.py

import os
import pickle
import tensorflow as tf
import numpy as np

class MyPredictor(object):

    def __init__(self,model,bow_model):
        self._model = model
        self._bow_model = bow_model

    def predict(self,instances):

        outputs = []

        for x in instances:
            vector = self.embedding(x)
            output = self._model.predict(vector)
            outputs.append(output)

        return outputs

    def embedding(self,statement):
        vector = self._bow_model.transform(statement).toarray()
        vector = vector.to_list()
        return vector


    @classmethod
    def from_path(cls,model_dir):

        model_path = os.path.join(model_dir,'model.h5')
        model = tf.keras.models.load_model(model_path,compile = False)

        preprocessor_path = os.path.join(model_dir,'bow.pkl')
        with open(preprocessor_path,'rb') as f:
            bow_model = pickle.load(f)

        return cls(model,bow_model)

我用于测试的脚本是:

import googleapiclient.discovery

instances = ['test','test']]

service = googleapiclient.discovery.build('ml','v1')
name = 'projects/{}/models/{}/versions/{}'.format(PROJECT_ID,MODEL_NAME,VERSION_NAME)

response = service.projects().predict(
    name=name,body={'instances': instances}
).execute()

if 'error' in response:
    raise RuntimeError(response['error'])
else:
  print(response['predictions'])

解决方法

根据自定义预测例程 documentation,一旦创建预测器 classpredict() 方法应提供 self,instances,**kwargs 参数以正确处理预测请求。

instances:预测输入实例列表。

**kwargs:在预测请求正文中作为附加字段提供的关键字参数字典。