昇思25天学习打卡营第14天|基于MindSpore的红酒分类实验

背景介绍

本文主要介绍使用MindSpore在部分wine数据集上进行KNN实验。

K近邻算法原理

K近邻算法(K-Nearest-Neighbor, KNN)是一种用于分类和回归的非参数统计方法,最初由 Cover和Hart于1968年提出(Cover等人,1967),是机器学习最基础的算法之一。它正是基于以上思想:要确定一个样本的类别,可以计算它与所有训练样本的距离,然后找出和该样本最接近的k个样本,统计出这些样本的类别并进行投票,票数最多的那个类就是分类的结果。KNN的三个基本要素:

  • K值,一个样本的分类是由K个邻居的"多数表决"确定的。K值越小,容易受噪声影响,反之,会使类别之间的界限变得模糊。
  • 距离度量,反映了特征空间中两个样本间的相似度,距离越小,越相似。常用的有Lp距离(p=2时,即为欧式距离)、曼哈顿距离、海明距离等。
  • 分类决策规则,通常是多数表决,或者基于距离加权的多数表决(权值与距离成反比)。

数据集

数据准备

Wine数据集是模式识别最著名的数据集之一,Wine数据集的官网:Wine Data Set。这些数据是对来自意大利同一地区但来自三个不同品种的葡萄酒进行化学分析的结果。数据集分析了三种葡萄酒中每种所含13种成分的量,分别是:

  1. Alcohol,酒精
  2. Malic acid,苹果酸
  3. Ash,灰
  4. Alcalinity of ash,灰的碱度
  5. Magnesium,镁
  6. Total phenols,总酚
  7. Flavanoids,类黄酮
  8. Nonflavanoid phenols,非黄酮酚
  9. Proanthocyanins,原花青素
  10. Color intensity,色彩强度
  11. Hue,色调
  12. OD280/OD315 of diluted wines,稀释酒的OD280/OD315
  13. Proline,脯氨酸

获取方式:

Key Value Key Value
Data Set Characteristics: Multivariate Number of Instances: 178
Attribute Characteristics: Integer, Real Number of Attributes: 13
Associated Tasks: Classification Missing Values? No

代码示例:

python 复制代码
from download import download

# 下载红酒数据集
url = "https://ascend-professional-construction-dataset.obs.cn-north-4.myhuaweicloud.com:443/MachineLearning/wine.zip"  
path = download(url, "./", kind="zip", replace=True)

数据读取与处理

代码示例:

python 复制代码
%matplotlib inline
import os
import csv
import numpy as np
import matplotlib.pyplot as plt

import mindspore as ms
from mindspore import nn, ops

ms.set_context(device_target="CPU")

# 读取Wine数据集并查看部分数据
with open('wine.data') as csv_file:
    data = list(csv.reader(csv_file, delimiter=','))
print(data[56:62]+data[130:133])

# 取三类样本(共178条),将数据集的13个属性作为自变量X。将数据集的3个类别作为因变量Y
X = np.array([[float(x) for x in s[1:]] for s in data[:178]], np.float32)
Y = np.array([s[0] for s in data[:178]], np.int32)

# 将数据集按128:50划分为训练集(已知类别样本)和验证集(待验证样本)
train_idx = np.random.choice(178, 128, replace=False)
test_idx = np.array(list(set(range(178)) - set(train_idx)))
X_train, Y_train = X[train_idx], Y[train_idx]
X_test, Y_test = X[test_idx], Y[test_idx]

# 取样本的某两个属性进行2维可视化,可以看到在某两个属性上样本的分布情况以及可分性
attrs = ['Alcohol', 'Malic acid', 'Ash', 'Alcalinity of ash', 'Magnesium', 'Total phenols',
         'Flavanoids', 'Nonflavanoid phenols', 'Proanthocyanins', 'Color intensity', 'Hue',
         'OD280/OD315 of diluted wines', 'Proline']
plt.figure(figsize=(10, 8))
for i in range(0, 4):
    plt.subplot(2, 2, i+1)
    a1, a2 = 2 * i, 2 * i + 1
    plt.scatter(X[:59, a1], X[:59, a2], label='1')
    plt.scatter(X[59:130, a1], X[59:130, a2], label='2')
    plt.scatter(X[130:, a1], X[130:, a2], label='3')
    plt.xlabel(attrs[a1])
    plt.ylabel(attrs[a2])
    plt.legend()
plt.show()

运行结果:

模型构建

利用MindSpore提供的tile, square, ReduceSum, sqrt, TopK等算子,通过矩阵运算的方式同时计算输入样本x和已明确分类的其他样本X_train的距离,并计算出top k近邻。

代码示例:

python 复制代码
class KnnNet(nn.Cell):
    def __init__(self, k):
        super(KnnNet, self).__init__()
        self.k = k

    def construct(self, x, X_train):
        #平铺输入x以匹配X_train中的样本数
        x_tile = ops.tile(x, (128, 1))
        square_diff = ops.square(x_tile - X_train)
        square_dist = ops.sum(square_diff, 1)
        dist = ops.sqrt(square_dist)
        #-dist表示值越大,样本就越接近
        values, indices = ops.topk(-dist, self.k)
        return indices

def knn(knn_net, x, X_train, Y_train):
    x, X_train = ms.Tensor(x), ms.Tensor(X_train)
    indices = knn_net(x, X_train)
    topk_cls = [0]*len(indices.asnumpy())
    for idx in indices.asnumpy():
        topk_cls[Y_train[idx]] += 1
    cls = np.argmax(topk_cls)
    return cls

模型预测

在验证集上验证KNN算法的有效性,取 𝑘=5 ,验证精度接近80%,说明KNN算法在该3分类任务上有效,能根据酒的13种属性判断出酒的品种。

代码示例:

python 复制代码
acc = 0
knn_net = KnnNet(5)
for x, y in zip(X_test, Y_test):
    pred = knn(knn_net, x, X_train, Y_train)
    acc += (pred == y)
    print('label: %d, prediction: %s' % (y, pred))
print('Validation accuracy is %f' % (acc/len(Y_test)))

运行结果:

截图时间

相关推荐
铭瑾熙10 分钟前
深度学习之GAN应用
人工智能·深度学习·生成对抗网络
一只老虎27 分钟前
基于 OpenCV 和 dlib 方法进行视频人脸检测的研究
人工智能·opencv·音视频
GOSIM 全球开源创新汇29 分钟前
对话 OpenCV 之父 Gary Bradski:灾难性遗忘和持续学习是尚未解决的两大挑战 | Open AGI Forum
opencv·学习·计算机视觉·ai·自动驾驶
全域观察32 分钟前
开源,一天200star,解锁视频字幕生成新方式——一款轻量级开源字幕工具,免费,支持花字,剪映最新会员模式吃相太难看了
人工智能·新媒体运营·开源软件·内容运营·程序员创富
她说人狗殊途38 分钟前
数据分析24.11.13
数据挖掘·数据分析
L_cl42 分钟前
Python学习从0到1 day29 Python 高阶技巧 ⑦ 正则表达式
学习
不去幼儿园1 小时前
【SSL-RL】自监督强化学习: 好奇心驱动探索 (CDE)算法
大数据·人工智能·python·算法·机器学习·强化学习
SaNDJie1 小时前
24.11.13 机器学习 特征降维(主成份分析) KNN算法 交叉验证(K-Fold) 超参数搜索
人工智能·算法·机器学习
努力成为DBA的小王2 小时前
Linux( 权限+特殊权限 图片+大白话)
linux·运维·服务器·学习
YAy173 小时前
CC3学习记录
java·开发语言·学习·网络安全·安全威胁分析