Problem description
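I am new to TensorFlow and Python and am trying to define a SageMaker pipeline. Right now I am stuck on saving the model from the Movie Lens example when running it in SageMaker. I managed to train the model with the code below and save it to the appropriate S3 bucket using the SavedModel API. When I load the model and try to make predictions with it, I get this error: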
ValueError: Could not find matching function to call loaded from the SavedModel. Got:
  Positional arguments (3 total):
    * <_VariantDataset shapes: (), types: tf.string>
    * True
    * None
  Keyword arguments: {}
Expected these arguments to match one of the following 0 option(s):
Model loading and prediction:
import os
import tarfile

import tensorflow as tf

model_path = "/opt/ml/processing/model"
tar_path = os.path.join(model_path, "model.tar.gz")
logger.info(tar_path)
with tarfile.open(tar_path) as tar:
    tar.extractall(path=model_path)
logger.info("Extracted model.")
model = tf.saved_model.load(model_path)
scores, titles = model(tf_ratings.take(1))
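A side note on the prediction call: tf_ratings.take(1) is itself a tf.data.Dataset, not a tensor, which appears to be why the first positional argument in the error is a _VariantDataset. A function restored from a SavedModel can only dispatch to the input types and shapes it was traced with, so a dataset usually has to be iterated and its tensors passed in. The sketch below shows that calling pattern on a hypothetical stand-in module (EchoLength, which just returns string lengths and has an explicit input_signature); it is not the MovieLens model itself:

```python
import os
import tempfile

import tensorflow as tf


class EchoLength(tf.Module):
    """Hypothetical stand-in model: returns the length of each string ID."""

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
    def __call__(self, ids):
        return tf.strings.length(ids)


export_dir = os.path.join(tempfile.mkdtemp(), "model")
tf.saved_model.save(EchoLength(), export_dir)
loaded = tf.saved_model.load(export_dir)

dataset = tf.data.Dataset.from_tensor_slices(["1", "42", "123"])

# Passing dataset.take(1) would hand the Dataset object itself to the
# restored function, which matches none of its saved input specs.
# Instead, iterate the dataset and pass each batch as a tensor:
for batch in dataset.batch(2).take(1):
    lengths = loaded(batch)
    print(lengths.numpy())  # [1 2]
```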
Model class:
from typing import Dict, Text

import tensorflow as tf
import tensorflow_recommenders as tfrs


class MovieLensModel(tfrs.Model):
    # We derive from a custom base class to help reduce boilerplate. Under the hood,
    # these are still plain Keras Models.
    def __init__(self, user_model: tf.keras.Model, movie_model: tf.keras.Model, task: tfrs.tasks.Retrieval):
        super(tfrs.Model, self).__init__()  # added arguments for super
        # Set up user and movie representations.
        self.user_model = user_model
        self.movie_model = movie_model
        # Set up a retrieval task.
        self.task = task

    @tf.function
    def __call__(self, x, training=True, mask=None):
        user_embeddings = self.user_model(x)
        movie_embeddings = self.movie_model(x)
        return self.task(user_embeddings, movie_embeddings)

    def compute_loss(self, features: Dict[Text, tf.Tensor], training=False) -> tf.Tensor:
        # Define how the loss is computed.
        user_embeddings = self.user_model(features[0])
        movie_embeddings = self.movie_model(features[0])
        return self.task(user_embeddings, movie_embeddings)
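One detail in this class that likely matters for saving: Keras models normally implement call() and let the base class's __call__ wrap it, whereas here __call__ itself is replaced by a bare @tf.function with no input_signature. Training with fit() goes through tfrs.Model's train_step, which calls compute_loss, so the overridden __call__ is probably never executed, and therefore never traced, before tf.saved_model.save runs. The sketch below mimics that situation with a hypothetical plain tf.Module (TwoTowerSketch) and a hand-written loop standing in for fit(). It uses experimental_get_tracing_count(), which exists in recent TF 2.x releases but may not be available in TF 2.2:

```python
import tensorflow as tf


class TwoTowerSketch(tf.Module):
    """Hypothetical stand-in mirroring MovieLensModel's structure."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(1.0)

    def compute_loss(self, features):
        # Training only ever calls this method (as tfrs.Model.train_step does).
        return tf.reduce_mean(tf.square(self.w * features - 1.0))

    @tf.function
    def __call__(self, x):
        # Overridden __call__ with no input_signature: traced only when called.
        return self.w * x


model = TwoTowerSketch()
features = tf.constant([1.0, 2.0])

# Hand-written stand-in for model.fit(): three gradient steps on compute_loss.
for _ in range(3):
    with tf.GradientTape() as tape:
        loss = model.compute_loss(features)
    grad = tape.gradient(loss, model.w)
    model.w.assign_sub(0.1 * grad)

count_after_training = model.__call__.experimental_get_tracing_count()
print(count_after_training)  # 0: nothing for tf.saved_model.save to store

model(tf.constant([1.0]))  # one call with a tensor traces __call__
count_after_call = model.__call__.experimental_get_tracing_count()
print(count_after_call)  # 1
```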
Model training and saving:
task = tfrs.tasks.Retrieval(metrics=tfrs.metrics.FactorizedTopK(
    tf_movies.batch(128).map(movie_model)
))
logger.info("Model training...")
# Create a retrieval model.
model = MovieLensModel(user_model,movie_model,task)
model.compile(optimizer=tf.keras.optimizers.Adagrad(0.5))
# Train for 3 epochs.
model.fit(tf_ratings.batch(args.batch_size),epochs = args.epochs)
# Save the model.
model_path = os.path.join(args.model_dir,"movie_lens")
logger.info("Model path is " + args.model_dir)
model.task = tfrs.tasks.Retrieval() # Removes the metrics.
tf.saved_model.save(model,args.model_dir)
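If the aim is to call the loaded object directly, one option is to trace the function for a concrete input spec before saving, and optionally export that trace as a named signature. A minimal sketch, assuming a hypothetical ScoringModule with a dummy scoring rule in place of the real two-tower model; the input name user_ids and the signature key serving_default are choices made for this example. Also note that, as written, MovieLensModel.__call__ returns whatever self.task(...) returns, which for tfrs.tasks.Retrieval is a loss value rather than a (scores, titles) pair; the TFRS tutorials use a tfrs.layers.factorized_top_k.BruteForce index for that kind of lookup.

```python
import os
import tempfile

import tensorflow as tf


class ScoringModule(tf.Module):
    """Hypothetical stand-in model with a dummy scoring rule."""

    def __init__(self):
        super().__init__()
        self.bias = tf.Variable(1.0)

    @tf.function  # no input_signature, like MovieLensModel.__call__
    def __call__(self, user_ids):
        return tf.cast(tf.strings.length(user_ids), tf.float32) + self.bias


module = ScoringModule()

# Trace __call__ for the input spec expected at prediction time, before saving.
serving_fn = module.__call__.get_concrete_function(
    tf.TensorSpec(shape=[None], dtype=tf.string, name="user_ids"))

export_dir = os.path.join(tempfile.mkdtemp(), "movie_lens")
tf.saved_model.save(module, export_dir, signatures={"serving_default": serving_fn})

loaded = tf.saved_model.load(export_dir)
ids = tf.constant(["7", "42"])

direct = loaded(ids).numpy()  # dispatches to the traced concrete function
print(direct)  # [2. 3.]

outputs = loaded.signatures["serving_default"](user_ids=ids)  # dict of tensors
print(sorted(outputs))
```

Here the get_concrete_function call is what makes loaded(ids) work after loading; the signatures argument additionally exposes the same trace under loaded.signatures, which is what serving tools such as TensorFlow Serving look up.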
The image used for the SageMaker container is tensorflow-training:2.2-cpu-py37.
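I think the positional arguments in the error above correspond to the parameters of the __call__ function in the model class (x, training=True, mask=None). What puzzles me is the error itself: if model(value) is recognized as a call to __call__ when making a prediction, why do I get an error saying there is no matching function?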
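The "0 option(s)" at the end of the error is the telling part: it suggests the SavedModel holds no traced version of __call__ at all, rather than one whose arguments differ. A tf.function without an input_signature is saved only for the input specs it was actually traced with, so one that was never called before tf.saved_model.save has nothing to match against. A minimal, self-contained reproduction under that reading, using hypothetical Untraced and Traced modules in plain TensorFlow (no TFRS, no SageMaker):

```python
import os
import tempfile

import tensorflow as tf


class Untraced(tf.Module):
    # No input_signature and never called before saving, so no concrete
    # function exists for tf.saved_model.save to store.
    @tf.function
    def __call__(self, x):
        return x * 2.0


class Traced(tf.Module):
    # An explicit input_signature lets tf.saved_model.save trace and store
    # a concrete function even though the function was never called.
    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def __call__(self, x):
        return x * 2.0


def save_and_load(module):
    """Save a module to a fresh temporary directory and load it back."""
    path = os.path.join(tempfile.mkdtemp(), "saved")
    tf.saved_model.save(module, path)
    return tf.saved_model.load(path)


x = tf.constant([1.0, 2.0])

untraced_error = None
try:
    save_and_load(Untraced())(x)
except Exception as err:  # expected: the "no matching function" error from the question
    untraced_error = err
print("untraced call failed with:", type(untraced_error).__name__)

traced_out = save_and_load(Traced())(x).numpy()
print("traced call returned:", traced_out)  # [2. 4.]
```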