// index.js — Express server that runs a TensorFlow.js object detection
// SavedModel over uploaded images (deployable as a Firebase HTTPS function
// or as a standalone Node server).
const functions = require('firebase-functions');
const express = require('express');
const Busboy = require('busboy');
const path = require('path');

const tf = require('@tensorflow/tfjs-node');

// Class-index → human-readable label map shipped alongside the SavedModel.
const labels = require('./model/new_object_detection_1/assets/labels.json');

const app = express();

// Lazily-initialized SavedModel handle; populated by loadModel().
let objectDetectionModel;


/**
 * Lazily loads the TensorFlow SavedModel and warms it up with a dummy
 * prediction so the first real request does not pay the full load/compile
 * cost. The loaded model is cached in the module-level
 * `objectDetectionModel`, so repeated calls are cheap.
 *
 * @returns {Promise<void>} resolves once the model is loaded and warmed up.
 */
async function loadModel() {
  if (!objectDetectionModel) {
    // Load the TensorFlow SavedModel through the tfjs-node API. See:
    // https://js.tensorflow.org/api_node/1.3.1/#node.loadSavedModel
    objectDetectionModel = await tf.node.loadSavedModel(
        './model/new_object_detection_1', ['serve'], 'serving_default');
  }
  // Warm up the model with a tiny dummy input. Dispose both the input and
  // the prediction output — the original leaked these tensors on every call.
  const tempTensor = tf.zeros([1, 2, 2, 3]).toInt();
  const warmupOutput = objectDetectionModel.predict(tempTensor);
  tf.dispose(warmupOutput);
  tempTensor.dispose();
}


// Serve the demo page and kick off model loading in the background so the
// model is likely warm by the time the user submits an image.
app.get('/', async (req, res) => {
  // path.join(__dirname, 'index.html') — the original concatenated with '+'
  // inside join, which defeats join's separator handling.
  res.sendFile(path.join(__dirname, 'index.html'));
  // Fire-and-forget warm-up; attach a catch so a load failure surfaces in
  // the logs instead of becoming an unhandled promise rejection.
  loadModel().catch((err) => console.error('Model warm-up failed:', err));
});


/**
 * POST /predict
 * Receives a multipart form upload containing an image, runs the object
 * detection model on it, and responds with JSON:
 *   { boxes: number[][], names: string[], inferenceTime: number }  // ms
 */
app.post('/predict', async (req, res) => {
  const busboy = new Busboy({headers: req.headers});
  let fileBuffer = Buffer.from('');
  const files = [];

  busboy.on('field', (fieldname, value) => {
    // req.body may be unset for multipart requests (no body parser ran).
    req.body = req.body || {};
    req.body[fieldname] = value;
  });

  busboy.on('file', (fieldname, file, filename, encoding, mimetype) => {
    file.on('data', (data) => {
      fileBuffer = Buffer.concat([fileBuffer, data]);
    });

    file.on('end', () => {
      files.push({
        fieldname,
        'originalname': filename,
        encoding,
        mimetype,
        buffer: fileBuffer,
      });
    });
  });

  busboy.on('finish', async () => {
    try {
      if (files.length === 0) {
        res.status(400).send({error: 'No image file was uploaded.'});
        return;
      }
      const uint8array = new Uint8Array(files[0].buffer);

      // Ensure the model is loaded BEFORE predicting. The original fired
      // loadModel() without awaiting it, so the first request could reach
      // predict() while objectDetectionModel was still undefined.
      await loadModel();

      // Decode the image into a tensor and add a batch dimension.
      const imageTensor = await tf.node.decodeImage(uint8array);
      const input = imageTensor.expandDims(0);

      // Feed the image tensor into the model for inference.
      const startTime = tf.util.now();
      const outputTensor = objectDetectionModel.predict({'x': input});

      // Pull the detections out as plain arrays. arraySync() is synchronous,
      // so no await is needed (the original awaited it pointlessly).
      const scores = outputTensor['detection_scores'].arraySync();
      const boxes = outputTensor['detection_boxes'].arraySync();
      const names = outputTensor['detection_classes'].arraySync();
      const endTime = tf.util.now();

      // Free every tensor we created — the original leaked imageTensor and
      // input on each request.
      outputTensor['detection_scores'].dispose();
      outputTensor['detection_boxes'].dispose();
      outputTensor['detection_classes'].dispose();
      outputTensor['num_detections'].dispose();
      input.dispose();
      imageTensor.dispose();

      // Keep only detections above the confidence threshold.
      const detectedBoxes = [];
      const detectedNames = [];
      for (let i = 0; i < scores[0].length; i++) {
        if (scores[0][i] > 0.3) {
          detectedBoxes.push(boxes[0][i]);
          detectedNames.push(labels[names[0][i]]);
        }
      }

      res.send({
        boxes: detectedBoxes,
        names: detectedNames,
        inferenceTime: endTime - startTime,
      });
    } catch (err) {
      // Surface failures as a 500 instead of hanging the request.
      console.error('Prediction failed:', err);
      res.status(500).send({error: 'Prediction failed.'});
    }
  });

  // Stream the raw request body into busboy for parsing.
  req.pipe(busboy);
});


// Expose the Express app as a Firebase HTTPS function (preserved from the
// original deployment target).
exports.app = functions.https.onRequest(app);

// Also eagerly load/warm the model and listen on a port for standalone use
// (local development / Cloud Run). Exit on load failure rather than serving
// requests with no model — the original .then() had no rejection handler.
loadModel().then(function() {
  console.log(`loaded saved model`);
  const port = process.env.PORT || 8080;
  app.listen(port, () => {
    console.log(`listening on port ${port}`);
  });
}).catch(function(err) {
  console.error('Failed to load model:', err);
  process.exit(1);
});