支持向量机 (support vector machine,SVM)

支持向量机 (support vector machine,SVM)

flyfish

支持向量机是一种用于分类和回归的机器学习模型。在分类任务中,SVM试图找到一个最佳的分隔超平面,使得不同类别的数据点在空间中被尽可能宽的间隔分开。

超平面方程和直线方程

超平面(hyperplane)是一个在高维空间中将空间分成两个部分的几何对象。它的方程可以在不同维度的空间中有不同的形式。

一维空间中的"超平面"

在一维空间中,超平面就是一个点。假设我们在一维空间中有一个超平面,它可以表示为:
x = a x = a x=a

其中, a a a 是某个常数。这表示一维空间中的一个特定点,将空间分成两个部分: x < a x < a x<a 和 x > a x > a x>a。

二维空间中的超平面(直线)

在二维空间中,超平面就是一条直线。直线的方程可以表示为:
y = k x + b y = kx + b y=kx+b

其中, k k k 是斜率, b b b 是截距。或者,可以表示为标准形式:
a x + b y + c = 0 ax + by + c = 0 ax+by+c=0

其中, a a a、 b b b、 c c c 是常数。

这条直线将二维空间分成两个半平面。

三维空间中的超平面(平面)

在三维空间中,超平面是一个平面。平面的方程可以表示为:
a x + b y + c z + d = 0 ax + by + cz + d = 0 ax+by+cz+d=0

其中, a a a、 b b b、 c c c 和 d d d 是常数。

这个平面将三维空间分成两个半空间。

一般形式的超平面方程

在更高维度的空间中,超平面的方程一般可以表示为:
w ⋅ x + b = 0 \mathbf{w} \cdot \mathbf{x} + b = 0 w⋅x+b=0

其中:

  • w = ( w 1 , w 2 , ... , w n ) \mathbf{w} = (w_1, w_2, \ldots, w_n) w=(w1,w2,...,wn) 是一个权重向量,定义了超平面的方向。

  • x = ( x 1 , x 2 , ... , x n ) \mathbf{x} = (x_1, x_2, \ldots, x_n) x=(x1,x2,...,xn) 是一个点的坐标向量。

  • b b b 是偏置。

    这个超平面将 n n n 维空间分成两个半空间。

直线方程是超平面方程在二维空间中的一种特例。一般来说,超平面是 n n n 维空间中的一个 ( n − 1 ) (n-1) (n−1) 维的对象:

  • 在一维空间中,超平面是一个点。

  • 在二维空间中,超平面是一个直线。

  • 在三维空间中,超平面是一个平面。

  • 在四维及更高维空间中,超平面是一个 ( n − 1 ) (n-1) (n−1) 维的对象。

示例和理解

一维空间中的超平面

x = 2 x = 2 x=2

这是在一维空间中的一个点,将空间分为 x < 2 x < 2 x<2 和 x > 2 x > 2 x>2 两部分。

二维空间中的超平面

标准形式:
2 x + 3 y − 6 = 0 2x + 3y - 6 = 0 2x+3y−6=0

或者:
y = − 2 3 x + 2 y = -\frac{2}{3}x + 2 y=−32x+2

这是在二维空间中的一条直线。

三维空间中的超平面

2 x + 3 y + 4 z − 5 = 0 2x + 3y + 4z - 5 = 0 2x+3y+4z−5=0

这是在三维空间中的一个平面。

py 复制代码
import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm

# 生成一些数据
np.random.seed(0)
X = np.r_[np.random.randn(20, 2) - [2, 2], np.random.randn(20, 2) + [2, 2]]
Y = [0] * 20 + [1] * 20

# 拟合模型
clf = svm.SVC(kernel='linear')
clf.fit(X, Y)

# 绘制数据点和分类超平面
plt.scatter(X[:, 0], X[:, 1], c=Y, cmap=plt.cm.Paired)
ax = plt.gca()
xlim = ax.get_xlim()
ylim = ax.get_ylim()

# 创建网格以评估模型
xx = np.linspace(xlim[0], xlim[1], 30)
yy = np.linspace(ylim[0], ylim[1], 30)
YY, XX = np.meshgrid(yy, xx)
xy = np.vstack([XX.ravel(), YY.ravel()]).T
Z = clf.decision_function(xy).reshape(XX.shape)

# 绘制分类超平面
ax.contour(XX, YY, Z, colors='k', levels=[-1, 0, 1], alpha=0.5, linestyles=['--', '-', '--'])
ax.scatter(clf.support_vectors_[:, 0], clf.support_vectors_[:, 1], s=100, linewidth=1, facecolors='none', edgecolors='k')
plt.show()
py 复制代码
import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm
from mpl_toolkits.mplot3d import Axes3D

# 生成三维数据
np.random.seed(0)
X = np.r_[np.random.randn(20, 3) - [2, 2, 2], np.random.randn(20, 3) + [2, 2, 2]]
Y = [0] * 20 + [1] * 20

# 拟合模型
clf = svm.SVC(kernel='linear')
clf.fit(X, Y)

# 创建一个网格来绘制分类平面
xx, yy = np.meshgrid(np.linspace(-5, 5, 50), np.linspace(-5, 5, 50))
zz = (-clf.intercept_[0] - clf.coef_[0][0] * xx - clf.coef_[0][1] * yy) / clf.coef_[0][2]

# 绘制数据点和分类平面
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

ax.scatter(X[:20, 0], X[:20, 1], X[:20, 2], color='b', marker='o', label='Class 0')
ax.scatter(X[20:, 0], X[20:, 1], X[20:, 2], color='r', marker='^', label='Class 1')

ax.plot_surface(xx, yy, zz, color='g', alpha=0.5, rstride=100, cstride=100)

ax.set_xlabel('X1')
ax.set_ylabel('X2')
ax.set_zlabel('X3')

plt.legend()
plt.show()

最大间隔解释

py 复制代码
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.svm import SVC
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# 生成一个简单的二维分类数据集
X, y = datasets.make_blobs(n_samples=50, centers=2, random_state=6)

# 训练一个线性支持向量机
clf = SVC(kernel='linear', C=1000)
clf.fit(X, y)

# 获取分隔超平面
w = clf.coef_[0]
b = clf.intercept_[0]

# 计算分隔超平面的两个端点
x = np.linspace(-10, 10, 100)
y_hyperplane = -w[0]/w[1] * x - b/w[1]

# 计算间隔边界
margin = 1 / np.sqrt(np.sum(w ** 2))
y_margin_up = y_hyperplane + margin
y_margin_down = y_hyperplane - margin

# 绘制数据点、分隔超平面及其间隔边界
plt.scatter(X[:, 0], X[:, 1], c=y, cmap='coolwarm')
plt.plot(x, y_hyperplane, 'k-', label='分隔超平面')
plt.plot(x, y_margin_up, 'k--', label='上间隔边界')
plt.plot(x, y_margin_down, 'k--', label='下间隔边界')

# 绘制支持向量
plt.scatter(clf.support_vectors_[:, 0], clf.support_vectors_[:, 1], 
            s=100, facecolors='none', edgecolors='k', label='支持向量')

plt.legend()
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.title('最大化间隔的 SVM')
plt.show()

拉格朗日乘子法

相关推荐
LDG_AGI3 小时前
【人工智能】Transformers之Pipeline(七):图像分割(image-segmentation)
大数据·人工智能·python·深度学习·机器学习·计算机视觉·docker
zhangbin_2373 小时前
【Python机器学习】利用AdaBoost元算法提高分类性能——基于数据集多重抽样的分类器
开发语言·python·算法·机器学习·分类
zhangbin_2373 小时前
【Python机器学习】利用AdaBoost元算法提高分类性能——基于AdaBoost的分类
开发语言·人工智能·python·算法·机器学习·分类
嫦娥妹妹等等我3 小时前
理解 Objective-C 中 +load 方法的执行顺序
c++·算法·图论
Bee.Bee.4 小时前
【前端面试】七、算法-数组展平
前端·javascript·算法
B站计算机毕业设计超人5 小时前
计算机毕业设计Python+Tensorflow股票推荐系统 股票预测系统 股票可视化 股票数据分析 量化交易系统 股票爬虫 股票K线图 大数据毕业设计 AI
人工智能·爬虫·python·深度学习·机器学习·tensorflow·数据可视化
临街的小孩6 小时前
yolov8 剪枝
算法·yolo·剪枝
╰★忝若冇凊★丶7 小时前
C++面试基础算法的简要介绍
c++·算法·排序算法
近听水无声4778 小时前
排序算法2:直接选择排序与快速排序
数据结构·算法·排序算法
zhqh1008 小时前
动手学强化学习 第 12 章 PPO 算法(PPOContinuous) 训练代码
人工智能·pytorch·深度学习·神经网络·算法