支持向量机 (support vector machine,SVM)

支持向量机 (support vector machine,SVM)

flyfish

支持向量机是一种用于分类和回归的机器学习模型。在分类任务中,SVM试图找到一个最佳的分隔超平面,使得不同类别的数据点在空间中被尽可能宽的间隔分开。

超平面方程和直线方程

超平面(hyperplane)是一个在高维空间中将空间分成两个部分的几何对象。它的方程可以在不同维度的空间中有不同的形式。

一维空间中的"超平面"

在一维空间中,超平面就是一个点。假设我们在一维空间中有一个超平面,它可以表示为:
x = a x = a x=a

其中, a a a 是某个常数。这表示一维空间中的一个特定点,将空间分成两个部分: x < a x < a x<a 和 x > a x > a x>a。

二维空间中的超平面(直线)

在二维空间中,超平面就是一条直线。直线的方程可以表示为:
y = k x + b y = kx + b y=kx+b

其中, k k k 是斜率, b b b 是截距。或者,可以表示为标准形式:
a x + b y + c = 0 ax + by + c = 0 ax+by+c=0

其中, a a a、 b b b、 c c c 是常数。

这条直线将二维空间分成两个半平面。

三维空间中的超平面(平面)

在三维空间中,超平面是一个平面。平面的方程可以表示为:
a x + b y + c z + d = 0 ax + by + cz + d = 0 ax+by+cz+d=0

其中, a a a、 b b b、 c c c 和 d d d 是常数。

这个平面将三维空间分成两个半空间。

一般形式的超平面方程

在更高维度的空间中,超平面的方程一般可以表示为:
w ⋅ x + b = 0 \mathbf{w} \cdot \mathbf{x} + b = 0 w⋅x+b=0

其中:

  • w = ( w 1 , w 2 , ... , w n ) \mathbf{w} = (w_1, w_2, \ldots, w_n) w=(w1,w2,...,wn) 是一个权重向量,定义了超平面的方向。

  • x = ( x 1 , x 2 , ... , x n ) \mathbf{x} = (x_1, x_2, \ldots, x_n) x=(x1,x2,...,xn) 是一个点的坐标向量。

  • b b b 是偏置。

    这个超平面将 n n n 维空间分成两个半空间。

直线方程是超平面方程在二维空间中的一种特例。一般来说,超平面是 n n n 维空间中的一个 ( n − 1 ) (n-1) (n−1) 维的对象:

  • 在一维空间中,超平面是一个点。

  • 在二维空间中,超平面是一个直线。

  • 在三维空间中,超平面是一个平面。

  • 在四维及更高维空间中,超平面是一个 ( n − 1 ) (n-1) (n−1) 维的对象。

示例和理解

一维空间中的超平面

x = 2 x = 2 x=2

这是在一维空间中的一个点,将空间分为 x < 2 x < 2 x<2 和 x > 2 x > 2 x>2 两部分。

二维空间中的超平面

标准形式:
2 x + 3 y − 6 = 0 2x + 3y - 6 = 0 2x+3y−6=0

或者:
y = − 2 3 x + 2 y = -\frac{2}{3}x + 2 y=−32x+2

这是在二维空间中的一条直线。

三维空间中的超平面

2 x + 3 y + 4 z − 5 = 0 2x + 3y + 4z - 5 = 0 2x+3y+4z−5=0

这是在三维空间中的一个平面。

py 复制代码
import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm

# 生成一些数据
np.random.seed(0)
X = np.r_[np.random.randn(20, 2) - [2, 2], np.random.randn(20, 2) + [2, 2]]
Y = [0] * 20 + [1] * 20

# 拟合模型
clf = svm.SVC(kernel='linear')
clf.fit(X, Y)

# 绘制数据点和分类超平面
plt.scatter(X[:, 0], X[:, 1], c=Y, cmap=plt.cm.Paired)
ax = plt.gca()
xlim = ax.get_xlim()
ylim = ax.get_ylim()

# 创建网格以评估模型
xx = np.linspace(xlim[0], xlim[1], 30)
yy = np.linspace(ylim[0], ylim[1], 30)
YY, XX = np.meshgrid(yy, xx)
xy = np.vstack([XX.ravel(), YY.ravel()]).T
Z = clf.decision_function(xy).reshape(XX.shape)

# 绘制分类超平面
ax.contour(XX, YY, Z, colors='k', levels=[-1, 0, 1], alpha=0.5, linestyles=['--', '-', '--'])
ax.scatter(clf.support_vectors_[:, 0], clf.support_vectors_[:, 1], s=100, linewidth=1, facecolors='none', edgecolors='k')
plt.show()
py 复制代码
import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm
from mpl_toolkits.mplot3d import Axes3D

# 生成三维数据
np.random.seed(0)
X = np.r_[np.random.randn(20, 3) - [2, 2, 2], np.random.randn(20, 3) + [2, 2, 2]]
Y = [0] * 20 + [1] * 20

# 拟合模型
clf = svm.SVC(kernel='linear')
clf.fit(X, Y)

# 创建一个网格来绘制分类平面
xx, yy = np.meshgrid(np.linspace(-5, 5, 50), np.linspace(-5, 5, 50))
zz = (-clf.intercept_[0] - clf.coef_[0][0] * xx - clf.coef_[0][1] * yy) / clf.coef_[0][2]

# 绘制数据点和分类平面
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

ax.scatter(X[:20, 0], X[:20, 1], X[:20, 2], color='b', marker='o', label='Class 0')
ax.scatter(X[20:, 0], X[20:, 1], X[20:, 2], color='r', marker='^', label='Class 1')

ax.plot_surface(xx, yy, zz, color='g', alpha=0.5, rstride=100, cstride=100)

ax.set_xlabel('X1')
ax.set_ylabel('X2')
ax.set_zlabel('X3')

plt.legend()
plt.show()

最大间隔解释

py 复制代码
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.svm import SVC
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# 生成一个简单的二维分类数据集
X, y = datasets.make_blobs(n_samples=50, centers=2, random_state=6)

# 训练一个线性支持向量机
clf = SVC(kernel='linear', C=1000)
clf.fit(X, y)

# 获取分隔超平面
w = clf.coef_[0]
b = clf.intercept_[0]

# 计算分隔超平面的两个端点
x = np.linspace(-10, 10, 100)
y_hyperplane = -w[0]/w[1] * x - b/w[1]

# 计算间隔边界
margin = 1 / np.sqrt(np.sum(w ** 2))
y_margin_up = y_hyperplane + margin
y_margin_down = y_hyperplane - margin

# 绘制数据点、分隔超平面及其间隔边界
plt.scatter(X[:, 0], X[:, 1], c=y, cmap='coolwarm')
plt.plot(x, y_hyperplane, 'k-', label='分隔超平面')
plt.plot(x, y_margin_up, 'k--', label='上间隔边界')
plt.plot(x, y_margin_down, 'k--', label='下间隔边界')

# 绘制支持向量
plt.scatter(clf.support_vectors_[:, 0], clf.support_vectors_[:, 1], 
            s=100, facecolors='none', edgecolors='k', label='支持向量')

plt.legend()
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.title('最大化间隔的 SVM')
plt.show()

拉格朗日乘子法

相关推荐
掘金安东尼2 小时前
Amazon Lambda + API Gateway 实战,无服务器架构入门
算法·架构
码流之上3 小时前
【一看就会一写就废 指间算法】设计电子表格 —— 哈希表、字符串处理
javascript·算法
快手技术5 小时前
快手提出端到端生成式搜索框架 OneSearch,让搜索“一步到位”!
算法
CoovallyAIHub1 天前
中科大DSAI Lab团队多篇论文入选ICCV 2025,推动三维视觉与泛化感知技术突破
深度学习·算法·计算机视觉
NAGNIP1 天前
Serverless 架构下的大模型框架落地实践
算法·架构
moonlifesudo1 天前
半开区间和开区间的两个二分模版
算法
moonlifesudo1 天前
300:最长递增子序列
算法
CoovallyAIHub1 天前
港大&字节重磅发布DanceGRPO:突破视觉生成RLHF瓶颈,多项任务性能提升超180%!
深度学习·算法·计算机视觉
CoovallyAIHub1 天前
英伟达ViPE重磅发布!解决3D感知难题,SLAM+深度学习完美融合(附带数据集下载地址)
深度学习·算法·计算机视觉
聚客AI2 天前
🙋‍♀️Transformer训练与推理全流程:从输入处理到输出生成
人工智能·算法·llm