transformer注意力权重系数绘图

参考绘制tsne图,首先将模型中的注意力权重导出,因为我的模型中L=2,所以导出两层

python 复制代码
# plot_weight
weight_model_layer0 = Model(inputs=model.inputs, outputs=model.get_layer('transformer_0').output)
weight_output_layer0 = weight_model_layer0.predict(X_test)
np.save('weight_output_layer0', weight_output_layer0[1])

weight_model_layer1 = Model(inputs=model.inputs, outputs=model.get_layer('transformer_1').output)
weight_output_layer1 = weight_model_layer1.predict(X_test)
np.save('weight_output_layer1', weight_output_layer1[1])

然后,搜到一些使用seaborn绘制热力图的代码,其中我查的比较多的问题是

1、如何修改colorbar字体的大小

2、如何修改xy轴labelsize

3、如何给子图添加标题

就是cbar=False,再重新绘制一个colorbar

参考python使用seaborn画热力图中设置colorbar图例刻度字体大小_seaborn 设置colorbar刻度-CSDN博客

python 复制代码
# 绘制热力图
hm1 = sns.heatmap(attention_per_head_0[0:40, 0:40], cbar=False, cbar_kws={'shrink': 0.8}, square=True, xticklabels='auto', yticklabels='auto')
# 修改xy轴labelsize
hm1.tick_params(labelsize=8)
# 设置标题
hm1.set_title('layer1_head_{}'.format(4), size=12)
# 显示colorbar
cb = hm1.figure.colorbar(hm1.collections[0])  
# 修改colorbar的labelsize
cb.ax.tick_params(labelsize=8)
相关推荐
火车叼位5 分钟前
也许你不需要创建.venv, 此规范使python脚本自备依赖
python
火车叼位12 分钟前
脚本伪装:让 Python 与 Node.js 像原生 Shell 命令一样运行
运维·javascript·python
孤狼warrior22 分钟前
YOLO目标检测 一千字解析yolo最初的摸样 模型下载,数据集构建及模型训练代码
人工智能·python·深度学习·算法·yolo·目标检测·目标跟踪
机器学习之心32 分钟前
TCN-Transformer-BiGRU组合模型回归+SHAP分析+新数据预测+多输出!深度学习可解释分析
深度学习·回归·transformer·shap分析
Katecat9966332 分钟前
YOLO11分割算法实现甲状腺超声病灶自动检测与定位_DWR方法应用
python
玩大数据的龙威1 小时前
农经权二轮延包—各种地块示意图
python·arcgis
ZH15455891311 小时前
Flutter for OpenHarmony Python学习助手实战:数据库操作与管理的实现
python·学习·flutter
belldeep1 小时前
python:用 Flask 3 , mistune 2 和 mermaid.min.js 10.9 来实现 Markdown 中 mermaid 图表的渲染
javascript·python·flask
喵手1 小时前
Python爬虫实战:电商价格监控系统 - 从定时任务到历史趋势分析的完整实战(附CSV导出 + SQLite持久化存储)!
爬虫·python·爬虫实战·零基础python爬虫教学·电商价格监控系统·从定时任务到历史趋势分析·采集结果sqlite存储
喵手1 小时前
Python爬虫实战:京东/淘宝搜索多页爬虫实战 - 从反爬对抗到数据入库的完整工程化方案(附CSV导出 + SQLite持久化存储)!
爬虫·python·爬虫实战·零基础python爬虫教学·京东淘宝页面数据采集·反爬对抗到数据入库·采集结果csv导出