笔记:TREX工具-4

TREx API 示例

trex 提供了一系列 Python API,方便在代码中直接查询和分析引擎 plan 的数据。

加载引擎

python 复制代码
import trex
from trex.notebook import *

engine_name = "../tests/inputs/mobilenet.qat.onnx.engine"
plan = trex.EnginePlan(
    f"{engine_name}.graph.json",
    f"{engine_name}.profile.json",
    f"{engine_name}.metadata.json"
)

EnginePlan 加载后,核心数据结构是 plan.df --- 一个 pandas DataFrame,每行代表一个层,包含层的名称、类型、延迟、精度、tactic 等字段。

列出 k 个最慢的层

python 复制代码
top3 = plan.df.nlargest(3, 'latency.pct_time')
for i in range(len(top3)):
    layer = top3.iloc[i]
    print("%s: %s" % (layer["Name"], layer["type"]))

输出:

复制代码
features.15.conv.2.weight + QuantizeLinear_722 + Conv_726 + Add_728: Convolution
features.16.conv.2.weight + QuantizeLinear_771 + Conv_775 + Add_777: Convolution
features.13.conv.2.weight + QuantizeLinear_625 + Conv_629 + Add_631: Convolution

计算这 3 个最慢层的延迟总和:

python 复制代码
top3_latency = top3['latency.avg_time'].sum()
top3_percent = top3['latency.pct_time'].sum()
print(f"top3 latency: {top3_latency:.6f} ms ({top3_percent:.2f}%)")
# top3 latency: 0.045236 ms (9.62%)

按类型查询层

用 pandas query 按层类型过滤:

python 复制代码
ltype = "Convolution"
convs = plan.df.query(f"type == \"{ltype}\"")
print(f"There are {len(convs)} convolutions")
print(convs['latency.avg_time'].median())
# There are 53 convolutions
# 0.00586459

也可以用 trex 内置方法,效果相同:

python 复制代码
convs2 = plan.get_layers_by_type('Convolution')
print(f"There are {len(convs2)} convolutions")

访问层激活

层的输入和输出信息有多种访问方式。

直接从 DataFrame 读取原始数据:

python 复制代码
print(convs.iloc[0]['Inputs'])
# [{'Name': '317', 'Location': 'Device', 'Dimensions': [1, 3, 224, 224],
#   'Format/Datatype': 'Four wide channel vectorized row major Int8 format'}]

clean_df 简化显示:

python 复制代码
clean_convs = trex.clean_df(convs2.copy(), inplace=True)
clean_convs.iloc[0]['Inputs']
# 'Int8 NC/4HW4'

作为 Activation 实例,获取结构化字段:

python 复制代码
inputs, outputs = trex.create_activations(convs.iloc[0])
print(inputs[0].name)       # 317
print(inputs[0].shape)      # [1, 3, 224, 224]
print(inputs[0].precision)  # INT8
print(inputs[0].format)     # Int8 NC/4HW4
print(inputs[0].size_bytes) # 150528
属性 说明
name 激活 tensor 的名称
shape tensor 的维度
precision 数据类型(FP32/FP16/INT8)
format 内存布局格式(NCHW/NC/4HW4 等)
size_bytes tensor 大小(字节)

查询和分组

trex 底层是 pandas,可以直接用 pandas 的 groupby 做聚合分析:

python 复制代码
# 按类型分组,汇总延迟
plan.df.groupby(["type"])[["latency.avg_time", "latency.pct_time"]].sum()
type latency.avg_time latency.pct_time
Convolution 0.383324 81.49%
Pooling 0.004525 0.96%
Reformat 0.082539 17.55%

trex 也提供了更简洁的封装:

python 复制代码
# 按类型汇总延迟(与上面等价)
trex.group_sum_attr(plan.df, "type", "latency.avg_time")

# 按类型统计层数
trex.group_count(plan.df, "type")
type count
Convolution 53
Pooling 1
Reformat 18

常用 API 速查

API 作用
plan.df 获取所有层的 DataFrame
plan.df.nlargest(k, 'latency.pct_time') 列出 k 个最慢的层
plan.get_layers_by_type(name) 按类型过滤层
trex.create_activations(layer) 获取层的输入/输出 Activation 对象
trex.clean_df(df) 简化 DataFrame 显示(格式化精度/格式字段)
trex.group_sum_attr(df, group_col, val_col) 按列分组并求和
trex.group_count(df, col) 按列分组并计数

这些 API 就是在前面几篇中 report_cardcompare_engines 各种图表的底层数据来源。直接用 API 可以按需做自定义分析,比如只导出某类层的数据、写脚本批量对比多个引擎的某几项指标等。

参考:

deepseek

Tensorrt git 仓库

相关推荐
半导体守望者2 小时前
MKS Profibus-DP 接口等离子发生器Plasma Generators EIite
经验分享·笔记·机器人·自动化·制造
玄米乌龙茶1232 小时前
思维导图笔记:模型微调技术
笔记
叶~小兮2 小时前
Jenkins构建生产CICD环境学习笔记
笔记·学习·jenkins
暴躁小师兄数据学院2 小时前
【AI大模型应用开发工程师特训笔记】第04讲(第4章):运算符
人工智能·笔记·机器学习
问心无愧05132 小时前
ctf show web 入门258
android·前端·笔记
xxl大卡3 小时前
Redis完整详细学习笔记
redis·笔记·学习
魔都大虾3 小时前
六月北京小吃是什么 有那些特色
笔记
一口吃俩胖子4 小时前
【脉宽调制DCDC功率变换学习笔记022】DCDC变换器的稳定性、奈奎斯特准则、增益裕度和相位裕度
笔记·学习
GLDbalala4 小时前
GPU PRO 5 - 2.3 Volumetric Light Effects in Killzone: Shadow Fall 笔记
笔记