前回 DeepLab v3 Semantic Segmentation を TensorFlow.js で試す(その2)TensorFlow Lite モデル編 のコードを流用して、別のモデル Mobile food segmentation model の TFLiteを試す。
推測結果
以下の3種類のパンの写っている写真を使用しました。 このモデルを使って推測するには、入力用の画像が 513 x 513 ピクセルである必要があります。
入力画像
推測結果の画像
実際に使用した 513x513 の JPG画像はこちら
入力画像
推測結果の画像
実際に使用した 513x513 の JPG画像はこちら
入力画像
推測結果の画像
チェリーのピクセル部分はパンではないので、(当然)除外されている。
実際に使用した 513x513 の JPG画像はこちら
https://tfhub.dev/google/seefood/segmenter/mobile_food_segmenter_V1/1 の labelmap をダウンロードして確認しましょう。
id,name
0,background
1,vegetables | leafy_greens
2,vegetables | stem_vegetables
3,vegetables | non-starchy_roots
4,vegetables | other
5,fruits
6,protein | meat
7,protein | poultry
8,protein | seafood
9,protein | eggs
10,protein | beans/nuts
11,starches/grains | baked_goods
12,starches/grains | rice/grains/cereals
13,starches/grains | noodles/pasta
14,starches/grains | starchy_vegetables
15,starches/grains | other
16,soups/stews
17,herbs/spices
18,dairy
19,snacks
20,sweets/desserts
21,beverages
22,fats/oils/sauces
23,food_containers
24,dining_tools
25,other_food
今回は、パンのピクセルを取得したいので、 推測結果から 11,starches/grains | baked_goods と推定されたピクセルを狙います。
環境:
$ node --version
v18.12.1
$ npm --version
8.19.2
必要なモジュール:
$ npm install @tensorflow/tfjs-node
$ npm install tfjs-tflite-node
$ npm install jimp
コードの内容は、前回とほとんど同じなので細かい説明は省きます。
前回のモデルでは21種類、 このモデルでは、26種類の識別対象があるので、そこは変更しています。( toPredictedClassIndex の numOfLabels )
const fs = require('fs')
const tf = require('@tensorflow/tfjs-node')
const tflite = require('tfjs-tflite-node')
const Jimp = require('jimp')
const loadModel = async () => {
const modelURL = 'https://tfhub.dev/google/lite-model/seefood/segmenter/mobile_food_segmenter_V1/1'
return await tflite.loadTFLiteModel(modelURL)
}
const range = (v)=>{ return [...Array(v).keys()] }
//
// そのピクセル(x,y)において推測されたクラス番号を返す.
//
const toPredictedClassIndex = (jsArray, x,y)=>{
const numOfLabels = 26
const predictValues = range(numOfLabels).map(classIndex=> jsArray[0][y][x][classIndex])
const maxPredictValue = predictValues.reduce( (acc,value) => (acc<value) ? value : acc )
return range(numOfLabels).map((classIndex)=> {
return {
ok: (jsArray[0][y][x][classIndex] == maxPredictValue),
index: classIndex}
}).filter( item=> item.ok )[0].index
}
// 入力画像ファイル名を受け取り、結果を出力するファイル名も用意する.
const inputJPGFilename = process.argv[2]
inputJPGFilename.match(/(.*).jpg/)
const outputJPGFilename = `masked_${RegExp.$1}.jpg`
loadModel().then((model)=>{
tf.tidy(()=>{
// 画像をロードする.
const targetImage = fs.readFileSync(inputJPGFilename)
const targetImageTensor = tf.node.decodeImage(targetImage)
console.log(targetImageTensor.shape) // [ 513, 513, 3 ]
//const inputTensor = targetImageTensor.reshape( [ 1, 513, 513, 3 ] )
const inputTensor = tf.expandDims(targetImageTensor, 0)
console.log(inputTensor.shape) // [1, 513, 513, 3 ]
// 推測する.
const outputTensor = model.predict(inputTensor)
console.log(outputTensor.shape) // [ 1, 513, 513, 26 ]
// jsArrayにする.
const jsArray = outputTensor.arraySync()
// 結果を視覚化する.
const image = new Jimp(513, 513, 'black', (err, image) => {})
const imageW = image.bitmap.width
const imageH = image.bitmap.height
image.scan(0, 0, imageW, imageH, (x, y, idx)=> {
const predictedClassIndex = toPredictedClassIndex(jsArray, x,y)
if( predictedClassIndex==11 ){
// 11,starches/grains | baked_goods
image.bitmap.data[idx + 0] = 255 // red
image.bitmap.data[idx + 1] = 255 // green
image.bitmap.data[idx + 2] = 255 // blue
image.bitmap.data[idx + 3] = 255 // alpha
}
})
image.write(outputJPGFilename)
console.log(`save result as ${outputJPGFilename}`)
})
})
以下のように実行します。
$ node index.js cup-cake.jpg
結果は masked_cup-cake.jpg に保存されます。
パンはそれなりに識別できることがわかりました。
バナナも識別できるか試してみます。 対象写真もわざと雑然としたところで撮影してみました。
実際に使用した 513x513 の JPG画像はこちら
バナナを識別したいので、今度は推測結果の 5,fruits を対象にします。
if( predictedClassIndex==5 ){
// 5,fruits
image.bitmap.data[idx + 0] = 255 // red
image.bitmap.data[idx + 1] = 100 // green
image.bitmap.data[idx + 2] = 100 // blue
image.bitmap.data[idx + 3] = 255 // alpha
}
実行した結果の画像
黒くなったバナナの柄の部分はバナナとして識別されませんでした。
最終的にできたパンとバナナを識別できるコードを掲載します。
const fs = require('fs')
const tf = require('@tensorflow/tfjs-node')
const tflite = require('tfjs-tflite-node')
const Jimp = require('jimp')
const loadModel = async () => {
const modelURL = 'https://tfhub.dev/google/lite-model/seefood/segmenter/mobile_food_segmenter_V1/1'
return await tflite.loadTFLiteModel(modelURL)
}
const range = (v)=>{ return [...Array(v).keys()] }
//
// そのピクセル(x,y)において推測されたクラス番号を返す.
//
const toPredictedClassIndex = (jsArray, x,y)=>{
const numOfLabels = 26
const predictValues = range(numOfLabels).map(classIndex=> jsArray[0][y][x][classIndex])
const maxPredictValue = predictValues.reduce( (acc,value) => (acc<value) ? value : acc )
return range(numOfLabels).map((classIndex)=> {
return {
ok: (jsArray[0][y][x][classIndex] == maxPredictValue),
index: classIndex}
}).filter( item=> item.ok )[0].index
}
// 入力画像ファイル名を受け取り、結果を出力するファイル名も用意する.
const inputJPGFilename = process.argv[2]
inputJPGFilename.match(/(.*).jpg/)
const outputJPGFilename = `masked_${RegExp.$1}.jpg`
loadModel().then((model)=>{
tf.tidy(()=>{
// 画像をロードする.
const targetImage = fs.readFileSync(inputJPGFilename)
const targetImageTensor = tf.node.decodeImage(targetImage)
console.log(targetImageTensor.shape) // [ 513, 513, 3 ]
//const inputTensor = targetImageTensor.reshape( [ 1, 513, 513, 3 ] )
const inputTensor = tf.expandDims(targetImageTensor, 0)
console.log(inputTensor.shape) // [1, 513, 513, 3 ]
// 推測する.
const outputTensor = model.predict(inputTensor)
console.log(outputTensor.shape) // [ 1, 513, 513, 26 ]
// jsArrayにする.
const jsArray = outputTensor.arraySync()
// 結果を視覚化する.
const image = new Jimp(513, 513, 'black', (err, image) => {})
const imageW = image.bitmap.width
const imageH = image.bitmap.height
image.scan(0, 0, imageW, imageH, (x, y, idx)=> {
const predictedClassIndex = toPredictedClassIndex(jsArray, x,y)
if( predictedClassIndex==5 ){
// 5,fruits
image.bitmap.data[idx + 0] = 255 // red
image.bitmap.data[idx + 1] = 100 // green
image.bitmap.data[idx + 2] = 100 // blue
image.bitmap.data[idx + 3] = 255 // alpha
}
if( predictedClassIndex==11 ){
// 11,starches/grains | baked_goods
image.bitmap.data[idx + 0] = 255 // red
image.bitmap.data[idx + 1] = 255 // green
image.bitmap.data[idx + 2] = 255 // blue
image.bitmap.data[idx + 3] = 255 // alpha
}
})
image.write(outputJPGFilename)
console.log(`save result as ${outputJPGFilename}`)
})
})
実行:
$ node index.js cup-cake.jpg
[ 513, 513, 3 ]
[ 1, 513, 513, 3 ]
WARNING: converting 'int32' to 'uint8'
[ 1, 513, 513, 26 ]
save result as masked_cup-cake.jpg
$ node index.js banana_513x513.jpg
[ 513, 513, 3 ]
[ 1, 513, 513, 3 ]
WARNING: converting 'int32' to 'uint8'
[ 1, 513, 513, 26 ]
save result as masked_banana_513x513.jpg
- カップケーキ
- 入力画像 cup-cake.jpg
- 出力画像 masked_cup-cake.jpg
- バナナ
- 入力画像 banana_513x513.jpg
- 出力画像 masked_banana_513x513.jpg
以上です。