FileManager.write_with_template in PyTorch's File Generation Mechanism
Preface
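Some files in PyTorch are generated by running scripts during the build. For example, `.pyi` files are generated from `.pyi.in` files, and the `.cpp` files under `torch/csrc/autograd/generated` are generated from the template `.cpp` files under `tools/autograd/templates`.

Under the hood, both cases call `FileManager.write_with_template`, which replaces specific placeholder strings in the template file in the way a callback function dictates, and thereby produces the corresponding `.pyi` or `.cpp` file.

This article first looks at how `FileManager.write_with_template` is called, and then examines its implementation in detail.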
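As a rough analogy for this placeholder-plus-callback idea, the sketch below uses Python's standard-library `string.Template`, which also recognizes `$identifier` and `${identifier}`. This is a swapped-in stand-in, not torchgen's own template class: it has no special handling for lists or indentation. The template text and the `env_callable` function are made up for illustration.

```python
from string import Template
from typing import Dict

# A made-up stand-in for the contents of a .pyi.in template file.
template_text = "# ${generated_comment}\n$function_hints\n"


def env_callable() -> Dict[str, str]:
    # Zero-argument callback that builds the substitution environment.
    return {
        "generated_comment": "@generated from example.pyi.in",
        "function_hints": "def foo() -> None: ...",
    }


# Replace every placeholder with the matching value from the environment.
output = Template(template_text).substitute(env_callable())
print(output)
```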
Calling FileManager.write_with_template
gen_pyi
tools/pyi/gen_pyi.py
```python
fm.write_with_template(
    "torch/_C/__init__.pyi",
    "torch/_C/__init__.pyi.in",
    lambda: {
        "generated_comment": "@" + "generated from torch/_C/__init__.pyi.in",
        **env,
    },
)
fm.write_with_template(
    "torch/_C/_VariableFunctions.pyi",
    "torch/_C/_VariableFunctions.pyi.in",
    lambda: {
        "generated_comment": "@"
        + "generated from torch/_C/_VariableFunctions.pyi.in",
        **env,
    },
)
fm.write_with_template(
    "torch/_VF.pyi",
    "torch/_C/_VariableFunctions.pyi.in",
    lambda: {
        "generated_comment": "@"
        + "generated from torch/_C/_VariableFunctions.pyi.in",
        **env,
    },
)
fm.write_with_template(
    "torch/return_types.pyi",
    "torch/_C/return_types.pyi.in",
    lambda: {
        "generated_comment": "@" + "generated from torch/_C/return_types.pyi",
        **env,
    },
)
gen_nn_functional(fm)
```
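These four `fm.write_with_template` calls use three `.pyi.in` templates in the `torch/_C` directory (`_VariableFunctions.pyi.in` is used twice) to generate `__init__.pyi` and `_VariableFunctions.pyi` in the `torch/_C` directory, and `_VF.pyi` and `return_types.pyi` in the `torch` directory.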
gen_nn_functional
tools/pyi/gen_pyi.py
```python
def gen_nn_functional(fm: FileManager) -> None:
    # ...
    fm.write_with_template(
        "torch/nn/functional.pyi",
        "torch/nn/functional.pyi.in",
        lambda: {
            "imported_hints": import_code,
            "dispatched_hints": dispatch_code,
        },
    )
    # ...
    fm.write_with_template(
        "torch/_C/_nn.pyi",
        "torch/_C/_nn.pyi.in",
        lambda: {
            "imported_hints": import_code,
            "dispatched_hints": dispatch_code,
        },
    )
```
These two `fm.write_with_template` calls generate `torch/nn/functional.pyi` and `torch/_C/_nn.pyi` from `torch/nn/functional.pyi.in` and `torch/_C/_nn.pyi.in`, respectively.
write_sharded
torchgen/utils.py
```python
def write_sharded(
    self,
    filename: str,
    items: Iterable[T],
    *,
    key_fn: Callable[[T], str],
    env_callable: Callable[[T], Dict[str, List[str]]],
    num_shards: int,
    base_env: Optional[Dict[str, Any]] = None,
    sharded_keys: Set[str],
) -> None:
    # ...
    for shard in all_shards:
        shard_id = shard["shard_id"]
        self.write_with_template(
            f"{base_filename}{shard_id}{extension}", filename, lambda: shard
        )
    # ...
```
Here, `all_shards` is:

```python
[{'shard_id': 'Everything'}, {'shard_id': '_0'}, {'shard_id': '_1'}, {'shard_id': '_2'}]
```
So for the template `filename`, namely `python_torch_functions.cpp`, the `write_with_template` calls here generate four files: `python_torch_functionsEverything.cpp`, `python_torch_functions_0.cpp`, `python_torch_functions_1.cpp`, and `python_torch_functions_2.cpp`.
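The file names come from inserting each `shard_id` between a base name and an extension, as in `f"{base_filename}{shard_id}{extension}"` above. Below is a minimal sketch of that naming scheme, assuming `base_filename` and `extension` come from `os.path.splitext(filename)` (the excerpt elides where they are computed). `make_shards` and `shard_filenames` are hypothetical helpers, not torchgen functions.

```python
import os
from typing import Dict, List


def make_shards(num_shards: int) -> List[Dict[str, str]]:
    # One "Everything" shard followed by num_shards numbered shards,
    # mirroring the shape of all_shards shown above.
    return [{"shard_id": "Everything"}] + [
        {"shard_id": f"_{i}"} for i in range(num_shards)
    ]


def shard_filenames(filename: str, num_shards: int) -> List[str]:
    # "python_torch_functions.cpp" -> ("python_torch_functions", ".cpp")
    base_filename, extension = os.path.splitext(filename)
    return [
        f"{base_filename}{shard['shard_id']}{extension}"
        for shard in make_shards(num_shards)
    ]


print(shard_filenames("python_torch_functions.cpp", 3))
```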
Note that in all three examples above, the third argument of `write_with_template` (`env_callable`) is a lambda that returns a `dict` when called.
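Because such a lambda builds a new dict on every call, `{"generated_comment": ..., **env}` copies the shared `env` into a fresh dict, so a consumer that adds keys to the returned dict does not modify `env` itself. The copy is shallow, and with this literal ordering a `generated_comment` key already present in `env` would override the one written before `**env`. The sketch below checks these properties; `shared_env` and `make_env_callable` are hypothetical names.

```python
from typing import Any, Callable, Dict

# Hypothetical shared environment, similar in spirit to env in gen_pyi.
shared_env: Dict[str, Any] = {"function_hints": ["def f() -> None: ..."]}


def make_env_callable(source: str) -> Callable[[], Dict[str, Any]]:
    # Same shape as the lambdas passed to write_with_template above.
    return lambda: {
        "generated_comment": "@" + "generated from " + source,
        **shared_env,
    }


env_callable = make_env_callable("torch/_C/_VariableFunctions.pyi.in")
first = env_callable()
first["extra"] = "added by a consumer"  # only changes this dict
second = env_callable()                 # a fresh dict, without "extra"

# Later keys win in a dict display, so a key unpacked from **shared_env
# would override an identically named key written before it.
overridden = {"generated_comment": "a", **{"generated_comment": "b"}}
```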
Implementation of FileManager.write_with_template
torchgen/utils.py
FileManager.write_with_template
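Besides `self`, `write_with_template` takes three parameters:

- `filename`: the name of the generated `.pyi` or `.cpp` file
- `template_fn`: the name of the input `.pyi.in` file or template `.cpp` file
- `env_callable`: the callback function used when performing the substitution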
```python
def write_with_template(
    self,
    filename: str,
    template_fn: str,
    env_callable: Callable[[], Union[str, Dict[str, Any]]],
) -> None:
    filename = "{}/{}".format(self.install_dir, filename)
    # Note: the message has no f prefix, so "{filename}" is not interpolated.
    assert filename not in self.filenames, "duplicate file write {filename}"
    self.filenames.add(filename)
    if not self.dry_run:
        substitute_out = self.substitute_with_template(
            template_fn=template_fn,
            env_callable=env_callable,
        )
        self._write_if_changed(filename=filename, contents=substitute_out)
```
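Apart from prefixing `self.install_dir` to `filename` and using `self.filenames` to reject duplicate writes, the core of this function is the call to `substitute_with_template`, which produces `substitute_out`. The substituted result, `substitute_out`, is then written to the file `filename` (the generated `.pyi` or `.cpp` file) through `_write_if_changed`. Both steps are skipped when `self.dry_run` is set.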
Note: for type checking, the callback function is annotated with `typing.Callable`; for details, see the Python `typing` library and `torch.types`.
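To make the control flow concrete, here is a minimal, self-contained sketch of the same steps. `MiniFileManager` is a hypothetical class, not torchgen's `FileManager`: it substitutes with the standard-library `string.Template` instead of `CodeTemplate`, its `_write_if_changed` is an assumed implementation that leaves the file alone when it already holds identical contents, and its assert message uses an f-string.

```python
import os
from string import Template
from typing import Any, Callable, Dict, Set


class MiniFileManager:
    """Hypothetical, simplified stand-in for torchgen's FileManager."""

    def __init__(self, install_dir: str, template_dir: str, dry_run: bool = False) -> None:
        self.install_dir = install_dir
        self.template_dir = template_dir
        self.dry_run = dry_run
        self.filenames: Set[str] = set()

    def write_with_template(
        self,
        filename: str,
        template_fn: str,
        env_callable: Callable[[], Dict[str, Any]],
    ) -> None:
        filename = os.path.join(self.install_dir, filename)
        # Each output path may be written at most once per manager.
        assert filename not in self.filenames, f"duplicate file write {filename}"
        self.filenames.add(filename)
        if not self.dry_run:
            # env_callable is only invoked when output is actually produced.
            with open(os.path.join(self.template_dir, template_fn)) as f:
                contents = Template(f.read()).substitute(env_callable())
            self._write_if_changed(filename, contents)

    def _write_if_changed(self, filename: str, contents: str) -> None:
        # Assumed behavior: skip the write if the file already matches.
        try:
            with open(filename) as f:
                if f.read() == contents:
                    return
        except FileNotFoundError:
            pass
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        with open(filename, "w") as f:
            f.write(contents)
```

In dry-run mode the sketch still records the file name (so duplicates are still caught) but never calls `env_callable` or touches the disk.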
FileManager.substitute_with_template
torchgen/utils.py
Besides `self`, it takes two parameters:

- `template_fn`: the name of the input `.pyi.in` file or template `.cpp` file
- `env_callable`: the callback function used when performing the substitution
```python
# Read from template file and replace pattern with callable (type could be dict or str).
def substitute_with_template(
    self, template_fn: str, env_callable: Callable[[], Union[str, Dict[str, Any]]]
) -> str:
    template_path = os.path.join(self.template_dir, template_fn)
    env = env_callable()
    if isinstance(env, dict):
        # TODO: Update the comment reference to the correct location
        if "generated_comment" not in env:
            comment = "@" + "generated by torchgen/gen.py"
            comment += " from {}".format(os.path.basename(template_path))
            env["generated_comment"] = comment
        template = _read_template(template_path)
        return template.substitute(env)
    elif isinstance(env, str):
        return env
    else:
        assert_never(env)
```
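Since `env_callable` here is a lambda that returns a `dict`, execution takes the `isinstance(env, dict)` branch. If `env` has no `generated_comment` key, a default comment naming the template file is added. Then `_read_template` reads the template file (the `.pyi.in` file or template `.cpp` file), and `template.substitute` is called on the result. If `env_callable` returned a `str` instead, that string would be returned unchanged as the output.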
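The branching can be sketched as follows. `substitute_sketch` is a hypothetical function that takes the template text directly instead of reading it from `template_dir`, uses `string.Template` as a stand-in for `CodeTemplate`, and raises `TypeError` where the original calls `assert_never`; the default-comment logic and the `dict`/`str` dispatch follow the code above.

```python
import os
from string import Template
from typing import Any, Callable, Dict, Union


def substitute_sketch(
    template_path: str,
    template_text: str,
    env_callable: Callable[[], Union[str, Dict[str, Any]]],
) -> str:
    env = env_callable()
    if isinstance(env, dict):
        # Fill in a default generated_comment naming the template file.
        if "generated_comment" not in env:
            comment = "@" + "generated by torchgen/gen.py"
            comment += " from {}".format(os.path.basename(template_path))
            env["generated_comment"] = comment
        return Template(template_text).substitute(env)
    elif isinstance(env, str):
        # A string result is used verbatim; the template is ignored.
        return env
    raise TypeError(f"unexpected env type: {type(env).__name__}")


print(substitute_sketch("templates/foo.pyi.in", "# ${generated_comment}\n", lambda: {}))
```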
_read_template
torchgen/utils.py
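Its parameter `template_fn` is the path of the `.pyi.in` file or template `.cpp` file (in `substitute_with_template` it receives `template_path`, i.e. `template_fn` joined onto `self.template_dir`).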
```python
@functools.lru_cache(maxsize=None)
def _read_template(template_fn: str) -> CodeTemplate:
    return CodeTemplate.from_file(template_fn)
```
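It reads `template_fn`, builds a `CodeTemplate` object from it, and returns that object. Because the function is decorated with `functools.lru_cache(maxsize=None)`, repeated calls with the same path return the cached `CodeTemplate` instead of reading the file again.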
torchgen/code_template.py
CodeTemplate
torchgen/code_template.py
Let's first look at what the `CodeTemplate` class is for.
```python
# match $identifier or ${identifier} and replace with value in env
# If this identifier is at the beginning of whitespace on a line
# and its value is a list then it is treated as
# block substitution by indenting to that depth and putting each element
# of the list on its own line
# if the identifier is on a line starting with non-whitespace and a list
# then it is comma separated ${,foo} will insert a comma before the list
# if this list is not empty and ${foo,} will insert one after.


class CodeTemplate:
    substitution_str = r"(^[^\n\S]*)?\$([^\d\W]\w*|\{,?[^\d\W]\w*\,?})"
    substitution = re.compile(substitution_str, re.MULTILINE)

    pattern: str
    filename: str
    # ...
```
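The comment explains that `CodeTemplate` replaces each `$identifier` or `${identifier}` in the template with the corresponding value in `env`. List values get special treatment: if the placeholder is preceded only by whitespace on its line, each element is placed on its own line at that indentation; otherwise the elements are joined with commas, and `${,foo}` or `${foo,}` adds a comma before or after the list when it is non-empty.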
For example, `torch/_C/_VariableFunctions.pyi.in` contains these placeholders:
```python
# ${generated_comment}
# ...
${function_hints}
${all_directive}
```
and `python_torch_functions.cpp` contains these:
```cpp
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
$ops_headers
#endif
// ...
// generated forward declarations start here
${py_forwards}
// ...
static PyMethodDef torch_functions_shard[] = {
  ${py_method_defs}
};
// ...
// generated methods start here
${py_methods}
```
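To see what `substitution_str` captures for placeholders like these, the sketch below runs the same regex over a small made-up template and records both groups. Group 1 is the leading whitespace and is only set when the placeholder starts its line (it can be the empty string, as for a `$ops_headers` at column 0); group 2 is the key, including any braces and commas.

```python
import re

# Same pattern and flag as CodeTemplate.substitution.
substitution_str = r"(^[^\n\S]*)?\$([^\d\W]\w*|\{,?[^\d\W]\w*\,?})"
substitution = re.compile(substitution_str, re.MULTILINE)

# A made-up template mixing the placeholder styles discussed above.
sample = "# ${generated_comment}\n$ops_headers\n  ${py_method_defs}\nf(self${,args})\n"

found = [(m.group(1), m.group(2)) for m in substitution.finditer(sample)]
for indent, key in found:
    print(repr(indent), key)
```

Only the second and third placeholders get a non-`None` group 1, so only they would be treated as block substitutions.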
CodeTemplate.from_file
torchgen/code_template.py
```python
class CodeTemplate:
    # ...
    @staticmethod
    def from_file(filename: str) -> "CodeTemplate":
        with open(filename, "r") as f:
            return CodeTemplate(f.read(), filename)

    # ...
```
It calls the `CodeTemplate` constructor, passing in the contents and the name of `filename`.
CodeTemplate.__init__
- `filename`: the name of the input `.pyi.in` file or template `.cpp` file
- `pattern`: `CodeTemplate.from_file` calls the constructor as `CodeTemplate(f.read(), filename)`, so the `pattern` member is set to the contents read from the file `filename`
```python
class CodeTemplate:
    # ...
    def __init__(self, pattern: str, filename: str = "") -> None:
        self.pattern = pattern
        self.filename = filename

    # ...
```
substitute
torchgen/code_template.py
Recall this line in `substitute_with_template` in `torchgen/utils.py`:
```python
template = _read_template(template_path)
```
After creating the `CodeTemplate` object `template`, it goes on to call:
```python
return template.substitute(env)
```
`substitute` performs regex-based substitution:
```python
class CodeTemplate:
    # ...
    def substitute(
        self, env: Optional[Mapping[str, object]] = None, **kwargs: object
    ) -> str:
        if env is None:
            env = {}

        def lookup(v: str) -> object:
            assert env is not None
            return kwargs[v] if v in kwargs else env[v]

        def indent_lines(indent: str, v: Sequence[object]) -> str:
            return "".join(
                [indent + l + "\n" for e in v for l in str(e).splitlines()]
            ).rstrip()

        def replace(match: Match[str]) -> str:
            indent = match.group(1)
            key = match.group(2)
            comma_before = ""
            comma_after = ""
            if key[0] == "{":
                key = key[1:-1]
                if key[0] == ",":
                    comma_before = ", "
                    key = key[1:]
                if key[-1] == ",":
                    comma_after = ", "
                    key = key[:-1]
            v = lookup(key)
            if indent is not None:
                if not isinstance(v, list):
                    v = [v]
                return indent_lines(indent, v)
            elif isinstance(v, list):
                middle = ", ".join([str(x) for x in v])
                if len(v) == 0:
                    return middle
                return comma_before + middle + comma_after
            else:
                return str(v)

        return self.substitution.sub(replace, self.pattern)
```
In the function's final statement, `self.substitution.sub(replace, self.pattern)`, `self.substitution` is a class member of `CodeTemplate`:
```python
substitution_str = r"(^[^\n\S]*)?\$([^\d\W]\w*|\{,?[^\d\W]\w*\,?})"
substitution = re.compile(substitution_str, re.MULTILINE)
```
Calling `re.compile` yields `substitution`, an `re.Pattern` object.
First, what does `re.Pattern.sub` do? Here is an example based on those given in "Passing a function to re.sub in Python" and "Python: re.compile and re.sub":
```python
import re

substitution = re.compile(r'\d')
number_mapping = {'1': 'one', '2': 'two', '3': 'three'}
s = "1 testing 2 3"
substitution.sub(lambda x: number_mapping[x.group()], s)  # 'one testing two three'
```
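The first argument of `re.Pattern.sub` is the replacement, here a function, and the second is the string to process. `sub` finds every substring that matches the pattern (here `r'\d'`), calls the function with the corresponding match object, substitutes the function's return value for that substring, and returns the resulting string.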
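So `self.substitution.sub(replace, self.pattern)` searches `self.pattern` (the contents of the `.pyi.in` file or template `.cpp` file) for substrings matching `substitution_str`, and replaces each one as specified by the `replace` function: `replace` strips the braces and optional commas from the key, looks up its value in `kwargs` or `env`, and formats it as an indented block, a comma-separated list, or a plain string.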
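To watch `replace` handle each case, here is a condensed, standalone re-implementation of the `substitute` logic above. `mini_substitute` is a hypothetical name; it drops the `kwargs` lookup and otherwise follows the same branches. The made-up template exercises a comma-prefixed inline list and an indented block.

```python
import re
from typing import Mapping, Sequence

# Same pattern as CodeTemplate.substitution_str.
SUBSTITUTION = re.compile(
    r"(^[^\n\S]*)?\$([^\d\W]\w*|\{,?[^\d\W]\w*\,?})", re.MULTILINE
)


def mini_substitute(pattern: str, env: Mapping[str, object]) -> str:
    def indent_lines(indent: str, v: Sequence[object]) -> str:
        # Every line of every element goes on its own line, prefixed by indent.
        return "".join(
            indent + line + "\n" for e in v for line in str(e).splitlines()
        ).rstrip()

    def replace(match: re.Match) -> str:
        indent, key = match.group(1), match.group(2)
        comma_before = comma_after = ""
        if key[0] == "{":
            key = key[1:-1]  # strip the braces
            if key[0] == ",":
                comma_before, key = ", ", key[1:]
            if key[-1] == ",":
                comma_after, key = ", ", key[:-1]
        v = env[key]
        if indent is not None:
            # Placeholder starts its line: block substitution at that indent.
            return indent_lines(indent, v if isinstance(v, list) else [v])
        if isinstance(v, list):
            # Inline list: comma-joined, extra commas only if non-empty.
            middle = ", ".join(str(x) for x in v)
            return comma_before + middle + comma_after if v else middle
        return str(v)

    return SUBSTITUTION.sub(replace, pattern)


template = "def f(self${,args}):\n    ${body}\n"
print(mini_substitute(template, {"args": ["a", "b"], "body": ["x = 1", "return x"]}))
```

With `args` set to an empty list, `${,args}` expands to nothing, so the leading comma disappears along with the list.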
Once the substitution result is obtained, it is returned from `substitute_with_template`:
```python
return template.substitute(env)
```
which in turn returns it to `write_with_template`:
```python
substitute_out = self.substitute_with_template(
    template_fn=template_fn,
    env_callable=env_callable,
)
self._write_if_changed(filename=filename, contents=substitute_out)
```
There, the substitution result `substitute_out` is written to `filename`, i.e. the generated `.pyi` or `.cpp` file.
As a concrete example, consider `${generated_comment}` in `torch/_C/_VariableFunctions.pyi.in`.
Recall that when `gen_pyi` calls `write_with_template`, the lambda passes a `generated_comment` key-value pair together with `env`:
```python
fm.write_with_template(
    "torch/_C/_VariableFunctions.pyi",
    "torch/_C/_VariableFunctions.pyi.in",
    lambda: {
        "generated_comment": "@"
        + "generated from torch/_C/_VariableFunctions.pyi.in",
        **env,
    },
)
```
So when execution reaches `substitute`, its `env` argument is a dict that contains the `generated_comment` key-value pair.
After substitution, the template line `# ${generated_comment}` becomes the first line of the generated `torch/_C/_VariableFunctions.pyi`:
```python
# @generated from torch/_C/_VariableFunctions.pyi.in
```