data_path = Path("data/mnist")
data_path.mkdir(parents=True, exist_ok=True)

# Whether to download the data file from the network (True) or read a local copy (False).
source_data_file_from_net = False

# Download location of the gzipped MNIST pickle.
# This URL loads slowly; the file can also be downloaded from Alibaba Cloud:
# https://tianchi.aliyun.com/dataset/165658
mnist_url = "http://deeplearning.net/data/mnist/"
mnist_zip_name = "mnist.pkl.gz"

if source_data_file_from_net:
    mnist_zip_path = data_path / mnist_zip_name
    # `exists` must be *called*: the bare bound method is always truthy, so the
    # previous `not (...).exists` check meant the download never ran.
    if not mnist_zip_path.exists():
        response = requests.get(mnist_url + mnist_zip_name)
        # Fail loudly instead of saving an HTTP error page as the archive.
        response.raise_for_status()
        # write_bytes opens, writes and closes the file (no leaked handle).
        mnist_zip_path.write_bytes(response.content)
    # The archive is gzip-compressed whether it was just downloaded or was
    # already on disk, so both cases load through gzip. (The old download
    # branch called `pick.load(f)` on an undefined `f` with no decompression.)
    # NOTE(review): pickle can execute arbitrary code on load; only use an
    # archive obtained from a trusted source.
    with gzip.open(mnist_zip_path, "rb") as data_file:
        ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(data_file, encoding='latin-1')
else:
    # Local, uncompressed copy of the same pickle; read-only and closed after loading.
    with open(data_path / "mnist.pkl", "rb") as data_file:
        ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(data_file, encoding='latin-1')
# Inspect the loaded arrays: for each one, print its shape and then its first five entries.
for caption, array in (("x_train shape ", x_train), ("y_train shape ", y_train)):
    print(caption, array.shape)
    print(array[:5])

# Show the first training sample (the digit 5) as a 28x28 grayscale image.
print("第一个数字5展示:")
first_image = x_train[0].reshape((28, 28))
plt.imshow(first_image, cmap='gray')
# Small fully connected classifier trained with one-hot labels.
# The previous "Incompatible shapes: [64,10] vs. [64]" error came from a
# mismatch: the loss was switched to CategoricalCrossentropy (which expects
# one-hot labels, produced below by to_categorical), but the metric was still
# SparseCategoricalAccuracy, which expects integer class ids. The metric must
# use the same label encoding as the loss, hence CategoricalAccuracy.
model = tf.keras.Sequential()
model.add(layers.Dense(32, activation='relu'))
model.add(layers.Dense(32, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))

# Reload the raw integer labels so that re-running this cell does not try to
# one-hot encode labels that are already one-hot. Opened read-only; the `with`
# block closes the file as soon as loading is done.
with open('data/mnist/mnist.pkl', 'rb') as data_file:
    ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(data_file, encoding='latin-1')

model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
              loss=tf.keras.losses.CategoricalCrossentropy(),
              metrics=[tf.keras.metrics.CategoricalAccuracy()])

# Convert integer class ids into 10-wide one-hot vectors to match the loss and metric.
y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)
y_valid = tf.keras.utils.to_categorical(y_valid, num_classes=10)
print("x_train.shape ", x_train.shape)
print("y_train.shape ", y_train.shape)
print("x_valid.shape ", x_valid.shape)
print("y_valid.shape ", y_valid.shape)
model.fit(x_train, y_train, epochs=5, batch_size=64,
          validation_data=(x_valid, y_valid))
x_train.shape (50000, 784)
y_train.shape (50000, 10)
x_valid.shape (10000, 784)
y_valid.shape (10000, 10)
Epoch 1/5
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
Cell In[63], line 18
16 print("x_valid.shape ", x_valid.shape)
17 print("y_valid.shape ", y_valid.shape)
---> 18 model.fit(x_train, y_train, epochs=5, batch_size=64,
19 validation_data=(x_valid, y_valid))
20 data_file.close()
File D:\python\Lib\site-packages\keras\src\utils\traceback_utils.py:122, in filter_traceback.<locals>.error_handler(*args, **kwargs)
119 filtered_tb = _process_traceback_frames(e.__traceback__)
120 # To get the full stack trace, call:
121 # `keras.config.disable_traceback_filtering()`
--> 122 raise e.with_traceback(filtered_tb) from None
123 finally:
124 del filtered_tb
File D:\python\Lib\site-packages\tensorflow\python\eager\execute.py:53, in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
51 try:
52 ctx.ensure_initialized()
---> 53 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
54 inputs, attrs, num_outputs)
55 except core._NotOkStatusException as e:
56 if name is not None:
InvalidArgumentError: Graph execution error:
Detected at node Equal defined at (most recent call last):
File "D:\python\Lib\runpy.py", line 198, in _run_module_as_main
File "D:\python\Lib\runpy.py", line 88, in _run_code
File "D:\python\Lib\site-packages\ipykernel_launcher.py", line 18, in <module>
File "D:\python\Lib\site-packages\traitlets\config\application.py", line 1043, in launch_instance
File "D:\python\Lib\site-packages\ipykernel\kernelapp.py", line 739, in start
File "D:\python\Lib\site-packages\tornado\platform\asyncio.py", line 205, in start
File "D:\python\Lib\asyncio\base_events.py", line 607, in run_forever
File "D:\python\Lib\asyncio\base_events.py", line 1919, in _run_once
File "D:\python\Lib\asyncio\events.py", line 80, in _run
File "D:\python\Lib\site-packages\ipykernel\kernelbase.py", line 545, in dispatch_queue
File "D:\python\Lib\site-packages\ipykernel\kernelbase.py", line 534, in process_one
File "D:\python\Lib\site-packages\ipykernel\kernelbase.py", line 437, in dispatch_shell
File "D:\python\Lib\site-packages\ipykernel\ipkernel.py", line 362, in execute_request
File "D:\python\Lib\site-packages\ipykernel\kernelbase.py", line 778, in execute_request
File "D:\python\Lib\site-packages\ipykernel\ipkernel.py", line 449, in do_execute
File "D:\python\Lib\site-packages\ipykernel\zmqshell.py", line 549, in run_cell
File "D:\python\Lib\site-packages\IPython\core\interactiveshell.py", line 2945, in run_cell
File "D:\python\Lib\site-packages\IPython\core\interactiveshell.py", line 3000, in _run_cell
File "D:\python\Lib\site-packages\IPython\core\async_helpers.py", line 129, in _pseudo_sync_runner
File "D:\python\Lib\site-packages\IPython\core\interactiveshell.py", line 3203, in run_cell_async
File "D:\python\Lib\site-packages\IPython\core\interactiveshell.py", line 3382, in run_ast_nodes
File "D:\python\Lib\site-packages\IPython\core\interactiveshell.py", line 3442, in run_code
File "C:\Users\AXZQ\AppData\Local\Temp\ipykernel_8960\1530222387.py", line 18, in <module>
File "D:\python\Lib\site-packages\keras\src\utils\traceback_utils.py", line 117, in error_handler
File "D:\python\Lib\site-packages\keras\src\backend\tensorflow\trainer.py", line 320, in fit
File "D:\python\Lib\site-packages\keras\src\backend\tensorflow\trainer.py", line 121, in one_step_on_iterator
File "D:\python\Lib\site-packages\keras\src\backend\tensorflow\trainer.py", line 108, in one_step_on_data
File "D:\python\Lib\site-packages\keras\src\backend\tensorflow\trainer.py", line 77, in train_step
File "D:\python\Lib\site-packages\keras\src\trainers\trainer.py", line 452, in compute_metrics
File "D:\python\Lib\site-packages\keras\src\trainers\compile_utils.py", line 330, in update_state
File "D:\python\Lib\site-packages\keras\src\trainers\compile_utils.py", line 17, in update_state
File "D:\python\Lib\site-packages\keras\src\metrics\reduction_metrics.py", line 204, in update_state
File "D:\python\Lib\site-packages\keras\src\metrics\accuracy_metrics.py", line 240, in sparse_categorical_accuracy
File "D:\python\Lib\site-packages\keras\src\ops\numpy.py", line 2355, in equal
File "D:\python\Lib\site-packages\keras\src\backend\tensorflow\numpy.py", line 1144, in equal
Incompatible shapes: [64,10] vs. [64]
[[{{node Equal}}]] [Op:__inference_one_step_on_iterator_69389]