こんにちは!おかいです。
今回は、MATLABで始めるディープラーニングNo.1で行ったディープラーニングを行い、その後 予測結果、オクリュージョンを使い判断根拠をみていきます。
ディープラーニング
フォルダ
分類する個数は3つです。
分類は、犬、猫、鳥にします。
.
├── sample.m
└── pictures
├── dog
│ ├── ....jpg
│ ├── ....jpg
│ ....
│ └── ....jpg
├── cat
│ ├── ....jpg
│ ├── ....jpg
│ ....
│ └── ....jpg
└── bird
├── ....jpg
....
├── ....jpg
└── ....jpg
では、MATLABで始めるディープラーニングNo.1で書いた内容のプログラムを以下に書きます。
%% データ読み込み
imds = imageDatastore('./pictures','IncludeSubfolders',true, 'LabelSource','foldernames');
%% データの分割
[trainData, testData] = splitEachLabel(imds, 0.6, 'randomized');
%% ニューラルネットワーク読み込み
net = vgg16;
% deepNetworkDesigner(net)
%% ネットワーク層修正
layers = net.Layers;
numClasses = numel(categories(imds.Labels));
layers(39) = fullyConnectedLayer(numClasses);
layers(41) = classificationLayer;
%% ネットワークのインプットサイズに合わせて画像リサイズ
inputSize = net.Layers(1).InputSize;
augTrainData = augmentedImageDatastore(inputSize, trainData);
augTestData = augmentedImageDatastore(inputSize, testData);
%% 学習オプション
options = trainingOptions('sgdm', ...
'MiniBatchSize',10, ...
'MaxEpochs',6, ...
'InitialLearnRate',1e-4, ...
'Shuffle','every-epoch', ...
'ValidationData',augTrainData, ...
'ValidationFrequency',3, ...
'Verbose',false, ...
'Plots','training-progress');
%% 転移学習
netTransfer = trainNetwork(augTrainData, layers, options);

上は学習状況をプロットした画像です。
テストデータを使って結果を予測
%% テストデータで予測
[predLabels, scores] = classify(netTransfer, augTestData);
predLabelsには予測したラベル(犬or猫or鳥)が格納されます。
scoresは犬or猫or鳥の割合が格納されています。例えば
0.99994731 3.3723525e-05 1.9006877e-05
この数値に合わせて一番数値の高いラベルが予測ラベルとしてpredLabelsに保存されます。
予測結果の検証
予測したラベルが本当に正しいのかどうか検証しましょう。
テストデータのラベル(正解ラベル)をtestDataLabelに格納します。
%% テストデータを用いてモデル精度の検証
testDataLabel = testData.Labels;

テストデータのラベルが保存された変数はaugmentedImageDatastore関数で変換したaugTestData変数ではダメなのですか?

ダメです。いや、ダメと言うよりaugTestData変数にはそもそもラベルデータが保存されていないのです。
中身を見てみましょう。
augTestData変数の中身

testData変数の中身


本当だ!サイズを変更した後のaugTestData変数にはLabelsが入ってないですね。
納得。。
次に、正解ラベルと予測したラベルを比べて、nnz関数を使い、正しく予測した個数を数えます。
テストデータの全個数をnumel関数で数えます
% 正しく予測されたデータの個数をcorrectに格納
correct = nnz(testDataLabel == predLabels);
% テストデータの全個数をtestDataCount格納
testDataCount = numel(testData.Labels);
MATLAB公式サイト: nnz関数
MATLAB公式サイト: numel関数
ll変数に 正解率(正解した個数 / 全個数) を保存します。
% ll変数に正解率を格納
ll = correct / testDataCount;

数値だけだと分かりにくいので可視化しましょうか。
%% 可視化
% 真のラベルと予測したラベル
confusionchart(testDataLabel, predLabels);

何と、珍しく100%の正答率が出てしまいました。ww
llの中身も1となっております。
画像を読み込ませてみよう

モデルの作り方もわかったし、検証の仕方もわかりました。
でも、以下の画像のように、自分が画像を読ませて予測するっていうってのをやってみたいです。


そうですね。せっかくモデルと作ったのでやりましょうか
フォルダ構成
では、予測を行う画像とそのコードを書いていくので改めて以下のようにファイルを作ります。
.
├── sample.m
├── sample2.m
├── dog.jpg
├── cat.jpg
├── bird.jpg
└── pictures
├── dog
│ ├── ....jpg
│ ├── ....jpg
│ ....
│ └── ....jpg
├── cat
│ ├── ....jpg
│ ├── ....jpg
│ ....
│ └── ....jpg
└── bird
├── ....jpg
....
├── ....jpg
└── ....jpg
作成したモデルのインプットサイズと最終層を変数へ
%% ネットワークのインプットサイズと出力クラス
inputSize = netTransfer.Layers(1).InputSize(1:2);
classes = netTransfer.Layers(end).Classes;
予測させたい画像を読み込む
imread関数で画像を読み込み、imresize関数で作成モデルに入れるサイズに変更します。
%% 画像読み込み
img = imread("dog.jpg");
img = imresize(img,inputSize);
MATLAB公式サイト: imread関数
MATLAB公式サイト: imresize関数
予測(分類)
モデルを作成しテストデータを予測した時同様にclassify関数を使います.
その後、予測結果をみやすいのでグラフとどんな予測をしているのかをラベル名(数値)で表します.
コードと説明
%% 分類
[Ypred, scores] = classify(netTransfer, img); % 予測
[~,topIdx] = maxk(scores, 3); % maxk関数によりscoresのトップ3を取り出し、toIdxにはインデックス番号(1番数値の高いインデックス番号順)が入る
topScores = scores(topIdx); % score(インデックス番号)よりtopScoresには1番高い数値から順に保存される
topClasses = classes(topIdx); % topScoresと同様
imshow(img)
titleString = compose("%s (%.2f)",topClasses,topScores'); % compose関数によりタイトルの文
title(sprintf(join(titleString, "; ")));

MATLAB公式サイト: maxk関数
MATLAB公式サイト: compose関数
オクリュージョンを使って判断根拠を記す
世の中の多くの物事は根拠がないと認めてもらえない、納得できないことが多いかと思います。
今回作ったモデルは読み込んだ画像のどこを見て、犬or猫or鳥を判断しているのでしょうか。
以下のようにコードを書きます。(コードの説明は省きます)
%% どこを見ているのかデータに落とす
map = occlusionSensitivity(netTransfer,img,Ypred);
%% 可視化
imshow(img,'InitialMagnification', 150)
hold on
imagesc(map,'AlphaData',0.5)
colormap jet
colorbar
title(sprintf("Occlusion sensitivity (%s)", ...
Ypred))
MATLAB公式サイト: occlusionSensitivity関数

どうやら顔を見て犬と判断しているようですね
最後に、猫と鳥の画像でもやってみましょう。
猫


鳥


全コード
%% データ読み込み
imds = imageDatastore('./pictures','IncludeSubfolders',true, 'LabelSource','foldernames');
%% データの分割
[trainData, testData] = splitEachLabel(imds, 0.6, 'randomized');
%% ニューラルネットワーク読み込み
net = vgg16;
% deepNetworkDesigner(net)
%% ネットワーク層修正
layers = net.Layers;
numClasses = numel(categories(imds.Labels));
layers(39) = fullyConnectedLayer(numClasses);
layers(41) = classificationLayer;
%% ネットワークのインプットサイズに合わせて画像リサイズ
inputSize = net.Layers(1).InputSize;
augTrainData = augmentedImageDatastore(inputSize, trainData);
augTestData = augmentedImageDatastore(inputSize, testData);
%% 学習オプション
options = trainingOptions('sgdm', ...
'MiniBatchSize',10, ...
'MaxEpochs',6, ...
'InitialLearnRate',1e-4, ...
'Shuffle','every-epoch', ...
'ValidationData',augTrainData, ...
'ValidationFrequency',3, ...
'Verbose',false, ...
'Plots','training-progress');
%% 転移学習
netTransfer = trainNetwork(augTrainData, layers, options);
%% テストデータで予測
[predLabels, scores] = classify(netTransfer, augTestData);
%% テストデータを用いてモデル精度の検証
testDataLabel = testData.Labels;
% 正しく予測されたデータの個数をcorrectに格納
correct = nnz(testDataLabel == predLabels);
% テストデータの全個数をtestDataCount格納
testDataCount = numel(testData.Labels);
% ll変数に正解率を格納
ll = correct / testDataCount;
%% 可視化
% 真のラベルと予測したラベル
confusionchart(testDataLabel, predLabels);
%% ネットワークのインプットサイズと出力クラス
inputSize = netTransfer.Layers(1).InputSize(1:2);
classes = netTransfer.Layers(end).Classes;
%% 画像読み込み
img = imread("dog.jpg");
img = imresize(img,inputSize);
%% 分類
[Ypred, scores] = classify(netTransfer, img); % 予測
[~,topIdx] = maxk(scores, 3); % maxk関数によりscoresのトップ3を取り出し、toIdxにはインデックス番号(1番数値の高いインデックス番号順)が入る
topScores = scores(topIdx); % score(インデックス番号)よりtopScoresには1番高い数値から順に保存される
topClasses = classes(topIdx); % topScoresと同様
imshow(img)
titleString = compose("%s (%.2f)",topClasses,topScores'); % compose関数によりタイトルの文
title(sprintf(join(titleString, "; ")));
%% どこを見ているのかデータに落とす
map = occlusionSensitivity(netTransfer,img,Ypred);
%% 可視化
imshow(img,'InitialMagnification', 150)
hold on
imagesc(map,'AlphaData',0.5)
colormap jet
colorbar
title(sprintf("Occlusion sensitivity (%s)", ...
Ypred))
ありがとうございました。
コメント