使用 plt.subplots_adjust
或 fig.subplots_adjust
方法来控制两个子图之间的间距:
python
# 假设 best_true_target_labels, best_pred, new_target_data_input, new_target_data 是由 trainer.train() 返回的
best_true_target_labels, best_pred, new_target_data_input, new_target_data = trainer.train()
if len(best_true_target_labels) == len(best_pred):
index_length = len(best_true_target_labels)
print("打印预测标签和真实标签的长度index_length:", index_length)
i = index_length - 4
# 创建两个图,每个图包含两个子图
fig1, axs1 = plt.subplots(2, 1, figsize=(10, 10))
fig2, axs2 = plt.subplots(2, 1, figsize=(10, 10))
while i <= index_length - 1:
# 拿到最后一个batch的数据去画input图像
input_data = new_target_data_input[i] # input_data.shape:torch.Size([1, 30])
target_data = new_target_data[i]
# input_data对应的数据标签
input_data_labels_plot = best_true_target_labels[i]
best_pred_plot = best_pred[i]
input_data_flattened = input_data.flatten() # 调整形状为 torch.Size([30])
target_data_flattened = target_data.flatten() # 调整形状为 torch.Size([31])
# 创建从 1 开始的 x 轴序列
x_values_input = range(1, input_data_flattened.shape[0] + 1)
x_values_target = range(1, target_data_flattened.shape[0] + 1)
# 根据i选择当前子图和图
if i - (index_length - 4) < 2:
ax = axs1[i - (index_length - 4)]
fig = fig1
else:
ax = axs2[i - (index_length - 4) - 2]
fig = fig2
# 绘制input_data_flattened和target_data_flattened
ax.plot(x_values_input, input_data_flattened.numpy(), label='Input Data', linewidth=2)
ax.plot(x_values_target, target_data_flattened.numpy(), label='Target Data', linewidth=1)
# 绘制 input_data_labels 在 x 轴 31 位置的点
ax.scatter([31], input_data_labels_plot.numpy(), color='green', label='True Label', zorder=2)
ax.scatter([31], best_pred_plot.numpy(), color='red', label='Predicted Label', zorder=2)
ax.set_title(f'Input Data and Forecast Results for index {i}')
ax.set_xlabel('Time Step')
ax.set_ylabel('Value')
ax.legend()
# 扩展 x 轴的范围,以便后续绘图
ax.set_xlim(0, max(len(input_data_flattened), len(target_data_flattened)) + 4)
i += 1
fig1.tight_layout()
fig2.tight_layout()
# 调整两个图的子图之间的间距
fig1.subplots_adjust(hspace=0.4) # 控制 fig1 子图之间的竖直间距
fig2.subplots_adjust(hspace=0.4) # 控制 fig2 子图之间的竖直间距
fig1.suptitle('First Two Subplots')
fig2.suptitle('Second Two Subplots')
plt.show()
在这个代码中:
- 使用
fig1.tight_layout()
和fig2.tight_layout()
调整每个图的子图布局,使其不重叠。 - 使用
fig1.subplots_adjust(hspace=0.4)
和fig2.subplots_adjust(hspace=0.4)
来控制每个图中的子图之间的竖直间距。hspace
参数控制子图之间的竖直间距,值越大间距越大,你可以根据需要调整这个值。
通过这种方式,你可以更好地控制两个子图之间的间距,使图表的布局更加清晰和美观。