问题描述
我正在建立一个tfx管道并使用tensorflow服务来服务我的模型。我用model.save(...)
保存签名。
到目前为止,我能够使用tf_transform_output.transform_features_layer()
进行预测之前使用变换层对特征进行变换(请参见下面的代码)。
但是,我想知道如何检测输入数据中的异常?例如,我不想预测一个输入值与以前使用某功能训练过的分布相距太远。
tfdv
库提供的功能类似于generate_statistics_from_[csv|dataframe|tfrecord]
,但是我找不到任何好的示例来为序列化的tf.Example
(或未保存在文件中的内容)生成统计信息,例如csv,tfrecords等。
我知道以下example in the documentation:
import tensorflow_data_validation as tfdv
import tfx_bsl
import pyarrow as pa
decoder = tfx_bsl.coders.example_coder.ExamplesToRecordBatchDecoder()
example = decoder.DecodeBatch([serialized_tfexample])
options = tfdv.StatsOptions(schema=schema)
anomalies = tfdv.validate_instance(example,options)
但是在此示例中,serialized_tfexample
是字符串,而在我的代码中,参数serialized_tf_examples
下面是字符串的张量。
很抱歉,如果这是一个明显的问题。我整天都在寻找解决方案,但没有成功。也许我把这一切弄错了。也许这不是进行验证的正确位置。因此,我更笼统的问题是:在生产中通过tfx管道创建的模型提供服务时,如何在预测之前验证传入的输入数据? 我感谢您为正确的方向提供了指导。
...
tf_transform_output = tft.TFTransformOutput(...)
model.tft_layer = tf_transform_output.transform_features_layer()
@tf.function(input_signature=[
tf.TensorSpec(shape=[None],dtype=tf.string,name='examples')
])
def serve_tf_examples_fn(serialized_tf_examples):
#### How can I generate stats and validate serialized_tf_examples? ###
#### Is this the right place? ###
feature_spec = tf_transform_output.raw_feature_spec()
feature_spec.pop(TARGET_LABEL)
parsed_features = tf.io.parse_example(serialized_tf_examples,feature_spec)
transformed_features = model.tft_layer(parsed_features)
return model(transformed_features)
...
model.save(serving_model_dir,save_format='tf',signatures={
'serving_default': serve_tf_examples_fn
})
解决方法
暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!
如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。
小编邮箱:dio#foxmail.com (将#修改为@)