【无标题】

Day31

一、目录结构

复制代码
project/
├── config.py
├── data/
│   ├── heart.csv
│   ├── heart_train.csv
│   └── heart_test.csv
├── models/
├── src/
│   ├── data_preparation.py
│   ├── evaluate.py
│   ├── train.py
│   └── utils.py
└── main.py

二、具体结构

1、config.py(该文件用于集中存储所有配置参数)

复制代码
# 文件路径配置
DATA_DIR = '../data'
MODELS_DIR = '../models'
INPUT_FILE_PATH = f'{DATA_DIR}/heart.csv'
TRAIN_FILE_PATH = f'{DATA_DIR}/heart_train.csv'
TEST_FILE_PATH = f'{DATA_DIR}/heart_test.csv'

# 数据拆分配置
TEST_SIZE = 0.2
RANDOM_STATE = 42

2、src/utils.py(此文件存放通用的辅助函数)

复制代码
import pandas as pd
from config import DATA_DIR, MODELS_DIR

def load_csv_data(file_path):
    """
    加载 CSV 数据
    """
    return pd.read_csv(file_path)

def save_csv_data(data, file_path):
    """
    保存数据到 CSV 文件
    """
    data.to_csv(file_path, index=False)

def get_model_save_path(model_name):
    """
    获取模型保存路径
    """
    return f'{MODELS_DIR}/{model_name}.pkl'

3、src/data_preparation.py(负责数据处理部分,包括数据的加载和拆分)

复制代码
from sklearn.model_selection import train_test_split
from src.utils import load_csv_data, save_csv_data
from config import INPUT_FILE_PATH, TRAIN_FILE_PATH, TEST_FILE_PATH, TEST_SIZE, RANDOM_STATE

def prepare_data():
    """
    加载数据并拆分为训练集和测试集,然后保存
    """
    df = load_csv_data(INPUT_FILE_PATH)
    train_data, test_data = train_test_split(df, test_size=TEST_SIZE, random_state=RANDOM_STATE)
    save_csv_data(train_data, TRAIN_FILE_PATH)
    save_csv_data(test_data, TEST_FILE_PATH)
    return train_data, test_data

4、src/train.py(用于编写模型训练的代码)

复制代码
import joblib
from sklearn.linear_model import LinearRegression
from src.utils import get_model_save_path

def train_model(train_data):
    """
    训练模型
    """
    X_train = train_data.drop('target', axis=1)  # 假设目标列为 'target',需根据实际调整
    y_train = train_data['target']
    model = LinearRegression()
    model.fit(X_train, y_train)
    # 保存模型
    model_save_path = get_model_save_path('linear_regression_model')
    joblib.dump(model, model_save_path)
    return model

5、src/evalution.py(对训练好的模型进行评估)

复制代码
from sklearn.metrics import mean_squared_error
import joblib
from src.utils import get_model_save_path

def evaluate_model(model, test_data):
    """
    评估模型
    """
    X_test = test_data.drop('target', axis=1)  # 假设目标列为 'target',需根据实际调整
    y_test = test_data['target']
    y_pred = model.predict(X_test)
    mse = mean_squared_error(y_test, y_pred)
    return mse

def load_saved_model(model_name):
    """
    加载保存的模型
    """
    model_path = get_model_save_path(model_name)
    return joblib.load(model_path)

6、main.py(主文件,协调数据处理、训练和评估的流程)

复制代码
from src.data_preparation import prepare_data
from src.train import train_model
from src.evaluate import evaluate_model, load_saved_model
from config import DATA_DIR, MODELS_DIR

# 数据准备
train_data, test_data = prepare_data()

# 模型训练
model = train_model(train_data)

# 模型评估
mse = evaluate_model(model, test_data)
print(f"模型的均方误差: {mse}")

# 加载保存的模型进行评估(可选,展示加载功能)
saved_model = load_saved_model('linear_regression_model')
saved_mse = evaluate_model(saved_model, test_data)
print(f"加载保存模型的均方误差: {saved_mse}")

@浙大疏锦行

相关推荐
m0_737302584 天前
iOS IPA 安装 Plist 文件生成工具
macos·objective-c·cocoa
pop_xiaoli10 天前
effective-Objective-C 第四章阅读笔记
笔记·ios·objective-c·cocoa·xcode
松叶似针13 天前
Flutter三方库适配OpenHarmony【secure_application】— iOS 端原生模糊遮罩实现分析
flutter·ios·cocoa
追夢秋陽14 天前
Cocoa 使用NSCollectionView显示列表,数据不足布局异常处理
macos·objective-c·cocoa·swift·collectionview
TheNextByte116 天前
如何修复iPhone短信消失问题?
ios·cocoa·iphone
TheNextByte116 天前
如何在没有iTunes的情况下重启/恢复出厂设置iPhone
ios·cocoa·iphone
带娃的IT创业者16 天前
解密OpenClaw系列04-OpenClaw技术架构
macos·架构·cocoa·agent·ai agent·openclaw
带娃的IT创业者16 天前
解密OpenClaw_03-OpenClaw核心功能特性
macos·系统架构·objective-c·cocoa·软件工程·智能体开发·openclaw
符哥200819 天前
Moya+Alamofire搭建网络框架
网络·macos·cocoa
TheNextByte119 天前
如何轻松地将 iPhone 上的信息传输到荣耀手机
智能手机·cocoa·iphone