【Machine Learning】Exam 4

Implementing logistic regression for a linearly non-separable dataset

Everything we have fitted so far has been a linear model, for example $y = w_1 x_1 + w_2 x_2 + b$.

Just by eyeballing this dataset it is clear that a linear decision boundary will not work. Following what the course teaches, can we add the features $x_3 = x_1 x_2$, $x_4 = x_1^2$, $x_5 = x_2^2$? The logistic regression model then becomes

$y = w_1 x_1 + w_2 x_2 + w_3 x_3 + w_4 x_4 + w_5 x_5 + b$
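A minimal sketch of that feature mapping on its own (the helper name `add_poly_features` and the sample points are made up for illustration, not part of the exercise code):

```python
import numpy as np

def add_poly_features(X):
    # X has two columns x1 and x2; append x1*x2, x1^2 and x2^2 as columns 3-5
    x1, x2 = X[:, 0], X[:, 1]
    return np.column_stack((X, x1 * x2, x1 ** 2, x2 ** 2))

X = np.array([[0.5, -0.2],
              [1.0,  2.0]])
print(add_poly_features(X))
# row 1: 0.5, -0.2, -0.1, 0.25, 0.04
# row 2: 1.0,  2.0,  2.0, 1.00, 4.00
```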

```python
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt


def sigmoid(x):
    return 1/(1+np.exp(-x))


def compute_loss(X, y, w, b, lambada):
    m = X.shape[0]
    cost = 0.
    for i in range(m):
        f_i = sigmoid(np.dot(X[i], w) + b)   # model output for sample i
        cost += -y[i] * np.log(f_i) - (1 - y[i]) * np.log(1 - f_i)
    reg = lambada * np.sum(w ** 2) / (2 * m)  # L2 penalty sums over the weights, not the samples
    return cost / m + reg

def compute_gradient_logistic(X, y, w, b):
    # unregularized gradients of the loss w.r.t. w and b
    m, n = X.shape
    db_w = np.zeros(n)
    db_b = 0.
    for i in range(m):
        f_i = sigmoid(np.dot(X[i], w) + b)
        err_i = f_i - y[i]
        for j in range(n):
            db_w[j] += err_i * X[i][j]
        db_b += err_i
    return db_w / m, db_b / m


def gradient_descent(X, y, w, b, eta, lambada, iterator):
    m, n = X.shape
    for i in range(iterator):
        db_w, db_b = compute_gradient_logistic(X, y, w, b)
        db_w += lambada * w / m   # add the L2 regularization gradient (b is not regularized)
        w = w - eta * db_w
        b = b - eta * db_b
    return w, b

if __name__ == '__main__':
    # note: if ex2data2.txt has no header row, pass header=None to read_csv
    data = pd.read_csv(r'D:\BaiduNetdiskDownload\data_sets\ex2data2.txt')
    X_train = data.iloc[:, 0:-1].to_numpy()
    y_train = data.iloc[:, -1].to_numpy()

    # polynomial features: x3 = x1*x2, x4 = x1^2, x5 = x2^2
    x3 = (X_train[:, 0] * X_train[:, 1]).reshape(-1, 1)
    x4 = (X_train[:, 0] ** 2).reshape(-1, 1)
    x5 = (X_train[:, 1] ** 2).reshape(-1, 1)

    X_train = np.hstack((X_train, x3, x4, x5))
    w_tmp = np.zeros_like(X_train[0])
    b_tmp = 0.
    alpha = 0.1      # learning rate
    lambada = 0.01   # regularization strength
    iters = 10000
    w_out, b_out = gradient_descent(X_train, y_train, w_tmp, b_tmp, alpha, lambada, iters)

    # training-set accuracy: predict class 1 when the model output exceeds 0.5
    count = 0
    for i in range(X_train.shape[0]):
        ans = sigmoid(np.dot(X_train[i], w_out) + b_out)
        prediction = 1 if ans > 0.5 else 0
        if y_train[i] == prediction:
            count += 1
    print('Accuracy = {}'.format(count / X_train.shape[0]))
    print(w_out, b_out)
    plt.scatter(X_train[:, 0], X_train[:, 1], c=y_train)

    # build a mesh grid over the feature space for the decision boundary
    x_min, x_max = X_train[:, 0].min() - 0.1, X_train[:, 0].max() + 0.1
    y_min, y_max = X_train[:, 1].min() - 0.1, X_train[:, 1].max() + 0.1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.01),
                         np.arange(y_min, y_max, 0.01))

    # build the same polynomial features for every grid point
    grid = np.c_[xx.ravel(), yy.ravel()]
    print('grid_shape : {}'.format(grid.shape))
    grid_x3 = (grid[:, 0] * grid[:, 1]).reshape(-1, 1)
    grid_x4 = (grid[:, 0] ** 2).reshape(-1, 1)
    grid_x5 = (grid[:, 1] ** 2).reshape(-1, 1)
    grid_features = np.hstack((grid, grid_x3, grid_x4, grid_x5))

    # model outputs at every grid point
    Z = sigmoid(np.dot(grid_features, w_out) + b_out)
    Z = Z.reshape(xx.shape)

    # draw the decision boundary, i.e. the 0.5 level set
    plt.contour(xx, yy, Z, levels=[0.5], colors='g')

    # show the figure
    plt.xlabel('x1')
    plt.ylabel('x2')
    plt.title('Decision Boundary')
    plt.show()
```
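As an aside, the per-sample loops above can be collapsed into matrix operations. A minimal vectorized sketch of the same regularized gradient descent, assuming the `sigmoid` and imports defined above (not the code that produced the results below):

```python
def gradient_descent_vectorized(X, y, w, b, eta, lambada, iterator):
    m = X.shape[0]
    for _ in range(iterator):
        f = sigmoid(X @ w + b)                  # outputs for all samples at once
        err = f - y
        db_w = X.T @ err / m + lambada * w / m  # regularized gradient w.r.t. w
        db_b = np.sum(err) / m                  # gradient w.r.t. b (not regularized)
        w = w - eta * db_w
        b = b - eta * db_b
    return w, b
```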
    
Some plots (training data with the fitted decision boundary):

Accuracy = 0.8376068376068376

And the learned parameters w1 through w5 and b:

[ 2.12915132  2.82388529 -4.83135528 -8.64819153 -8.31828602] 3.7305124000753627
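To classify a new point with these parameters, it has to go through the same feature mapping first. A minimal sketch (the point (0.2, 0.3) is just a made-up example):

```python
x_new = np.array([0.2, 0.3])
features = np.array([x_new[0], x_new[1],
                     x_new[0] * x_new[1],   # x3 = x1 * x2
                     x_new[0] ** 2,         # x4 = x1^2
                     x_new[1] ** 2])        # x5 = x2^2
prob = sigmoid(np.dot(features, w_out) + b_out)
print('P(y=1) = {:.3f}, predicted class = {}'.format(prob, 1 if prob > 0.5 else 0))
```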
