Linearly separable logistic regression
The points in the data set fall into two classes. Following the course, the sigmoid function plus gradient descent is all that's needed, using a linear model for the decision boundary.
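For reference, the model, cost, and gradients the code below implements are the standard logistic-regression ones:

$$f_{w,b}(x) = \sigma(w \cdot x + b), \qquad \sigma(z) = \frac{1}{1 + e^{-z}}$$

$$J(w,b) = -\frac{1}{m}\sum_{i=1}^{m}\Big[y^{(i)}\log f_{w,b}(x^{(i)}) + \big(1-y^{(i)}\big)\log\big(1-f_{w,b}(x^{(i)})\big)\Big]$$

$$\frac{\partial J}{\partial w_j} = \frac{1}{m}\sum_{i=1}^{m}\big(f_{w,b}(x^{(i)})-y^{(i)}\big)\,x_j^{(i)}, \qquad \frac{\partial J}{\partial b} = \frac{1}{m}\sum_{i=1}^{m}\big(f_{w,b}(x^{(i)})-y^{(i)}\big)$$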
```python
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

# sigmoid (logistic) function
def sigmoid(x):
    return 1 / (1 + np.exp(-x))
# cross-entropy cost over the whole training set
def compute_cost_logistic(X, y, w, b):
    m = X.shape[0]
    cost = 0.0
    for i in range(m):
        z_i = np.dot(X[i], w) + b
        f_wb_i = sigmoid(z_i)
        cost += -y[i] * np.log(f_wb_i) - (1 - y[i]) * np.log(1 - f_wb_i)
    return cost / m
# gradients of the cost with respect to w and b
def compute_gradient_logistic(X, y, w, b):
    m, n = X.shape
    dj_dw = np.zeros(n)
    dj_db = 0.
    for i in range(m):
        f_wb_i = sigmoid(np.dot(X[i], w) + b)  # model prediction for sample i
        err_i = f_wb_i - y[i]                  # prediction error
        for j in range(n):
            dj_dw[j] += err_i * X[i][j]
        dj_db += err_i
    return dj_dw / m, dj_db / m
# batch gradient descent
def gradient_descent(X, y, w, b, eta, num_iter):
    for i in range(num_iter):
        dj_dw, dj_db = compute_gradient_logistic(X, y, w, b)
        w = w - eta * dj_dw  # builds a new array, so no explicit copy is needed
        b = b - eta * dj_db
        # uncomment to monitor the cost during training
        # if i % 1000 == 0:
        #     print(compute_cost_logistic(X, y, w, b))
    return w, b
if __name__ == '__main__':
    # ex2data1.txt has no header row, so header=None keeps the first sample
    data = pd.read_csv(r'D:\BaiduNetdiskDownload\data_sets\ex2data1.txt', header=None)
    # min-max normalization of every column to [0, 1]
    data = (data - data.min()) / (data.max() - data.min())
    # split into training features X and labels y
    X_train = data.iloc[:, 0:-1].to_numpy()
    y_train = data.iloc[:, -1].to_numpy()
    w_tmp = np.zeros_like(X_train[0])
    b_tmp = 0.
    alpha = 0.1  # learning rate
    iters = 10000
    w_out, b_out = gradient_descent(X_train, y_train, w_tmp, b_tmp, alpha, iters)
    print(w_out, b_out)
    # plot the decision boundary w0*x0 + w1*x1 + b = 0 over the scaled data
    x = np.linspace(0, 1, 100)
    boundary = (-b_out - w_out[0] * x) / w_out[1]
    plt.plot(x, boundary, color='blue')
    plt.scatter(X_train[:, 0], X_train[:, 1], c=y_train)
    plt.show()
    # training-set accuracy: predict 1 when sigmoid(w.x + b) >= 0.5
    count = 0
    for i in range(X_train.shape[0]):
        ans = sigmoid(np.dot(X_train[i], w_out) + b_out)
        prediction = 1 if ans >= 0.5 else 0
        if prediction == y_train[i]:
            count += 1
    print('Accuracy: {:.0%}'.format(count / X_train.shape[0]))
    print(f"\nupdated parameters: w:{w_out}, b:{b_out}")
```
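The per-sample Python loops above are easy to read but slow on larger data sets. A minimal vectorized sketch of the same gradient computation (same shapes as above: `X` is `(m, n)`, `y` is `(m,)`; `compute_gradient_vectorized` is a hypothetical name, not part of the course code) could look like this:

```python
import numpy as np

def compute_gradient_vectorized(X, y, w, b):
    # evaluate the model on every sample at once: shape (m,)
    f_wb = 1 / (1 + np.exp(-(X @ w + b)))
    err = f_wb - y                  # per-sample error, shape (m,)
    dj_dw = X.T @ err / X.shape[0]  # averaged gradient w.r.t. w, shape (n,)
    dj_db = err.mean()              # averaged gradient w.r.t. b, scalar
    return dj_dw, dj_db
```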
Plots
Decision boundary over the data set:
Expected output:
w: [9.24150506 8.78629869] b: -8.125896329768265
Accuracy:88%
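As an optional sanity check (not part of the original exercise), scikit-learn's `LogisticRegression` should reach a similar boundary on the same scaled data; a quick sketch, assuming scikit-learn is installed and `X_train`/`y_train` are prepared as above:

```python
from sklearn.linear_model import LogisticRegression

clf = LogisticRegression()  # note: applies L2 regularization by default
clf.fit(X_train, y_train)
print("sklearn accuracy:", clf.score(X_train, y_train))
print("sklearn w:", clf.coef_, "b:", clf.intercept_)
```

Because of the default regularization, the learned weights will not match the hand-rolled run exactly, but the accuracy should be in the same ballpark.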