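This is a memo of what I learned while figuring out how to get the MNIST data through KotlinDL.
I'll reimplement, with KotlinDL, the same functionality as the download.py I wrote earlier in Python.
The environment is as follows.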
$ kotlinc -version
info: kotlinc-jvm 1.8.10 (JRE 17.0.8.1+1-Ubuntu-0ubuntu122.04)
It isn't required, but it helps to grab the source from https://github.com/Kotlin/kotlindl and read the related code, starting with these files:
- kotlindl/dataset/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/dataset/embedded/MnistUtil.kt
- kotlindl/dataset/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/dataset/OnHeapDataset.kt
Fetch the required libraries from the Maven repository.
dl.main.kts
@file:Repository("https://repo1.maven.org/maven2/")
@file:DependsOn("org.jetbrains.kotlinx:kotlin-deeplearning-api:0.5.2")
@file:DependsOn("org.jetbrains.kotlinx:kotlin-deeplearning-dataset:0.5.2")
import org.jetbrains.kotlinx.dl.dataset.embedded.mnist
val (train, test) = mnist()
That alone gives you the train images and labels and the test images and labels. Both train and test are instances of the OnHeapDataset class.
Let's inspect train.
println( train ) // OnHeapDataset
println( train.xSize() ) // 60000
println( train.getX(0).size ) // 784 (28x28)
println( train.getY(0) ) // 5.0
So there are 60,000 samples, and getX() and getY() return the input (the image) and the output (the ground-truth label), respectively.
$ kotlinc -script dl.main.kts
org.jetbrains.kotlinx.dl.dataset.OnHeapDataset@689cc29a
60000
784
5.0
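Since getY() returns each label as a Float, one quick way to look further into the data is to count how often each label occurs. A minimal sketch: labelCounts is a hypothetical helper, not part of KotlinDL, and it takes the sample count and a label getter as plain parameters so it does not depend on any particular KotlinDL type.

```kotlin
// Hypothetical helper: counts how many samples carry each label.
// size    : number of samples (e.g. train.xSize())
// labelAt : returns the label of sample i as a Float (e.g. train.getY(i))
fun labelCounts(size: Int, labelAt: (Int) -> Float): Map<Int, Int> =
    (0 until size)
        .groupingBy { labelAt(it).toInt() }
        .eachCount()
        .toSortedMap()
```

In the script it could be called as labelCounts(train.xSize()) { train.getY(it) }, which prints one count per digit.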
If you get an error about running out of heap space, work around it with something like export JAVA_OPTS="-Xmx2g".
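To confirm that the heap setting actually reached the JVM running the script, you can print the limit the JVM reports through Runtime.maxMemory(). A tiny sketch; maxHeapMiB is a made-up name.

```kotlin
// Hypothetical helper: the maximum heap this JVM may use, in MiB.
// With JAVA_OPTS="-Xmx2g" it should report roughly 2048.
fun maxHeapMiB(): Long = Runtime.getRuntime().maxMemory() / (1024 * 1024)
```

Adding println("max heap: ${maxHeapMiB()} MiB") near the top of dl.main.kts is enough to check.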
The first run downloads the data, so it takes a little while. The data is cached under ./cache in the current directory:
./cache/datasets/mnist/train-labels-idx1-ubyte.gz
./cache/datasets/mnist/train-images-idx3-ubyte.gz
./cache/datasets/mnist/t10k-labels-idx1-ubyte.gz
./cache/datasets/mnist/t10k-images-idx3-ubyte.gz
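Judging by their names, these cached files appear to be the standard gzip-compressed IDX files of the original MNIST distribution. Assuming that, their header can be read with nothing but the JDK, which is a handy way to peek at them directly. A minimal sketch; IdxHeader and readIdxHeader are hypothetical names.

```kotlin
import java.io.DataInputStream
import java.io.InputStream
import java.util.zip.GZIPInputStream

// Header of an IDX file: the magic number and the size of each dimension.
data class IdxHeader(val magic: Int, val dims: List<Int>)

// Reads the IDX header from a gzip-compressed stream.
// IDX stores its integers as big-endian 32-bit values, which is
// exactly what DataInputStream.readInt() expects.
fun readIdxHeader(gz: InputStream): IdxHeader {
    val input = DataInputStream(GZIPInputStream(gz))
    val magic = input.readInt()
    // Third byte: data type (0x08 = unsigned byte).
    // Lowest byte: number of dimensions (3 for images, 1 for labels).
    val numDims = magic and 0xFF
    val dims = List(numDims) { input.readInt() }
    return IdxHeader(magic, dims)
}
```

For example, File("./cache/datasets/mnist/train-images-idx3-ubyte.gz").inputStream().use { println(readIdxHeader(it)) } should show magic 2051 (0x00000803) and dims [60000, 28, 28].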
train.getX(0) returns a FloatArray of length 784 (= 28x28): the pixel data of the first MNIST image. Each element is a value in the range 0.0..1.0.
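The 784 values are the 28x28 pixels laid out row by row (row-major order), which is what the row*28 + col indexing in the printouts below relies on. A small sketch of the mapping in both directions; IMAGE_WIDTH, toIndex and toRowCol are hypothetical names.

```kotlin
// MNIST images are 28 pixels wide, stored one row after another.
val IMAGE_WIDTH = 28

// (row, col) -> position in the 784-element FloatArray
fun toIndex(row: Int, col: Int): Int = row * IMAGE_WIDTH + col

// position in the FloatArray -> (row, col)
fun toRowCol(index: Int): Pair<Int, Int> = Pair(index / IMAGE_WIDTH, index % IMAGE_WIDTH)
```

So the last pixel, index 783, is row 27, column 27.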
The following code prints all 784 values to standard output.
val firstImageDataArray: FloatArray = train.getX(0)
0.until(firstImageDataArray.size).forEach { index->
    val pxValue = firstImageDataArray[index]
    println("- ${index}=${pxValue}")
}
That is a bit hard to follow, so let's insert a line break after every 28 values instead.
val firstImageDataArray: FloatArray = train.getX(0)
0.until(28).forEach { row->
    val rowData = 0.until(28).map { col->
        val index = row*28 + col
        val v = firstImageDataArray[index]
        "%.1f".format(v) // print with one decimal place
    }.joinToString(",")
    println(rowData)
}
Running it prints the following. If you squint, you can just about make out the digit 5.
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.1,0.1,0.1,0.5,0.5,0.7,0.1,0.7,1.0,1.0,0.5,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.1,0.1,0.4,0.6,0.7,1.0,1.0,1.0,1.0,1.0,0.9,0.7,1.0,0.9,0.8,0.3,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.2,0.9,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.4,0.3,0.3,0.2,0.2,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.1,0.9,1.0,1.0,1.0,1.0,1.0,0.8,0.7,1.0,0.9,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.3,0.6,0.4,1.0,1.0,0.8,0.0,0.0,0.2,0.6,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.1,0.0,0.6,1.0,0.4,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.5,1.0,0.7,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.7,1.0,0.3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.1,0.9,0.9,0.6,0.4,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.3,0.9,1.0,1.0,0.5,0.1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.2,0.7,1.0,1.0,0.6,0.1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.1,0.4,1.0,1.0,0.7,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0,0.3,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.2,0.5,0.7,1.0,1.0,0.8,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.2,0.6,0.9,1.0,1.0,1.0,1.0,0.7,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.1,0.4,0.9,1.0,1.0,1.0,1.0,0.8,0.3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.1,0.3,0.8,1.0,1.0,1.0,1.0,0.8,0.3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.1,0.7,0.9,1.0,1.0,1.0,1.0,0.8,0.3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.2,0.7,0.9,1.0,1.0,1.0,1.0,1.0,0.5,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.5,1.0,1.0,1.0,0.8,0.5,0.5,0.1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Hold on, what if we print the number only when the value is greater than 0f, like this...
0.until(28).forEach { row->
    val rowData = 0.until(28).map { col->
        val index = row*28 + col
        val v = firstImageDataArray[index]
        //"%.1f".format(v) // print with one decimal place
        if( v>0f ){ "%.1f".format(v) } else { "..." }
    }.joinToString(",")
    println(rowData)
}
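And there it is, clearly visible. (Cells that still show 0.0 hold small positive values that round to 0.0 at one decimal place.)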
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,...,0.0,0.1,0.1,0.1,0.5,0.5,0.7,0.1,0.7,1.0,1.0,0.5,...,...,...,...
...,...,...,...,...,...,...,...,0.1,0.1,0.4,0.6,0.7,1.0,1.0,1.0,1.0,1.0,0.9,0.7,1.0,0.9,0.8,0.3,...,...,...,...
...,...,...,...,...,...,...,0.2,0.9,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.4,0.3,0.3,0.2,0.2,...,...,...,...,...
...,...,...,...,...,...,...,0.1,0.9,1.0,1.0,1.0,1.0,1.0,0.8,0.7,1.0,0.9,...,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,0.3,0.6,0.4,1.0,1.0,0.8,0.0,...,0.2,0.6,...,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,0.1,0.0,0.6,1.0,0.4,...,...,...,...,...,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,0.5,1.0,0.7,0.0,...,...,...,...,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,0.0,0.7,1.0,0.3,...,...,...,...,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,...,0.1,0.9,0.9,0.6,0.4,0.0,...,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,...,...,0.3,0.9,1.0,1.0,0.5,0.1,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,0.2,0.7,1.0,1.0,0.6,0.1,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,0.1,0.4,1.0,1.0,0.7,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,1.0,1.0,1.0,0.3,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,0.2,0.5,0.7,1.0,1.0,0.8,0.0,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,...,0.2,0.6,0.9,1.0,1.0,1.0,1.0,0.7,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,0.1,0.4,0.9,1.0,1.0,1.0,1.0,0.8,0.3,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,0.1,0.3,0.8,1.0,1.0,1.0,1.0,0.8,0.3,0.0,...,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,0.1,0.7,0.9,1.0,1.0,1.0,1.0,0.8,0.3,0.0,...,...,...,...,...,...,...,...,...,...,...,...
...,...,...,...,0.2,0.7,0.9,1.0,1.0,1.0,1.0,1.0,0.5,0.0,...,...,...,...,...,...,...,...,...,...,...,...,...,...
...,...,...,...,0.5,1.0,1.0,1.0,0.8,0.5,0.5,0.1,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
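A variation on the same idea, not something KotlinDL provides: map each brightness to a character from a "ramp" so the digit reads more like a picture. A minimal sketch; toAsciiArt and the ramp string are my own choices.

```kotlin
import kotlin.math.roundToInt

// Renders a flattened grayscale image as text, one line per row.
// Each pixel value (0.0..1.0) picks a character from the ramp,
// from blank (0.0) to dense (1.0).
fun toAsciiArt(pixels: FloatArray, width: Int = 28, ramp: String = " .:-=+*#%@"): String =
    pixels.asList()
        .chunked(width)
        .joinToString("\n") { row ->
            row.joinToString("") { v ->
                val i = (v.coerceIn(0f, 1f) * (ramp.length - 1)).roundToInt()
                ramp[i].toString()
            }
        }
```

println(toAsciiArt(train.getX(0))) draws the first image; chunked(width) replaces the manual row*28 + col indexing.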
Now let's turn this into an actual image.
// needs at the top of the script: import java.awt.Color and import java.awt.image.BufferedImage
val toBufferedImage: (FloatArray, Int, Int)-> BufferedImage = { pixelValues ,w , h->
    val image = BufferedImage(w, h, BufferedImage.TYPE_BYTE_GRAY)
    0.until(h).forEach { row->
        0.until(w).forEach { col->
            val index = row*w + col
            // scale 0.0..1.0 to 0..255
            val v = (pixelValues[index] * 255).toInt()
            image.setRGB(col, row, Color(v, v, v).getRGB())
        }
    }
    image
}
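Using this toBufferedImage function as follows saves the image as 0_5.jpg (<index>_<label>.jpg).
// needs at the top of the script: import java.io.File and import javax.imageio.ImageIO
val index = 0
val bimg = toBufferedImage( train.getX(index), 28, 28 )
ImageIO.write(bimg, "JPEG", File("${index}_${train.getY(index).toInt()}.jpg"))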
That saves the first train sample as an image.
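One alternative worth knowing about: setRGB() goes through the image's color model, and for TYPE_BYTE_GRAY that conversion may shift intermediate gray levels, so anti-aliased edges can come out darker than the original values. Writing the gray level straight into the raster avoids this. A sketch under that assumption; toGrayImage is a hypothetical name, not the code used in this memo.

```kotlin
import java.awt.image.BufferedImage
import kotlin.math.roundToInt

// Variant of toBufferedImage that stores each gray level as-is in band 0
// of the raster, instead of converting through setRGB().
fun toGrayImage(pixels: FloatArray, w: Int, h: Int): BufferedImage {
    require(pixels.size == w * h) { "expected ${w * h} pixels, got ${pixels.size}" }
    val image = BufferedImage(w, h, BufferedImage.TYPE_BYTE_GRAY)
    val raster = image.raster
    for (row in 0 until h) {
        for (col in 0 until w) {
            // 0.0..1.0 -> 0..255, clamped just in case
            val v = (pixels[row * w + col] * 255).roundToInt().coerceIn(0, 255)
            raster.setSample(col, row, 0, v)
        }
    }
    return image
}
```

It is a drop-in replacement: ImageIO.write(toGrayImage(train.getX(0), 28, 28), "JPEG", File("0_5.jpg")). Note that JPEG itself is lossy; PNG keeps the values exact.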
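Next, convert the ground-truth labels to one-hot encoding and save them as CSV files.
train.getY(0) returns the ground-truth value telling which digit (0..9) image 0 shows. Image 0 is the digit 5, so it returns 5.0.
My first attempt was the following code...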
val toOneHotString: (Float)-> String = { v->
    when(v.toInt()){
        0 -> { "1,0,0,0,0,0,0,0,0,0" }
        1 -> { "0,1,0,0,0,0,0,0,0,0" }
        2 -> { "0,0,1,0,0,0,0,0,0,0" }
        3 -> { "0,0,0,1,0,0,0,0,0,0" }
        4 -> { "0,0,0,0,1,0,0,0,0,0" }
        5 -> { "0,0,0,0,0,1,0,0,0,0" }
        6 -> { "0,0,0,0,0,0,1,0,0,0" }
        7 -> { "0,0,0,0,0,0,0,1,0,0" }
        8 -> { "0,0,0,0,0,0,0,0,1,0" }
        9 -> { "0,0,0,0,0,0,0,0,0,1" }
        else -> { "0,0,0,0,0,0,0,0,0,0" }
    }
}
This works, of course, but looking at org.jetbrains.kotlinx.dl.dataset.OnHeapDataset.kt, I found:
/** Creates binary vector with size [numClasses] from [label]. */
@JvmStatic
public fun toOneHotVector(numClasses: Int, label: Byte): FloatArray {
    val ret = FloatArray(numClasses)
    ret[label.toInt() and SHIFT_NUMBER] = 1f
    return ret
}
So a function doing exactly the same thing already exists.
Incidentally, SHIFT_NUMBER = 0xFF.
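Why mask with 0xFF at all? Kotlin's Byte is signed (-128..127), so a raw byte such as 200 (0xC8) is seen as -56 after toInt(); masking with 0xFF recovers the unsigned value. Presumably that is why the library does it when handling labels as bytes. A tiny illustration; unsignedValue is a hypothetical name.

```kotlin
// Widens a (signed) Byte to Int and keeps only the low 8 bits,
// turning -128..127 back into the unsigned range 0..255.
fun unsignedValue(b: Byte): Int = b.toInt() and 0xFF
```

unsignedValue(0xC8.toByte()) gives 200, while 0xC8.toByte().toInt() alone gives -56. For MNIST labels (0..9) the mask changes nothing.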
Using OnHeapDataset's toOneHotVector as a reference, it can be written like this:
val toOneHotVector: (Float, Int)-> IntArray = { v, numClasses->
    val oneHotV = IntArray(numClasses)
    oneHotV[v.toInt()] = 1
    oneHotV
}
val toOneHotString: (Float)-> String = { v->
    toOneHotVector(v, 10).joinToString(",")
}
I left out the and 0xFF step after v.toInt(), because here v is always in the range 0.0..9.0.
With this, a value in 0.0..9.0 can be turned into a one-hot string.
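If you ever reuse this with data you don't control, a version that checks the range fails with a readable message instead of an ArrayIndexOutOfBoundsException. A sketch; toOneHotChecked is a hypothetical name.

```kotlin
// Same idea as toOneHotVector above, plus a range check on the label.
fun toOneHotChecked(label: Float, numClasses: Int): IntArray {
    val index = label.toInt()
    require(index in 0 until numClasses) { "label $label is outside 0..${numClasses - 1}" }
    return IntArray(numClasses).also { it[index] = 1 }
}
```

toOneHotChecked(5f, 10).joinToString(",") yields the same "0,0,0,0,0,1,0,0,0,0" as before.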
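Let's put everything together. The script below writes the images and the one-hot label CSVs under ./data/ (data/train/img, data/train/label, data/test/img, data/test/label).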
dl.main.kts
@file:Repository("https://repo1.maven.org/maven2/")
@file:DependsOn("org.jetbrains.kotlinx:kotlin-deeplearning-api:0.5.2")
@file:DependsOn("org.jetbrains.kotlinx:kotlin-deeplearning-dataset:0.5.2")
import java.io.File
import java.awt.Color
import java.awt.image.BufferedImage
import javax.imageio.ImageIO
import org.jetbrains.kotlinx.dl.dataset.embedded.mnist
import org.jetbrains.kotlinx.dl.dataset.Dataset
// run AWT headless: no display is needed to build and save the images
System.setProperty("java.awt.headless", "true")
// zero-pad to 5 digits, e.g. 7 -> "00007"
val zfill5: (Int)->String = { v-> String.format("%1$05d", v) }
// FloatArray (0.0..1.0, row by row) -> w x h grayscale image
val toBufferedImage: (FloatArray, Int, Int)-> BufferedImage = { pixelValues ,w , h->
    val image = BufferedImage(w, h, BufferedImage.TYPE_BYTE_GRAY)
    0.until(h).forEach { row->
        0.until(w).forEach { col->
            val index = row*w + col
            val v = (pixelValues[index] * 255).toInt()
            image.setRGB(col, row, Color(v, v, v).getRGB())
        }
    }
    image
}
// writes every image of the dataset into dir as 00000.jpg, 00001.jpg, ...
val exportJpgFile: (File, Dataset, Int)->Unit = { dir, dataset, imageSize->
    0.until(dataset.xSize()).forEach { index->
        val bimg = toBufferedImage(dataset.getX(index), imageSize, imageSize)
        val filename = "${zfill5(index)}.jpg"
        ImageIO.write(bimg, "JPEG", File(dir, filename))
    }
}
// label (0.0..numClasses-1) -> one-hot IntArray / comma-separated string
val toOneHotVector: (Float, Int)-> IntArray = { v, numClasses->
    val oneHotV = IntArray(numClasses)
    oneHotV[v.toInt()] = 1
    oneHotV
}
val toOneHotString: (Float, Int)-> String = { v, numClasses->
    toOneHotVector(v, numClasses).joinToString(",")
}
// writes every label of the dataset into dir as one-line CSVs: 00000.csv, 00001.csv, ...
val exportCSVFile: (File, Dataset, Int)->Unit = { dir, dataset, numClasses->
    0.until(dataset.xSize()).forEach { index->
        val v = toOneHotString(dataset.getY(index), numClasses)
        val filename = "${zfill5(index)}.csv"
        File(dir, filename).printWriter().use { out ->
            out.print( v )
        }
    }
}
val numClasses = 10
val imageSize = 28
// output layout: data/{train,test}/{img,label}
val trainDir = File("data/train")
val trainImgDir = File(trainDir, "img")
val trainLabelDir = File(trainDir, "label")
val testDir = File("data/test")
val testImgDir = File(testDir, "img")
val testLabelDir = File(testDir, "label")
listOf(trainImgDir, trainLabelDir, testImgDir, testLabelDir).forEach { it.mkdirs() }
val (train, test) = mnist()
exportJpgFile(trainImgDir, train, imageSize)
exportCSVFile(trainLabelDir, train, numClasses)
exportJpgFile(testImgDir, test, imageSize)
exportCSVFile(testLabelDir, test, numClasses)
Run it:
$ kotlinc -script dl.main.kts
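After the run, a quick sanity check is to read the label CSVs back and confirm that each one decodes to exactly one label. A hypothetical sketch, not part of the script above; decodeOneHot and countInvalidLabelFiles are names I made up.

```kotlin
import java.io.File

// Decodes a one-hot line such as "0,0,0,0,0,1,0,0,0,0" back to its label (5).
// Returns null if the line is not a valid one-hot vector.
fun decodeOneHot(line: String): Int? {
    val values = line.trim().split(",").map { it.trim().toIntOrNull() ?: return null }
    if (values.count { it == 1 } != 1 || values.any { it != 0 && it != 1 }) return null
    return values.indexOf(1)
}

// Counts the .csv files in dir (as written by exportCSVFile) that fail to decode.
fun countInvalidLabelFiles(dir: File): Int =
    dir.listFiles()?.filter { it.extension == "csv" }.orEmpty()
        .count { decodeOneHot(it.readText()) == null }
```

countInvalidLabelFiles(File("data/train/label")) should come out as 0 if the export went through cleanly.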
That's it.