如何使用 SavedModel 在 Tensorflowjs 中读取 predict() 结果

问题描述

使用 tfjs-node代码

const model = await tf.node.loadSavedModel(modelPath);
const data = fs.readFileSync(imgPath);
const tfimage = tf.node.decodeImage(data,3);
const expanded = tfimage.expandDims(0);
const result = model.predict(expanded);
console.log(result);
for (r of result) {
   console.log(r.dataSync());
}

输出

(8) [Tensor,Tensor,Tensor]
Float32Array(100) [48700,48563,48779,49041,...]
Float32Array(400) [0.10901492834091187,0.18931034207344055,0.9181075692176819,0.8344497084617615,...]
Float32Array(100) [61,88,65,84,67,51,62,20,59,9,18,...]
Float32Array(9000) [0.009332209825515747,0.003941178321838379,0.0005068182945251465,0.001926332712173462,0.0020033419132232666,0.000742495059967041,0.022082984447479248,0.0032682716846466064,0.05071520805358887,0.000018596649169921875,...]
Float32Array(100) [0.6730095148086548,0.1356855034828186,0.12674063444137573,0.12360832095146179,0.10837388038635254,0.10075071454048157,...]
Float32Array(1) [100]
Float32Array(196416) [0.738592267036438,0.4373246729373932,0.738592267036438,0.546840488910675,-0.010780575685203075,0.00041256844997406006,0.03478313609957695,0.11279871314764023,-0.0504981130361557,-0.11237315833568573,0.02907072752714157,0.06638012826442719,0.001794634386897087,0.0009463857859373093,...]
Float32Array(4419360) [0.0564018189907074,0.016801774501800537,0.025803595781326294,0.011671125888824463,0.014013528823852539,0.008442580699920654,...]

如何读取对象检测的 predict() 响应?我期待一本包含 num_detectionsdetection_Boxesdetection_classes 等的字典,如 here 所述。

我也尝试过使用 tf.execute(),但它引发了以下错误UnhandledPromiseRejectionWarning: Error: execute() of TFSavedModel is not supported yet

我使用的是从 here 下载的 efficientdet/d0

解决方法

当您使用 dataSync() 下载张量时,它只会保留值。如果你想要一个没有张量的带有每个结果描述的对象,你只需要console.log(result)。然后你在浏览器控制台中展开你的日志结果,它应该返回如下内容:

Tensor {
  "dataId": Object {},"dtype": "float32","id": 160213,"isDisposedInternal": false,"kept": false,"rankType": "2","scopeId": 365032,"shape": Array [
    1,3,],"size": 3,"strides": Array [
    3,}

您的 console.log(result) 的输出中包含 8 tensors,这表明它是正确的。您正在遍历每个结果,每个输出都应遵循以下格式:

['num_detections','detection_boxes','detection_classes','detection_scores','raw_detection_boxes','raw_detection_scores,'detection_anchor_indices','detection_multiclass_scores']