前回MNIST手書きデータを学習したので、それをロードしてテストだけしてみます。
※前回はこちら
DL4Jを使って、学習したMNIST手書きデータをロードするサンプル
前回学習したデータをロードします。
ロード後、MNISTデータを使って、テストだけします。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 | import java.io.File; import java.io.IOException; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; public class MnistTest2 { public static void main(String[] args) throws IOException { //MNISTデータを準備 DataSetIterator mnistTest = new MnistDataSetIterator(128, false, 123); //学習データをロード MultiLayerNetwork network = MultiLayerNetwork.load(new File("C:\\work\\train.dat"), false); //トレーニング後のテスト実行と評価の出力 Evaluation eval = network.evaluate(mnistTest); System.out.println(eval.stats()); } } |
実行結果
学習したデータをロード。テストの結果が出力されます。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 | SLF4J: Failed to load class "org.slf4j.impl.StaticLoggerBinder". SLF4J: Defaulting to no-operation (NOP) logger implementation SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details. ========================Evaluation Metrics======================== # of classes: 10 Accuracy: 0.9828 Precision: 0.9829 Recall: 0.9826 F1 Score: 0.9827 Precision, recall & F1: macro-averaged (equally weighted avg. of 10 classes) =========================Confusion Matrix========================= 0 1 2 3 4 5 6 7 8 9 --------------------------------------------------- 970 0 2 1 0 1 1 1 3 1 | 0 = 0 0 1128 2 1 0 0 2 0 2 0 | 1 = 1 0 3 1020 1 1 0 1 3 3 0 | 2 = 2 1 1 3 996 0 0 0 2 4 3 | 3 = 3 0 0 2 1 961 0 3 2 0 13 | 4 = 4 2 0 0 9 1 870 4 1 3 2 | 5 = 5 4 2 1 1 1 4 944 0 1 0 | 6 = 6 1 6 7 3 0 0 0 1004 2 5 | 7 = 7 3 0 2 5 1 0 4 2 955 2 | 8 = 8 4 2 0 6 8 2 1 2 4 980 | 9 = 9 Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times ================================================================== |
先頭でSLF4Jが見つからない旨、エラー出力されます。
気になる方は、SLF4Jも入れてください。
サンプルの解説
学習済みのデータをロードすることで、時間のかかる学習をスキップできます。
データのロードには、MultiLayerNetwork.load(File, boolean)を使います。
第1引数に学習済みのデータファイル。第2引数は追加の学習有無です。
あとは、前回同様にテストするだけですね。
※このコードを使用するには、別途DL4Jの入手が必要です。
入手方法などはこちらの記事(前回)に書いてあります。