This example shows how to retrain a pretrained SqueezeNet neural network to perform classification on a new collection of images.

```matlab
%% Load and split the data
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ...
    'nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true,'LabelSource','foldernames');
labelCount = countEachLabel(imds);
numImages = numel(imds.Labels);
classNames = categories(imds.Labels)
numClasses = numel(classNames)
% Split each class: 70% training, 15% validation, remaining 15% test.
[imdsTrain,imdsValidation,imdsTest] = splitEachLabel(imds,0.7,0.15,"randomized");
```
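Because `splitEachLabel` splits each class separately, every split keeps approximately the class balance of the full dataset. A quick way to confirm this is to tabulate per-class counts with `countEachLabel`. The sketch below reloads the data so that it runs on its own; the `splitSummary` table and its column names are illustrative choices, not part of the original example.

```matlab
% Reload and split the DigitDataset images, as in the code above.
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ...
    'nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true,'LabelSource','foldernames');
[imdsTrain,imdsValidation,imdsTest] = splitEachLabel(imds,0.7,0.15,"randomized");

% Count the images of each class in every split.
countsTrain = countEachLabel(imdsTrain);
countsValidation = countEachLabel(imdsValidation);
countsTest = countEachLabel(imdsTest);

% One row per class, one column per split (illustrative summary table).
splitSummary = table(countsTrain.Label,countsTrain.Count, ...
    countsValidation.Count,countsTest.Count, ...
    VariableNames=["Label","Train","Validation","Test"])
```

Every image lands in exactly one split, so each row's three counts should add up to that class's total in `imds`.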
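```matlab
%% Load the pretrained network
% Load SqueezeNet with its classification head resized to numClasses outputs.
net = imagePretrainedNetwork("squeezenet",NumClasses=numClasses)
inputSize = net.Layers(1).InputSize
% Assumption: networkHead and freezeNetwork are helper functions available on
% the MATLAB path (for example, supporting files from the MathWorks
% transfer-learning example), not built-in toolbox functions.
[layerName,learnableNames] = networkHead(net)
% Freeze every learnable parameter except those of the head layer.
net = freezeNetwork(net,LayerNamesToIgnore=layerName);
```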
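The code above does not define `freezeNetwork`, so here is one way such a helper could work. This is an assumption, not its actual implementation: set the learn-rate factor of every learnable parameter outside the head to 0 with `setLearnRateFactor`, so the solver leaves those weights unchanged. The tiny network and the layer name `"head"` are hypothetical stand-ins for SqueezeNet and the name returned by `networkHead`.

```matlab
% Tiny hypothetical network standing in for SqueezeNet.
layers = [
    featureInputLayer(4,Name="input")
    fullyConnectedLayer(8,Name="fc1")
    reluLayer(Name="relu")
    fullyConnectedLayer(3,Name="head")
    softmaxLayer(Name="softmax")];
net = dlnetwork(layers);
layerNamesToIgnore = "head";   % plays the role of layerName from networkHead

learnables = net.Learnables;
for i = 1:height(learnables)
    layerName = string(learnables.Layer(i));
    if ~ismember(layerName,layerNamesToIgnore)
        % A learn-rate factor of 0 means the solver never updates this parameter.
        net = setLearnRateFactor(net,layerName,string(learnables.Parameter(i)),0);
    end
end

getLearnRateFactor(net,"fc1","Weights")    % 0 (frozen)
getLearnRateFactor(net,"head","Weights")   % 1 (default, still trainable)
```

Freezing everything except the head keeps the pretrained features intact and trains only the new classification layer, which is why it suits small datasets.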
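```matlab
%% Prepare the images
% SqueezeNet expects 227-by-227 RGB input, but the digit images are 28-by-28
% grayscale, so resize them and convert them to three channels.
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain,ColorPreprocessing="gray2rgb");
augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation,ColorPreprocessing="gray2rgb");
augimdsTest = augmentedImageDatastore(inputSize(1:2),imdsTest,ColorPreprocessing="gray2rgb");
```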
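To see what the datastore does to each digit, the sketch below reproduces the two steps on one hypothetical 28-by-28 image: resize to SqueezeNet's 227-by-227 input, then copy the single gray channel into three channels. It uses nearest-neighbour index mapping instead of `imresize` so that it needs no extra toolbox. `augmentedImageDatastore` may interpolate differently, so treat the pixel values as illustrative and only the output size as exact.

```matlab
% Hypothetical 28-by-28 grayscale image standing in for one DigitDataset image.
gray = rand(28,28,"single");
targetSize = [227 227];   % SqueezeNet's input height and width

% Nearest-neighbour resize by index mapping (toolbox-free stand-in for imresize).
rowIdx = round(linspace(1,size(gray,1),targetSize(1)));
colIdx = round(linspace(1,size(gray,2),targetSize(2)));
resized = gray(rowIdx,colIdx);

% Replicate the single channel three times, as 'gray2rgb' conceptually does.
rgb = repmat(resized,[1 1 3]);
size(rgb)   % 227 227 3
```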
```matlab
%% Train the network
options = trainingOptions("adam", ...
    ValidationData=augimdsValidation, ...
    ValidationFrequency=5, ...
    Plots="training-progress", ...
    Metrics="accuracy", ...
    Verbose=false);
net = trainnet(augimdsTrain,net,"crossentropy",options);
```
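`ValidationFrequency` counts iterations, not epochs, so it is worth checking how often validation actually runs. The arithmetic below rests on three assumptions to verify for your release: the DigitDataset holds 10,000 images (7,000 for training after the 70% split), `MiniBatchSize` keeps its default of 128, and the last partial mini-batch of each epoch is dropped.

```matlab
numTrain = 7000;          % assumed: 70% of 10,000 images
miniBatchSize = 128;      % assumed default MiniBatchSize
validationFrequency = 5;  % value used in trainingOptions above

% Full mini-batches per epoch (partial final batch assumed dropped).
iterationsPerEpoch = floor(numTrain/miniBatchSize)             % 54
% Average number of validation passes per epoch.
validationsPerEpoch = iterationsPerEpoch/validationFrequency   % 10.8 on average
```

Each of those passes runs over the whole validation set. If that is slow, setting `ValidationFrequency=iterationsPerEpoch` validates roughly once per epoch instead.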
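```matlab
%% Test the network
% minibatchpredict returns class scores; scores2label converts them to labels.
scores = minibatchpredict(net,augimdsTest);
YTest = scores2label(scores,classNames);
```

Visualize the classification accuracy in a confusion chart.

```matlab
TTest = imdsTest.Labels;
figure
confusionchart(TTest,YTest)
```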
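The confusion chart is easiest to summarise with two numbers: overall accuracy (the fraction of matching labels) and per-class recall (each diagonal entry of the chart divided by its row total, since rows are true classes). The sketch below uses small hypothetical label vectors so that it runs on its own; with the real results you would pass `TTest` and `YTest` from the code above.

```matlab
% Hypothetical true and predicted labels for three classes.
TTest = categorical(["0";"1";"1";"2";"2";"2"]);
YTest = categorical(["0";"1";"2";"2";"2";"1"]);

% Overall accuracy: fraction of predictions that match the true label.
accuracy = mean(YTest == TTest)          % 0.6667

% Per-class recall: correct predictions among images of each true class.
classNames = categories(TTest);
recall = zeros(numel(classNames),1);
for k = 1:numel(classNames)
    isClass = TTest == classNames{k};
    recall(k) = mean(YTest(isClass) == classNames{k});
end
recall                                   % [1; 0.5; 0.6667]
```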
For very small datasets (fewer than 20 images per class), use feature extraction instead of retraining the network.
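In feature extraction the pretrained network stays fixed: its activations are computed once and used as feature vectors for a separate, lightweight classifier. The sketch below shows only that second stage, with hypothetical 2-D feature vectors and a nearest-centroid classifier. The nearest-centroid rule is a deliberately simple stand-in for classifiers such as an SVM; in practice the features would come from a layer of the pretrained network.

```matlab
% Hypothetical features: one row per training image, one column per feature.
featuresTrain = [0 0; 0 1; 5 5; 5 6];
labelsTrain = categorical(["a";"a";"b";"b"]);
featuresTest = [0.2 0.4; 4.8 5.5];

% Mean feature vector (centroid) of each class.
classNames = categories(labelsTrain);
centroids = zeros(numel(classNames),size(featuresTrain,2));
for k = 1:numel(classNames)
    centroids(k,:) = mean(featuresTrain(labelsTrain == classNames{k},:),1);
end

% Squared Euclidean distance from each test row to each centroid.
dist = zeros(size(featuresTest,1),numel(classNames));
for k = 1:numel(classNames)
    dist(:,k) = sum((featuresTest - centroids(k,:)).^2,2);
end

% Assign each test image to its nearest centroid.
[~,idx] = min(dist,[],2);
predicted = categorical(classNames(idx),classNames)   % a, b
```

Only the small classifier is fitted, so there are far fewer parameters to learn than when retraining the network, which is what makes this approach workable with a handful of images per class.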


