Kotlin, MNIST

Fetching MNIST data with KotlinDL and saving it locally

These are my notes from looking into how to fetch the MNIST data through KotlinDL.

I'll implement the same functionality as the download.py I previously wrote in Python, this time with KotlinDL.

Environment

The environment is as follows.

$ kotlinc -version
info: kotlinc-jvm 1.8.10 (JRE 17.0.8.1+1-Ubuntu-0ubuntu122.04)

Fetching the MNIST dataset

It isn't required, but it's worth getting the source code from https://github.com/Kotlin/kotlindl and reading the related code, starting with the following class.

Pull the relevant libraries from the Maven repository.

dl.main.kts

@file:Repository("https://repo1.maven.org/maven2/")
@file:DependsOn("org.jetbrains.kotlinx:kotlin-deeplearning-api:0.5.2")
@file:DependsOn("org.jetbrains.kotlinx:kotlin-deeplearning-dataset:0.5.2")

import org.jetbrains.kotlinx.dl.dataset.embedded.mnist

val (train, test) = mnist()

That's all it takes to get the images and labels for both train and test. train and test are instances of the OnHeapDataset class.

Let's look at what train contains.

println( train ) // OnHeapDataset
println( train.xSize() ) // 60000
println( train.getX(0).size ) // 784 (28x28)
println( train.getY(0) ) // 5.0

So there are 60000 samples, and getX() and getY() return the input (the image) and the output (the ground-truth label).

$ kotlinc -script dl.main.kts
org.jetbrains.kotlinx.dl.dataset.OnHeapDataset@689cc29a
60000
784
5.0

If you get an error about running out of heap, work around it with something like export JAVA_OPTS="-Xmx2g".

The first run downloads the data, so it takes a little while. The data is cached under ./cache in the current directory.
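The cached files are gzip-compressed IDX files, the format of the original MNIST distribution. As a rough sketch, assuming the standard IDX layout (a big-endian magic number, 0x00000803 for images and 0x00000801 for labels, then the dimensions as big-endian 32-bit ints, then one unsigned byte per value), they can be read with the JDK alone. readIdxImages and readIdxLabels are hypothetical helpers, not KotlinDL API, and dividing by 255 to get 0.0..1.0 is an assumption about matching KotlinDL's scaling.

```kotlin
import java.io.DataInputStream
import java.io.InputStream
import java.util.zip.GZIPInputStream

// Hypothetical helper: reads an IDX image file (magic 0x00000803) into one FloatArray per image.
// Pixel bytes are unsigned 0..255; dividing by 255 maps them to 0.0..1.0 (assumed to match KotlinDL).
fun readIdxImages(input: InputStream): List<FloatArray> =
    DataInputStream(input).use { data ->
        val magic = data.readInt() // DataInputStream reads big-endian, as IDX requires
        require(magic == 0x00000803) { "not an IDX image file: magic=$magic" }
        val count = data.readInt()
        val rows = data.readInt()
        val cols = data.readInt()
        List(count) {
            val bytes = ByteArray(rows * cols)
            data.readFully(bytes)
            FloatArray(bytes.size) { i -> (bytes[i].toInt() and 0xFF) / 255f }
        }
    }

// Hypothetical helper: reads an IDX label file (magic 0x00000801) into digit labels 0..9.
fun readIdxLabels(input: InputStream): IntArray =
    DataInputStream(input).use { data ->
        val magic = data.readInt()
        require(magic == 0x00000801) { "not an IDX label file: magic=$magic" }
        val count = data.readInt()
        IntArray(count) { data.readUnsignedByte() }
    }

// Usage against the cache listed above (not run here):
// val images = readIdxImages(GZIPInputStream(java.io.File("cache/datasets/mnist/train-images-idx3-ubyte.gz").inputStream()))
```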

./cache/datasets/mnist/train-labels-idx1-ubyte.gz
./cache/datasets/mnist/train-images-idx3-ubyte.gz
./cache/datasets/mnist/t10k-labels-idx1-ubyte.gz
./cache/datasets/mnist/t10k-images-idx3-ubyte.gz

Converting to JPEG images

train.getX(0) returns a FloatArray of length 784 (28x28=784). This is the first MNIST image. Each value in the FloatArray is in the range 0.0..1.0.

The following code prints all 784 values to standard output.

val firstImageDataArray: FloatArray = train.getX(0)

0.until(firstImageDataArray.size).forEach { index->
    val pxValue = firstImageDataArray[index]
    println("- ${index}=${pxValue}")
}

That's a bit hard to read, so let's start a new line every 28 values instead.

val firstImageDataArray: FloatArray = train.getX(0)

0.until(28).forEach { row->
    val rowData = 0.until(28).map { col->
        val index = row*28 + col
        val v = firstImageDataArray[index]
        "%.1f".format(v) // print to one decimal place
    }.joinToString(",")
    println(rowData)
}

Running it prints the following. If you squint, you can just about make out the digit 5.

0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.1,0.1,0.1,0.5,0.5,0.7,0.1,0.7,1.0,1.0,0.5,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.1,0.1,0.4,0.6,0.7,1.0,1.0,1.0,1.0,1.0,0.9,0.7,1.0,0.9,0.8,0.3,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.2,0.9,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.4,0.3,0.3,0.2,0.2,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.1,0.9,1.0,1.0,1.0,1.0,1.0,0.8,0.7,1.0,0.9,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.3,0.6,0.4,1.0,1.0,0.8,0.0,0.0,0.2,0.6,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.1,0.0,0.6,1.0,0.4,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.5,1.0,0.7,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.7,1.0,0.3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.1,0.9,0.9,0.6,0.4,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.3,0.9,1.0,1.0,0.5,0.1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.2,0.7,1.0,1.0,0.6,0.1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.1,0.4,1.0,1.0,0.7,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0,0.3,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.2,0.5,0.7,1.0,1.0,0.8,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.2,0.6,0.9,1.0,1.0,1.0,1.0,0.7,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.1,0.4,0.9,1.0,1.0,1.0,1.0,0.8,0.3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.1,0.3,0.8,1.0,1.0,1.0,1.0,0.8,0.3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.1,0.7,0.9,1.0,1.0,1.0,1.0,0.8,0.3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.2,0.7,0.9,1.0,1.0,1.0,1.0,1.0,0.5,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.5,1.0,1.0,1.0,0.8,0.5,0.5,0.1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0

Wait a moment. What if we only print a number when the value is greater than 0f, like this...

0.until(28).forEach { row->
    val rowData = 0.until(28).map { col->
        val index = row*28 + col
        val v = firstImageDataArray[index]
        //"%.1f".format(v) // print to one decimal place
        if( v>0f ){ "%.1f".format(v) } else { "..." } // "..." marks pixels that are exactly 0
    }.joinToString(",")
    println(rowData)
}

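And there it is, visualized. (Cells that still show 0.0 are small positive values below 0.05 that round to 0.0 at one decimal place.)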

...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,...,0.0,0.1,0.1,0.1,0.5,0.5,0.7,0.1,0.7,1.0,1.0,0.5,...,...,...,...
...,...,...,...,...,...,...,...,0.1,0.1,0.4,0.6,0.7,1.0,1.0,1.0,1.0,1.0,0.9,0.7,1.0,0.9,0.8,0.3,...,...,...,...
...,...,...,...,...,...,...,0.2,0.9,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.4,0.3,0.3,0.2,0.2,...,...,...,...,...
...,...,...,...,...,...,...,0.1,0.9,1.0,1.0,1.0,1.0,1.0,0.8,0.7,1.0,0.9,...,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,0.3,0.6,0.4,1.0,1.0,0.8,0.0,...,0.2,0.6,...,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,0.1,0.0,0.6,1.0,0.4,...,...,...,...,...,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,0.5,1.0,0.7,0.0,...,...,...,...,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,0.0,0.7,1.0,0.3,...,...,...,...,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,...,0.1,0.9,0.9,0.6,0.4,0.0,...,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,...,...,0.3,0.9,1.0,1.0,0.5,0.1,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,0.2,0.7,1.0,1.0,0.6,0.1,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,0.1,0.4,1.0,1.0,0.7,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,1.0,1.0,1.0,0.3,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,0.2,0.5,0.7,1.0,1.0,0.8,0.0,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,...,0.2,0.6,0.9,1.0,1.0,1.0,1.0,0.7,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,0.1,0.4,0.9,1.0,1.0,1.0,1.0,0.8,0.3,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,0.1,0.3,0.8,1.0,1.0,1.0,1.0,0.8,0.3,0.0,...,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,0.1,0.7,0.9,1.0,1.0,1.0,1.0,0.8,0.3,0.0,...,...,...,...,...,...,...,...,...,...,...,...
...,...,...,...,0.2,0.7,0.9,1.0,1.0,1.0,1.0,1.0,0.5,0.0,...,...,...,...,...,...,...,...,...,...,...,...,...,...
...,...,...,...,0.5,1.0,1.0,1.0,0.8,0.5,0.5,0.1,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
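The same row-by-row view can be written a little more compactly by splitting the array into rows of 28 with chunked. This is a minimal sketch; toAsciiArt and its character choices are hypothetical, and any value above 0f counts as ink, matching the condition used above.

```kotlin
// Hypothetical helper: renders a flat width*height pixel array as text, one character per pixel.
// ' ' for pixels that are exactly 0, '+' for faint ink (< 0.5), '#' for strong ink (>= 0.5).
fun toAsciiArt(pixels: FloatArray, width: Int = 28): String =
    pixels.asList()
        .chunked(width) // chunk r holds pixels[r*width until (r+1)*width], i.e. one image row
        .joinToString("\n") { row ->
            row.joinToString("") { v ->
                when {
                    v <= 0f -> " "
                    v < 0.5f -> "+"
                    else -> "#"
                }
            }
        }

// Usage with the dataset loaded earlier (not run here):
// println(toAsciiArt(train.getX(0)))
```

chunked does the row*28 + col bookkeeping for you, which removes one place where an off-by-one could slip in.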

Now let's convert this into an image.

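import java.awt.Color
import java.awt.image.BufferedImage

// Converts a row-major array of 0.0..1.0 values into a w x h grayscale image.
val toBufferedImage: (FloatArray, Int, Int)-> BufferedImage = { pixelValues ,w , h->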
    val image = BufferedImage(w, h, BufferedImage.TYPE_BYTE_GRAY)

    0.until(h).forEach { row->
        0.until(w).forEach { col->
            val index = row*w + col
            val v = (pixelValues[index] * 255).toInt()
            image.setRGB(col, row, Color(v, v, v).getRGB())
        }
    }

    image
}

Using this toBufferedImage function as follows saves the image as 0_5.jpg (index_label.jpg).

import java.io.File
import javax.imageio.ImageIO

val index = 0
val bimg = toBufferedImage( train.getX(index), 28, 28 )
ImageIO.write(bimg, "JPEG", File("${index}_${train.getY(index).toInt()}.jpg"))

[Image: 0_5.jpg, the handwritten digit 5]

That saves the first sample of train as an image.
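One caveat worth checking: on a TYPE_BYTE_GRAY image, setRGB goes through a color conversion (the image's gray color space is linear, while a Color value is sRGB), so mid-gray values may be stored somewhat darker than v. As a sketch, the gray level can instead be written straight into the raster with setSample, and rounding instead of truncating keeps a value like 0.99999 from becoming 254. toGrayImage is a hypothetical name, and this is an alternative, not the approach used in the rest of this article.

```kotlin
import java.awt.image.BufferedImage
import kotlin.math.roundToInt

// Hypothetical alternative to toBufferedImage: writes each gray level directly into band 0
// of the raster, so the stored byte is exactly the computed 0..255 value.
fun toGrayImage(pixelValues: FloatArray, w: Int, h: Int): BufferedImage {
    require(pixelValues.size == w * h) { "expected ${w * h} values, got ${pixelValues.size}" }
    val image = BufferedImage(w, h, BufferedImage.TYPE_BYTE_GRAY)
    val raster = image.raster
    for (row in 0 until h) {
        for (col in 0 until w) {
            // Map 0.0..1.0 to 0..255, rounding and clamping to stay in the byte range.
            val v = (pixelValues[row * w + col] * 255f).roundToInt().coerceIn(0, 255)
            raster.setSample(col, row, 0, v)
        }
    }
    return image
}
```

It would be called the same way: ImageIO.write(toGrayImage(train.getX(0), 28, 28), "JPEG", File("0_5.jpg")).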

One-hot encoding the ground-truth labels

Next, convert the ground-truth labels to one-hot encoding and save them as CSV files.

train.getY(0) returns the ground-truth value for image 0, that is, which of the digits 0..9 it shows. Image 0 is the digit 5, so the value is 5.0.

My first idea was code like this...

val toOneHotString: (Float)-> String = { v->
    when(v.toInt()){
        0 -> { "1,0,0,0,0,0,0,0,0,0" }
        1 -> { "0,1,0,0,0,0,0,0,0,0" }
        2 -> { "0,0,1,0,0,0,0,0,0,0" }
        3 -> { "0,0,0,1,0,0,0,0,0,0" }
        4 -> { "0,0,0,0,1,0,0,0,0,0" }
        5 -> { "0,0,0,0,0,1,0,0,0,0" }
        6 -> { "0,0,0,0,0,0,1,0,0,0" }
        7 -> { "0,0,0,0,0,0,0,1,0,0" }
        8 -> { "0,0,0,0,0,0,0,0,1,0" }
        9 -> { "0,0,0,0,0,0,0,0,0,1" }
        else -> { "0,0,0,0,0,0,0,0,0,0" }
    }
}

This works, of course, but looking at org.jetbrains.kotlinx.dl.dataset.OnHeapDataset.kt:

/** Creates binary vector with size [numClasses] from [label]. */
@JvmStatic
public fun toOneHotVector(numClasses: Int, label: Byte): FloatArray {
    val ret = FloatArray(numClasses)
    ret[label.toInt() and SHIFT_NUMBER] = 1f
    return ret
}

it turns out code that does the same thing already exists.

For reference, SHIFT_NUMBER = 0xFF.
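The mask is there because Kotlin's Byte is signed: a raw byte above 127 converts to a negative Int, which would be an invalid array index, and `and 0xFF` recovers the unsigned value. A small worked example (the value 200 is chosen only for illustration):

```kotlin
val raw: Byte = 200.toByte()        // stored as -56, because Byte is signed (-128..127)
val signed = raw.toInt()            // -56: sign-extended, unusable as an array index
val unsigned = raw.toInt() and 0xFF // 200: the original unsigned byte value
println("signed=$signed unsigned=$unsigned") // prints "signed=-56 unsigned=200"
```

For MNIST labels (0..9) both forms give the same number, since those bytes are never negative.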

Based on that code, we can write:

val toOneHotVector: (Float, Int)-> IntArray = { v, numClasses->
    val oneHotV = IntArray(numClasses)
    oneHotV[v.toInt()] = 1
    oneHotV
}

val toOneHotString: (Float)-> String = { v->
    toOneHotVector(v, 10).joinToString(",")
}

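I left out the `and 0xFF` on v.toInt(). In the library, label is a Byte, so the mask turns a negative (signed) byte back into the range 0..255; here v is always in the range 0.0..9.0, so it isn't needed.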

With this, a value in 0.0..9.0 can be converted to a one-hot string.
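To spot-check the CSV files later, the encoding can be reversed by finding the position of the single 1. A minimal sketch; fromOneHotString is a hypothetical helper that assumes the comma-separated 0/1 format produced by toOneHotString.

```kotlin
// Hypothetical inverse of toOneHotString: "0,0,0,0,0,1,0,0,0,0" -> 5.
fun fromOneHotString(line: String): Int {
    val values = line.trim().split(",").map { it.trim().toInt() }
    val index = values.indexOf(1)
    // Reject lines with no 1 (like the all-zero `else` case) or with more than one 1.
    require(index >= 0 && values.count { it == 1 } == 1) { "not a one-hot vector: $line" }
    return index
}

println(fromOneHotString("0,0,0,0,0,1,0,0,0,0")) // prints 5
```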

Putting it all together

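Here is all the code so far in one script. It writes the images and the ground-truth label CSVs under ./data/ (data/train and data/test, each with img/ and label/ subdirectories).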

dl.main.kts

@file:Repository("https://repo1.maven.org/maven2/")
@file:DependsOn("org.jetbrains.kotlinx:kotlin-deeplearning-api:0.5.2")
@file:DependsOn("org.jetbrains.kotlinx:kotlin-deeplearning-dataset:0.5.2")

import java.io.File
import java.awt.Color
import java.awt.image.BufferedImage
import javax.imageio.ImageIO
import org.jetbrains.kotlinx.dl.dataset.embedded.mnist
import org.jetbrains.kotlinx.dl.dataset.Dataset

System.setProperty("java.awt.headless", "true")

val zfill5: (Int)->String = { v-> String.format("%1$05d", v) }

val toBufferedImage: (FloatArray, Int, Int)-> BufferedImage = { pixelValues ,w , h->
    val image = BufferedImage(w, h, BufferedImage.TYPE_BYTE_GRAY)

    0.until(h).forEach { row->
        0.until(w).forEach { col->
            val index = row*w + col
            val v = (pixelValues[index] * 255).toInt()
            image.setRGB(col, row, Color(v, v, v).getRGB())
        }
    }

    image
}

val exportJpgFile: (File, Dataset, Int)->Unit = { dir, dataset, imageSize->
    0.until(dataset.xSize()).forEach { index->
        val bimg = toBufferedImage(dataset.getX(index), imageSize, imageSize)
        val filename = "${zfill5(index)}.jpg"
        ImageIO.write(bimg, "JPEG", File(dir, filename))
    }
}

val toOneHotVector: (Float, Int)-> IntArray = { v, numClasses->
    val oneHotV = IntArray(numClasses)
    oneHotV[v.toInt()] = 1
    oneHotV
}

val toOneHotString: (Float, Int)-> String = { v, numClasses->
    toOneHotVector(v, numClasses).joinToString(",")
}

val exportCSVFile: (File, Dataset, Int)->Unit = { dir, dataset, numClasses->
    0.until(dataset.xSize()).forEach { index->
        val v = toOneHotString(dataset.getY(index), numClasses)
        val filename = "${zfill5(index)}.csv"
        File(dir, filename).printWriter().use { out ->
            out.print( v )
        }
    }
}


val numClasses = 10
val imageSize = 28

val trainDir = File("data/train")
val trainImgDir   = File(trainDir, "img")
val trainLabelDir = File(trainDir, "label")

val testDir = File("data/test")
val testImgDir   = File(testDir, "img")
val testLabelDir = File(testDir, "label")

listOf(trainImgDir, trainLabelDir, testImgDir, testLabelDir).forEach { it.mkdirs() }

val (train, test) = mnist()

exportJpgFile(trainImgDir, train, imageSize)
exportCSVFile(trainLabelDir, train, numClasses)

exportJpgFile(testImgDir, test, imageSize)
exportCSVFile(testLabelDir, test, numClasses)

Run it.

$ kotlinc -script dl.main.kts

That's all.