Generating the Causal Mask Upper-Triangular -inf Matrix in XLA

How transformers builds the upper-triangular -inf matrix for its causal attention mask:

The snippet below is adapted from the transformers source:

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import os

os.environ['PJRT_DEVICE']='IPU'
# os.environ['PJRT_DEVICE']='GPU'
# os.environ['XLA_FLAGS']='--xla_dump_to=gen_AttnFwd-XLA_GPU'

tgt_len = 10
dtype=torch.float32
device = xm.xla_device()

# src/transformers/modeling_attn_mask_utils.py#AttentionMaskConverter::_make_causal_mask
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)  # start all float32-min
mask_cond = torch.arange(mask.size(-1), device=device)
# condition j < i + 1 (i.e. j <= i): zero the diagonal and below, keep -inf above
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
print(mask)
# print(mask.size())
# print(mask[3][3])
```

Running this with `PJRT_DEVICE='IPU'` aborts while XLA builds the computation:

```
2024-11-07 07:16:18.824506: F tensorflow/compiler/xla/service/hlo_computation.cc:70] Check failed: nullptr != root (nullptr vs. 0)
Aborted (core dumped)
```

The traced graph corresponds to this MHLO module:

```
module @SyncTensorsGraph.25 {
  func.func @main() -> tuple<tensor<10x10xf32>> {
    %0 = mhlo.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<10xi64>
    %1 = "mhlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<10xi64>) -> tensor<10x10xi64>
    %2 = mhlo.constant dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]> : tensor<10xi64>
    %3 = "mhlo.broadcast_in_dim"(%2) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<10xi64>) -> tensor<10x10xi64>
    %4 = mhlo.compare  LT, %1, %3 : (tensor<10x10xi64>, tensor<10x10xi64>) -> tensor<10x10xi1>
    %5 = mhlo.constant dense<false> : tensor<i1>
    %6 = "mhlo.broadcast_in_dim"(%5) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i1>) -> tensor<10x10xi1>
    %7 = mhlo.compare  NE, %4, %6 : (tensor<10x10xi1>, tensor<10x10xi1>) -> tensor<10x10xi1>
    %8 = mhlo.constant dense<0.000000e+00> : tensor<f32>
    %9 = "mhlo.broadcast_in_dim"(%8) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<10x10xf32>
    %10 = mhlo.constant dense<-3.40282347E+38> : tensor<f32>
    %11 = "mhlo.broadcast_in_dim"(%10) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<10x10xf32>
    %12 = "mhlo.select"(%7, %9, %11) : (tensor<10x10xi1>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
    %13 = "mhlo.tuple"(%12) {xla_shape = "(f32[10,10]{1,0})"} : (tensor<10x10xf32>) -> tuple<tensor<10x10xf32>>
    return %13 : tuple<tensor<10x10xf32>>
  }
}
```
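
Reading the MHLO: `%1` broadcasts the column index `[0..9]` across rows, `%3` broadcasts the row index plus one across columns, the `LT` compare marks positions where `j < i + 1` (the lower triangle including the diagonal), and `mhlo.select` picks 0 there and float32 min elsewhere (the intervening `NE false` compare is a no-op on the predicate). A minimal eager-mode rendering of the same graph, for reference:

```python
import torch

tgt_len = 10
j = torch.arange(tgt_len).expand(tgt_len, tgt_len)       # %0/%1: column index at every (i, j)
i_plus_1 = (torch.arange(tgt_len) + 1).view(tgt_len, 1)  # %2/%3: row index + 1, broadcast over columns
keep = j < i_plus_1                                      # %4: true on and below the diagonal
mask = torch.where(                                      # %12: select(keep, 0, float32 min)
    keep,
    torch.tensor(0.0),
    torch.tensor(torch.finfo(torch.float32).min),
)
```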

With `PJRT_DEVICE='GPU'` (plus `XLA_FLAGS='--xla_dump_to=gen_AttnFwd-XLA_GPU'`, the commented-out lines in the script), XLA_GPU compiles the graph and even dumps the complete optimized HLO to `gen_AttnFwd-XLA_GPU/module_0000.SyncTensorsGraph.25.sm_8.0_gpu_after_optimizations.txt`:

```
HloModule SyncTensorsGraph.25, entry_computation_layout={(f32[])->(f32[10,10]{1,0})}

fused_computation {
  iota.3 = s64[10,10]{1,0} iota(), iota_dimension=1
  iota.2 = s64[10]{0} iota(), iota_dimension=0
  constant_5 = s64[] constant(1)
  broadcast.7 = s64[10]{0} broadcast(constant_5), dimensions={}
  add.0 = s64[10]{0} add(iota.2, broadcast.7)
  broadcast.6 = s64[10,10]{1,0} broadcast(add.0), dimensions={0}
  compare.1 = pred[10,10]{1,0} compare(iota.3, broadcast.6), direction=LT
  constant_3 = pred[] constant(false)
  broadcast.4 = pred[10,10]{1,0} broadcast(constant_3), dimensions={}
  compare.0 = pred[10,10]{1,0} compare(compare.1, broadcast.4), direction=NE
  constant_0 = f32[] constant(0)
  broadcast.3 = f32[10,10]{1,0} broadcast(constant_0), dimensions={}
  param_0.1 = f32[] parameter(0)
  broadcast.2 = f32[10,10]{1,0} broadcast(param_0.1), dimensions={}
  ROOT select.0 = f32[10,10]{1,0} select(compare.0, broadcast.3, broadcast.2)
}

ENTRY SyncTensorsGraph.25 {
  p0.13 = f32[] parameter(0)
  fusion = f32[10,10]{1,0} fusion(p0.13), kind=kLoop, calls=fused_computation
  ROOT tuple.24 = (f32[10,10]{1,0}) tuple(fusion)
}
```

The whole mask computation fuses into a single kLoop fusion, with the float32 minimum hoisted out as a runtime parameter (p0.13). The GPU run log and the resulting mask:

```
INFO:torch_xla:Letting libtpu.so load fail during _XLAC import. libtpu.so will be loaded from `libtpu` Python package when the ComputationClient is created.
2024-11-07 11:50:41.174644: I tensorflow/compiler/xla/service/service.cc:173] XLA service 0x905c190 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2024-11-07 11:50:41.174714: I tensorflow/compiler/xla/service/service.cc:181]   StreamExecutor device (0): NVIDIA A100-SXM4-80GB, Compute Capability 8.0
2024-11-07 11:50:41.175641: I tensorflow/compiler/xla/pjrt/gpu/se_gpu_pjrt_client.cc:194] Using BFC allocator.
2024-11-07 11:50:41.175713: I tensorflow/compiler/xla/pjrt/gpu/gpu_helpers.cc:105] XLA backend allocating 75175958937 bytes on device 0 for BFCAllocator.
2024-11-07 11:50:42.013482: I tensorflow/compiler/xla/service/dump.cc:485] HloModule dump enabled with path prefix: , suffix: before_optimizations
2024-11-07 11:50:42.037845: I tensorflow/tsl/platform/default/subprocess.cc:304] Start cannot spawn child process: No such file or directory
tensor([[ 0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38,
         -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
        [ 0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38,
         -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38,
         -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38,
         -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]],
       device='xla:0')

```
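
The printed tensor matches the eager-mode result. As a quick host-side check (a sketch; `mask` is the XLA tensor from the script above):

```python
import torch

# Re-run the same transformers construction eagerly on CPU and compare.
ref = torch.full((10, 10), torch.finfo(torch.float32).min)
cond = torch.arange(10)
ref.masked_fill_(cond < (cond + 1).view(10, 1), 0)
# .cpu() forces the lazy XLA graph to execute and copies the values back.
assert torch.equal(mask.cpu(), ref)
```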