convert pytorch model to paddlepaddle model

This commit is contained in:
zhangwenwen 2021-05-06 23:03:23 +08:00
parent eec0e76fdb
commit 7486fe1cc6
2 changed files with 29 additions and 0 deletions

1
.gitignore vendored
View File

@ -1,3 +1,4 @@
.idea
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

View File

@ -0,0 +1,28 @@
import torch
import numpy as np
from models.swin_transformer import SwinTransformer
# 构建输入
input_data = np.random.rand(1, 3, 224, 224).astype("float32")
swin_model_cfg_map = {
"swin_tiny_patch4_window7_224": {
"EMBED_DIM": 96,
"DEPTHS": [ 2, 2, 6, 2 ],
"NUM_HEADS": [ 3, 6, 12, 24 ],
"WINDOW_SIZE": 7,
}
}
model_name = "swin_tiny_patch4_window7_224"
torch_module = SwinTransformer(**swin_model_cfg_map[model_name])
torch_state_dict = torch.load("/home/andy/data/pretrained_models/{}.pth".format(model_name))["model"]
torch_module.load_state_dict(torch_state_dict)
# 设置为eval模式
torch_module.eval()
# 进行转换
from x2paddle.convert import pytorch2paddle
pytorch2paddle(torch_module,
save_dir="pd_{}".format(model_name),
jit_type="trace",
input_examples=[torch.tensor(input_data)])