Python-JupyterGPU机器学习代码

使用 Jupyter Notebook 进行 GPU 加速的机器学习代码开发,通常涉及到利用 GPU 运行深度学习模型,特别是基于 TensorFlow 或 PyTorch 这样的深度学习框架。GPU 的并行计算能力可以显著加快模型训练的速度,尤其对于大规模数据集和复杂模型来说效果更为明显。

根据需要训练的数据进行机器学习建模

python 复制代码
import time
 
start = time.time()

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Activation, Dropout, Flatten, Dense
from keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img
import matplotlib.pyplot as plt
from PIL import Image 
from glob import glob
import os

import os 
print(os.system('ls /dataset/stress_test_data_2'))

train_path = '/dataset/stress_test_data_2/Training/'
test_path = '/dataset/stress_test_data_2/Test/'

img = load_img(train_path + "Apple Braeburn/0_100.jpg", target_size=(100,100))
plt.imshow(img)
plt.axis("off")
plt.show()

images = ['Orange', 'Banana', 'Cauliflower', 'Cactus fruit', 'Eggplant', 'Avocado', 'Blueberry','Lemon', 'Kiwi']

import matplotlib.pyplot as plt
import numpy as np

fig = plt.figure(figsize =(15,5))
for i in range(9):
    ax = fig.add_subplot(3,3,i+1,xticks=[],yticks=[])
    #fig.patch.set_facecolor('#E53090')
    #Above code adds a background color for subplots you can change the hex color code as you wish
    plt.title(images[i])
    plt.axis("off")
    ax.imshow(load_img(train_path + images[i] +"/0_100.jpg", target_size=(100,100)))

x = img_to_array(img)
print(x.shape)

className = glob(train_path + '/*')
number_of_class = len(className)
print(number_of_class)

model = Sequential()
model.add(Conv2D(32, (3,3), input_shape= x.shape))
model.add(Activation("relu"))
model.add(MaxPooling2D())

model.add(Conv2D(32, (3,3),))
model.add(Activation("relu"))
model.add(MaxPooling2D())

model.add(Conv2D(64, (3,3),))
model.add(Activation("relu"))
model.add(MaxPooling2D())

model.add(Flatten())
model.add(Dense(1024))
model.add(Activation("relu"))
model.add(Dropout(0.5))
model.add(Dense(number_of_class))#output
model.add(Activation("softmax"))

model.compile(loss = "categorical_crossentropy",
             optimizer = "rmsprop",
             metrics = ["accuracy"])

model.summary()

batch_size = 32


train_datagen = ImageDataGenerator(rescale = 1./255,
                  shear_range = 0.3,
                  horizontal_flip=True,
                  vertical_flip=False,
                  zoom_range = 0.3
                  )
test_datagen  = ImageDataGenerator(rescale = 1./255)

train_generator = train_datagen.flow_from_directory(train_path,
                                                    target_size=x.shape[:2],
                                                    batch_size = batch_size,
                                                    color_mode= "rgb",
                                                    class_mode = "categorical")
test_generator = test_datagen.flow_from_directory(test_path,
                                                    target_size=x.shape[:2],
                                                    batch_size = batch_size,
                                                    color_mode= "rgb",
                                                    class_mode = "categorical")

hist = model.fit_generator(generator = train_generator, 
                   steps_per_epoch = 1600 // batch_size,
                   epochs = 50,
                   validation_data = test_generator,
                   validation_steps = 800 // batch_size)


print(hist.history.keys())

plt.plot(hist.history["loss"], label = "Train Loss")
plt.plot(hist.history["val_loss"], label = "Validation Loss")
plt.legend()
plt.show()

plt.plot(hist.history["accuracy"], label = "Train Accuracy")
plt.plot(hist.history["val_accuracy"], label = "Validation Accuracy")
plt.legend()
plt.show()

end = time.time()
 
print("time cost",end - start)

在 Jupyter Notebook 中进行 GPU 加速的机器学习代码开发,可以带来训练速度的显著提升,尤其适用于大规模数据和复杂模型的场景。同时,结合 Jupyter Notebook 的交互式编程和展示优势,可以更方便地进行实验、调试和结果展示。如果你需要进一步的帮助或有其他问题,请随时告诉我。

相关推荐
红色石榴30 分钟前
Qt中文乱码解决
开发语言·qt
Htht11131 分钟前
【Qt】实现模拟触摸屏 上下滑动表格 的两种方式
开发语言·qt
A 八方31 分钟前
Python MongoDB
开发语言·python·mongodb
sz66cm2 小时前
Python基础 -- 使用Python实现ssh终端并实现数据处理与统计功能
开发语言·python·ssh
liangbm34 小时前
MATLAB系列02:MATLAB基础
开发语言·数据结构·笔记·matlab·教程·工程基础·高级绘图
ac-er88885 小时前
如何在Flask中实现国际化和本地化
后端·python·flask
Adolf_19935 小时前
Flask-WTF的使用
后端·python·flask
空城皆是旧梦5 小时前
python爬虫初体验(一)
爬虫·python
藓类少女5 小时前
正则表达式
数据库·python·mysql·正则表达式
change95135 小时前
PHP纯离线搭建(php 8.1.7)
开发语言·php