方法一:
使用这种方式FPN(encoder=maxvit_small_tf_224(pretrained_cfg_overlay=dict(file=cached_safe_file),features_only=True,pretrained=False),读取bin文件
python
cached_safe_file = "/home/xx/myProject/data/pytorch_weights/maxvit_small_tf_224.bin"
# args.net.encoder.model = timm.create_model(
# "maxvit_base_tf_224",pretrained=True,
# pretrained_cfg_overlay=dict(file=cached_safe_file),features_only=True)
from Netwroks.segmentation.segmentation.decoders.fpn.transformer_decoder import FPN
from Netwroks.segmentation.encoder.model2 import maxvit_base_tf_224,maxvit_tiny_tf_224,maxvit_small_tf_224
args.net = FPN(encoder=maxvit_small_tf_224(pretrained_cfg_overlay=dict(file=cached_safe_file),features_only=True,pretrained=False),
in_channels=3,
classes=1,
activation='sigmoid' )
方法二
下面这种方式也可以,读取model.safetensors
python
args.net = smp.FPN(
encoder_name="tu-maxvit_base_tf_224",#'tu-maxvit_base_tf_224', # 选择解码器, 例如 mobilenet_v2 或 efficientnet-b7
encoder_weights=None, # 使用预先训练的权重imagenet进行解码器初始化
in_channels=3, # 模型输入通道(1个用于灰度图像,3个用于RGB等)
classes=1,
activation='sigmoid' # 模型输出通道(数据集所分的类别总数)
)
import timm
cached_safe_file = "/home/xx/.cache/huggingface/hub/models--timm--maxvit_base_tf_224.in1k/snapshots/ede2304fc169e23779075c092a477302f23660f8/model.safetensors"
args.net.encoder.model = timm.create_model(
"maxvit_base_tf_224",pretrained=True,
pretrained_cfg_overlay=dict(file=cached_safe_file),features_only=True)