问题描述
我有一个关于使用 tensorflow.js 通过网络摄像头检测对象的问题。目前我使用的是预训练模型 coco-ssd。
index.html:
<!DOCTYPE html>
<html lang="en">
<head>
  <!-- The charset declaration comes first so it is seen before any other content. -->
  <meta charset="utf-8">
  <meta http-equiv="X-UA-Compatible" content="IE=edge">
  <meta name="viewport" content="width=device-width,initial-scale=1">
  <title>Multiple object detection using pre trained model in tensorflow.js</title>
  <!-- Import the webpage's stylesheet -->
  <link rel="stylesheet" href="style.css">
</head>
<body>
  <h1>Multiple object detection using pre trained model in tensorflow.js</h1>
  <p>Wait for the model to load before clicking the button to enable the webcam - at which point it will become visible to use.</p>
  <!-- script.js removes the "invisible" class once the model has loaded. -->
  <section id="demos" class="invisible">
    <p>Hold some objects up close to your webcam to get a real-time classification! When ready click "enable webcam" below and accept access to the webcam when the browser asks (check the top left of your window)</p>
    <div id="liveView" class="camView">
      <button id="webcamButton">Enable Webcam</button>
      <!-- playsinline keeps iOS Safari from forcing the live stream into fullscreen. -->
      <video id="webcam" autoplay playsinline width="640" height="480"></video>
    </div>
  </section>
  <!-- Import the TensorFlow.js library. -->
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js"></script>
  <!-- Load the pre-trained COCO-SSD model used to recognize things in images.
       A custom model converted with tensorflowjs_converter does not need this
       script; load it in script.js with tf.loadGraphModel() instead. -->
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/coco-ssd"></script>
  <!-- Import the page's own JavaScript. -->
  <script src="script.js" defer></script>
</body>
</html>
script.js:
// Look up the DOM elements the demo works with, in the same order as before:
// the <video> showing the webcam, the overlay container, the initially
// hidden demo section, and the "Enable Webcam" button (ids from index.html).
const [video, liveView, demosSection, enableWebcamButton] =
  ['webcam', 'liveView', 'demos', 'webcamButton'].map(function (id) {
    return document.getElementById(id);
  });
// Returns true when the browser exposes navigator.mediaDevices.getUserMedia,
// i.e. when webcam access can be requested at all; false otherwise.
function getUserMediaSupported() {
  const devices = navigator.mediaDevices;
  return Boolean(devices && devices.getUserMedia);
}
// Without getUserMedia there is no way to reach the camera, so just warn.
// Otherwise let the button start the webcam; enableCam is a function
// declaration further down, so it is already hoisted and usable here.
if (!getUserMediaSupported()) {
  console.warn('getUserMedia() is not supported by your browser');
} else {
  enableWebcamButton.addEventListener('click', enableCam);
}
// Enable the live webcam view and start classification.
// Click handler for the "Enable Webcam" button: does nothing until the
// COCO-SSD model has loaded, then hides the button, requests a video-only
// stream and starts predictWebcam once the video has frame data.
// Returns the getUserMedia promise chain (undefined when the model is not
// ready yet); event listeners ignore the return value.
function enableCam(event) {
  // Only continue if the COCO-SSD model has finished loading.
  if (!model) {
    return;
  }

  // Capture the button now; it is needed again if webcam access fails.
  const button = event.target;

  // Hide the button once clicked.
  button.classList.add('removed');

  // getUserMedia parameters: request video but not audio.
  const constraints = {
    video: true
  };

  // Activate the webcam stream and start predicting once frame data is loaded.
  return navigator.mediaDevices.getUserMedia(constraints)
    .then(function (stream) {
      video.srcObject = stream;
      // `once` prevents a second, parallel prediction loop if 'loadeddata'
      // fires again.
      video.addEventListener('loadeddata', predictWebcam, { once: true });
    })
    .catch(function (err) {
      // Permission denied or no camera available: report it and show the
      // button again so the user can retry instead of facing a blank page.
      console.error('Could not access the webcam:', err);
      button.classList.remove('removed');
    });
}
// The loaded COCO-SSD model, kept in the global scope of the app.
// It stays undefined until cocoSsd.load() resolves; enableCam() checks it
// before starting the webcam.
var model = undefined;

// Machine learning models can be large and take a moment to get everything
// they need, so the demo section stays hidden until loading has finished.
// Note: cocoSsd is a global provided by the coco-ssd <script> tag in
// index.html, so ignore any "undefined" warning from the editor.
cocoSsd.load().then(function (loadedModel) {
  model = loadedModel;
  // Show the demo section now that the model is ready to use.
  demosSection.classList.remove('invisible');
}).catch(function (err) {
  // Without a handler a failed load is an unhandled rejection and the page
  // silently stays hidden; at least report why.
  console.error('Failed to load the COCO-SSD model:', err);
});
// Only detections scoring above this are drawn (scores are fractions: 0.66 = 66%).
const MIN_DETECTION_SCORE = 0.66;

// Overlay elements (highlight boxes and labels) drawn for the previous frame,
// kept so they can be removed before the next frame is drawn.
var children = [];

// Remove every overlay element drawn for the previous frame.
function clearHighlights() {
  for (const child of children) {
    liveView.removeChild(child);
  }
  children.splice(0);
}

// Draw a highlight box and a text label for one COCO-SSD prediction.
// COCO-SSD reports the box as `bbox` (all lowercase): [x, y, width, height].
function drawPrediction(prediction) {
  const [x, y, width, height] = prediction.bbox;

  const label = document.createElement('p');
  label.textContent = `${prediction.class} - with ${Math.round(prediction.score * 100)}% confidence.`;
  label.style.cssText = `margin-left: ${x}px; margin-top: ${y - 10}px; width: ${width - 10}px; top: 0; left: 0;`;

  const highlighter = document.createElement('div');
  highlighter.className = 'Highlighter';
  highlighter.style.cssText = `left: ${x}px; top: ${y}px; width: ${width}px; height: ${height}px;`;

  liveView.appendChild(highlighter);
  liveView.appendChild(label);
  children.push(highlighter, label);
}

// Classify the current video frame, redraw the overlays for confident
// detections, and schedule the next run with requestAnimationFrame.
// Used both as the video's 'loadeddata' listener and as the animation-frame
// callback, so any argument it receives is ignored. Returns the detection
// promise; existing callers ignore it.
function predictWebcam() {
  return model.detect(video).then(function (predictions) {
    // Remove any highlighting drawn for the previous frame.
    clearHighlights();

    for (const prediction of predictions) {
      if (prediction.score > MIN_DETECTION_SCORE) {
        drawPrediction(prediction);
      }
    }

    // Keep predicting when the browser is ready for the next frame.
    window.requestAnimationFrame(predictWebcam);
  }).catch(function (err) {
    // Otherwise the rejection would go unhandled and the loop would just stop.
    console.error('Object detection failed; stopping the prediction loop:', err);
  });
}
现在我想修改脚本以使用我自己的模型：该模型是我之前用 TensorFlow（Python）创建并训练的，并且我已经使用转换工具 tensorflowjs_converter 将其转换为 TensorFlow.js 的 model.json 格式。
如何修改我的代码以便使用我自己的模型？我已经尝试了一些方法，但很遗憾没有取得任何进展。
解决方法
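您可以使用 @tensorflow/tfjs-converter 中的 loadGraphModel（在 @tensorflow/tfjs 包中以 `tf.loadGraphModel()` 提供）从转换得到的 model.json 加载模型；如果模型是从 Keras 格式转换的，则应改用 `tf.loadLayersModel()`。注意，自定义模型没有 coco-ssd 的 `detect()` 方法：需要先用 `tf.browser.fromPixels(video)` 把视频帧转换为张量，再调用 `model.executeAsync()`（或 `model.predict()`），并按照自己模型的输出格式解析边界框、类别和分数。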
我比较推荐参考这个示例（原文中的示例链接在转载时已丢失）。