Python classmethod()
该方法返回给定函数的类方法 。
classmethod()
什么是类方法?
类方法是绑定到类而不是其对象的方法。它不需要创建类实例,就像静态方法一样。
静态方法和类方法的区别在于:
- 静态方法对类一无所知,只处理参数;
- 类方法适用于类,因为它的参数始终是类本身;
- 类方法既可以由类调用,也可以由其对象调用。
下面是一个例子:
python
class damn():
total_damage = 0
def compute_damage(self, gears: int, base_damage: int):
self.total_damage = base_damage**gears
print(f"Gears{gears}! 造成了{self.total_damage}点伤害!")
damn.damage = classmethod(damn.compute_damage)
damn.damage(3, 64)
# Outputs: Gears3! 造成了262144点伤害!
@classmethod
python
# 用法
@classmethod
def func(cls, args...)
# cls接受类作为参数,而不是类的对象或者实例
例子:
python
from datetime import date
# random Person
class Person:
def __init__(self, name, age):
self.name = name
self.age = age
@classmethod
def fromBirthYear(cls, name, birthYear):
return cls(name, date.today().year - birthYear)
# 返回带有初始化参数的Person实例
def display(self):
print(self.name + "'s age is: " + str(self.age))
person = Person('luffy', 25)
person.display()
person1 = Person.fromBirthYear('Kaido', 1900)
person1.display()
# OUTPUTS:
luffy's age is: 25
Kaido's age is: 124
可以看出,它类似与C++中的函数重载。
在深度学习模型代码中用于超参数载入
继续举例子:
python
class luffy():
def __init__(
self,
enemy_name: str,
gears: int,
base_damage: int,
enemy_health: int,
) -> None:
self.enemy_name = enemy_name
self.gears = gears
self.base_damage = base_damage
self.enemy_health = enemy_health
def print_info(self):
print(f"enemy_name: {self.enemy_name}, gears: {self.gears}, base_damage: {self.base_damage}, enemy_health: {self.enemy_health}")
def battle_info(self):
num = self.base_damage**self.gears
if num >= self.enemy_health:
print(f"路飞使用了{self.gears}档攻击,造成了{num}点伤害,{self.enemy_name}被击溃了!")
else:
print(f"路飞使用了{self.gears}档攻击,造成了{num}点伤害,{self.enemy_name}还剩下{self.enemy_health - num}点生命值!")
@classmethod
def init_from_class_method(cls, luffyname, conf):
args = {}
args['enemy_name'] = conf[luffyname]['enemy_name']
args['gears'] = conf[luffyname]['gears']
args['base_damage'] = conf[luffyname]['base_damage']
args['enemy_health'] = conf[luffyname]['enemy_health']
return cls(**args)
以上代码定义了一个类,用来描述一场对决信息,实例中的数值则存放在config.yaml
配置文件中:
yaml
luffy1:
enemy_name: "BigMom"
gears: 3
base_damage: 12
enemy_health: 1500
luffy2:
enemy_name: "Kaido"
gears: 4
base_damage: 12
enemy_health: 25000
luffy3:
enemy_name: "Kaido"
gears: 5
base_damage: 12
enemy_health: 25000
我们把超参加载的部分封装在类方法中,接下来是如何实例加载演示:
python
config_path='practice/config.yaml'
conf = yaml.safe_load(open(config_path))
luffy1 = luffy.init_from_class_method('luffy1', conf)
luffy1.battle_info()
luffy2 = luffy.init_from_class_method('luffy2', conf)
luffy2.battle_info()
luffy3 = luffy.init_from_class_method('luffy3', conf)
luffy3.battle_info()
终端输出结果如下:
powershell
路飞使用了3档攻击,造成了1728点伤害,BigMom被击溃了!
路飞使用了4档攻击,造成了20736点伤害,Kaido还剩下4264点生命值!
路飞使用了5档攻击,造成了248832点伤害,Kaido被击溃了!
代码范例
python
class M2T2(nn.Module):
def __init__(
self,
backbone: nn.Module,
transformer: nn.Module,
object_encoder: nn.Module = None,
grasp_mlp: nn.Module = None,
set_criterion: nn.Module = None,
grasp_criterion: nn.Module = None,
place_criterion: nn.Module = None
):
super(M2T2, self).__init__()
self.backbone = backbone
self.object_encoder = object_encoder
self.transformer = transformer
self.grasp_mlp = grasp_mlp
self.set_criterion = set_criterion
self.grasp_criterion = grasp_criterion
self.place_criterion = place_criterion
@classmethod
def from_config(cls, cfg):
args = {}
args['backbone'] = PointNet2MSG.from_config(cfg.scene_encoder)
channels = args['backbone'].out_channels
obj_channels = None
if cfg.contact_decoder.num_place_queries > 0:
args['object_encoder'] = PointNet2MSGCls.from_config(
cfg.object_encoder
)
obj_channels = args['object_encoder'].out_channels
args['place_criterion'] = PlaceCriterion.from_config(
cfg.place_loss
)
args['transformer'] = ContactDecoder.from_config(
cfg.contact_decoder, channels, obj_channels
)
if cfg.contact_decoder.num_grasp_queries > 0:
args['grasp_mlp'] = ActionDecoder.from_config(
cfg.action_decoder, args['transformer']
)
matcher = HungarianMatcher.from_config(cfg.matcher)
args['set_criterion'] = SetCriterion.from_config(
cfg.grasp_loss, matcher
)
args['grasp_criterion'] = GraspCriterion.from_config(
cfg.grasp_loss
)
return cls(**args)
...
模型创建与加载部分代码:
python
model = M2T2.from_config(cfg.m2t2)
ckpt = torch.load(cfg.eval.checkpoint)
model.load_state_dict(ckpt['model'])
model = model.cuda().eval()