目录
[3.神经网络 DNN](#3.神经网络 DNN)
1.前言
生活中有很多事情,有因果关系,但关系不明显。
例如,一个店的生意多少和多个因素有关,是否下雨,是否是周末或节假日,今天温度多少度。当我们收集到足够多的数据,也许就能根据明天的日期、天气、温度来预测明天的生意流水。

本文以火灾检测为例子,根据温度,烟雾,co浓度 来判断是否发生了火灾。

数据如下:
python
#三特征三分类问题
#三特征为(温度、CO、烟雾),类别为无火,阴燃火,明火
#无火
Class1 = [(0.0, 0.0, 20.0), (11.1, 5.27, 20.5), (12.4, 6.04, 20.6), (14.2, 7.68, 20.7), (15.2, 8.73, 20.9),
(16.1, 12.2, 21.4), (16.0, 12.8, 21.6), (16.4, 13.4, 21.9), (16.6, 13.3, 22.0), (16.7, 12.9, 22.2),
(16.6, 12.1, 22.4), (16.9, 11.5, 22.6), (17.1, 10.9, 22.7), (19.6, 10.0, 22.8), (18.6, 15.4, 22.9),
(17.6, 15.0, 23.0), (15.5, 13.8, 23.0), (17.9, 13.1, 23.1), (18.2, 13.2, 23.1), (18.3, 13.3, 23.1),
(18.5, 13.7, 23.2), (18.2, 14.0, 23.2), (18.3, 14.2, 23.2), (17.7, 17.0, 23.3), (18.3, 15.3, 23.3),
(19.1, 15.7, 23.4), (20.8, 16.6, 23.5), (21.9, 17.2, 23.6), (23.2, 17.9, 23.7), (26.5, 13.6, 24.0),
(28.8, 14.8, 24.1), (32.1, 16.5, 24.3), (32.4, 21.9, 24.6), (34.0, 26.1, 26.0), (34.4, 28.0, 26.4),
(33.9, 29.1, 26.8), (34.2, 29.7, 27.7), (34.4, 29.0, 28.1), (34.8, 27.9, 28.4), (34.9, 25.9, 28.9), (35.3, 24.9, 29.9),
(35.6, 25.5, 30.1), (35.8, 26.5, 30.3), (35.9, 28.8, 30.8), (36.1, 29.4, 31.0), (36.3, 29.4, 31.2),
(37.2, 27.0, 31.8), (39.7, 20.5, 31.6), (37.9, 19.5, 31.5), (38.5, 19.8, 31.4)]
#阴燃火
Class2=[(40.9, 21.1, 31.2), (43.1, 22.3, 31.2), (45.5, 23.6, 31.2), (49.0, 23.9, 31.2), (50.0, 24.3, 31.2),
(50.5, 24.8, 31.3), (50.4, 25.2, 31.3), (50.0, 25.4, 31.4), (49.3, 25.4, 31.4), (46.7, 24.1, 31.4),
(45.9, 24.2, 31.4), (45.6, 23.9, 31.4), (45.3, 23.8, 31.4), (44.8, 23.5, 31.3), (44.7, 23.7, 31.3),
(44.8, 23.9, 31.2), (46.3, 24.3, 31.2), (47.5, 24.9, 31.1), (49.1, 25.8, 31.1), (51.9, 27.3, 31.2),
(53.4, 28.1, 31.2), (57.7, 30.2, 31.5), (58.5, 30.6, 31.6), (59.8, 31.3, 31.7), (60.4, 31.6, 31.8),
(62.6, 32.8, 32.2), (63.4, 33.3, 32.3), (65.5, 34.4, 32.5), (67.0, 35.2, 32.6), (68.9, 36.2, 32.8),
(73.6, 38.8, 33.1), (76.3, 40.3, 33.2), (80.0, 42.3, 33.4), (88.5, 47.1, 33.8), (91.2, 48.6, 34.0),
(93.0, 49.6, 34.2), (93.0, 49.7, 34.7), (91.7, 49.0, 34.9), (89.9, 47.9, 35.0), (84.6, 45.0, 35.4),
(82.1, 43.6, 35.5), (79.8, 42.3, 35.6), (76.6, 40.6, 35.7), (76.1, 40.3, 35.8), (76.0, 40.2, 35.8),
(77.7, 41.1, 36.0), (79.3, 42.0, 36.0), (84.3, 44.8, 36.2), (87.7, 46.8, 36.3), (102.0, 54.7, 36.9),
(107.0, 57.5, 37.2), (110.0, 59.8, 37.4), (114.0, 62.1, 38.1), (114.0, 61.6, 38.4), (112.0, 60.4, 38.6),
(106.0, 57.4, 39.0), (104.0, 56.1, 39.2), (103.0, 55.2, 39.3), (106.0, 57.0, 39.6), (110.0, 59.7, 39.8),
(117.0, 63.6, 40.1), (131.0, 71.9, 40.7), (136.0, 75.1, 41.1), (140.0, 77.4, 41.5), (142.0, 78.9, 42.3),
(142.0, 78.4, 42.7), (140.0, 77.7, 43.2), (141.0, 77.8, 43.9), (142.0, 78.7, 44.3), (143.0, 79.3, 44.6),
(140.0, 77.5, 45.3), (138.0, 76.3, 45.6), (138.0, 76.0, 45.9), (136.0, 75.3, 46.3), (132.0, 72.8, 46.5),
(126.0, 69.4, 46.6), (118.0, 64.6, 46.7), (116.0, 63.1, 46.6), (114.0, 62.1, 46.6), (112.0, 60.3, 46.4),
(110.0, 59.4, 46.3), (108.0, 58.4, 46.2), (105.0, 56.6, 45.9), (104.0, 55.7, 45.7), (103.0, 55.3, 45.6),
(105.0, 56.3, 45.3), (108.0, 58.0, 45.2), (112.0, 60.3, 45.1), (122.0, 66.6, 45.1), (129.0, 70.4, 45.3),
(135.0, 74.2, 45.4), (151.0, 83.4, 46.0), (153.0, 84.8, 46.3), (152.0, 84.2, 46.6), (146.0, 80.3, 47.2),
(142.0, 78.1, 47.3), (138.0, 75.9, 47.4), (122.0, 66.1, 47.5), (127.0, 73.2, 47.3), (125.0, 71.9, 47.2),
(122.0, 70.4, 47.0), (128.0, 78.3, 46.7), (129.0, 78.5, 46.6), (175.0, 79.6, 46.5), (139.0, 84.6, 46.4),
(147.0, 89.4, 46.4) ]
#明火
Class3=[(158.0, 95.8, 46.6), (172.0, 100.0, 47.1), (175.0, 109.0, 47.6), (158.0, 77.0, 48.2), (172.0, 86.0, 49.6),
(179.0, 95.0, 50.4), (171.0, 89.0, 51.0), (166.0, 93.3, 51.8), (158.0, 88.5, 52.0), (153.0, 85.2, 52.2),
(152.0, 84.7, 52.4), (159.0, 89.1, 52.6), (170.0, 96.1, 52.9), (173.0, 95.0, 53.9), (169.0, 92.0, 54.3),
(171.0, 96.6, 54.6), (163.0, 95.5, 54.9), (155.0, 90.6, 54.9), (159.0, 86.0, 54.8), (161.0, 78.3, 54.6),
(152.0, 76.9, 54.4), (159.0, 78.2, 54.3), (162.0, 85.1, 54.2), (156.0, 80.7, 54.3), (157.0, 87.4, 54.5),
(177.0, 100.0, 55.0), (188.0, 107.0, 55.5), (196.0, 113.0, 56.0), (208.0, 121.0, 57.4), (205.0, 119.0, 58.0),
(199.0, 115.0, 58.6), (180.0, 103.0, 59.7), (175.0, 99.5, 60.1), (177.0, 101.0, 60.4), (195.0, 113.0, 61.2),
(205.0, 119.0, 61.8), (199.0, 125.0, 62.4), (201.0, 132.0, 63.9), (208.0, 136.0, 64.6), (209.0, 144.0, 65.4),
(201.0, 150.0, 67.3), (200.0, 136.0, 68.0), (201.0, 117.0, 68.3), (169.0, 111.0, 67.3), (184.0, 136.0, 66.3),
(172.0, 129.0, 65.3), (178.0, 141.0, 62.9), (189.0, 151.0, 61.7), (205.0, 165.0, 60.7), (200.0, 153.0, 59.2),
(193.0, 145.0, 58.6), (203.0, 155.0, 58.2), (189.0, 143.0, 57.7), (189.0, 143.0, 57.6), (192.0, 145.0, 57.5),
(194.0, 142.0, 57.5), (161.0, 107.0, 57.6), (166.0, 110.0, 57.7), (173.0, 116.0, 57.9), (177.0, 119.0, 58.0),
(181.0, 122.0, 58.1), (190.0, 129.0, 58.2), (193.0, 132.0, 58.2), (195.0, 134.0, 58.2), (197.0, 137.0, 58.1),
(201.0, 141.0, 58.0), (199.0, 135.0, 58.0), (195.0, 120.0, 58.5), (198.0, 116.0, 59.3), (200.0, 138.0, 60.5),
(202.0, 167.0, 64.0), (201.0, 176.0, 65.9), (204.0, 178.0, 67.8), (200.0, 126.0, 73.0), (202.0, 132.0, 72.9),
(189.0, 140.0, 70.3), (188.0, 140.0, 69.6), (189.0, 141.0, 68.9), (194.0, 142.0, 68.1), (202.0, 147.0, 67.8),
(210.0, 153.0, 67.6), (237.0, 173.0, 68.5), (239.0, 174.0, 68.9), (169.0, 105.0, 69.7), (169.0, 105.0, 69.8),
(171.0, 107.0, 69.8), (185.0, 117.0, 69.6), (192.0, 121.0, 69.5), (210.0, 134.0, 69.4), (204.0, 131.0, 69.3),
(189.0, 130.0, 68.6), (184.0, 131.0, 68.0), (183.0, 132.0, 67.3), (201.0, 143.0, 66.2), (182.0, 120.0, 65.8),
(182.0, 121.0, 66.1), (184.0, 124.0, 66.1), (182.0, 124.0, 65.7), (182.0, 125.0, 65.5), (180.0, 124.0, 64.9),
(183.0, 128.0, 64.5), (181.0, 127.0, 64.1), (180.0, 126.0, 63.3), (185.0, 129.0, 63.0), (186.0, 128.0, 62.7),
(208.0, 142.0, 62.6), (199.0, 99.5, 63.0), (208.0, 123.0, 64.0), (245.0, 150.0, 69.0), (234.0, 142.0, 70.7),
(246.0, 126.0, 74.5), (244.0, 130.0, 75.3), (256.0, 145.0, 76.7), (267.0, 166.0, 79.9)]

X 烟雾,y co,z 温度。我们需要把这些三维空间中的点,分成3类。
=======================================
python环境配置 略
numpy:
pip install numpy
keras:
pip install keras -i https://pypi.tuna.tsinghua.edu.cn/simple --trusted-host pypi.tuna.tsinghua.edu.cn
tensorflow:
pip install tensorflow -i https://pypi.tuna.tsinghua.edu.cn/simple
sklearn:
pip install scikit-learn -i https://pypi.tuna.tsinghua.edu.cn/simple
================================================
下面使用不同的方法,完成此任务。
2.SVM
svm很适合干这件事


安装sklearn库
from sklearn.svm import SVC
代码如下:
python
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from joblib import dump
#三特征三分类问题
#三特征为(温度、CO、烟雾),类别为无火,阴燃火,明火
#无火
Class1 = [(0.0, 0.0, 20.0), (11.1, 5.27, 20.5), (12.4, 6.04, 20.6), (14.2, 7.68, 20.7), (15.2, 8.73, 20.9),
(16.1, 12.2, 21.4), (16.0, 12.8, 21.6), (16.4, 13.4, 21.9), (16.6, 13.3, 22.0), (16.7, 12.9, 22.2),
(16.6, 12.1, 22.4), (16.9, 11.5, 22.6), (17.1, 10.9, 22.7), (19.6, 10.0, 22.8), (18.6, 15.4, 22.9),
(17.6, 15.0, 23.0), (15.5, 13.8, 23.0), (17.9, 13.1, 23.1), (18.2, 13.2, 23.1), (18.3, 13.3, 23.1),
(18.5, 13.7, 23.2), (18.2, 14.0, 23.2), (18.3, 14.2, 23.2), (17.7, 17.0, 23.3), (18.3, 15.3, 23.3),
(19.1, 15.7, 23.4), (20.8, 16.6, 23.5), (21.9, 17.2, 23.6), (23.2, 17.9, 23.7), (26.5, 13.6, 24.0),
(28.8, 14.8, 24.1), (32.1, 16.5, 24.3), (32.4, 21.9, 24.6), (34.0, 26.1, 26.0), (34.4, 28.0, 26.4),
(33.9, 29.1, 26.8), (34.2, 29.7, 27.7), (34.4, 29.0, 28.1), (34.8, 27.9, 28.4), (34.9, 25.9, 28.9), (35.3, 24.9, 29.9),
(35.6, 25.5, 30.1), (35.8, 26.5, 30.3), (35.9, 28.8, 30.8), (36.1, 29.4, 31.0), (36.3, 29.4, 31.2),
(37.2, 27.0, 31.8), (39.7, 20.5, 31.6), (37.9, 19.5, 31.5), (38.5, 19.8, 31.4)]
#阴燃火
Class2=[(40.9, 21.1, 31.2), (43.1, 22.3, 31.2), (45.5, 23.6, 31.2), (49.0, 23.9, 31.2), (50.0, 24.3, 31.2),
(50.5, 24.8, 31.3), (50.4, 25.2, 31.3), (50.0, 25.4, 31.4), (49.3, 25.4, 31.4), (46.7, 24.1, 31.4),
(45.9, 24.2, 31.4), (45.6, 23.9, 31.4), (45.3, 23.8, 31.4), (44.8, 23.5, 31.3), (44.7, 23.7, 31.3),
(44.8, 23.9, 31.2), (46.3, 24.3, 31.2), (47.5, 24.9, 31.1), (49.1, 25.8, 31.1), (51.9, 27.3, 31.2),
(53.4, 28.1, 31.2), (57.7, 30.2, 31.5), (58.5, 30.6, 31.6), (59.8, 31.3, 31.7), (60.4, 31.6, 31.8),
(62.6, 32.8, 32.2), (63.4, 33.3, 32.3), (65.5, 34.4, 32.5), (67.0, 35.2, 32.6), (68.9, 36.2, 32.8),
(73.6, 38.8, 33.1), (76.3, 40.3, 33.2), (80.0, 42.3, 33.4), (88.5, 47.1, 33.8), (91.2, 48.6, 34.0),
(93.0, 49.6, 34.2), (93.0, 49.7, 34.7), (91.7, 49.0, 34.9), (89.9, 47.9, 35.0), (84.6, 45.0, 35.4),
(82.1, 43.6, 35.5), (79.8, 42.3, 35.6), (76.6, 40.6, 35.7), (76.1, 40.3, 35.8), (76.0, 40.2, 35.8),
(77.7, 41.1, 36.0), (79.3, 42.0, 36.0), (84.3, 44.8, 36.2), (87.7, 46.8, 36.3), (102.0, 54.7, 36.9),
(107.0, 57.5, 37.2), (110.0, 59.8, 37.4), (114.0, 62.1, 38.1), (114.0, 61.6, 38.4), (112.0, 60.4, 38.6),
(106.0, 57.4, 39.0), (104.0, 56.1, 39.2), (103.0, 55.2, 39.3), (106.0, 57.0, 39.6), (110.0, 59.7, 39.8),
(117.0, 63.6, 40.1), (131.0, 71.9, 40.7), (136.0, 75.1, 41.1), (140.0, 77.4, 41.5), (142.0, 78.9, 42.3),
(142.0, 78.4, 42.7), (140.0, 77.7, 43.2), (141.0, 77.8, 43.9), (142.0, 78.7, 44.3), (143.0, 79.3, 44.6),
(140.0, 77.5, 45.3), (138.0, 76.3, 45.6), (138.0, 76.0, 45.9), (136.0, 75.3, 46.3), (132.0, 72.8, 46.5),
(126.0, 69.4, 46.6), (118.0, 64.6, 46.7), (116.0, 63.1, 46.6), (114.0, 62.1, 46.6), (112.0, 60.3, 46.4),
(110.0, 59.4, 46.3), (108.0, 58.4, 46.2), (105.0, 56.6, 45.9), (104.0, 55.7, 45.7), (103.0, 55.3, 45.6),
(105.0, 56.3, 45.3), (108.0, 58.0, 45.2), (112.0, 60.3, 45.1), (122.0, 66.6, 45.1), (129.0, 70.4, 45.3),
(135.0, 74.2, 45.4), (151.0, 83.4, 46.0), (153.0, 84.8, 46.3), (152.0, 84.2, 46.6), (146.0, 80.3, 47.2),
(142.0, 78.1, 47.3), (138.0, 75.9, 47.4), (122.0, 66.1, 47.5), (127.0, 73.2, 47.3), (125.0, 71.9, 47.2),
(122.0, 70.4, 47.0), (128.0, 78.3, 46.7), (129.0, 78.5, 46.6), (175.0, 79.6, 46.5), (139.0, 84.6, 46.4),
(147.0, 89.4, 46.4) ]
#明火
Class3=[(158.0, 95.8, 46.6), (172.0, 100.0, 47.1), (175.0, 109.0, 47.6), (158.0, 77.0, 48.2), (172.0, 86.0, 49.6),
(179.0, 95.0, 50.4), (171.0, 89.0, 51.0), (166.0, 93.3, 51.8), (158.0, 88.5, 52.0), (153.0, 85.2, 52.2),
(152.0, 84.7, 52.4), (159.0, 89.1, 52.6), (170.0, 96.1, 52.9), (173.0, 95.0, 53.9), (169.0, 92.0, 54.3),
(171.0, 96.6, 54.6), (163.0, 95.5, 54.9), (155.0, 90.6, 54.9), (159.0, 86.0, 54.8), (161.0, 78.3, 54.6),
(152.0, 76.9, 54.4), (159.0, 78.2, 54.3), (162.0, 85.1, 54.2), (156.0, 80.7, 54.3), (157.0, 87.4, 54.5),
(177.0, 100.0, 55.0), (188.0, 107.0, 55.5), (196.0, 113.0, 56.0), (208.0, 121.0, 57.4), (205.0, 119.0, 58.0),
(199.0, 115.0, 58.6), (180.0, 103.0, 59.7), (175.0, 99.5, 60.1), (177.0, 101.0, 60.4), (195.0, 113.0, 61.2),
(205.0, 119.0, 61.8), (199.0, 125.0, 62.4), (201.0, 132.0, 63.9), (208.0, 136.0, 64.6), (209.0, 144.0, 65.4),
(201.0, 150.0, 67.3), (200.0, 136.0, 68.0), (201.0, 117.0, 68.3), (169.0, 111.0, 67.3), (184.0, 136.0, 66.3),
(172.0, 129.0, 65.3), (178.0, 141.0, 62.9), (189.0, 151.0, 61.7), (205.0, 165.0, 60.7), (200.0, 153.0, 59.2),
(193.0, 145.0, 58.6), (203.0, 155.0, 58.2), (189.0, 143.0, 57.7), (189.0, 143.0, 57.6), (192.0, 145.0, 57.5),
(194.0, 142.0, 57.5), (161.0, 107.0, 57.6), (166.0, 110.0, 57.7), (173.0, 116.0, 57.9), (177.0, 119.0, 58.0),
(181.0, 122.0, 58.1), (190.0, 129.0, 58.2), (193.0, 132.0, 58.2), (195.0, 134.0, 58.2), (197.0, 137.0, 58.1),
(201.0, 141.0, 58.0), (199.0, 135.0, 58.0), (195.0, 120.0, 58.5), (198.0, 116.0, 59.3), (200.0, 138.0, 60.5),
(202.0, 167.0, 64.0), (201.0, 176.0, 65.9), (204.0, 178.0, 67.8), (200.0, 126.0, 73.0), (202.0, 132.0, 72.9),
(189.0, 140.0, 70.3), (188.0, 140.0, 69.6), (189.0, 141.0, 68.9), (194.0, 142.0, 68.1), (202.0, 147.0, 67.8),
(210.0, 153.0, 67.6), (237.0, 173.0, 68.5), (239.0, 174.0, 68.9), (169.0, 105.0, 69.7), (169.0, 105.0, 69.8),
(171.0, 107.0, 69.8), (185.0, 117.0, 69.6), (192.0, 121.0, 69.5), (210.0, 134.0, 69.4), (204.0, 131.0, 69.3),
(189.0, 130.0, 68.6), (184.0, 131.0, 68.0), (183.0, 132.0, 67.3), (201.0, 143.0, 66.2), (182.0, 120.0, 65.8),
(182.0, 121.0, 66.1), (184.0, 124.0, 66.1), (182.0, 124.0, 65.7), (182.0, 125.0, 65.5), (180.0, 124.0, 64.9),
(183.0, 128.0, 64.5), (181.0, 127.0, 64.1), (180.0, 126.0, 63.3), (185.0, 129.0, 63.0), (186.0, 128.0, 62.7),
(208.0, 142.0, 62.6), (199.0, 99.5, 63.0), (208.0, 123.0, 64.0), (245.0, 150.0, 69.0), (234.0, 142.0, 70.7),
(246.0, 126.0, 74.5), (244.0, 130.0, 75.3), (256.0, 145.0, 76.7), (267.0, 166.0, 79.9)]
# 整合数据并添加标签
data = []
labels = []
for i, cls in enumerate([Class1, Class2, Class3], 1):
for sample in cls:
data.append(sample)
labels.append(i - 1) # 标签从0开始
# 转换为numpy数组
X = np.array(data)
y = np.array(labels)
# 特征缩放(可选)
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
# 选择模型并训练
#model = SVC(kernel='linear', C=1.0, random_state=42) # 使用线性SVM作为示例
#model = SVC(kernel='poly', degree=3, C=1.0, random_state=42) # degree是多项式的阶数
model = SVC(kernel='rbf', gamma='scale', C=1.0, random_state=42) # gamma是RBF的系数
model.fit(X_scaled, y)
# 保存模型
dump(model, 'fire_model3.joblib')
# 保存scaler
dump(scaler, 'scaler_model3.joblib')
其中,SVM的三种模型区别:

使用模型:
python
import numpy as np
from joblib import load
# 加载已保存的模型和预处理工具
model = load('fire_model3.joblib')
scaler = load('scaler_model3.joblib')
# 准备新数据(需保持与训练时相同的特征顺序:温度、CO、烟雾)
new_sample = np.array([[150.0, 215.0, 151.0]]) # 注意必须是二维数组
# 数据预处理(必须与训练时相同的标准化)
scaled_sample = scaler.transform(new_sample)
# 进行预测
prediction = model.predict(scaled_sample)
class_names = ['无火', '阴燃火', '明火']
print(f"预测结果:{class_names[prediction[0]]}")
批量测试
python
from joblib import load
from sklearn.preprocessing import StandardScaler
import numpy as np
from sklearn.metrics import accuracy_score
# 加载模型
model = load('fire_model3.joblib')
loaded_scaler = load('scaler_model3.joblib') # 加载scaler对象
#无火
test_data_c1 = [ (15.8, 9.79, 21.0), (16.1, 11.5, 21.3), (17.5, 13.5, 23.1), (7.7, 13.2, 23.1),(32.8, 25.8, 24.9),
(33.5, 20.6, 25.2),(35.0, 25.0, 29.2), (34.9, 24.6, 29.4), (36.8, 25.7, 31.6), (0, 0, 21)]
#阴燃
test_data_c2=[(47.9, 24.7, 31.4), (47.1, 24.3, 31.4), (54.7, 28.8, 31.3), (56.8, 30.0, 31.4), (60.9, 31.9, 31.9),
(62.0, 32.5, 32.1),(76.5, 40.5, 35.9), (91.9, 49.1, 36.5),(130.0, 70.8, 47.5), (125.0, 68.1, 47.5), ]
#明火
test_data_c3=[(201.0, 138.0, 72.4), (191.0, 120.0, 71.8),(228.0, 164.0, 67.5), (235.0, 171.0, 68.1),
(195.0, 124.0, 69.5), (205.0, 131.0, 69.4),(180.0, 123.0, 66.1), (181.0, 124.0, 65.9),
(207.0, 167.0, 71.3), (206.0, 147.0, 72.4)]
# 将测试数据转换为numpy数组
test_data_c1 = np.array(test_data_c1)
test_data_c2 = np.array(test_data_c2)
test_data_c3 = np.array(test_data_c3)
# 将所有测试数据合并为一个数组
test_X = np.concatenate((test_data_c1, test_data_c2, test_data_c3), axis=0)
# 特征缩放(可选)
scaler = loaded_scaler
# 如果训练数据经过了特征缩放,测试数据也需要进行相同的缩放
test_X_scaled = scaler.fit_transform(test_X)
# 创建测试数据的标签
# 注意:这些标签是为了验证准确率而人为设置的,实际上在测试阶段你不会知道真实的标签
test_y_c1 = np.zeros(len(test_data_c1), dtype=int) # 无火的标签是0
test_y_c2 = np.ones(len(test_data_c2), dtype=int) # 阴燃火的标签是1
test_y_c3 = np.full(len(test_data_c3), 2, dtype=int) # 明火的标签是2
# 合并测试标签
test_y = np.concatenate((test_y_c1, test_y_c2, test_y_c3))
# 预测测试数据的类别
predictions = model.predict(test_X_scaled)
# 打印每个数据点的预测类别和真实类别
for idx, (pred, true) in enumerate(zip(predictions, test_y)):
print(f"Sample {idx}: Predicted class {pred}, True class {true}")
# 计算准确率
accuracy = accuracy_score(test_y, predictions)
print(f"Model accuracy: {accuracy}")
3.神经网络 DNN
采用全连接神经网络(Dense Neural Network)进行多分类任务:
数据预处理:标准化(StandardScaler) + One-hot编码
网络结构:单隐藏层(10个神经元) + 输出层(3个神经元)
激活函数:ReLU(隐藏层) + Softmax(输出层)
损失函数:分类交叉熵(Categorical Crossentropy)
优化器:Ada
代码如下:
python
import numpy as np
from keras.models import Sequential
from keras.layers import Dense
from keras.utils import to_categorical
import tensorflow
from sklearn.preprocessing import StandardScaler
from joblib import dump
#三特征三分类问题
#三特征为(温度、CO、烟雾),类别为无火,阴燃火,明火
#无火
Class1 = [(0.0, 0.0, 20.0), (11.1, 5.27, 20.5), (12.4, 6.04, 20.6), (14.2, 7.68, 20.7), (15.2, 8.73, 20.9),
(16.1, 12.2, 21.4), (16.0, 12.8, 21.6), (16.4, 13.4, 21.9), (16.6, 13.3, 22.0), (16.7, 12.9, 22.2),
(16.6, 12.1, 22.4), (16.9, 11.5, 22.6), (17.1, 10.9, 22.7), (19.6, 10.0, 22.8), (18.6, 15.4, 22.9),
(17.6, 15.0, 23.0), (15.5, 13.8, 23.0), (17.9, 13.1, 23.1), (18.2, 13.2, 23.1), (18.3, 13.3, 23.1),
(18.5, 13.7, 23.2), (18.2, 14.0, 23.2), (18.3, 14.2, 23.2), (17.7, 17.0, 23.3), (18.3, 15.3, 23.3),
(19.1, 15.7, 23.4), (20.8, 16.6, 23.5), (21.9, 17.2, 23.6), (23.2, 17.9, 23.7), (26.5, 13.6, 24.0),
(28.8, 14.8, 24.1), (32.1, 16.5, 24.3), (32.4, 21.9, 24.6), (34.0, 26.1, 26.0), (34.4, 28.0, 26.4),
(33.9, 29.1, 26.8), (34.2, 29.7, 27.7), (34.4, 29.0, 28.1), (34.8, 27.9, 28.4), (34.9, 25.9, 28.9), (35.3, 24.9, 29.9),
(35.6, 25.5, 30.1), (35.8, 26.5, 30.3), (35.9, 28.8, 30.8), (36.1, 29.4, 31.0), (36.3, 29.4, 31.2),
(37.2, 27.0, 31.8), (39.7, 20.5, 31.6), (37.9, 19.5, 31.5), (38.5, 19.8, 31.4)]
#阴燃火
Class2=[(40.9, 21.1, 31.2), (43.1, 22.3, 31.2), (45.5, 23.6, 31.2), (49.0, 23.9, 31.2), (50.0, 24.3, 31.2),
(50.5, 24.8, 31.3), (50.4, 25.2, 31.3), (50.0, 25.4, 31.4), (49.3, 25.4, 31.4), (46.7, 24.1, 31.4),
(45.9, 24.2, 31.4), (45.6, 23.9, 31.4), (45.3, 23.8, 31.4), (44.8, 23.5, 31.3), (44.7, 23.7, 31.3),
(44.8, 23.9, 31.2), (46.3, 24.3, 31.2), (47.5, 24.9, 31.1), (49.1, 25.8, 31.1), (51.9, 27.3, 31.2),
(53.4, 28.1, 31.2), (57.7, 30.2, 31.5), (58.5, 30.6, 31.6), (59.8, 31.3, 31.7), (60.4, 31.6, 31.8),
(62.6, 32.8, 32.2), (63.4, 33.3, 32.3), (65.5, 34.4, 32.5), (67.0, 35.2, 32.6), (68.9, 36.2, 32.8),
(73.6, 38.8, 33.1), (76.3, 40.3, 33.2), (80.0, 42.3, 33.4), (88.5, 47.1, 33.8), (91.2, 48.6, 34.0),
(93.0, 49.6, 34.2), (93.0, 49.7, 34.7), (91.7, 49.0, 34.9), (89.9, 47.9, 35.0), (84.6, 45.0, 35.4),
(82.1, 43.6, 35.5), (79.8, 42.3, 35.6), (76.6, 40.6, 35.7), (76.1, 40.3, 35.8), (76.0, 40.2, 35.8),
(77.7, 41.1, 36.0), (79.3, 42.0, 36.0), (84.3, 44.8, 36.2), (87.7, 46.8, 36.3), (102.0, 54.7, 36.9),
(107.0, 57.5, 37.2), (110.0, 59.8, 37.4), (114.0, 62.1, 38.1), (114.0, 61.6, 38.4), (112.0, 60.4, 38.6),
(106.0, 57.4, 39.0), (104.0, 56.1, 39.2), (103.0, 55.2, 39.3), (106.0, 57.0, 39.6), (110.0, 59.7, 39.8),
(117.0, 63.6, 40.1), (131.0, 71.9, 40.7), (136.0, 75.1, 41.1), (140.0, 77.4, 41.5), (142.0, 78.9, 42.3),
(142.0, 78.4, 42.7), (140.0, 77.7, 43.2), (141.0, 77.8, 43.9), (142.0, 78.7, 44.3), (143.0, 79.3, 44.6),
(140.0, 77.5, 45.3), (138.0, 76.3, 45.6), (138.0, 76.0, 45.9), (136.0, 75.3, 46.3), (132.0, 72.8, 46.5),
(126.0, 69.4, 46.6), (118.0, 64.6, 46.7), (116.0, 63.1, 46.6), (114.0, 62.1, 46.6), (112.0, 60.3, 46.4),
(110.0, 59.4, 46.3), (108.0, 58.4, 46.2), (105.0, 56.6, 45.9), (104.0, 55.7, 45.7), (103.0, 55.3, 45.6),
(105.0, 56.3, 45.3), (108.0, 58.0, 45.2), (112.0, 60.3, 45.1), (122.0, 66.6, 45.1), (129.0, 70.4, 45.3),
(135.0, 74.2, 45.4), (151.0, 83.4, 46.0), (153.0, 84.8, 46.3), (152.0, 84.2, 46.6), (146.0, 80.3, 47.2),
(142.0, 78.1, 47.3), (138.0, 75.9, 47.4), (122.0, 66.1, 47.5), (127.0, 73.2, 47.3), (125.0, 71.9, 47.2),
(122.0, 70.4, 47.0), (128.0, 78.3, 46.7), (129.0, 78.5, 46.6), (175.0, 79.6, 46.5), (139.0, 84.6, 46.4),
(147.0, 89.4, 46.4) ]
#明火
Class3=[(158.0, 95.8, 46.6), (172.0, 100.0, 47.1), (175.0, 109.0, 47.6), (158.0, 77.0, 48.2), (172.0, 86.0, 49.6),
(179.0, 95.0, 50.4), (171.0, 89.0, 51.0), (166.0, 93.3, 51.8), (158.0, 88.5, 52.0), (153.0, 85.2, 52.2),
(152.0, 84.7, 52.4), (159.0, 89.1, 52.6), (170.0, 96.1, 52.9), (173.0, 95.0, 53.9), (169.0, 92.0, 54.3),
(171.0, 96.6, 54.6), (163.0, 95.5, 54.9), (155.0, 90.6, 54.9), (159.0, 86.0, 54.8), (161.0, 78.3, 54.6),
(152.0, 76.9, 54.4), (159.0, 78.2, 54.3), (162.0, 85.1, 54.2), (156.0, 80.7, 54.3), (157.0, 87.4, 54.5),
(177.0, 100.0, 55.0), (188.0, 107.0, 55.5), (196.0, 113.0, 56.0), (208.0, 121.0, 57.4), (205.0, 119.0, 58.0),
(199.0, 115.0, 58.6), (180.0, 103.0, 59.7), (175.0, 99.5, 60.1), (177.0, 101.0, 60.4), (195.0, 113.0, 61.2),
(205.0, 119.0, 61.8), (199.0, 125.0, 62.4), (201.0, 132.0, 63.9), (208.0, 136.0, 64.6), (209.0, 144.0, 65.4),
(201.0, 150.0, 67.3), (200.0, 136.0, 68.0), (201.0, 117.0, 68.3), (169.0, 111.0, 67.3), (184.0, 136.0, 66.3),
(172.0, 129.0, 65.3), (178.0, 141.0, 62.9), (189.0, 151.0, 61.7), (205.0, 165.0, 60.7), (200.0, 153.0, 59.2),
(193.0, 145.0, 58.6), (203.0, 155.0, 58.2), (189.0, 143.0, 57.7), (189.0, 143.0, 57.6), (192.0, 145.0, 57.5),
(194.0, 142.0, 57.5), (161.0, 107.0, 57.6), (166.0, 110.0, 57.7), (173.0, 116.0, 57.9), (177.0, 119.0, 58.0),
(181.0, 122.0, 58.1), (190.0, 129.0, 58.2), (193.0, 132.0, 58.2), (195.0, 134.0, 58.2), (197.0, 137.0, 58.1),
(201.0, 141.0, 58.0), (199.0, 135.0, 58.0), (195.0, 120.0, 58.5), (198.0, 116.0, 59.3), (200.0, 138.0, 60.5),
(202.0, 167.0, 64.0), (201.0, 176.0, 65.9), (204.0, 178.0, 67.8), (200.0, 126.0, 73.0), (202.0, 132.0, 72.9),
(189.0, 140.0, 70.3), (188.0, 140.0, 69.6), (189.0, 141.0, 68.9), (194.0, 142.0, 68.1), (202.0, 147.0, 67.8),
(210.0, 153.0, 67.6), (237.0, 173.0, 68.5), (239.0, 174.0, 68.9), (169.0, 105.0, 69.7), (169.0, 105.0, 69.8),
(171.0, 107.0, 69.8), (185.0, 117.0, 69.6), (192.0, 121.0, 69.5), (210.0, 134.0, 69.4), (204.0, 131.0, 69.3),
(189.0, 130.0, 68.6), (184.0, 131.0, 68.0), (183.0, 132.0, 67.3), (201.0, 143.0, 66.2), (182.0, 120.0, 65.8),
(182.0, 121.0, 66.1), (184.0, 124.0, 66.1), (182.0, 124.0, 65.7), (182.0, 125.0, 65.5), (180.0, 124.0, 64.9),
(183.0, 128.0, 64.5), (181.0, 127.0, 64.1), (180.0, 126.0, 63.3), (185.0, 129.0, 63.0), (186.0, 128.0, 62.7),
(208.0, 142.0, 62.6), (199.0, 99.5, 63.0), (208.0, 123.0, 64.0), (245.0, 150.0, 69.0), (234.0, 142.0, 70.7),
(246.0, 126.0, 74.5), (244.0, 130.0, 75.3), (256.0, 145.0, 76.7), (267.0, 166.0, 79.9)]
X = np.vstack((np.array(Class1), np.array(Class2), np.array(Class3))).astype(np.float32)
y = np.hstack((np.zeros(len(Class1)), np.ones(len(Class2)) * 1, np.ones(len(Class3)) * 2)).astype(
np.int32) # 类别标签编码为0, 1, 2
# 对标签进行one-hot编码
y = to_categorical(y) # 无需减1,因为直接使用0, 1, 2编码
# 数据标准化
scaler = StandardScaler()
X = scaler.fit_transform(X) # 直接对整个数据集进行fit_transform
# 定义网络结构
model = Sequential()
model.add(Dense(10, input_dim=3, activation='relu')) # 假设有一个隐藏层,10个神经元
model.add(Dense(3, activation='softmax')) # 输出层,3个神经元用于三分类,使用softmax激活函数
model.compile(
loss='categorical_crossentropy', # 适用于多分类任务
optimizer='adam', # 自适应学习率优化器
metrics=['accuracy'] # 监控准确率
)
model.fit(X, y, epochs=100, batch_size=10) # 训练100轮,每批10个样本
# 保存模型
model.save('my_model.keras')
# 保存scaler
dump(scaler, 'my_scaler.joblib') # 保存scaler对象
print("保存完成")
运行后:

使用模型
python
from joblib import load
from keras.models import load_model
import numpy as np
# 加载模型
model = load_model('my_model.keras')
scaler = load('my_scaler.joblib') # 加载scaler对象
#无火
test_data_c1 = [(15.8, 9.79, 21.0), (16.1, 11.5, 21.3), (17.5, 13.5, 23.1), (7.7, 13.2, 23.1),(32.8, 25.8, 24.9),
(33.5, 20.6, 25.2),(35.0, 25.0, 29.2), (34.9, 24.6, 29.4), (36.8, 25.7, 31.6), (37, 16, 21)]
#阴燃
test_data_c2=[(47.9, 24.7, 31.4), (47.1, 24.3, 31.4), (54.7, 28.8, 31.3), (46.8, 30.0, 31.4), (60.9, 31.9, 31.9),
(62.0, 32.5, 32.1),(34.5, 40.5, 35.9), (91.9, 49.1, 36.5),(130.0, 70.8, 47.5), (125.0, 68.1, 47.5), ]
#明火
test_data_c3=[(60.0, 138.0, 72.4), (121.0, 120.0, 71.8),(128.0, 164.0, 67.5), (235.0, 171.0, 68.1),
(195.0, 124.0, 69.5), (205.0, 131.0, 69.4),(180.0, 123.0, 66.1), (181.0, 124.0, 65.9),
(207.0, 167.0, 71.3), (206.0, 147.0, 72.4)]
# 合并测试数据
test_data = np.vstack((np.array(test_data_c1), np.array(test_data_c2), np.array(test_data_c3)))
test_data_scaled = scaler.transform(test_data)
# # 对测试数据进行预测
# predictions = model.predict(test_data_scaled)
#
# # 将预测结果转换为类别标签(取概率最高的类别)
# predicted_classes = np.argmax(predictions, axis=1)
#
# # 打印预测结果
# for i, predicted_class in enumerate(predicted_classes):
# if i < len(test_data_c1):
# print(f"Sample {i + 1} from Class 1 (No Fire) predicted as: {predicted_class}")
# elif i < len(test_data_c1) + len(test_data_c2):
# print(f"Sample {i + 1} from Class 2 (Smoldering Fire) predicted as: {predicted_class}")
# else:
# print(f"Sample {i + 1} from Class 3 (Flaming Fire) predicted as: {predicted_class}")
# 对测试数据进行预测
predictions = model.predict(test_data_scaled)
# 将预测结果转换为类别标签(取概率最高的类别)
predicted_classes = np.argmax(predictions, axis=1)
# 打印预测结果和概率
for i, (predicted_probabilities, predicted_class) in enumerate(zip(predictions, predicted_classes)):
class_label = "Class 1 (No Fire)" if i < len(test_data_c1) else \
"Class 2 (Smoldering Fire)" if i < len(test_data_c1) + len(test_data_c2) else \
"Class 3 (Flaming Fire)"
print(
f"Sample {i + 1} from {class_label} predicted as class: {predicted_class}, probabilities: {predicted_probabilities}")
4.随机森林
训练模型:
python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from joblib import dump
#三特征三分类问题
#三特征为(温度、CO、烟雾),类别为无火,阴燃火,明火
#无火
Class1 = [(0.0, 0.0, 20.0), (11.1, 5.27, 20.5), (12.4, 6.04, 20.6), (14.2, 7.68, 20.7), (15.2, 8.73, 20.9),
(16.1, 12.2, 21.4), (16.0, 12.8, 21.6), (16.4, 13.4, 21.9), (16.6, 13.3, 22.0), (16.7, 12.9, 22.2),
(16.6, 12.1, 22.4), (16.9, 11.5, 22.6), (17.1, 10.9, 22.7), (19.6, 10.0, 22.8), (18.6, 15.4, 22.9),
(17.6, 15.0, 23.0), (15.5, 13.8, 23.0), (17.9, 13.1, 23.1), (18.2, 13.2, 23.1), (18.3, 13.3, 23.1),
(18.5, 13.7, 23.2), (18.2, 14.0, 23.2), (18.3, 14.2, 23.2), (17.7, 17.0, 23.3), (18.3, 15.3, 23.3),
(19.1, 15.7, 23.4), (20.8, 16.6, 23.5), (21.9, 17.2, 23.6), (23.2, 17.9, 23.7), (26.5, 13.6, 24.0),
(28.8, 14.8, 24.1), (32.1, 16.5, 24.3), (32.4, 21.9, 24.6), (34.0, 26.1, 26.0), (34.4, 28.0, 26.4),
(33.9, 29.1, 26.8), (34.2, 29.7, 27.7), (34.4, 29.0, 28.1), (34.8, 27.9, 28.4), (34.9, 25.9, 28.9), (35.3, 24.9, 29.9),
(35.6, 25.5, 30.1), (35.8, 26.5, 30.3), (35.9, 28.8, 30.8), (36.1, 29.4, 31.0), (36.3, 29.4, 31.2),
(37.2, 27.0, 31.8), (39.7, 20.5, 31.6), (37.9, 19.5, 31.5), (38.5, 19.8, 31.4)]
#阴燃火
Class2=[(40.9, 21.1, 31.2), (43.1, 22.3, 31.2), (45.5, 23.6, 31.2), (49.0, 23.9, 31.2), (50.0, 24.3, 31.2),
(50.5, 24.8, 31.3), (50.4, 25.2, 31.3), (50.0, 25.4, 31.4), (49.3, 25.4, 31.4), (46.7, 24.1, 31.4),
(45.9, 24.2, 31.4), (45.6, 23.9, 31.4), (45.3, 23.8, 31.4), (44.8, 23.5, 31.3), (44.7, 23.7, 31.3),
(44.8, 23.9, 31.2), (46.3, 24.3, 31.2), (47.5, 24.9, 31.1), (49.1, 25.8, 31.1), (51.9, 27.3, 31.2),
(53.4, 28.1, 31.2), (57.7, 30.2, 31.5), (58.5, 30.6, 31.6), (59.8, 31.3, 31.7), (60.4, 31.6, 31.8),
(62.6, 32.8, 32.2), (63.4, 33.3, 32.3), (65.5, 34.4, 32.5), (67.0, 35.2, 32.6), (68.9, 36.2, 32.8),
(73.6, 38.8, 33.1), (76.3, 40.3, 33.2), (80.0, 42.3, 33.4), (88.5, 47.1, 33.8), (91.2, 48.6, 34.0),
(93.0, 49.6, 34.2), (93.0, 49.7, 34.7), (91.7, 49.0, 34.9), (89.9, 47.9, 35.0), (84.6, 45.0, 35.4),
(82.1, 43.6, 35.5), (79.8, 42.3, 35.6), (76.6, 40.6, 35.7), (76.1, 40.3, 35.8), (76.0, 40.2, 35.8),
(77.7, 41.1, 36.0), (79.3, 42.0, 36.0), (84.3, 44.8, 36.2), (87.7, 46.8, 36.3), (102.0, 54.7, 36.9),
(107.0, 57.5, 37.2), (110.0, 59.8, 37.4), (114.0, 62.1, 38.1), (114.0, 61.6, 38.4), (112.0, 60.4, 38.6),
(106.0, 57.4, 39.0), (104.0, 56.1, 39.2), (103.0, 55.2, 39.3), (106.0, 57.0, 39.6), (110.0, 59.7, 39.8),
(117.0, 63.6, 40.1), (131.0, 71.9, 40.7), (136.0, 75.1, 41.1), (140.0, 77.4, 41.5), (142.0, 78.9, 42.3),
(142.0, 78.4, 42.7), (140.0, 77.7, 43.2), (141.0, 77.8, 43.9), (142.0, 78.7, 44.3), (143.0, 79.3, 44.6),
(140.0, 77.5, 45.3), (138.0, 76.3, 45.6), (138.0, 76.0, 45.9), (136.0, 75.3, 46.3), (132.0, 72.8, 46.5),
(126.0, 69.4, 46.6), (118.0, 64.6, 46.7), (116.0, 63.1, 46.6), (114.0, 62.1, 46.6), (112.0, 60.3, 46.4),
(110.0, 59.4, 46.3), (108.0, 58.4, 46.2), (105.0, 56.6, 45.9), (104.0, 55.7, 45.7), (103.0, 55.3, 45.6),
(105.0, 56.3, 45.3), (108.0, 58.0, 45.2), (112.0, 60.3, 45.1), (122.0, 66.6, 45.1), (129.0, 70.4, 45.3),
(135.0, 74.2, 45.4), (151.0, 83.4, 46.0), (153.0, 84.8, 46.3), (152.0, 84.2, 46.6), (146.0, 80.3, 47.2),
(142.0, 78.1, 47.3), (138.0, 75.9, 47.4), (122.0, 66.1, 47.5), (127.0, 73.2, 47.3), (125.0, 71.9, 47.2),
(122.0, 70.4, 47.0), (128.0, 78.3, 46.7), (129.0, 78.5, 46.6), (175.0, 79.6, 46.5), (139.0, 84.6, 46.4),
(147.0, 89.4, 46.4) ]
#明火
Class3=[(158.0, 95.8, 46.6), (172.0, 100.0, 47.1), (175.0, 109.0, 47.6), (158.0, 77.0, 48.2), (172.0, 86.0, 49.6),
(179.0, 95.0, 50.4), (171.0, 89.0, 51.0), (166.0, 93.3, 51.8), (158.0, 88.5, 52.0), (153.0, 85.2, 52.2),
(152.0, 84.7, 52.4), (159.0, 89.1, 52.6), (170.0, 96.1, 52.9), (173.0, 95.0, 53.9), (169.0, 92.0, 54.3),
(171.0, 96.6, 54.6), (163.0, 95.5, 54.9), (155.0, 90.6, 54.9), (159.0, 86.0, 54.8), (161.0, 78.3, 54.6),
(152.0, 76.9, 54.4), (159.0, 78.2, 54.3), (162.0, 85.1, 54.2), (156.0, 80.7, 54.3), (157.0, 87.4, 54.5),
(177.0, 100.0, 55.0), (188.0, 107.0, 55.5), (196.0, 113.0, 56.0), (208.0, 121.0, 57.4), (205.0, 119.0, 58.0),
(199.0, 115.0, 58.6), (180.0, 103.0, 59.7), (175.0, 99.5, 60.1), (177.0, 101.0, 60.4), (195.0, 113.0, 61.2),
(205.0, 119.0, 61.8), (199.0, 125.0, 62.4), (201.0, 132.0, 63.9), (208.0, 136.0, 64.6), (209.0, 144.0, 65.4),
(201.0, 150.0, 67.3), (200.0, 136.0, 68.0), (201.0, 117.0, 68.3), (169.0, 111.0, 67.3), (184.0, 136.0, 66.3),
(172.0, 129.0, 65.3), (178.0, 141.0, 62.9), (189.0, 151.0, 61.7), (205.0, 165.0, 60.7), (200.0, 153.0, 59.2),
(193.0, 145.0, 58.6), (203.0, 155.0, 58.2), (189.0, 143.0, 57.7), (189.0, 143.0, 57.6), (192.0, 145.0, 57.5),
(194.0, 142.0, 57.5), (161.0, 107.0, 57.6), (166.0, 110.0, 57.7), (173.0, 116.0, 57.9), (177.0, 119.0, 58.0),
(181.0, 122.0, 58.1), (190.0, 129.0, 58.2), (193.0, 132.0, 58.2), (195.0, 134.0, 58.2), (197.0, 137.0, 58.1),
(201.0, 141.0, 58.0), (199.0, 135.0, 58.0), (195.0, 120.0, 58.5), (198.0, 116.0, 59.3), (200.0, 138.0, 60.5),
(202.0, 167.0, 64.0), (201.0, 176.0, 65.9), (204.0, 178.0, 67.8), (200.0, 126.0, 73.0), (202.0, 132.0, 72.9),
(189.0, 140.0, 70.3), (188.0, 140.0, 69.6), (189.0, 141.0, 68.9), (194.0, 142.0, 68.1), (202.0, 147.0, 67.8),
(210.0, 153.0, 67.6), (237.0, 173.0, 68.5), (239.0, 174.0, 68.9), (169.0, 105.0, 69.7), (169.0, 105.0, 69.8),
(171.0, 107.0, 69.8), (185.0, 117.0, 69.6), (192.0, 121.0, 69.5), (210.0, 134.0, 69.4), (204.0, 131.0, 69.3),
(189.0, 130.0, 68.6), (184.0, 131.0, 68.0), (183.0, 132.0, 67.3), (201.0, 143.0, 66.2), (182.0, 120.0, 65.8),
(182.0, 121.0, 66.1), (184.0, 124.0, 66.1), (182.0, 124.0, 65.7), (182.0, 125.0, 65.5), (180.0, 124.0, 64.9),
(183.0, 128.0, 64.5), (181.0, 127.0, 64.1), (180.0, 126.0, 63.3), (185.0, 129.0, 63.0), (186.0, 128.0, 62.7),
(208.0, 142.0, 62.6), (199.0, 99.5, 63.0), (208.0, 123.0, 64.0), (245.0, 150.0, 69.0), (234.0, 142.0, 70.7),
(246.0, 126.0, 74.5), (244.0, 130.0, 75.3), (256.0, 145.0, 76.7), (267.0, 166.0, 79.9)]
# 合并数据和标签
X = np.vstack((np.array(Class1), np.array(Class2), np.array(Class3))).astype(np.float32)
y = np.hstack((np.zeros(len(Class1)), np.ones(len(Class2)), 2 * np.ones(len(Class3)))).astype(
np.int32) # 类别标签编码为0, 1, 2
# 创建随机森林分类器
clf = RandomForestClassifier(n_estimators=100, random_state=42) # n_estimators是决策树的数量
# 训练模型
clf.fit(X, y)
# 保存模型
dump(clf, 'my_random_forest.joblib') # 保存为joblib文件
print("保存完成")
使用模型:
python
from joblib import load
import numpy as np
# 加载模型
model = load('my_random_forest.joblib')
#无火
test_data_c1 = [ (15.8, 9.79, 21.0), (16.1, 11.5, 21.3), (17.5, 13.5, 23.1), (7.7, 13.2, 23.1),(32.8, 25.8, 24.9),
(33.5, 20.6, 25.2),(35.0, 25.0, 29.2), (34.9, 24.6, 29.4), (36.8, 25.7, 31.6), (0, 0, 21)]
#阴燃
test_data_c2=[(47.9, 24.7, 31.4), (47.1, 24.3, 31.4), (54.7, 28.8, 31.3), (56.8, 30.0, 31.4), (60.9, 31.9, 31.9),
(62.0, 32.5, 32.1),(76.5, 40.5, 35.9), (91.9, 49.1, 36.5),(130.0, 70.8, 47.5), (125.0, 68.1, 47.5), ]
#明火
test_data_c3=[(201.0, 138.0, 72.4), (191.0, 120.0, 71.8),(228.0, 164.0, 67.5), (235.0, 171.0, 68.1),
(195.0, 124.0, 69.5), (205.0, 131.0, 69.4),(180.0, 123.0, 66.1), (181.0, 124.0, 65.9),
(207.0, 167.0, 71.3), (206.0, 147.0, 72.4)]
# 测试数据
# 注意:这里我们将列表的列表转换为 NumPy 数组
test_data_c1 = np.array(test_data_c1).astype(np.float32)
test_data_c2 = np.array(test_data_c2).astype(np.float32)
test_data_c3 = np.array(test_data_c3).astype(np.float32)
# 合并所有测试数据(可选,但通常更方便一次性处理)
# 如果模型是用标准化数据训练的,你需要使用保存的scaler来标准化测试数据
# 但从你的代码来看,你没有进行标准化,所以我们不需要scaler
test_data = np.vstack((test_data_c1, test_data_c2, test_data_c3))
# 使用模型进行预测
predictions = model.predict(test_data)
# 打印预测结果
print("Predictions:", predictions)
# 如果你想要将预测结果对应回原来的类别(Class1, Class2, Class3),你可以这样做:
# 创建一个空的列表来存储带有类别标签的预测结果
predicted_classes = []
for i, pred in enumerate(predictions):
if i < len(test_data_c1):
predicted_classes.append(f"Class1: {pred}")
elif i < len(test_data_c1) + len(test_data_c2):
predicted_classes.append(f"Class2: {pred}")
else:
predicted_classes.append(f"Class3: {pred}")
# 打印带有类别标签的预测结果
print("Predicted Classes:", predicted_classes)
5.总结:
本文基于温度、CO浓度和烟雾三个特征,采用三种机器学习方法(SVM、神经网络和随机森林)实现火灾检测的三分类任务(无火、阴燃火、明火)。通过采集大量样本数据,分别构建了三种分类模型,并详细介绍了模型训练和预测过程。实验结果表明,这些方法能有效区分不同火灾状态,其中SVM采用RBF核函数,神经网络使用单隐藏层结构,随机森林则通过多决策树集成实现分类。每种方法均提供模型保存和加载功能,便于实际部署应用。该研究为基于多参数传感的火灾智能检测提供了可行方案