[Python Machine Learning] Experiment 11: Neural Networks - The Perceptron

Table of Contents

  • Artificial Neural Networks
    • The Perceptron
    • 1. A hand-written perceptron model
      • 1.1 Loading the data
      • 1.2 Building the perceptron model
      • 1.3 Instantiating and training the model
      • 1.4 Visualization
    • 2. A perceptron with sklearn
      • 2.1 Data: same as above
      • 2.2 Importing the library
      • 2.3 Instantiating the perceptron
      • 2.4 Fitting the perceptron to the data
      • 2.5 Visualization
    • Experiment 1: Split the data above into training and test sets, define a score function in the Perception_model class, and use it after training to report the test score

Artificial Neural Networks

The Perceptron

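1. The perceptron is a linear model for binary classification: it assigns an input instance to one of two classes based on its feature vector $x$:

$$f(x)=\operatorname{sign}(w \cdot x+b)$$

The perceptron model corresponds to the separating hyperplane $w \cdot x+b=0$ in the input space (feature space).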

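2. The learning strategy of the perceptron is to minimize the loss function:

$$\min_{w, b} L(w, b)=-\sum_{x_{i} \in M} y_{i}\left(w \cdot x_{i}+b\right)$$

Here $M$ is the set of misclassified points. The loss corresponds to the total distance from the misclassified points to the separating hyperplane, with the factor $1/\|w\|$ omitted.

3. The perceptron learning algorithm optimizes this loss by stochastic gradient descent and has a primal form and a dual form. It is simple and easy to implement. In the primal form, an arbitrary hyperplane is chosen first and the objective is then minimized step by step by gradient descent, each step using one randomly chosen misclassified point.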

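4. When the training set is linearly separable, the perceptron learning algorithm converges. The number of misclassifications $k$ it makes on the training set satisfies:

$$k \leqslant\left(\frac{R}{\gamma}\right)^{2}$$

Here $R$ is the largest norm of the augmented training vectors $\hat{x}_i=(x_i, 1)$, and $\gamma>0$ is the smallest value of $y_i(\hat{w}\cdot\hat{x}_i)$ over the training set for some separating hyperplane $\hat{w}=(w, b)$ with $\|\hat{w}\|=1$.

When the training set is linearly separable, the algorithm has infinitely many solutions; which one it reaches depends on the initial values and on the order of the iterations.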
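The bound above can be checked numerically on a small example. The following is a minimal sketch, not part of the original experiment: the toy points, the function name `count_mistakes`, and the hand-picked separating hyperplane `w_hat` are all assumptions made for illustration.

```python
import numpy as np

def count_mistakes(X, y, eta=1.0):
    """Run the primal perceptron from w = 0, b = 0 and count the updates."""
    w = np.zeros(X.shape[1])
    b = 0.0
    k = 0
    changed = True
    while changed:                      # terminates because the toy data are separable
        changed = False
        for xi, yi in zip(X, y):
            if yi * (np.dot(w, xi) + b) <= 0:   # misclassified point
                w = w + eta * yi * xi
                b = b + eta * yi
                k += 1
                changed = True
    return k

# Hypothetical toy data, separable by the line x1 + x2 - 3 = 0
X_toy = np.array([[3.0, 3.0], [4.0, 3.0], [1.0, 1.0], [0.5, 1.5]])
y_toy = np.array([1, 1, -1, -1])

# A known separating hyperplane in augmented form (w, b), scaled to unit norm
w_hat = np.array([1.0, 1.0, -3.0])
w_hat = w_hat / np.linalg.norm(w_hat)

# Augment each x with a constant 1 so that b is absorbed into the weight vector
X_aug = np.hstack([X_toy, np.ones((len(X_toy), 1))])
R = np.max(np.linalg.norm(X_aug, axis=1))
gamma = np.min(y_toy * (X_aug @ w_hat))
bound = (R / gamma) ** 2

k = count_mistakes(X_toy, y_toy)
print(f"updates k = {k}, bound (R/gamma)^2 = {bound:.1f}")
```

Any unit-norm separating hyperplane gives a valid bound; a hyperplane with a larger margin gives a tighter one.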

The binary classification model

$$f(x)=\operatorname{sign}(w \cdot x+b)$$

$$\operatorname{sign}(x)=\left\{\begin{array}{ll}{+1,} & {x \geqslant 0} \\ {-1,} & {x<0}\end{array}\right.$$
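As an illustration of the model and the sign convention above, here is a minimal NumPy sketch. The helper names `sign` and `predict` and the weights are hypothetical. Note that `np.sign(0)` returns 0, so the convention $\operatorname{sign}(0)=+1$ has to be written out explicitly.

```python
import numpy as np

def sign(z):
    """Return +1 where z >= 0 and -1 elsewhere, matching the definition above."""
    return np.where(np.asarray(z) >= 0, 1, -1)

def predict(X, w, b):
    """Apply f(x) = sign(w . x + b) to every row of X."""
    return sign(X @ w + b)

# Hypothetical weights and inputs, chosen only for illustration
w = np.array([1.0, -1.0])
b = 0.0
X_demo = np.array([[2.0, 1.0],    # w.x + b =  1 -> +1
                   [1.0, 2.0],    # w.x + b = -1 -> -1
                   [1.0, 1.0]])   # w.x + b =  0 -> +1 by convention
print(predict(X_demo, w, b))
```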

Given the training set

$$T=\left\{\left(x_{1}, y_{1}\right),\left(x_{2}, y_{2}\right), \cdots,\left(x_{N}, y_{N}\right)\right\}$$

the loss function of the perceptron is defined as

$$L(w, b)=-\sum_{x_{i} \in M} y_{i}\left(w \cdot x_{i}+b\right)$$
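The loss can be evaluated directly from its definition. The sketch below is illustrative only: the function name `perceptron_loss` and the sample points and weights are assumptions, and a point counts as misclassified when $y_i(w \cdot x_i+b) \leqslant 0$.

```python
import numpy as np

def perceptron_loss(X, y, w, b):
    """Return (loss, number of misclassified points) for the hyperplane (w, b)."""
    margins = y * (X @ w + b)          # y_i * (w . x_i + b) for every sample
    misclassified = margins <= 0       # boolean mask for the set M
    return -np.sum(margins[misclassified]), int(np.sum(misclassified))

# Hypothetical data and two candidate hyperplanes
X_demo = np.array([[3.0, 3.0], [4.0, 3.0], [1.0, 1.0]])
y_demo = np.array([1, 1, -1])

loss_bad, n_bad = perceptron_loss(X_demo, y_demo, np.array([1.0, 1.0]), 0.0)
loss_good, n_good = perceptron_loss(X_demo, y_demo, np.array([1.0, 1.0]), -3.0)
print(loss_bad, n_bad)     # the third point is on the wrong side
print(loss_good, n_good)   # no misclassified points, so the loss is zero
```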


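Algorithm

Stochastic gradient descent (SGD)

Pick one misclassified point at random and take a gradient step on it:

$$w = w + \eta y_{i}x_{i}$$

$$b = b + \eta y_{i}$$

where $\eta$ is the learning rate. When an instance is misclassified, that is, when it lies on the wrong side of the separating hyperplane, $w$ and $b$ are adjusted so that the hyperplane moves toward the misclassified point, until the point is classified correctly.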
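To see why one update moves the hyperplane toward the misclassified point, note that the update raises $y_i(w \cdot x_i+b)$ by exactly $\eta(\|x_i\|^2+1)$. A minimal sketch of a single step follows; the function name `sgd_step` and the sample values are hypothetical.

```python
import numpy as np

def sgd_step(w, b, x_i, y_i, eta=0.1):
    """One perceptron update on a misclassified point (x_i, y_i)."""
    return w + eta * y_i * x_i, b + eta * y_i

# Hypothetical starting point: zero weights misclassify every sample
w, b = np.zeros(2), 0.0
x_i, y_i = np.array([3.0, 3.0]), 1
eta = 1.0

before = y_i * (w @ x_i + b)
w, b = sgd_step(w, b, x_i, y_i, eta=eta)
after = y_i * (w @ x_i + b)

# The functional margin grows by eta * (||x_i||^2 + 1)
print(before, after, eta * (x_i @ x_i + 1))
```

A single step does not always fix the point; repeated steps on the same point eventually do, since each one adds the same positive amount.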

The experiments below use two of the classes in the iris dataset, with [sepal length, sepal width] as features.

1. A hand-written perceptron model

1.1 Loading the data

```python
import pandas as pd
import numpy as np
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt
%matplotlib inline
```

```python
# load data
iris = load_iris()
iris
```

{'data': array([[5.1, 3.5, 1.4, 0.2],
        [4.9, 3. , 1.4, 0.2],
        [4.7, 3.2, 1.3, 0.2],
        [4.6, 3.1, 1.5, 0.2],
        [5. , 3.6, 1.4, 0.2],
        [5.4, 3.9, 1.7, 0.4],
        [4.6, 3.4, 1.4, 0.3],
        [5. , 3.4, 1.5, 0.2],
        [4.4, 2.9, 1.4, 0.2],
        [4.9, 3.1, 1.5, 0.1],
        [5.4, 3.7, 1.5, 0.2],
        [4.8, 3.4, 1.6, 0.2],
        [4.8, 3. , 1.4, 0.1],
        [4.3, 3. , 1.1, 0.1],
        [5.8, 4. , 1.2, 0.2],
        [5.7, 4.4, 1.5, 0.4],
        [5.4, 3.9, 1.3, 0.4],
        [5.1, 3.5, 1.4, 0.3],
        [5.7, 3.8, 1.7, 0.3],
        [5.1, 3.8, 1.5, 0.3],
        [5.4, 3.4, 1.7, 0.2],
        [5.1, 3.7, 1.5, 0.4],
        [4.6, 3.6, 1. , 0.2],
        [5.1, 3.3, 1.7, 0.5],
        [4.8, 3.4, 1.9, 0.2],
        [5. , 3. , 1.6, 0.2],
        [5. , 3.4, 1.6, 0.4],
        [5.2, 3.5, 1.5, 0.2],
        [5.2, 3.4, 1.4, 0.2],
        [4.7, 3.2, 1.6, 0.2],
        [4.8, 3.1, 1.6, 0.2],
        [5.4, 3.4, 1.5, 0.4],
        [5.2, 4.1, 1.5, 0.1],
        [5.5, 4.2, 1.4, 0.2],
        [4.9, 3.1, 1.5, 0.2],
        [5. , 3.2, 1.2, 0.2],
        [5.5, 3.5, 1.3, 0.2],
        [4.9, 3.6, 1.4, 0.1],
        [4.4, 3. , 1.3, 0.2],
        [5.1, 3.4, 1.5, 0.2],
        [5. , 3.5, 1.3, 0.3],
        [4.5, 2.3, 1.3, 0.3],
        [4.4, 3.2, 1.3, 0.2],
        [5. , 3.5, 1.6, 0.6],
        [5.1, 3.8, 1.9, 0.4],
        [4.8, 3. , 1.4, 0.3],
        [5.1, 3.8, 1.6, 0.2],
        [4.6, 3.2, 1.4, 0.2],
        [5.3, 3.7, 1.5, 0.2],
        [5. , 3.3, 1.4, 0.2],
        [7. , 3.2, 4.7, 1.4],
        [6.4, 3.2, 4.5, 1.5],
        [6.9, 3.1, 4.9, 1.5],
        [5.5, 2.3, 4. , 1.3],
        [6.5, 2.8, 4.6, 1.5],
        [5.7, 2.8, 4.5, 1.3],
        [6.3, 3.3, 4.7, 1.6],
        [4.9, 2.4, 3.3, 1. ],
        [6.6, 2.9, 4.6, 1.3],
        [5.2, 2.7, 3.9, 1.4],
        [5. , 2. , 3.5, 1. ],
        [5.9, 3. , 4.2, 1.5],
        [6. , 2.2, 4. , 1. ],
        [6.1, 2.9, 4.7, 1.4],
        [5.6, 2.9, 3.6, 1.3],
        [6.7, 3.1, 4.4, 1.4],
        [5.6, 3. , 4.5, 1.5],
        [5.8, 2.7, 4.1, 1. ],
        [6.2, 2.2, 4.5, 1.5],
        [5.6, 2.5, 3.9, 1.1],
        [5.9, 3.2, 4.8, 1.8],
        [6.1, 2.8, 4. , 1.3],
        [6.3, 2.5, 4.9, 1.5],
        [6.1, 2.8, 4.7, 1.2],
        [6.4, 2.9, 4.3, 1.3],
        [6.6, 3. , 4.4, 1.4],
        [6.8, 2.8, 4.8, 1.4],
        [6.7, 3. , 5. , 1.7],
        [6. , 2.9, 4.5, 1.5],
        [5.7, 2.6, 3.5, 1. ],
        [5.5, 2.4, 3.8, 1.1],
        [5.5, 2.4, 3.7, 1. ],
        [5.8, 2.7, 3.9, 1.2],
        [6. , 2.7, 5.1, 1.6],
        [5.4, 3. , 4.5, 1.5],
        [6. , 3.4, 4.5, 1.6],
        [6.7, 3.1, 4.7, 1.5],
        [6.3, 2.3, 4.4, 1.3],
        [5.6, 3. , 4.1, 1.3],
        [5.5, 2.5, 4. , 1.3],
        [5.5, 2.6, 4.4, 1.2],
        [6.1, 3. , 4.6, 1.4],
        [5.8, 2.6, 4. , 1.2],
        [5. , 2.3, 3.3, 1. ],
        [5.6, 2.7, 4.2, 1.3],
        [5.7, 3. , 4.2, 1.2],
        [5.7, 2.9, 4.2, 1.3],
        [6.2, 2.9, 4.3, 1.3],
        [5.1, 2.5, 3. , 1.1],
        [5.7, 2.8, 4.1, 1.3],
        [6.3, 3.3, 6. , 2.5],
        [5.8, 2.7, 5.1, 1.9],
        [7.1, 3. , 5.9, 2.1],
        [6.3, 2.9, 5.6, 1.8],
        [6.5, 3. , 5.8, 2.2],
        [7.6, 3. , 6.6, 2.1],
        [4.9, 2.5, 4.5, 1.7],
        [7.3, 2.9, 6.3, 1.8],
        [6.7, 2.5, 5.8, 1.8],
        [7.2, 3.6, 6.1, 2.5],
        [6.5, 3.2, 5.1, 2. ],
        [6.4, 2.7, 5.3, 1.9],
        [6.8, 3. , 5.5, 2.1],
        [5.7, 2.5, 5. , 2. ],
        [5.8, 2.8, 5.1, 2.4],
        [6.4, 3.2, 5.3, 2.3],
        [6.5, 3. , 5.5, 1.8],
        [7.7, 3.8, 6.7, 2.2],
        [7.7, 2.6, 6.9, 2.3],
        [6. , 2.2, 5. , 1.5],
        [6.9, 3.2, 5.7, 2.3],
        [5.6, 2.8, 4.9, 2. ],
        [7.7, 2.8, 6.7, 2. ],
        [6.3, 2.7, 4.9, 1.8],
        [6.7, 3.3, 5.7, 2.1],
        [7.2, 3.2, 6. , 1.8],
        [6.2, 2.8, 4.8, 1.8],
        [6.1, 3. , 4.9, 1.8],
        [6.4, 2.8, 5.6, 2.1],
        [7.2, 3. , 5.8, 1.6],
        [7.4, 2.8, 6.1, 1.9],
        [7.9, 3.8, 6.4, 2. ],
        [6.4, 2.8, 5.6, 2.2],
        [6.3, 2.8, 5.1, 1.5],
        [6.1, 2.6, 5.6, 1.4],
        [7.7, 3. , 6.1, 2.3],
        [6.3, 3.4, 5.6, 2.4],
        [6.4, 3.1, 5.5, 1.8],
        [6. , 3. , 4.8, 1.8],
        [6.9, 3.1, 5.4, 2.1],
        [6.7, 3.1, 5.6, 2.4],
        [6.9, 3.1, 5.1, 2.3],
        [5.8, 2.7, 5.1, 1.9],
        [6.8, 3.2, 5.9, 2.3],
        [6.7, 3.3, 5.7, 2.5],
        [6.7, 3. , 5.2, 2.3],
        [6.3, 2.5, 5. , 1.9],
        [6.5, 3. , 5.2, 2. ],
        [6.2, 3.4, 5.4, 2.3],
        [5.9, 3. , 5.1, 1.8]]),
 'target': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]),
 'frame': None,
 'target_names': array(['setosa', 'versicolor', 'virginica'], dtype='<U10'),
 'DESCR': '.. _iris_dataset:\n\nIris plants dataset\n--------------------\n\n**Data Set Characteristics:**\n\n    :Number of Instances: 150 (50 in each of three classes)\n    :Number of Attributes: 4 numeric, predictive attributes and the class\n    :Attribute Information:\n        - sepal length in cm\n        - sepal width in cm\n        - petal length in cm\n        - petal width in cm\n        - class:\n                - Iris-Setosa\n                - Iris-Versicolour\n                - Iris-Virginica\n                \n    :Summary Statistics:\n\n    ============== ==== ==== ======= ===== ====================\n                    Min  Max   Mean    SD   Class Correlation\n    ============== ==== ==== ======= ===== ====================\n    sepal length:   4.3  7.9   5.84   0.83    0.7826\n    sepal width:    2.0  4.4   3.05   0.43   -0.4194\n    petal length:   1.0  6.9   3.76   1.76    0.9490  (high!)\n    petal width:    0.1  2.5   1.20   0.76    0.9565  (high!)\n    ============== ==== ==== ======= ===== ====================\n\n    :Missing Attribute Values: None\n    :Class Distribution: 33.3% for each of 3 classes.\n    :Creator: R.A. Fisher\n    :Donor: Michael Marshall (MARSHALL%[email protected])\n    :Date: July, 1988\n\nThe famous Iris database, first used by Sir R.A. Fisher. The dataset is taken\nfrom Fisher\'s paper. Note that it\'s the same as in R, but not as in the UCI\nMachine Learning Repository, which has two wrong data points.\n\nThis is perhaps the best known database to be found in the\npattern recognition literature.  Fisher\'s paper is a classic in the field and\nis referenced frequently to this day.  (See Duda & Hart, for example.)  The\ndata set contains 3 classes of 50 instances each, where each class refers to a\ntype of iris plant.  One class is linearly separable from the other 2; the\nlatter are NOT linearly separable from each other.\n\n.. topic:: References\n\n   - Fisher, R.A. 
"The use of multiple measurements in taxonomic problems"\n     Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions to\n     Mathematical Statistics" (John Wiley, NY, 1950).\n   - Duda, R.O., & Hart, P.E. (1973) Pattern Classification and Scene Analysis.\n     (Q327.D83) John Wiley & Sons.  ISBN 0-471-22361-1.  See page 218.\n   - Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System\n     Structure and Classification Rule for Recognition in Partially Exposed\n     Environments".  IEEE Transactions on Pattern Analysis and Machine\n     Intelligence, Vol. PAMI-2, No. 1, 67-71.\n   - Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule".  IEEE Transactions\n     on Information Theory, May 1972, 431-433.\n   - See also: 1988 MLC Proceedings, 54-64.  Cheeseman et al"s AUTOCLASS II\n     conceptual clustering system finds 3 classes in the data.\n   - Many, many more ...',
 'feature_names': ['sepal length (cm)',
  'sepal width (cm)',
  'petal length (cm)',
  'petal width (cm)'],
 'filename': 'iris.csv',
 'data_module': 'sklearn.datasets.data'}
```python
# load data
iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df['label'] = iris.target
```

```python
df.head()
```

| | sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) | label |
| --- | --- | --- | --- | --- | --- |
| 0 | 5.1 | 3.5 | 1.4 | 0.2 | 0 |
| 1 | 4.9 | 3.0 | 1.4 | 0.2 | 0 |
| 2 | 4.7 | 3.2 | 1.3 | 0.2 | 0 |
| 3 | 4.6 | 3.1 | 1.5 | 0.2 | 0 |
| 4 | 5.0 | 3.6 | 1.4 | 0.2 | 0 |

```python
df.columns=["sepal length","sepal width","petal length","petal width","label"]
```

```python
# Count how many samples carry each label value
df["label"].value_counts()
```

```
0    50
1    50
2    50
Name: label, dtype: int64
```
```python
# Rows 0-49 are class 0 and rows 50-99 are class 1
plt.scatter(df[:50]['sepal length'], df[:50]['sepal width'], label='0')
plt.scatter(df[50:100]['sepal length'], df[50:100]['sepal width'], label='1')
plt.xlabel('sepal length')
plt.ylabel('sepal width')
plt.legend()
```

```
<matplotlib.legend.Legend at 0x215d7f87f40>
```
```python
# Keep the first 100 rows (classes 0 and 1) and the columns
# sepal length, sepal width, label
data = np.array(df.iloc[:100, [0, 1, -1]])
```

```python
X, y = data[:,:-1], data[:,-1]
```

```python
data[:,-1]
```

```
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
```
```python
# The perceptron expects labels in {-1, +1}: class 1 -> +1, class 0 -> -1
y = np.array([1 if i == 1 else -1 for i in y])
```

```python
y
```

```
array([-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
       -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
       -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  1,
        1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
        1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
        1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1])
```
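The label mapping above uses a list comprehension. As an alternative, a vectorized sketch with `np.where` gives the same result; the small `labels` array here is a hypothetical stand-in for the label column.

```python
import numpy as np

# Hypothetical stand-in for the label column extracted from the data array
labels = np.array([0., 0., 1., 1., 0.])

# List-comprehension form, as used in the text
y_loop = np.array([1 if i == 1 else -1 for i in labels])

# Vectorized form: +1 where the label equals 1, -1 everywhere else
y_vec = np.where(labels == 1, 1, -1)

print(y_loop, y_vec)
```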
```python
X[:5],y[:5]
```

```
(array([[5.1, 3.5],
        [4.9, 3. ],
        [4.7, 3.2],
        [4.6, 3.1],
        [5. , 3.6]]),
 array([-1, -1, -1, -1, -1]))
```

$$w = w + \eta y_{i}x_{i}$$

$$b = b + \eta y_{i}$$

1.2 Building the perceptron model

```python
y.shape
```

```
(100,)
```
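```python
class Perception_model:
    def __init__(self,n):
        # n is the number of features; start from w = 0, b = 0
        self.w=np.zeros(n,dtype=np.float32)
        self.b=0
        self.l_rate=0.1  # learning rate (eta)
    def sign(self,x):
        # Despite its name, this returns the raw score w.x + b;
        # the callers compare it against 0 themselves
        y=np.dot(x,self.w)+self.b
        return y
    def fit(self,X_train,y_train):
        # Sweep over the data until a full pass makes no update.
        # The loop ends only if the data are linearly separable.
        is_wrong=True
        while is_wrong:
            is_wrong=False
            for i in range(len(X_train)):
                # y_i * (w.x_i + b) <= 0 means sample i is misclassified
                if y_train[i]*self.sign(X_train[i])<=0:
                    self.w=self.w+self.l_rate*np.dot(y_train[i],X_train[i])
                    self.b=self.b+self.l_rate*y_train[i]
                    is_wrong=True
```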

1.3 Instantiating and training the model

```python
model=Perception_model(X.shape[1])
model.fit(X,y)
```
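The introduction notes that the algorithm also has a dual form. For comparison with the primal model trained here, the following is a minimal sketch of the dual form, which stores one coefficient per sample and works through the Gram matrix of inner products. The function name `fit_dual`, the `max_epochs` cap, and the three-point toy dataset are assumptions made for illustration.

```python
import numpy as np

def fit_dual(X, y, eta=1.0, max_epochs=1000):
    """Dual-form perceptron: learn alpha (one entry per sample) and b."""
    n = len(X)
    alpha = np.zeros(n)
    b = 0.0
    gram = X @ X.T                      # gram[j, i] = x_j . x_i
    for _ in range(max_epochs):         # cap added so the loop always ends
        updated = False
        for i in range(n):
            # Score of x_i expressed through alpha: sum_j alpha_j y_j (x_j . x_i) + b
            score = np.sum(alpha * y * gram[:, i]) + b
            if y[i] * score <= 0:       # misclassified point
                alpha[i] += eta
                b += eta * y[i]
                updated = True
        if not updated:
            break
    w = (alpha * y) @ X                 # recover the primal weights
    return w, b, alpha

# Hypothetical toy data
X_toy = np.array([[3.0, 3.0], [4.0, 3.0], [1.0, 1.0]])
y_toy = np.array([1, 1, -1])

w_dual, b_dual, alpha = fit_dual(X_toy, y_toy)
print(w_dual, b_dual, alpha)
```

Each `alpha[i] / eta` is the number of times sample `i` triggered an update, so samples close to the boundary tend to end up with larger coefficients.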

1.4 Visualization

```python
np.max(X[:,0]),np.min(X[:,0])
```

```
(7.0, 4.3)
```

```python
X_fig=np.arange(int(np.min(X[:,0])),int(np.max(X[:,0])+1),0.5)
X_fig
# The decision boundary is w[0]*x1 + w[1]*x2 + b = 0
```

```
array([4. , 4.5, 5. , 5.5, 6. , 6.5, 7. , 7.5])
```

```python
# Solve the boundary equation for x2 and draw it together with both classes
y1=-(model.w[0]*X_fig+model.b)/model.w[1]
plt.plot(X_fig,y1,"r-+")
plt.scatter(X[:50,0],X[:50,1],label=0)
plt.scatter(X[50:100,0],X[50:100,1],label=1)
plt.show()
```
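The line drawn in the plot comes from solving $w_0 x_1 + w_1 x_2 + b = 0$ for $x_2$. A small sketch of that step as a reusable helper follows; the name `boundary_x2` and the demo weights are hypothetical, and the helper guards against the vertical-line case $w_1 = 0$, where the division is undefined.

```python
import numpy as np

def boundary_x2(x1, w, b):
    """Solve w[0]*x1 + w[1]*x2 + b = 0 for x2 (requires w[1] != 0)."""
    if np.isclose(w[1], 0.0):
        raise ValueError("w[1] is zero: the boundary is a vertical line")
    return -(w[0] * np.asarray(x1) + b) / w[1]

# Hypothetical weights, chosen only for illustration
w_demo = np.array([2.0, -4.0])
b_demo = 1.0
x1 = np.array([4.0, 5.0, 6.0])

x2 = boundary_x2(x1, w_demo, b_demo)
# Every (x1, x2) pair on the line should satisfy the boundary equation
residual = w_demo[0] * x1 + w_demo[1] * x2 + b_demo
print(x2, residual)
```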

2. A perceptron with sklearn

2.1 Data: same as above

2.2 Importing the library

```python
from sklearn.linear_model import Perceptron
```

2.3 Instantiating the perceptron

```python
model=Perceptron(fit_intercept=True,max_iter=1000,shuffle=True)
```

2.4 Fitting the perceptron to the data

```python
model.fit(X,y)
```

```
Perceptron()
```

```python
model.coef_
```

```
array([[ 23.2, -38.7]])
```

```python
model.intercept_
```

```
array([-5.])
```
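The fitted `coef_` and `intercept_` play the roles of $w$ and $b$ in the model. The sketch below rebuilds the same two-feature data and recomputes the scores by hand so they can be compared with `decision_function`. The fixed `random_state=0` is an assumption added for reproducibility, so the learned values may differ from the ones printed above.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import Perceptron

# Same data as in the text: first 100 samples, sepal length and sepal width
iris = load_iris()
X = iris.data[:100, :2]
y = np.where(iris.target[:100] == 1, 1, -1)

# random_state is an assumption, added so that shuffling is repeatable
clf = Perceptron(fit_intercept=True, max_iter=1000, shuffle=True, random_state=0)
clf.fit(X, y)

# Manual score w . x + b from the fitted attributes
scores = X @ clf.coef_[0] + clf.intercept_[0]
manual_pred = np.where(scores > 0, 1, -1)

train_acc = clf.score(X, y)
print("w =", clf.coef_[0], "b =", clf.intercept_[0])
print("training accuracy:", train_acc)
```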

2.5 Visualization

```python
# Figure size
plt.figure(figsize=(6,4))

# Title
plt.title('Iris linear data example')

# Decision boundary: coef_[0][0]*x1 + coef_[0][1]*x2 + intercept_ = 0
X_fig=np.arange(int(np.min(X[:,0])),int(np.max(X[:,0])+1),0.5)
y1=-(model.coef_[0][0]*X_fig+model.intercept_)/model.coef_[0][1]
plt.plot(X_fig,y1,"r-+")
plt.scatter(X[:50,0],X[:50,1],label=0)
plt.scatter(X[50:100,0],X[50:100,1],label=1)

plt.legend()  # show the legend
plt.grid(False)  # hide the grid
plt.xlabel('sepal length')
plt.ylabel('sepal width')
plt.show()
```

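Note!

In the figure above, one blue point in the lower left is misclassified. This happens because sklearn's Perceptron has a tol parameter.

tol makes training stop once the loss improves by less than the given tolerance (1e-3 by default) from one iteration to the next. Setting tol=None disables this check so that training keeps iterating: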

```python
model=Perceptron(fit_intercept=True,max_iter=1000,shuffle=True,tol=None)
model.fit(X,y)
```

```
Perceptron(tol=None)
```
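One way to observe the effect of tol is to compare how many epochs each setting actually runs, using the fitted `n_iter_` attribute. This is a sketch under assumptions: `random_state=0` is added for repeatability, and the exact epoch counts and accuracies depend on the sklearn version and the shuffle order, so none are quoted here.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import Perceptron

iris = load_iris()
X = iris.data[:100, :2]
y = np.where(iris.target[:100] == 1, 1, -1)

results = {}
for tol in (1e-3, None):
    # random_state is an assumption, added so both runs shuffle the same way
    clf = Perceptron(max_iter=1000, tol=tol, random_state=0).fit(X, y)
    results[tol] = (clf.n_iter_, clf.score(X, y))
    print(f"tol={tol}: epochs run = {clf.n_iter_}, "
          f"training accuracy = {clf.score(X, y):.2f}")
```

With tol=None the early-stopping check is skipped, so the run with the default tolerance can never use more epochs than the run without it.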
```python
# Figure size
plt.figure(figsize=(6,4))

# Title
plt.title('Iris linear data example')

# Decision boundary: coef_[0][0]*x1 + coef_[0][1]*x2 + intercept_ = 0
X_fig=np.arange(int(np.min(X[:,0])),int(np.max(X[:,0])+1),0.5)
y1=-(model.coef_[0][0]*X_fig+model.intercept_)/model.coef_[0][1]
plt.plot(X_fig,y1,"r-+")
plt.scatter(X[:50,0],X[:50,1],label=0)
plt.scatter(X[50:100,0],X[50:100,1],label=1)

plt.legend()  # show the legend
plt.grid(False)  # hide the grid
plt.xlabel('sepal length')
plt.ylabel('sepal width')
plt.show()
```

Now every sample from both iris species is classified correctly.

Experiment 1: Split the data above into training and test sets, define a score function in the Perception_model class, and use it after training to report the test score

1. Loading the data

```python
import pandas as pd
import numpy as np
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt
%matplotlib inline
```

```python
# load data
iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df['label'] = iris.target
```

```python
df.columns=["sepal length","sepal width","petal length","petal width","label"]
# First 100 rows (classes 0 and 1); features: sepal length, sepal width
data = np.array(df.iloc[:100, [0, 1, -1]])
X, y = data[:,:-1], data[:,-1]
# Map the labels to {-1, +1}
y = np.array([1 if i == 1 else -1 for i in y])
```

2. Splitting the data into training and test sets

```python
from sklearn.model_selection import train_test_split
```

Hold out 20% of the samples as test data:

```python
X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.2)
```
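The split above is random, so the test score can change from run to run. As a hedged variant, `train_test_split` also accepts `random_state` for a repeatable split and `stratify` to keep the class ratio equal in both parts. The placeholder arrays below are hypothetical and only mimic the shape of the real data (100 samples, 2 features, 50 per class).

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical placeholder data with the same shape as the iris subset
X_demo = np.arange(200, dtype=float).reshape(100, 2)
y_demo = np.array([-1] * 50 + [1] * 50)

# random_state fixes the shuffle; stratify keeps 50/50 classes in both parts
X_tr, X_te, y_tr, y_te = train_test_split(
    X_demo, y_demo, test_size=0.2, random_state=42, stratify=y_demo
)

print(X_tr.shape, X_te.shape)
print("positives in test set:", int(np.sum(y_te == 1)))
```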

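3. Defining the perceptron class

The class below adds the instance method score:

```python
class Perception_model:
    def __init__(self,n):
        # n is the number of features; start from w = 0, b = 0
        self.w=np.zeros(n,dtype=np.float32)
        self.b=0
        self.l_rate=0.1  # learning rate (eta)
    def sign(self,x):
        # Returns the raw score w.x + b, not its sign
        y=np.dot(x,self.w)+self.b
        return y
    def fit(self,X_train,y_train):
        # Repeat full passes until no sample is misclassified
        is_wrong=True
        while is_wrong:
            is_wrong=False
            for i in range(len(X_train)):
                if y_train[i]*self.sign(X_train[i])<=0:
                    self.w=self.w+self.l_rate*np.dot(y_train[i],X_train[i])
                    self.b=self.b+self.l_rate*y_train[i]
                    is_wrong=True

    def score(self,X_test,y_test):
        # Fraction of test samples whose predicted side matches the label;
        # a score <= 0 is treated as class -1, a score > 0 as class +1
        accuracy=0
        for i in range(len(X_test)):
            if self.sign(X_test[i])<=0 and y_test[i]==-1:
                accuracy+=1
            if self.sign(X_test[i])>0 and y_test[i]==1:
                accuracy+=1
        return accuracy/len(X_test)
```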
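The same accuracy can be computed without a Python loop. The following vectorized sketch mirrors the thresholds used in score (a score of exactly 0 counts as class -1); the standalone function name `accuracy` and the demo weights and points are hypothetical.

```python
import numpy as np

def accuracy(X_test, y_test, w, b):
    """Fraction of samples whose predicted label matches y_test."""
    # score > 0 -> +1, score <= 0 -> -1, the same thresholds as in score()
    pred = np.where(X_test @ w + b > 0, 1, -1)
    return np.mean(pred == y_test)

# Hypothetical weights and test points
w_demo = np.array([1.0, 1.0])
b_demo = -3.0
X_demo = np.array([[3.0, 3.0], [4.0, 3.0], [1.0, 1.0], [2.0, 1.0]])
y_demo = np.array([1, 1, -1, 1])   # the last point lies exactly on the boundary

acc = accuracy(X_demo, y_demo, w_demo, b_demo)
print(acc)   # three of the four points are classified correctly
```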

4. Instantiating and training the model

```python
model_1=Perception_model(len(X_train[0]))
model_1.fit(X_train,y_train)
```

5. Testing the model

Call the instance method score on the test set:

```python
model_1.score(X_test,y_test)
```

```
1.0
```