Home About Contact
Semantic Segmentation , TensorFlow.js

Mobile Food Segmenter を試す

前回 DeepLab v3 Semantic Segmentation を TensorFlow.js で試す(その2)TensorFlow Lite モデル編 のコードを流用して、別のモデル Mobile food segmentation model の TFLiteを試す。

推測結果

Mobile food segmentation

対象とした画像

以下の3種類のパンの写っている写真を使用しました。 このモデルを使って推測するには、入力用の画像が 513 x 513 ピクセルである必要があります。

カップケーキ(をカットしたもの)

入力画像
cup cake

推測結果の画像
masked cup cake

実際に使用した 513x513 の JPG画像はこちら

パイのコロネ

入力画像
pai coronet

推測結果の画像
masked pai cornet

実際に使用した 513x513 の JPG画像はこちら

ダークチェリー

入力画像
dark cherry

推測結果の画像
masked dark cherry

チェリーのピクセル部分はパンではないので、(当然)除外されている。

実際に使用した 513x513 の JPG画像はこちら

ラベルマップ

https://tfhub.dev/google/seefood/segmenter/mobile_food_segmenter_V1/1labelmap をダウンロードして確認しましょう。

id,name
0,background
1,vegetables | leafy_greens
2,vegetables | stem_vegetables
3,vegetables | non-starchy_roots
4,vegetables | other
5,fruits
6,protein | meat
7,protein | poultry
8,protein | seafood
9,protein | eggs
10,protein | beans/nuts
11,starches/grains | baked_goods
12,starches/grains | rice/grains/cereals
13,starches/grains | noodles/pasta
14,starches/grains | starchy_vegetables
15,starches/grains | other
16,soups/stews
17,herbs/spices
18,dairy
19,snacks
20,sweets/desserts
21,beverages
22,fats/oils/sauces
23,food_containers
24,dining_tools
25,other_food

今回は、パンのピクセルを取得したいので、 推測結果から 11,starches/grains | baked_goods と推定されたピクセルを狙います。

コード

環境:

$ node --version
v18.12.1
$ npm --version
8.19.2

必要なモジュール:

$ npm install @tensorflow/tfjs-node
$ npm install tfjs-tflite-node
$ npm install jimp

コードの内容は、前回とほとんど同じなので細かい説明は省きます。

前回のモデルでは21種類、 このモデルでは、26種類の識別対象があるので、そこは変更しています。( toPredictedClassIndexnumOfLabels )

const fs = require('fs')
const tf = require('@tensorflow/tfjs-node')
const tflite = require('tfjs-tflite-node')
const Jimp = require('jimp')

const loadModel = async () => {
    const modelURL = 'https://tfhub.dev/google/lite-model/seefood/segmenter/mobile_food_segmenter_V1/1'
    return await tflite.loadTFLiteModel(modelURL)
}


const range = (v)=>{ return [...Array(v).keys()] }

//
// そのピクセル(x,y)において推測されたクラス番号を返す.
//
const toPredictedClassIndex = (jsArray, x,y)=>{
    const numOfLabels = 26
    const predictValues = range(numOfLabels).map(classIndex=> jsArray[0][y][x][classIndex])
    const maxPredictValue = predictValues.reduce( (acc,value) => (acc<value) ? value : acc )
    return range(numOfLabels).map((classIndex)=> {
        return {
            ok: (jsArray[0][y][x][classIndex] == maxPredictValue),
            index: classIndex}
    }).filter( item=> item.ok )[0].index
}


// 入力画像ファイル名を受け取り、結果を出力するファイル名も用意する.
const inputJPGFilename = process.argv[2]
inputJPGFilename.match(/(.*).jpg/)
const outputJPGFilename = `masked_${RegExp.$1}.jpg`

loadModel().then((model)=>{
    tf.tidy(()=>{
        // 画像をロードする.
        const targetImage = fs.readFileSync(inputJPGFilename)
        const targetImageTensor = tf.node.decodeImage(targetImage)
        console.log(targetImageTensor.shape) // [ 513, 513, 3 ]
    
        //const inputTensor = targetImageTensor.reshape( [ 1, 513, 513, 3 ] )
        const inputTensor = tf.expandDims(targetImageTensor, 0)
        console.log(inputTensor.shape) // [1, 513, 513, 3 ]
    
        // 推測する.
        const outputTensor = model.predict(inputTensor)
        console.log(outputTensor.shape) // [ 1, 513, 513, 26 ]
    
        // jsArrayにする.
        const jsArray = outputTensor.arraySync()
    
        // 結果を視覚化する.
        const image = new Jimp(513, 513, 'black', (err, image) => {})
        const imageW = image.bitmap.width
        const imageH = image.bitmap.height
        image.scan(0, 0, imageW, imageH, (x, y, idx)=> {
            const predictedClassIndex = toPredictedClassIndex(jsArray, x,y)
    
            if( predictedClassIndex==11 ){
                // 11,starches/grains | baked_goods
                image.bitmap.data[idx + 0] = 255 // red
                image.bitmap.data[idx + 1] = 255 // green
                image.bitmap.data[idx + 2] = 255 // blue
                image.bitmap.data[idx + 3] = 255 // alpha
            }
        })
    
        image.write(outputJPGFilename)
        console.log(`save result as ${outputJPGFilename}`)
    })
})

以下のように実行します。

$ node index.js cup-cake.jpg

結果は masked_cup-cake.jpg に保存されます。

バナナも識別してみる

パンはそれなりに識別できることがわかりました。

バナナも識別できるか試してみます。 対象写真もわざと雑然としたところで撮影してみました。

banana

実際に使用した 513x513 の JPG画像はこちら

バナナを識別したいので、今度は推測結果の 5,fruits を対象にします。

if( predictedClassIndex==5 ){
    // 5,fruits
    image.bitmap.data[idx + 0] = 255 // red
    image.bitmap.data[idx + 1] = 100 // green
    image.bitmap.data[idx + 2] = 100 // blue
    image.bitmap.data[idx + 3] = 255 // alpha
}

実行した結果の画像

masked banana resized

黒くなったバナナの柄の部分はバナナとして識別されませんでした。

まとめ

最終的にできたパンとバナナを識別できるコードを掲載します。

const fs = require('fs')
const tf = require('@tensorflow/tfjs-node')
const tflite = require('tfjs-tflite-node')
const Jimp = require('jimp')

const loadModel = async () => {
    const modelURL = 'https://tfhub.dev/google/lite-model/seefood/segmenter/mobile_food_segmenter_V1/1'
    return await tflite.loadTFLiteModel(modelURL)
}


const range = (v)=>{ return [...Array(v).keys()] }

//
// そのピクセル(x,y)において推測されたクラス番号を返す.
//
const toPredictedClassIndex = (jsArray, x,y)=>{
    const numOfLabels = 26
    const predictValues = range(numOfLabels).map(classIndex=> jsArray[0][y][x][classIndex])
    const maxPredictValue = predictValues.reduce( (acc,value) => (acc<value) ? value : acc )
    return range(numOfLabels).map((classIndex)=> {
        return {
            ok: (jsArray[0][y][x][classIndex] == maxPredictValue),
            index: classIndex}
    }).filter( item=> item.ok )[0].index
}


// 入力画像ファイル名を受け取り、結果を出力するファイル名も用意する.
const inputJPGFilename = process.argv[2]
inputJPGFilename.match(/(.*).jpg/)
const outputJPGFilename = `masked_${RegExp.$1}.jpg`


loadModel().then((model)=>{
    tf.tidy(()=>{
        // 画像をロードする.
        const targetImage = fs.readFileSync(inputJPGFilename)
        const targetImageTensor = tf.node.decodeImage(targetImage)
        console.log(targetImageTensor.shape) // [ 513, 513, 3 ]
    
        //const inputTensor = targetImageTensor.reshape( [ 1, 513, 513, 3 ] )
        const inputTensor = tf.expandDims(targetImageTensor, 0)
        console.log(inputTensor.shape) // [1, 513, 513, 3 ]
    
        // 推測する.
        const outputTensor = model.predict(inputTensor)
        console.log(outputTensor.shape) // [ 1, 513, 513, 26 ]
    
        // jsArrayにする.
        const jsArray = outputTensor.arraySync()
    
        // 結果を視覚化する.
        const image = new Jimp(513, 513, 'black', (err, image) => {})
        const imageW = image.bitmap.width
        const imageH = image.bitmap.height
        image.scan(0, 0, imageW, imageH, (x, y, idx)=> {
            const predictedClassIndex = toPredictedClassIndex(jsArray, x,y)

            if( predictedClassIndex==5 ){
                // 5,fruits
                image.bitmap.data[idx + 0] = 255 // red
                image.bitmap.data[idx + 1] = 100 // green
                image.bitmap.data[idx + 2] = 100 // blue
                image.bitmap.data[idx + 3] = 255 // alpha
            }
    
            if( predictedClassIndex==11 ){
                // 11,starches/grains | baked_goods
                image.bitmap.data[idx + 0] = 255 // red
                image.bitmap.data[idx + 1] = 255 // green
                image.bitmap.data[idx + 2] = 255 // blue
                image.bitmap.data[idx + 3] = 255 // alpha
            }
        })
    
        image.write(outputJPGFilename)
        console.log(`save result as ${outputJPGFilename}`)
    })
})

実行:

$ node index.js cup-cake.jpg
[ 513, 513, 3 ]
[ 1, 513, 513, 3 ]
WARNING: converting 'int32' to 'uint8'
[ 1, 513, 513, 26 ]
save result as masked_cup-cake.jpg

$ node index.js banana_513x513.jpg
[ 513, 513, 3 ]
[ 1, 513, 513, 3 ]
WARNING: converting 'int32' to 'uint8'
[ 1, 513, 513, 26 ]
save result as masked_banana_513x513.jpg

以上です。