[Learning NLP Together] Chapter 3: Solving Problems with Neural Networks

Table of Contents

Solving Problems with Neural Networks

```python
import sys
sys.path.append('..')  # so that files in the parent directory can be imported
from dataset import spiral
import matplotlib.pyplot as plt

x, t = spiral.load_data()
print('x', x.shape)  # (300, 2)
print('t', t.shape)  # (300, 3)
```

```
x (300, 2)
t (300, 3)
```

In the example above, spiral.py is imported from the dataset directory that sits alongside the ch01 directory. The code therefore adds the parent directory to the import search path with sys.path.append('..').

```python
print(x)
print(t)
```

```
[[-0.00000000e+00  0.00000000e+00]
 [-9.76986432e-04  9.95216044e-03]
 [ 5.12668241e-03  1.93317647e-02]
 ...
 [-5.75705887e-01  7.80680941e-01]
 [-2.06074574e-01  9.58088341e-01]
 [-5.97431027e-01  7.89415080e-01]]
[[1 0 0]
 [1 0 0]
 [1 0 0]
 ...
 [0 0 1]
 [0 0 1]
 [0 0 1]]
```

Here, x is the input data and t the teacher labels. Their shapes show that there are 300 samples each, with two features per sample in x and three elements per label in t. The labels t are one-hot vectors: the element corresponding to the correct class is 1 and the rest are 0. Next, let's plot the data; the result is shown in Figure 1-31.
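
Because t is one-hot, the integer class index of each sample can be recovered with np.argmax. A quick sketch (assuming x and t are loaded as above):

```python
import numpy as np

labels = np.argmax(t, axis=1)  # one-hot rows -> integer class ids
print(labels[:3], labels[100:103], labels[200:203])  # expected: [0 0 0] [1 1 1] [2 2 2]
```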

```python
# Plot the data points: N samples per class, one marker per class
N = 100
CLS_NUM = 3
markers = ['o', 'x', '^']
for i in range(CLS_NUM):
    plt.scatter(x[i*N:(i+1)*N, 0], x[i*N:(i+1)*N, 1], s=40, marker=markers[i])
plt.show()
```

As Figure 1-31 shows, the input is two-dimensional and there are three classes. Looking at the dataset, it clearly cannot be separated by straight lines, so we need to learn nonlinear decision boundaries. Can our neural network (a network whose hidden layer uses the nonlinear sigmoid activation function) learn this nonlinear pattern correctly? Let's find out by experiment.

Tip: splitting the dataset

Because this experiment is relatively simple, we do not split the dataset into training, validation, and test sets. In real tasks, however, the data is split into training and test sets (and a validation set) for training and evaluation, as in the sketch below.
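
For reference, here is a minimal sketch of such a split (not part of the book's code), holding out 20% of the spiral data as a test set:

```python
import numpy as np

# Shuffle indices, then hold out the last 20% of samples as a test set
idx = np.random.permutation(len(x))
split = int(len(x) * 0.8)
x_train, t_train = x[idx[:split]], t[idx[:split]]
x_test, t_test = x[idx[split:]], t[idx[split:]]
print(x_train.shape, x_test.shape)  # (240, 2) (60, 2)
```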

```python
import sys
sys.path.append('..')
import numpy as np
from common.layers import Affine, Sigmoid, SoftmaxWithLoss

class TwoLayerNet:
    def __init__(self, input_size, hidden_size, output_size):
        I, H, O = input_size, hidden_size, output_size
        # Initialize the weights and biases
        W1 = 0.01 * np.random.randn(I, H)  # weight matrix of shape (I, H)
        b1 = np.zeros(H)
        W2 = 0.01 * np.random.randn(H, O)  # weight matrix of shape (H, O)
        b2 = np.zeros(O)
        # Create the layers
        self.layers = [
            Affine(W1, b1),
            Sigmoid(),
            Affine(W2, b2)
        ]
        # Loss layer
        self.loss_layer = SoftmaxWithLoss()
        # Gather all the weights and gradients into lists
        self.params, self.grads = [], []
        for layer in self.layers:
            self.params += layer.params
            self.grads += layer.grads

    def predict(self, x):
        for layer in self.layers:
            x = layer.forward(x)
        return x

    def forward(self, x, t):
        score = self.predict(x)
        loss = self.loss_layer.forward(score, t)
        return loss

    def backward(self, dout=1):
        dout = self.loss_layer.backward(dout)
        for layer in reversed(self.layers):  # propagate gradients in reverse order
            dout = layer.backward(dout)
        return dout
```

The initializer takes three arguments: input_size is the number of neurons in the input layer, hidden_size the number in the hidden layer, and output_size the number in the output layer. Internally, the biases are first initialized with zero vectors (np.zeros()), and the weights with small random values (0.01 * np.random.randn()); setting the weights to small random values makes learning easier. Next, the necessary layers are created and collected in the instance-variable list layers. Finally, all the parameters and gradients used by this model are gathered together.
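
As a quick sanity check (a sketch, not from the book), instantiating the network and running a forward pass on the spiral data should produce one three-class score vector per sample:

```python
model = TwoLayerNet(input_size=2, hidden_size=10, output_size=3)
score = model.predict(x)   # forward pass through Affine -> Sigmoid -> Affine
print(score.shape)         # (300, 3): one score per class for each sample
```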

Training code

```python
# Training code
import sys
sys.path.append('..')
import numpy as np
from common.optimizer import SGD
from dataset import spiral
import matplotlib.pyplot as plt
# from two_layer_net import TwoLayerNet

# Hyperparameters
max_epoch = 300
batch_size = 30
hidden_size = 10
learning_rate = 1.0

# Load the data, create the model and the optimizer
x, t = spiral.load_data()
model = TwoLayerNet(input_size=2, hidden_size=hidden_size, output_size=3)
optimizer = SGD(lr=learning_rate)

# Variables used during training
data_size = len(x)
max_iters = data_size // batch_size  # number of mini-batches per epoch
total_loss = 0
loss_count = 0
loss_list = []

for epoch in range(max_epoch):
    # Shuffle the data
    idx = np.random.permutation(data_size)
    x = x[idx]  # samples
    t = t[idx]  # labels

    for iters in range(max_iters):
        batch_x = x[iters * batch_size:(iters + 1) * batch_size]
        batch_t = t[iters * batch_size:(iters + 1) * batch_size]

        # Safeguard: make sure the batches are 2-D (slices of a 2-D array
        # already are, so these checks should never trigger here)
        if batch_x.ndim == 1:
            batch_x = batch_x.reshape(1, -1)
        if batch_t.ndim == 1:
            batch_t = batch_t.reshape(1, -1)

        loss = model.forward(batch_x, batch_t)
        model.backward()
        optimizer.update(model.params, model.grads)
        total_loss += loss
        loss_count += 1

    # Report the average loss once per epoch (after the inner loop,
    # iters == max_iters - 1 == 9, so the condition holds every epoch)
    if (iters + 1) % 10 == 0:
        avg_loss = total_loss / loss_count
        print('|epoch %d | iter %d / %d | loss %.2f' % (epoch + 1, iters + 1, max_iters, avg_loss))
        loss_list.append(avg_loss)
        total_loss, loss_count = 0, 0
```

```
|epoch 1 | iter 10 / 10 | loss 1.13
|epoch 2 | iter 10 / 10 | loss 1.13
|epoch 3 | iter 10 / 10 | loss 1.12
...
|epoch 298 | iter 10 / 10 | loss 0.11
|epoch 299 | iter 10 / 10 | loss 0.11
|epoch 300 | iter 10 / 10 | loss 0.11
```

Tip: epoch

An epoch is a unit of training: one epoch corresponds to the model having "seen" all of the training data once (a full pass over the dataset). Here we train for 300 epochs.
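
With 300 samples and a batch size of 30, one epoch here amounts to 10 parameter updates; a quick check of that arithmetic:

```python
data_size, batch_size, max_epoch = 300, 30, 300
iters_per_epoch = data_size // batch_size
print(iters_per_epoch)              # 10 mini-batches per epoch
print(iters_per_epoch * max_epoch)  # 3000 parameter updates in total
```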

Tip: shuffling the data

During training, the data for each mini-batch should be chosen at random. Here we shuffle the data once per epoch and then draw mini-batches from the shuffled data in order, starting from the beginning. The shuffling (more precisely, the shuffling of the data indices) uses np.random.permutation(). Given an argument N, this method returns a random permutation of the integers 0 to N-1, as the following example shows.

```python
import numpy as np
np.random.permutation(10)
```

```
array([5, 1, 8, 4, 9, 7, 0, 2, 6, 3])
```

```python
np.random.permutation(10)
```

```
array([3, 4, 2, 7, 8, 5, 6, 0, 9, 1])
```

As shown, each call to np.random.permutation() returns a freshly shuffled set of data indices.
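
Indexing the data arrays with such a permutation shuffles samples and labels consistently; a minimal sketch:

```python
idx = np.random.permutation(len(x))
x_shuffled, t_shuffled = x[idx], t[idx]  # fancy indexing keeps rows of x and t aligned
```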

```python
loss_list
```

```
[1.1256062166823237,
 1.1255202354489933,
 1.1162613752115285,
 ...
 0.11160268411696476,
 0.10668719692739766,
 0.10854954795724314]
```
```python
# Plot the loss curve (one averaged value per epoch)
plt.plot(loss_list)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Loss Curve')
plt.grid()
plt.show()
```

The Trainer class

As mentioned before, this book trains neural networks on many occasions, which would mean writing training code like the above each time. Since writing the same code over and over is tedious, the training logic is provided as a Trainer class. Its internals are almost identical to the code we just wrote, with a few additional features; we will explain their usage in detail when the need arises.

The class's initializer receives a neural network (the model) and an optimizer, as shown below.

```python
# model = TwoLayerNet(...)
# optimizer = SGD(lr=1.0)
# trainer = Trainer(model, optimizer)
```

Training is then started by calling the fit() method, whose parameters are listed in Table 1-1.
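
Table 1-1 itself is not reproduced in this post. For reference, a sketch of the fit() parameters as they appear in the book's companion code (quoted from memory, so treat it as an assumption and check common/trainer.py):

```python
# Trainer.fit(x, t, max_epoch=10, batch_size=32, max_grad=None, eval_interval=20)
#   x             input data
#   t             labels
#   max_epoch     number of epochs to train
#   batch_size    mini-batch size
#   max_grad      maximum gradient norm (gradient clipping; None disables it)
#   eval_interval interval, in iterations, at which the average loss is reported
trainer.fit(x, t, max_epoch=300, batch_size=30, eval_interval=10)
```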

The Trainer class also has a plot() method, which plots the losses recorded by fit() (more precisely, the average loss evaluated at every eval_interval). Code that trains with the Trainer class looks like this:

```python
import sys
sys.path.append('..')
from common.optimizer import SGD
from common.trainer import Trainer
from dataset import spiral
# from two_layer_net import TwoLayerNet

max_epoch = 300
batch_size = 30
hidden_size = 10
learning_rate = 1.0

x, t = spiral.load_data()

model = TwoLayerNet(input_size=2, hidden_size=hidden_size, output_size=3)

optimizer = SGD(lr=learning_rate)
trainer = Trainer(model, optimizer)
trainer.fit(x, t, max_epoch, batch_size, eval_interval=10)
trainer.plot()
```

```
| epoch 1 |  iter 1 / 10 | time 0[s] | loss 1.10
| epoch 2 |  iter 1 / 10 | time 0[s] | loss 1.12
| epoch 3 |  iter 1 / 10 | time 0[s] | loss 1.13
...
| epoch 298 |  iter 1 / 10 | time 0[s] | loss 0.11
| epoch 299 |  iter 1 / 10 | time 0[s] | loss 0.11
| epoch 300 |  iter 1 / 10 | time 0[s] | loss 0.11
```

Running this code performs the same neural-network training as before. By handing the training code shown earlier over to the Trainer class, the code becomes much cleaner. From here on, this book uses the Trainer class for training.
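
To see how well the trained model fits the spiral data, we can compare its predictions against the labels (a sketch, not from the book; run it after trainer.fit()):

```python
import numpy as np

score = model.predict(x)                     # class scores for every sample
pred = np.argmax(score, axis=1)              # predicted class per sample
acc = np.mean(pred == np.argmax(t, axis=1))  # fraction of correct predictions
print('training accuracy: %.3f' % acc)       # should be high, given the final loss of ~0.11
```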

Tip: faster computation

Install the NVIDIA CUDA toolkit and use CuPy to run the computations on the GPU for a significant speedup.

https://baijiahao.baidu.com/s?id=1781877547348944869&wfr=spider&for=pc
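
A minimal sketch of the idea (assuming CuPy is installed and a CUDA-capable GPU is available): CuPy mirrors much of the NumPy API, so array code can often be moved to the GPU simply by swapping the import.

```python
import cupy as cp  # NumPy-compatible arrays backed by CUDA

a = cp.random.randn(1000, 1000)
b = cp.random.randn(1000, 1000)
c = a @ b                    # the matrix multiply runs on the GPU
print(cp.asnumpy(c).shape)   # copy the result back to host memory as a NumPy array
```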
