Compare commits


12 Commits
LR-Net ... main

Author     SHA1        Message                                                                               Date
           7486fe1cc6  convert pytorch model to paddlepaddle model                                           2021-05-06 23:03:23 +08:00
Han Hu     eec0e76fdb  Merge pull request #41 from kamalkraj/patch-1: Update README.md                       2021-04-21 19:33:08 +08:00
Han Hu     008d5afc38  More explanation on adding links to 3rd party repo                                    2021-04-21 13:06:40 +08:00
Kamal Raj  0fccdb0b6f  Update README.md                                                                      2021-04-20 21:33:14 +05:30
Ze Liu     081385b99d  fix link to Swin-L-IN22K                                                              2021-04-16 12:12:28 +08:00
Han Hu     110fb0e4e9  Merge pull request #27 from caoyue10/patch-2: Add a cross link to third party usage  2021-04-16 11:01:52 +08:00
Yue Cao    fa35a5e435  Add a cross link to third party usage                                                 2021-04-16 10:04:48 +08:00
Han Hu     bd122d3b42  Add cross link to third party usage                                                   2021-04-14 22:06:53 +08:00
Ze Liu     6cc8ebd200  minor bug fixes (#25)                                                                 2021-04-14 20:50:10 +08:00
Han Hu     c7e6d6efbd  Add explanation of the name "Swin"                                                    2021-04-13 14:06:04 +08:00
Ze Liu     657aeeeefa  Merge pull request #13 from caoyue10/patch-1: Fix a typo.                             2021-04-13 11:42:51 +08:00
Yue Cao    5e785157ed  Fix a typo.                                                                           2021-04-13 11:39:23 +08:00
5 changed files with 48 additions and 3 deletions

.gitignore (vendored) · 1 addition

@@ -1,3 +1,4 @@
+.idea
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

README.md

@@ -23,19 +23,19 @@ This repo is the official implementation of ["Swin Transformer: Hierarchical Vis
Initial commits:
-1. Pretrained models on ImageNet-1K ([Swin-T-IN1K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth), [Swin-S-IN1K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth), [Swin-B-IN1K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth)) and ImageNet-22K ([Swin-B-IN22K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth), [Swin-L-IN22K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth)) are provided.
+1. Pretrained models on ImageNet-1K ([Swin-T-IN1K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth), [Swin-S-IN1K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth), [Swin-B-IN1K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth)) and ImageNet-22K ([Swin-B-IN22K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth), [Swin-L-IN22K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth)) are provided.
2. The supported code and models for ImageNet-1K image classification, COCO object detection and ADE20K semantic segmentation are provided.
3. The cuda kernel implementation for the [local relation layer](https://arxiv.org/pdf/1904.11491.pdf) is provided in branch [LR-Net](https://github.com/microsoft/Swin-Transformer/tree/LR-Net).
## Introduction
-**Swin Transformer** is initially described in [arxiv](https://arxiv.org/abs/2103.14030), which capably serves as a
+**Swin Transformer** (the name `Swin` stands for **S**hifted **win**dow) is initially described in [arxiv](https://arxiv.org/abs/2103.14030), which capably serves as a
general-purpose backbone for computer vision. It is basically a hierarchical Transformer whose representation is
computed with shifted windows. The shifted windowing scheme brings greater efficiency by limiting self-attention
computation to non-overlapping local windows while also allowing for cross-window connection.
Swin Transformer achieves strong performance on COCO object detection (`58.7 box AP` and `51.1 mask AP` on test-dev) and
-ADE20K semantic segmentatiion (`53.5 mIoU` on val), surpassing previous models by a large margin.
+ADE20K semantic segmentation (`53.5 mIoU` on val), surpassing previous models by a large margin.
![teaser](figures/teaser.png)
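
The shifted-window scheme described in the introduction above confines self-attention to fixed-size local windows. Below is a minimal sketch of the two core operations, window partitioning and the cyclic shift; the repository's own version lives in models/swin_transformer.py, so treat this helper as an illustration rather than the exact source:

import torch

def window_partition(x, window_size):
    # Split a (B, H, W, C) feature map into non-overlapping windows of shape
    # (num_windows * B, window_size, window_size, C); self-attention is then
    # computed inside each window independently.
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

x = torch.randn(1, 56, 56, 96)                         # Swin-T stage-1 feature map
windows = window_partition(x, 7)                       # -> (64, 7, 7, 96)
shifted = torch.roll(x, shifts=(-3, -3), dims=(1, 2))  # cyclic shift (window_size // 2 = 3)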
@@ -102,6 +102,18 @@ Note: <sup>*</sup> indicates multi-scale testing.
- For **Object Detection and Instance Segmentation**, please see [Swin Transformer for Object Detection](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection).
- For **Semantic Segmentation**, please see [Swin Transformer for Semantic Segmentation](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation).
+## Third-party Usage and Experiments
+***In this paragraph, we cross-link third-party repositories which use Swin and report results. You can let us know by raising an issue.***
+(`Note: please report accuracy numbers and provide trained models in your new repository, so that others can get a sense of correctness and model behavior.`)
+[04/14/2021] Swin for RetinaNet in Detectron2: https://github.com/xiaohu2015/SwinT_detectron2.
+[04/16/2021] Included in a famous model zoo: https://github.com/rwightman/pytorch-image-models.
+[04/20/2021] Swin-Transformer classifier inference using TorchServe: https://github.com/kamalkraj/Swin-Transformer-Serve
## Contributing
This project welcomes contributions and suggestions. Most contributions require you to agree to a
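
The [04/16/2021] model-zoo entry above can be tried directly from Python. A short sketch, assuming the installed timm version already ships the Swin models and pretrained weights under this registry name:

import timm
import torch

# Model name assumed to match timm's registry for this release.
model = timm.create_model('swin_tiny_patch4_window7_224', pretrained=True)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))  # -> (1, 1000) ImageNet-1K logits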

main.py

@@ -110,7 +110,9 @@ def main(config):
    if resume_file:
        if config.MODEL.RESUME:
            logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}")
+       config.defrost()
        config.MODEL.RESUME = resume_file
+       config.freeze()
        logger.info(f'auto resuming from {resume_file}')
    else:
        logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')
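
The config.defrost() / config.freeze() pair added in this hunk is needed because the yacs CfgNode used for config is immutable once frozen; assigning to a frozen node raises AttributeError. A minimal illustration of that behavior:

from yacs.config import CfgNode as CN

config = CN()
config.MODEL = CN()
config.MODEL.RESUME = ''
config.freeze()

# config.MODEL.RESUME = 'ckpt.pth'  # would raise AttributeError: the node is frozen
config.defrost()                    # temporarily make the config mutable
config.MODEL.RESUME = 'ckpt.pth'
config.freeze()                     # re-freeze so later code cannot mutate it by accident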

New file: PyTorch-to-PaddlePaddle conversion script

@@ -0,0 +1,28 @@
import torch
import numpy as np
from models.swin_transformer import SwinTransformer

# Build the input
input_data = np.random.rand(1, 3, 224, 224).astype("float32")

# SwinTransformer's constructor kwargs are lowercase (embed_dim, depths, ...)
swin_model_cfg_map = {
    "swin_tiny_patch4_window7_224": {
        "embed_dim": 96,
        "depths": [2, 2, 6, 2],
        "num_heads": [3, 6, 12, 24],
        "window_size": 7,
    }
}
model_name = "swin_tiny_patch4_window7_224"
torch_module = SwinTransformer(**swin_model_cfg_map[model_name])
torch_state_dict = torch.load("/home/andy/data/pretrained_models/{}.pth".format(model_name))["model"]
torch_module.load_state_dict(torch_state_dict)
# Set to eval mode
torch_module.eval()

# Run the conversion
from x2paddle.convert import pytorch2paddle
pytorch2paddle(torch_module,
               save_dir="pd_{}".format(model_name),
               jit_type="trace",
               input_examples=[torch.tensor(input_data)])
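
A natural follow-up is to check that the converted model agrees with the original numerically. The sketch below continues from the script above (reusing input_data and torch_module); the inference_model/model path is an assumption about where this x2paddle version writes the traced program, so adjust it to the converter's actual output:

import numpy as np
import paddle
import torch

# Path assumed from x2paddle's usual save layout; check the save_dir contents.
pd_model = paddle.jit.load("pd_swin_tiny_patch4_window7_224/inference_model/model")
pd_model.eval()

pd_out = pd_model(paddle.to_tensor(input_data)).numpy()
with torch.no_grad():
    torch_out = torch_module(torch.tensor(input_data)).numpy()
print("max abs diff:", np.abs(pd_out - torch_out).max())  # expect roughly 1e-5 or smaller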

utils.py

@@ -29,7 +29,9 @@ def load_checkpoint(config, model, optimizer, lr_scheduler, logger):
    if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
+       config.defrost()
        config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
+       config.freeze()
        if 'amp' in checkpoint and config.AMP_OPT_LEVEL != "O0" and checkpoint['config'].AMP_OPT_LEVEL != "O0":
            amp.load_state_dict(checkpoint['amp'])
        logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})")