18.Kaggle竞赛--使用ResNet-50网络进行树叶分类

python 复制代码
import torch
from torch.utils.data import Dataset,DataLoader,random_split
from torchvision import transforms
import pandas as pd
from PIL import Image
import torch.nn as nn
import torchvision.models as models
from tqdm import tqdm
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
##########################################################################################################
#首先处理数据集--训练集与外部测试集
class LeavesDataset(Dataset):
    def __init__(self,csv_path,transform=None):
        self.data=pd.read_csv(csv_path)
        self.transform=transform
        self.label_to_idx=self.encode_labels(self.data['label'])
    def encode_labels(self,labels):
        unique_labels=sorted(set(labels))
        label_to_idx={label:idx for idx,label in enumerate(unique_labels)}
        return label_to_idx
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        image_path=self.data.iloc[idx,0]
        label_text=self.data.iloc[idx,1]
        label=self.label_to_idx[label_text]
        image=Image.open(image_path).convert('RGB')
        if self.transform:
            image=self.transform(image)
        return image,label
#这是测试集的处理方法,测试集没有label
class TestDataset(Dataset):
    def __init__(self,csv_path,transform=None):
        self.data=pd.read_csv(csv_path)
        self.transform=transform
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        image_path=self.data.iloc[idx,0]
        image=Image.open(image_path).convert('RGB')
        if self.transform:
            image=self.transform(image)
        return image,image_path
##########################################################################################################
#训练过程--绘图与训练
def plot_metrics(train_loss_list, train_acc_list, test_acc_list, title='Training Curve'):
    epochs = range(1, len(train_loss_list) + 1)
    plt.figure(figsize=(4, 3))
    plt.plot(epochs, train_loss_list, label='Train Loss')
    plt.plot(epochs, train_acc_list, label='Train Acc',linestyle='--')
    plt.plot(epochs, test_acc_list, label='Test Acc', linestyle='--')
    plt.xlabel('Epoch')
    plt.ylabel('Value')
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()
def train_model(model,train_data,test_data,num_epochs,device):
    train_loss_list = []
    train_acc_list = []
    test_acc_list = []
    for epoch in range(num_epochs):
        total_loss=0
        total_acc_sample=0
        total_samples=0
        loop1=tqdm(train_data,desc=f"EPOCHS[{epoch+1}/{num_epochs}]",colour='#FF5555')
        for X,y in loop1:
            #X=X.reshape(X.shape[0],-1)
            #print(X.shape)
            X=X.to(device)
            y=y.to(device)
            y_hat=model(X)
            loss=CEloss(y_hat,y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            #loss累加
            total_loss+=loss.item()*X.shape[0]
            y_pred=y_hat.argmax(dim=1).detach().cpu().numpy()
            y_true=y.detach().cpu().numpy()
            total_acc_sample+=accuracy_score(y_pred,y_true)*X.shape[0]#保存样本数
            total_samples+=X.shape[0]
        test_acc_samples=0
        test_samples=0
        loop2=tqdm(test_data,desc=f"EPOCHS[{epoch+1}/{num_epochs}]",colour='#FF5555')
        for X,y in loop2:
            X=X.to(device)
            y=y.to(device)
            #X=X.reshape(X.shape[0],-1)
            y_hat=model(X)
            y_pred=y_hat.argmax(dim=1).detach().cpu().numpy()
            y_true=y.detach().cpu().numpy()
            test_acc_samples+=accuracy_score(y_pred,y_true)*X.shape[0]#保存样本数
            test_samples+=X.shape[0]
        avg_train_loss=total_loss/total_samples
        avg_train_acc=total_acc_sample/total_samples
        avg_test_acc=test_acc_samples/test_samples
        train_loss_list.append(avg_train_loss)
        train_acc_list.append(avg_train_acc)
        test_acc_list.append(avg_test_acc)
        print(f"Epoch {epoch+1}: Loss: {avg_train_loss:.4f},Trian Accuracy: {avg_train_acc:.4f},test Accuracy: {avg_test_acc:.4f}")
    plot_metrics(train_loss_list, train_acc_list, test_acc_list)
    return model
def init_weights(m):
    if type(m) == nn.Linear or type(m) == nn.Conv2d:
        nn.init.xavier_uniform_(m.weight)
################################################################################################################
transforms=transforms.Compose([transforms.RandomHorizontalFlip(),
                               transforms.Resize(224),
                               transforms.RandomRotation(30),
                               transforms.ToTensor(),
                               transforms.Normalize((0.485, 0.456, 0.406),(0.229, 0.224, 0.225))])
data=LeavesDataset("train.csv",transform=transforms)   
label_mapping = data.label_to_idx
train_size=int(0.8*len(data))
test_size=len(data)-train_size
train_data,test_data=random_split(data,[train_size,test_size])
print(len(train_data),len(test_data))
train_dataloader=DataLoader(train_data,batch_size=64,num_workers=8,shuffle=True)
test_dataloader=DataLoader(test_data,batch_size=64,num_workers=8,shuffle=False)
#最终还是决定拿所有的图像进行训练
train_dataloader_all=DataLoader(data,batch_size=64,num_workers=8,shuffle=True)
################################################################################################################
################################################################################################################
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
model=models.resnet50(pretrained=True)#直接调用ResNet-50进行训练
num_classes=len(data.label_to_idx)
model.fc=nn.Linear(model.fc.in_features,num_classes)
model.apply(init_weights)
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-4)
CEloss=nn.CrossEntropyLoss()
model=train_model(model,train_dataloader,test_dataloader,num_epochs=50,device=device)
################################################################################################################
#这里是测试过程
test_dataset_ext = TestDataset('test.csv', transform=transforms)    
test_loader = torch.utils.data.DataLoader(test_dataset_ext, batch_size=32, shuffle=False)
all_preds = []
with torch.no_grad():
    for images, img_paths in test_loader:
        images = images.to(device)
        outputs = model(images)
        preds = outputs.argmax(dim=1).cpu().tolist()
        # 将数字标签映射回类别名称
        idx_to_label = {v: k for k, v in label_mapping.items()}
        pred_labels = [idx_to_label[p] for p in preds]
        # 保存图片路径和预测结果
        for path, label in zip(img_paths, pred_labels):
            all_preds.append((path, label))
#将结果保存为csv进行提交
df_result = pd.DataFrame(all_preds, columns=['image', 'label'])
df_result.to_csv('pred_submission.csv', index=False)
################################################################################################################
相关推荐
千匠网络13 小时前
破局出海壁垒,千匠网络新能源汽车跨境出海解决方案
人工智能
马丁聊GEO15 小时前
解码AI用户心智,筑牢可信GEO根基——悠易科技深度参与《中国AI用户态度与行为研究报告(2026)》发布会
人工智能·科技
nap-joker15 小时前
Fusion - Mamba用于跨模态目标检测
人工智能·目标检测·计算机视觉·fusion-mamba·可见光-红外成像融合·远距离/伪目标问题
一只幸运猫.16 小时前
2026Java 后端面试完整版|八股简答 + AI 大模型集成技术(最新趋势)
人工智能·面试·职场和发展
Promise微笑16 小时前
2026年国产替代油介损测试仪:油介损全场景解决方案与技术演进
大数据·网络·人工智能
深海鱼在掘金16 小时前
深入浅出 LangChain —— 第三章:模型抽象层
人工智能·langchain·agent
生信碱移16 小时前
PACells:这个方法可以鉴定疾病/预后相关的重要细胞亚群,作者提供的代码流程可以学习起来了,甚至兼容转录组与 ATAC 两种数据类型!
人工智能·学习·算法·机器学习·数据挖掘·数据分析·r语言
workflower16 小时前
具身智能行业应用-生活服务业
大数据·人工智能·机器人·动态规划·生活
GitCode官方16 小时前
基于昇腾 MindSpeed LLM 玩转 DeepSeekV4-Flash 模型的预训练复现部署
人工智能·开源·atomgit
大刘讲IT16 小时前
AI重塑企业信息价值标准:从“系统供给”到“用户定义”的企业数字化新范式
人工智能·经验分享·ai·制造