From 7486fe1cc63b184e808c09482e7cd0c99b696776 Mon Sep 17 00:00:00 2001 From: zhangwenwen Date: Thu, 6 May 2021 23:03:23 +0800 Subject: [PATCH] convert pytorch model to paddlepaddle model --- .gitignore | 1 + tools/pytorch2paddlepaddle.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 29 insertions(+) create mode 100644 tools/pytorch2paddlepaddle.py diff --git a/.gitignore b/.gitignore index b6e4761..e4a1406 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +.idea # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/tools/pytorch2paddlepaddle.py b/tools/pytorch2paddlepaddle.py new file mode 100644 index 0000000..56d9aa1 --- /dev/null +++ b/tools/pytorch2paddlepaddle.py @@ -0,0 +1,28 @@ +import torch +import numpy as np +from models.swin_transformer import SwinTransformer +# 构建输入 +input_data = np.random.rand(1, 3, 224, 224).astype("float32") + + +swin_model_cfg_map = { + "swin_tiny_patch4_window7_224": { + "EMBED_DIM": 96, + "DEPTHS": [ 2, 2, 6, 2 ], + "NUM_HEADS": [ 3, 6, 12, 24 ], + "WINDOW_SIZE": 7, + } +} + +model_name = "swin_tiny_patch4_window7_224" +torch_module = SwinTransformer(**swin_model_cfg_map[model_name]) +torch_state_dict = torch.load("/home/andy/data/pretrained_models/{}.pth".format(model_name))["model"] +torch_module.load_state_dict(torch_state_dict) +# 设置为eval模式 +torch_module.eval() +# 进行转换 +from x2paddle.convert import pytorch2paddle +pytorch2paddle(torch_module, + save_dir="pd_{}".format(model_name), + jit_type="trace", + input_examples=[torch.tensor(input_data)]) \ No newline at end of file