Siamese Neural Networks in MATLAB: A Hands-On Example [Source Code Included]

1. Algorithm Principles

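A Siamese neural network is a deep learning architecture built from two or more identical subnetworks that have the same structure and share the same parameters and weights. Siamese networks are typically used for tasks that involve finding the relationship between two comparable things; common applications include face recognition, signature verification, and paraphrase identification. The architecture couples two artificial neural networks: it takes two samples as input and outputs their representations embedded in a high-dimensional space, so that the similarity of the two samples can be compared. In the narrow sense, a Siamese neural network consists of two networks with identical structure and shared weights joined together. The network framework is shown in the figure below.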

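In the broad sense, a Siamese neural network (also called a pseudo-Siamese network) can be formed by joining any two neural networks, for example convolutional or recurrent networks, and the two branches do not have to share weights. The network framework is shown in the figure below.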
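The difference between the two forms comes down to whether the branches share parameters. A toy sketch, assuming single-layer branches with made-up sizes (plain matrices stand in for real subnetworks; none of these names come from the original code):

```matlab
% Toy comparison of a strict Siamese network and a pseudo-Siamese network.
% All sizes and variable names are hypothetical, chosen only for illustration.
rng(1);                              % reproducible random numbers
relu = @(z) max(z,0);                % simple nonlinearity

x1 = rand(6,1);                      % first input sample
x2 = rand(6,1);                      % second input sample

% Strict Siamese: one weight matrix serves both branches.
Wshared = randn(4,6);
f1 = relu(Wshared*x1);
f2 = relu(Wshared*x2);
numParamsSiamese = numel(Wshared);

% Pseudo-Siamese: each branch has its own weights
% (the branches could even use different architectures).
Wa = randn(4,6);
Wb = randn(4,6);
g1 = relu(Wa*x1);
g2 = relu(Wb*x2);
numParamsPseudo = numel(Wa) + numel(Wb);

fprintf("Siamese: %d parameters, pseudo-Siamese: %d parameters\n", ...
    numParamsSiamese,numParamsPseudo);
```

With shared weights the same input always maps to the same embedding regardless of which branch it enters, and the parameter count is half that of the two-branch version.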

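Put simply, a Siamese neural network measures how similar two inputs are. Each of the two inputs is fed into a neural network, which maps it into a new space and produces a representation of the input in that space. A loss computed on these representations then evaluates the similarity of the two inputs. Siamese networks perform well on such tasks because weight sharing means fewer parameters have to be learned during training, and they can produce good results with relatively little training data.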
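The flow described above (map both inputs, compare the representations, score the pair with a loss) can be sketched in a few lines. This is a minimal illustration under stated assumptions: the single-layer "subnetwork", the absolute-difference comparison, and the binary cross-entropy loss are illustrative choices, and every name and size is hypothetical.

```matlab
% Toy illustration only: plain matrices stand in for a real subnetwork.
rng(0);                                   % reproducible random numbers

W = 0.1*randn(4,6);                       % weights shared by both branches
embed = @(x) max(W*x,0);                  % the same mapping is applied to either input

w = 0.1*randn(1,4);                       % weights of the final comparison step
b = 0;                                    % bias of the final comparison step
sigmoidFcn = @(z) 1./(1 + exp(-z));       % squashes any value into (0,1)

% Similarity score: embed both inputs, take the absolute difference,
% then reduce it to a single number between 0 and 1.
score = @(x1,x2) sigmoidFcn(w*abs(embed(x1) - embed(x2)) + b);

x1 = rand(6,1);
x2 = rand(6,1);
p = score(x1,x2);

% Binary cross-entropy for one pair: label 1 = similar, 0 = dissimilar.
label = 1;
loss = -(label*log(p) + (1 - label)*log(1 - p));
fprintf("score = %.4f, loss = %.4f\n",p,loss);

% Identical inputs give identical embeddings, so the difference is zero
% and, with b = 0, the untrained score is exactly 0.5.
pSame = score(x1,x1);
```

Training adjusts `W`, `w`, and `b` so that the score moves toward 1 for similar pairs and toward 0 for dissimilar pairs.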

2. Code in Practice

%% matlab学习之家
%% Siamese (twin) neural network
clc
clear
%% Load the training set
dataFolderTrain = "D:\S\孪生神经网络\images_background"; % change to your own path
imdsTrain = imageDatastore(dataFolderTrain, ...
    IncludeSubfolders=true, ...
    LabelSource="none");

% Build each label from the two folder levels above the file name
files = imdsTrain.Files;
parts = split(files,filesep);
labels = join(parts(:,(end-2):(end-1)),"-");
imdsTrain.Labels = categorical(labels);

%% Display a random sample of images
idx = randperm(numel(imdsTrain.Files),8);

figure
for i = 1:numel(idx)
    subplot(4,2,i)
    imshow(readimage(imdsTrain,idx(i)))
    title(imdsTrain.Labels(idx(i)),Interpreter="none");
end

%% Display a batch of image pairs
batchSize = 10;
[pairImage1,pairImage2,pairLabel] = getTwinBatch(imdsTrain,batchSize);

figure % new figure so the pairs do not overwrite the sample images above
for i = 1:batchSize
    if pairLabel(i) == 1
        s = "similar";
    else
        s = "dissimilar";
    end
    subplot(2,5,i)
    imshow([pairImage1(:,:,:,i) pairImage2(:,:,:,i)]);
    title(s)
end
%% Define the Siamese network architecture
% Subnetwork applied to each input image (105-by-105, single channel);
% its last layer outputs a 4096-element feature vector
layers = [
    imageInputLayer([105 105 1],Normalization="none")
    convolution2dLayer(10,64,WeightsInitializer="narrow-normal",BiasInitializer="narrow-normal")
    reluLayer
    maxPooling2dLayer(2,Stride=2)
    convolution2dLayer(7,128,WeightsInitializer="narrow-normal",BiasInitializer="narrow-normal")
    reluLayer
    maxPooling2dLayer(2,Stride=2)
    convolution2dLayer(4,128,WeightsInitializer="narrow-normal",BiasInitializer="narrow-normal")
    reluLayer
    maxPooling2dLayer(2,Stride=2)
    convolution2dLayer(5,256,WeightsInitializer="narrow-normal",BiasInitializer="narrow-normal")
    reluLayer
    fullyConnectedLayer(4096,WeightsInitializer="narrow-normal",BiasInitializer="narrow-normal")];

net = dlnetwork(layers);

% Learnable weights (1-by-4096) and bias of the final fully connected
% operation, kept in a struct outside the dlnetwork
fcWeights = dlarray(0.01*randn(1,4096));
fcBias = dlarray(0.01*randn(1,1));

fcParams = struct(...
    "FcWeights",fcWeights,...
    "FcBias",fcBias);
%% Specify training options
numIterations = 10000;
miniBatchSize = 180;
learningRate = 6e-5;
gradDecay = 0.9;       % Adam gradient decay factor
gradDecaySq = 0.99;    % Adam squared-gradient decay factor
executionEnvironment = "auto";

if canUseGPU
    gpu = gpuDevice;
    disp(gpu.Name + " GPU detected and available for training.")
end

% Adam optimizer state for the subnetwork and for the final FC parameters
trailingAvgSubnet = [];
trailingAvgSqSubnet = [];
trailingAvgParams = [];
trailingAvgSqParams = [];

% Progress monitor that plots the loss against the iteration number
monitor = trainingProgressMonitor(Metrics="Loss",XLabel="Iteration",Info="ExecutionEnvironment");
if canUseGPU
    updateInfo(monitor,ExecutionEnvironment=gpu.Name + " GPU")
else
    updateInfo(monitor,ExecutionEnvironment="CPU")
end

start = tic;
iteration = 0;
%% Train the model with a custom training loop: loop over the training data and update the network parameters at each iteration
while iteration < numIterations && ~monitor.Stop

    iteration = iteration + 1;

    % Draw a mini-batch of image pairs and their similar/dissimilar labels
    [X1,X2,pairLabels] = getTwinBatch(imdsTrain,miniBatchSize);

    % Convert to dlarray with format "SSCB" (spatial, spatial, channel, batch)
    X1 = dlarray(X1,"SSCB");
    X2 = dlarray(X2,"SSCB");

    % Move the data to the GPU if one is to be used
    if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
        X1 = gpuArray(X1);
        X2 = gpuArray(X2);
    end

    % Evaluate the model loss and gradients using dlfeval and modelLoss
    [loss,gradientsSubnet,gradientsParams] = dlfeval(@modelLoss,net,fcParams,X1,X2,pairLabels);

    % Update the Siamese subnetwork parameters
    [net,trailingAvgSubnet,trailingAvgSqSubnet] = adamupdate(net,gradientsSubnet, ...
        trailingAvgSubnet,trailingAvgSqSubnet,iteration,learningRate,gradDecay,gradDecaySq);

    % Update the parameters of the final fully connected operation
    [fcParams,trailingAvgParams,trailingAvgSqParams] = adamupdate(fcParams,gradientsParams, ...
        trailingAvgParams,trailingAvgSqParams,iteration,learningRate,gradDecay,gradDecaySq);

    % Record the loss and update the progress bar
    recordMetrics(monitor,iteration,Loss=loss);
    monitor.Progress = 100 * iteration/numIterations;

end
%% Load the test set
dataFolderTest = "D:\S\孪生神经网络\images_evaluation"; % change to your own path
imdsTest = imageDatastore(dataFolderTest, ...
    IncludeSubfolders=true, ...
    LabelSource="none");

files = imdsTest.Files;
parts = split(files,filesep);
labels = join(parts(:,(end-2):(end-1)),"_");
imdsTest.Labels = categorical(labels);

numClasses = numel(unique(imdsTest.Labels));

%% Estimate accuracy on 5 random batches of test pairs
accuracy = zeros(1,5);
accuracyBatchSize = 150;

for i = 1:5
    % Draw a batch of test image pairs and labels
    [X1,X2,pairLabelsAcc] = getTwinBatch(imdsTest,accuracyBatchSize);

    % Convert to dlarray with format "SSCB"
    X1 = dlarray(X1,"SSCB");
    X2 = dlarray(X2,"SSCB");

    % Move the data to the GPU if one is to be used
    if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
        X1 = gpuArray(X1);
        X2 = gpuArray(X2);
    end

    % Predict similarity scores and round them to 0 (dissimilar) or 1 (similar)
    Y = predictTwin(net,fcParams,X1,X2);
    Y = gather(extractdata(Y));
    Y = round(Y);

    % Fraction of pairs classified correctly in this batch
    accuracy(i) = sum(Y == pairLabelsAcc)/accuracyBatchSize;
end

averageAccuracy = mean(accuracy)*100;
disp("Average accuracy: " + averageAccuracy + "%")

%% Predict on a small batch of test pairs
testBatchSize = 10;

[XTest1,XTest2,pairLabelsTest] = getTwinBatch(imdsTest,testBatchSize);

XTest1 = dlarray(XTest1,"SSCB");
XTest2 = dlarray(XTest2,"SSCB");

if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    XTest1 = gpuArray(XTest1);
    XTest2 = gpuArray(XTest2);
end
YScore = predictTwin(net,fcParams,XTest1,XTest2);
YScore = gather(extractdata(YScore));
YPred = round(YScore);
XTest1 = extractdata(XTest1);
XTest2 = extractdata(XTest2);
f = figure;
tiledlayout(2,5);
f.Position(3) = 2*f.Position(3);

predLabels = categorical(YPred,[0 1],["dissimilar" "similar"]);
targetLabels = categorical(pairLabelsTest,[0 1],["dissimilar","similar"]);
%% Plot each image pair with its target label, predicted label, and predicted score
for i = 1:numel(pairLabelsTest)
    nexttile
    imshow([XTest1(:,:,:,i) XTest2(:,:,:,i)]);

    title( ...
        "Target: " + string(targetLabels(i)) + newline + ...
        "Predicted: " + string(predLabels(i)) + newline + ...
        "Score: " + YScore(i))
end
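The script calls `getTwinBatch`, `modelLoss`, and `predictTwin`, but their definitions do not appear in the listing. The sketch below shows one way they could be written so that the script runs. Every function body here is an assumption rather than the author's code: the absolute-difference comparison, the binary cross-entropy loss, the 50/50 mix of similar and dissimilar pairs, and the helper names `twinScore`, `binaryCrossEntropy`, `getSimilarPair`, and `getDissimilarPair` are all illustrative choices. It also assumes 2-D grayscale images and at least two images per class.

```matlab
% Sketch of the helper functions the script calls (assumed, not from the post).
% In a script file, local functions go after the last line of script code.

function [loss,gradSubnet,gradParams] = modelLoss(net,fcParams,X1,X2,pairLabels)
% Pass both images through the shared subnetwork, then differentiate the loss.
Y = twinScore(forward(net,X1),forward(net,X2),fcParams);
loss = binaryCrossEntropy(Y,pairLabels);
[gradSubnet,gradParams] = dlgradient(loss,net.Learnables,fcParams);
end

function Y = predictTwin(net,fcParams,X1,X2)
% Same computation as in training, but using predict (inference mode).
Y = twinScore(predict(net,X1),predict(net,X2),fcParams);
end

function Y = twinScore(F1,F2,fcParams)
% Compare two feature vectors: |sigmoid(F1) - sigmoid(F2)| -> FC -> sigmoid.
D = abs(sigmoid(F1) - sigmoid(F2));
Y = sigmoid(fullyconnect(D,fcParams.FcWeights,fcParams.FcBias));
end

function loss = binaryCrossEntropy(Y,pairLabels)
% Clip scores away from 0 and 1 so that log() stays finite.
tol = eps("single");
Y(Y < tol) = tol;
Y(Y > 1 - tol) = 1 - tol;
loss = -pairLabels.*log(Y) - (1 - pairLabels).*log(1 - Y);
loss = sum(loss,"all")/numel(pairLabels);
end

function [X1,X2,pairLabels] = getTwinBatch(imds,miniBatchSize)
% Build a batch of random image pairs; about half are similar pairs.
pairLabels = zeros(1,miniBatchSize);
imgSize = size(readimage(imds,1));           % assumes 2-D grayscale images
X1 = zeros([imgSize 1 miniBatchSize],"single");
X2 = zeros([imgSize 1 miniBatchSize],"single");
for i = 1:miniBatchSize
    if rand < 0.5
        [idx1,idx2,pairLabels(i)] = getSimilarPair(imds.Labels);
    else
        [idx1,idx2,pairLabels(i)] = getDissimilarPair(imds.Labels);
    end
    X1(:,:,:,i) = readimage(imds,idx1);
    X2(:,:,:,i) = readimage(imds,idx2);
end
end

function [idx1,idx2,pairLabel] = getSimilarPair(classLabel)
% Pick two different images from one randomly chosen class (label 1).
classes = unique(classLabel);
idxs = find(classLabel == classes(randi(numel(classes))));
choice = randperm(numel(idxs),2);            % needs at least 2 images per class
idx1 = idxs(choice(1));
idx2 = idxs(choice(2));
pairLabel = 1;
end

function [idx1,idx2,pairLabel] = getDissimilarPair(classLabel)
% Pick one image from each of two different random classes (label 0).
classes = unique(classLabel);
classChoice = randperm(numel(classes),2);
idxs1 = find(classLabel == classes(classChoice(1)));
idxs2 = find(classLabel == classes(classChoice(2)));
idx1 = idxs1(randi(numel(idxs1)));
idx2 = idxs2(randi(numel(idxs2)));
pairLabel = 0;
end
```

`modelLoss` uses `forward` because it runs inside `dlfeval` during training, while `predictTwin` uses `predict` for evaluation. Both share `twinScore`, so the training and test paths compute the score the same way.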

Simulation Results
