FileManager.write_with_template in PyTorch's File Generation Mechanism

Preface

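Some files in PyTorch are generated by running scripts during the build. For example, .pyi files are generated from .pyi.in files, and the .cpp files under torch/csrc/autograd/generated are generated from the template .cpp files under tools/autograd/templates.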

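Under the hood, all of these call FileManager.write_with_template. This function replaces specific placeholder strings in a template file in the way a callback function dictates, and so produces the corresponding .pyi or .cpp file.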

This article first looks at how FileManager.write_with_template is called. It then examines the implementation in detail.

Calls to FileManager.write_with_template

gen_pyi

tools/pyi/gen_pyi.py

python
    fm.write_with_template(
        "torch/_C/__init__.pyi",
        "torch/_C/__init__.pyi.in",
        lambda: {
            "generated_comment": "@" + "generated from torch/_C/__init__.pyi.in",
            **env,
        },
    )
    fm.write_with_template(
        "torch/_C/_VariableFunctions.pyi",
        "torch/_C/_VariableFunctions.pyi.in",
        lambda: {
            "generated_comment": "@"
            + "generated from torch/_C/_VariableFunctions.pyi.in",
            **env,
        },
    )
    fm.write_with_template(
        "torch/_VF.pyi",
        "torch/_C/_VariableFunctions.pyi.in",
        lambda: {
            "generated_comment": "@"
            + "generated from torch/_C/_VariableFunctions.pyi.in",
            **env,
        },
    )
    fm.write_with_template(
        "torch/return_types.pyi",
        "torch/_C/return_types.pyi.in",
        lambda: {
            "generated_comment": "@" + "generated from torch/_C/return_types.pyi",
            **env,
        },
    )
    gen_nn_functional(fm)

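These four fm.write_with_template calls read three .pyi.in files under torch/_C (_VariableFunctions.pyi.in is used twice). From them they generate __init__.pyi and _VariableFunctions.pyi under torch/_C, plus _VF.pyi and return_types.pyi under torch.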
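All of the lambdas above build their dict with the same {"generated_comment": ..., **env} pattern. The sketch below shows what that expression evaluates to. It assumes a toy env: the function_hints entry is a hypothetical placeholder, not the real env that gen_pyi builds. Two things are worth noticing. First, the dict is only built when the lambda is called. Second, later entries in a dict display take precedence, so a generated_comment key inside env would override the explicit one.

```python
# Hypothetical stand-in for the env dict that gen_pyi builds.
env = {"function_hints": ["def abs(input: Tensor) -> Tensor: ..."]}

# Same shape as the env_callable lambdas passed to fm.write_with_template.
env_callable = lambda: {
    "generated_comment": "@" + "generated from torch/_C/_VariableFunctions.pyi.in",
    **env,
}

result = env_callable()  # the dict is built only now, when the callable is invoked
print(result["generated_comment"])  # @generated from torch/_C/_VariableFunctions.pyi.in
print(sorted(result))  # ['function_hints', 'generated_comment']

# In a dict display, later entries win, so a key coming from **env
# would replace the explicit "generated_comment" written before it.
merged = {"generated_comment": "explicit", **{"generated_comment": "from env"}}
print(merged["generated_comment"])  # from env
```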

gen_nn_functional

tools/pyi/gen_pyi.py

python
def gen_nn_functional(fm: FileManager) -> None:
    # ...
    fm.write_with_template(
        "torch/nn/functional.pyi",
        "torch/nn/functional.pyi.in",
        lambda: {
            "imported_hints": import_code,
            "dispatched_hints": dispatch_code,
        },
    )
    # ...
    fm.write_with_template(
        "torch/_C/_nn.pyi",
        "torch/_C/_nn.pyi.in",
        lambda: {
            "imported_hints": import_code,
            "dispatched_hints": dispatch_code,
        },
    )

These two fm.write_with_template calls generate torch/nn/functional.pyi and torch/_C/_nn.pyi from torch/nn/functional.pyi.in and torch/_C/_nn.pyi.in, respectively.

write_sharded

torchgen/utils.py

python
    def write_sharded(
        self,
        filename: str,
        items: Iterable[T],
        *,
        key_fn: Callable[[T], str],
        env_callable: Callable[[T], Dict[str, List[str]]],
        num_shards: int,
        base_env: Optional[Dict[str, Any]] = None,
        sharded_keys: Set[str],
    ) -> None:
        #...
        for shard in all_shards:
            shard_id = shard["shard_id"]
            self.write_with_template(
                f"{base_filename}{shard_id}{extension}", filename, lambda: shard
            )
        #...

Here, all_shards is:

python
[{'shard_id': 'Everything'}, {'shard_id': '_0'}, {'shard_id': '_1'}, {'shard_id': '_2'}]

So this write_with_template call takes the template filename, python_torch_functions.cpp, and generates four files from it: python_torch_functionsEverything.cpp, python_torch_functions_0.cpp, python_torch_functions_1.cpp, and python_torch_functions_2.cpp.
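The "#..." in write_sharded hides how base_filename and extension are derived. The sketch below assumes they come from os.path.splitext(filename). That is an assumption, because the relevant part is elided above. Under it, the sketch reproduces the four output names:

```python
import os

# Shard dicts as printed above; each one doubles as the env for its output file.
all_shards = [
    {"shard_id": "Everything"},
    {"shard_id": "_0"},
    {"shard_id": "_1"},
    {"shard_id": "_2"},
]

filename = "python_torch_functions.cpp"
# Assumption: the template name is split into stem and extension like this.
base_filename, extension = os.path.splitext(filename)

output_names = [f"{base_filename}{shard['shard_id']}{extension}" for shard in all_shards]
for name in output_names:
    print(name)
# python_torch_functionsEverything.cpp
# python_torch_functions_0.cpp
# python_torch_functions_1.cpp
# python_torch_functions_2.cpp
```

Creating lambda: shard inside a loop would normally risk the late-binding pitfall of closures. It does not apply here: write_with_template calls env_callable (through substitute_with_template) before the loop moves on to the next shard, so each lambda sees its own shard. In dry-run mode the lambda is never called at all.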

Note that in all three examples above, the third argument to write_with_template (env_callable) is a lambda that returns a dict when called.

Implementation of FileManager.write_with_template

torchgen/utils.py

FileManager.write_with_template

Apart from self, write_with_template takes three parameters:

  • filename: the name of the generated .pyi or .cpp file
  • template_fn: the name of the input .pyi.in file or template .cpp file
  • env_callable: the callback function used during substitution
python
    def write_with_template(
        self,
        filename: str,
        template_fn: str,
        env_callable: Callable[[], Union[str, Dict[str, Any]]],
    ) -> None:
        filename = "{}/{}".format(self.install_dir, filename)
        assert filename not in self.filenames, f"duplicate file write {filename}"
        self.filenames.add(filename)
        if not self.dry_run:
            substitute_out = self.substitute_with_template(
                template_fn=template_fn,
                env_callable=env_callable,
            )
            self._write_if_changed(filename=filename, contents=substitute_out)

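As you can see, the heart of this code is the call to substitute_with_template, which produces substitute_out.

The substituted result, substitute_out, is then written to filename (the generated .pyi or .cpp file). As its name suggests, _write_if_changed only writes when the new contents differ from what is already on disk. When self.dry_run is set, only the filename is recorded: env_callable is never called and nothing is written.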
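The control flow described above can be sketched in a self-contained way. MiniFileManager below is a hypothetical, simplified class, not torchgen's FileManager. It keeps three behaviors: the duplicate-write assertion, the dry-run short circuit, and a write-if-changed step modeled on what the name _write_if_changed implies. To keep the sketch short, it uses Python's string.Template in place of CodeTemplate. string.Template understands $name and ${name}, but not CodeTemplate's list handling.

```python
import os
import string
import tempfile
from typing import Callable, Dict, Optional, Set


class MiniFileManager:
    """Hypothetical, simplified stand-in for torchgen's FileManager."""

    def __init__(self, install_dir: str, template_dir: str, dry_run: bool = False) -> None:
        self.install_dir = install_dir
        self.template_dir = template_dir
        self.dry_run = dry_run
        self.filenames: Set[str] = set()
        self.writes = 0  # counts real disk writes, for illustration only

    def _write_if_changed(self, filename: str, contents: str) -> None:
        old_contents: Optional[str]
        try:
            with open(filename) as f:
                old_contents = f.read()
        except OSError:
            old_contents = None
        if contents != old_contents:  # skip the write when nothing changed
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            with open(filename, "w") as f:
                f.write(contents)
            self.writes += 1

    def write_with_template(
        self, filename: str, template_fn: str, env_callable: Callable[[], Dict[str, str]]
    ) -> None:
        filename = f"{self.install_dir}/{filename}"
        assert filename not in self.filenames, f"duplicate file write {filename}"
        self.filenames.add(filename)
        if not self.dry_run:  # in a dry run, env_callable is never invoked
            with open(os.path.join(self.template_dir, template_fn)) as f:
                template = string.Template(f.read())  # stand-in for CodeTemplate
            self._write_if_changed(filename, template.substitute(env_callable()))


calls = []


def counting_env() -> Dict[str, str]:
    calls.append("called")  # records every time the env is actually built
    return {"generated_comment": "@generated demo"}


with tempfile.TemporaryDirectory() as tmp:
    with open(os.path.join(tmp, "demo.pyi.in"), "w") as f:
        f.write("# ${generated_comment}\n")
    out_dir = os.path.join(tmp, "out")

    fm = MiniFileManager(out_dir, tmp)
    fm.write_with_template("demo.pyi", "demo.pyi.in", counting_env)
    with open(os.path.join(out_dir, "demo.pyi")) as f:
        generated = f.read()

    # A fresh manager producing identical contents leaves the file untouched.
    fm_again = MiniFileManager(out_dir, tmp)
    fm_again.write_with_template("demo.pyi", "demo.pyi.in", counting_env)

    # A dry run only records the filename; the env callable is not called.
    fm_dry = MiniFileManager(out_dir, tmp, dry_run=True)
    fm_dry.write_with_template("other.pyi", "demo.pyi.in", counting_env)

print(repr(generated), fm.writes, fm_again.writes, len(calls))
```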

Note: for type checking, the callback function is annotated with typing.Callable. See the Python typing library and torch.types for details.

FileManager.substitute_with_template

torchgen/utils.py

Apart from self, it takes two parameters:

  • template_fn: the name of the input .pyi.in file or template .cpp file
  • env_callable: the callback function used during substitution
python
    # Read from template file and replace pattern with callable (type could be dict or str).
    def substitute_with_template(
        self, template_fn: str, env_callable: Callable[[], Union[str, Dict[str, Any]]]
    ) -> str:
        template_path = os.path.join(self.template_dir, template_fn)
        env = env_callable()
        if isinstance(env, dict):
            # TODO: Update the comment reference to the correct location
            if "generated_comment" not in env:
                comment = "@" + "generated by torchgen/gen.py"
                comment += " from {}".format(os.path.basename(template_path))
                env["generated_comment"] = comment
            template = _read_template(template_path)
            return template.substitute(env)
        elif isinstance(env, str):
            return env
        else:
            assert_never(env)

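Since env_callable is a lambda that returns a dict when called, execution takes the isinstance(env, dict) branch. If env has no generated_comment key, a default comment naming the template file is added. Next, _read_template reads the template file (the .pyi.in file or template .cpp file), and template.substitute is called. If env_callable returns a str instead, that string is used verbatim as the output.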
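The branching in substitute_with_template can be sketched as a standalone function. substitute_sketch is a hypothetical simplification that differs from the real code in three ways. It takes the template text directly instead of reading it from template_dir. It uses string.Template instead of CodeTemplate. It raises TypeError where the real code calls assert_never. The template path below is only an illustrative value.

```python
import os
import string
from typing import Callable, Dict, Union

EnvCallable = Callable[[], Union[str, Dict[str, str]]]


def substitute_sketch(template_text: str, template_path: str, env_callable: EnvCallable) -> str:
    env = env_callable()
    if isinstance(env, dict):
        if "generated_comment" not in env:
            # Default comment names the template file, as in substitute_with_template.
            comment = "@" + "generated by torchgen/gen.py"
            comment += " from {}".format(os.path.basename(template_path))
            env["generated_comment"] = comment
        return string.Template(template_text).substitute(env)  # stand-in for CodeTemplate
    elif isinstance(env, str):
        return env  # a plain string is returned verbatim; the template is ignored
    raise TypeError(f"unexpected env type: {type(env).__name__}")


text = "// ${generated_comment}\n"
path = "tools/autograd/templates/python_torch_functions.cpp"  # illustrative path
with_default = substitute_sketch(text, path, lambda: {})
verbatim = substitute_sketch(text, path, lambda: "verbatim contents\n")
print(with_default, end="")  # // @generated by torchgen/gen.py from python_torch_functions.cpp
print(verbatim, end="")  # verbatim contents
```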

_read_template

torchgen/utils.py

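The parameter template_fn is the path of the .pyi.in file or template .cpp file. substitute_with_template passes in template_path, which joins template_dir with the file name.

python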
@functools.lru_cache(maxsize=None)
def _read_template(template_fn: str) -> CodeTemplate:
    return CodeTemplate.from_file(template_fn)

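It reads template_fn, builds a CodeTemplate object from it, and returns that object. Because of functools.lru_cache, each template path is read only once, and later calls with the same path return the cached CodeTemplate.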
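The effect of the lru_cache decorator can be observed with a small sketch. read_template is a hypothetical stand-in for _read_template. It returns the raw text instead of a CodeTemplate and counts how often the file is actually opened.

```python
import functools
import os
import tempfile

open_count = 0


@functools.lru_cache(maxsize=None)
def read_template(template_fn: str) -> str:
    # Hypothetical stand-in for _read_template / CodeTemplate.from_file.
    global open_count
    open_count += 1
    with open(template_fn) as f:
        return f.read()


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "python_torch_functions.cpp")
    with open(path, "w") as f:
        f.write("${py_methods}\n")
    first = read_template(path)
    second = read_template(path)  # served from the cache; the file is not reopened

print(open_count, first is second)  # 1 True
print(read_template.cache_info().hits)  # 1
```

This is useful for write_sharded, which renders the same template once per shard. With the cache, python_torch_functions.cpp is read from disk once rather than four times.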

torchgen/code_template.py

CodeTemplate

torchgen/code_template.py

Let's first look at what the CodeTemplate class does.

python
# match $identifier or ${identifier} and replace with value in env
# If this identifier is at the beginning of whitespace on a line
# and its value is a list then it is treated as
# block substitution by indenting to that depth and putting each element
# of the list on its own line
# if the identifier is on a line starting with non-whitespace and a list
# then it is comma separated ${,foo} will insert a comma before the list
# if this list is not empty and ${foo,} will insert one after.


class CodeTemplate:
    substitution_str = r"(^[^\n\S]*)?\$([^\d\W]\w*|\{,?[^\d\W]\w*\,?})"
    substitution = re.compile(substitution_str, re.MULTILINE)

    pattern: str
    filename: str
    
    # ...

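As the comment explains, CodeTemplate replaces $identifier or ${identifier} placeholders in the template with the corresponding values from env. How a list value is inserted depends on where the placeholder sits. If only whitespace precedes the placeholder on its line, each element goes on its own line at that indentation. Otherwise the elements are joined with commas, and ${,foo} or ${foo,} adds a comma before or after the list when it is non-empty.

For example, torch/_C/_VariableFunctions.pyi.in contains the following:

python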
# ${generated_comment}
# ...
${function_hints}

${all_directive}

And python_torch_functions.cpp contains the following:

cpp
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
$ops_headers
#endif
    
// ...
// generated forward declarations start here

${py_forwards}

// ...
static PyMethodDef torch_functions_shard[] = {
  ${py_method_defs}
};

// ...
// generated methods start here

${py_methods}
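To see which placeholders substitution_str recognizes, one can run it over a few sample lines modeled on the templates above. The lines themselves are made up. group(1) is not None only when nothing but whitespace precedes the placeholder on its line, and that is the condition that later triggers block substitution. group(2) is the raw key, braces and commas included. A $ followed by a digit is not treated as a placeholder.

```python
import re

# Copied from CodeTemplate above.
substitution_str = r"(^[^\n\S]*)?\$([^\d\W]\w*|\{,?[^\d\W]\w*\,?})"
substitution = re.compile(substitution_str, re.MULTILINE)

# Made-up sample lines modeled on the templates shown above.
sample = (
    "  ${py_method_defs}\n"
    "f(self${,extra_args})\n"
    "$ops_headers\n"
    "cost: $5\n"
)

results = [(m.group(1), m.group(2)) for m in substitution.finditer(sample)]
for indent, key in results:
    print(repr(indent), repr(key))
# '  ' '{py_method_defs}'
# None '{,extra_args}'
# '' 'ops_headers'
```

For $ops_headers at column 0, group(1) is an empty string, not None. The placeholder therefore still counts as starting a line, and a list value would be placed one element per line.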

CodeTemplate.from_file

torchgen/code_template.py

python
class CodeTemplate:
    # ...

    @staticmethod
    def from_file(filename: str) -> "CodeTemplate":
        with open(filename, "r") as f:
            return CodeTemplate(f.read(), filename)
        
    # ...

It calls the CodeTemplate constructor, passing in the contents of the file and its name.

CodeTemplate.__init__

  • filename: the name of the input .pyi.in file or template .cpp file
  • pattern: CodeTemplate.from_file invokes the constructor as CodeTemplate(f.read(), filename), so the pattern member is set to the text read from the file
python
class CodeTemplate:
    # ...
    
    def __init__(self, pattern: str, filename: str = "") -> None:
        self.pattern = pattern
        self.filename = filename
        
    # ...

substitute

torchgen/code_template.py

Recall this line in substitute_with_template in torchgen/utils.py:

python
            template = _read_template(template_path)

After creating the CodeTemplate object template, it goes on to call:

python
            return template.substitute(env)

This call performs the regex-based substitution:

python
class CodeTemplate:
    # ...
    def substitute(
        self, env: Optional[Mapping[str, object]] = None, **kwargs: object
    ) -> str:
        if env is None:
            env = {}

        def lookup(v: str) -> object:
            assert env is not None
            return kwargs[v] if v in kwargs else env[v]

        def indent_lines(indent: str, v: Sequence[object]) -> str:
            return "".join(
                [indent + l + "\n" for e in v for l in str(e).splitlines()]
            ).rstrip()

        def replace(match: Match[str]) -> str:
            indent = match.group(1)
            key = match.group(2)
            comma_before = ""
            comma_after = ""
            if key[0] == "{":
                key = key[1:-1]
                if key[0] == ",":
                    comma_before = ", "
                    key = key[1:]
                if key[-1] == ",":
                    comma_after = ", "
                    key = key[:-1]
            v = lookup(key)
            if indent is not None:
                if not isinstance(v, list):
                    v = [v]
                return indent_lines(indent, v)
            elif isinstance(v, list):
                middle = ", ".join([str(x) for x in v])
                if len(v) == 0:
                    return middle
                return comma_before + middle + comma_after
            else:
                return str(v)

        return self.substitution.sub(replace, self.pattern)

In the function's final line, self.substitution.sub(replace, self.pattern), self.substitution is a class attribute of CodeTemplate:

python
    substitution_str = r"(^[^\n\S]*)?\$([^\d\W]\w*|\{,?[^\d\W]\w*\,?})"
    substitution = re.compile(substitution_str, re.MULTILINE)

re.compile returns substitution, a compiled re.Pattern object.

First, let's see what re.Pattern.sub does, using the example given in Passing a function to re.sub in Python and Python: re.compile and re.sub:

python
import re
substitution = re.compile(r'\d')
number_mapping = {'1': 'one', '2': 'two', '3': 'three'}
s = "1 testing 2 3"
substitution.sub(lambda x: number_mapping[x.group()], s) # 'one testing two three'

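The first argument to re.Pattern.sub is the replacement. Here it is a function, which is called with each match object and returns the replacement text. The second argument is the string to process. sub finds every substring that matches the pattern (here r'\d'), replaces each one, and returns the resulting string.

So self.substitution.sub(replace, self.pattern) searches self.pattern (the contents of the .pyi.in file or template .cpp file) for strings matching substitution_str. It replaces each match with whatever replace returns for that match.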
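replace handles three cases: block substitution, inline comma-separated lists, and plain values. All three can be checked with a condensed, standalone copy of CodeTemplate.substitute. The copy drops the **kwargs lookup. The template text and env values are made-up examples loosely modeled on python_torch_functions.cpp.

```python
from __future__ import annotations  # keeps re.Match[str] out of runtime evaluation

import re
from typing import Mapping, Sequence

# Same regex as CodeTemplate.substitution.
SUBSTITUTION = re.compile(r"(^[^\n\S]*)?\$([^\d\W]\w*|\{,?[^\d\W]\w*\,?})", re.MULTILINE)


def substitute(pattern: str, env: Mapping[str, object]) -> str:
    def indent_lines(indent: str, v: Sequence[object]) -> str:
        # One output line per element (and per line within each element).
        return "".join(
            [indent + line + "\n" for e in v for line in str(e).splitlines()]
        ).rstrip()

    def replace(match: re.Match[str]) -> str:
        indent, key = match.group(1), match.group(2)
        comma_before = comma_after = ""
        if key[0] == "{":
            key = key[1:-1]
            if key[0] == ",":
                comma_before, key = ", ", key[1:]
            if key[-1] == ",":
                comma_after, key = ", ", key[:-1]
        v = env[key]
        if indent is not None:  # placeholder starts its line: block substitution
            return indent_lines(indent, v if isinstance(v, list) else [v])
        if isinstance(v, list):  # inline list: comma-separated
            middle = ", ".join(str(x) for x in v)
            return comma_before + middle + comma_after if v else middle
        return str(v)

    return SUBSTITUTION.sub(replace, pattern)


template = (
    "// ${generated_comment}\n"
    "$ops_headers\n"
    "static PyMethodDef defs[] = {\n"
    "  ${py_method_defs}\n"
    "};\n"
    "f(self${,extra_args});\n"
    "g(${,no_args});\n"
)
env = {
    "generated_comment": "@generated demo",
    "ops_headers": ["#include <a.h>", "#include <b.h>"],
    "py_method_defs": ["abs_def,", "add_def,"],
    "extra_args": ["a", "b"],
    "no_args": [],
}
result = substitute(template, env)
print(result)
```

In the output, the headers and method defs land one per line at their placeholder's indentation, f(self${,extra_args}) becomes f(self, a, b), and the empty list leaves g() with no stray comma.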

With the substituted result in hand, we return to substitute_with_template:

python
            return template.substitute(env)

That function returns the result in turn, which brings us back to write_with_template:

python
            substitute_out = self.substitute_with_template(
                template_fn=template_fn,
                env_callable=env_callable,
            )
            self._write_if_changed(filename=filename, contents=substitute_out)

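There, the substitution result substitute_out is written to filename, i.e. the generated .pyi or .cpp file.

Now let's look at ${generated_comment} in torch/_C/_VariableFunctions.pyi.in.

Recall that when gen_pyi calls write_with_template, it passes a generated_comment key-value pair along with env:

python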
    fm.write_with_template(
        "torch/_C/_VariableFunctions.pyi",
        "torch/_C/_VariableFunctions.pyi.in",
        lambda: {
            "generated_comment": "@"
            + "generated from torch/_C/_VariableFunctions.pyi.in",
            **env,
        },
    )

So by the time execution reaches substitute, the env parameter is a dict that contains the generated_comment key-value pair.

After substitution, # ${generated_comment} becomes the first line of the generated torch/_C/_VariableFunctions.pyi:

python
# @generated from torch/_C/_VariableFunctions.pyi.in