diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..b6e4761
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,129 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
diff --git a/README.md b/README.md
index 4ae38dd..76299ef 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,89 @@
# Swin Transformer
-By [Ze Liu](https://github.com/zeliu98/)\*, [Yutong Lin](https://github.com/impiga)\*, [Yue Cao](http://yue-cao.me)\*, [Han Hu](https://sites.google.com/site/hanhushomepage/)\*, [Yixuan Wei](https://github.com/weiyx16), [Zheng Zhang](https://stupidzz.github.io/), [Stephen Lin](https://scholar.google.com/citations?user=c3PYmxUAAAAJ&hl=en) and [Baining Guo](https://www.microsoft.com/en-us/research/people/bainguo/).
+[PWC: Object Detection on COCO test-dev](https://paperswithcode.com/sota/object-detection-on-coco?p=swin-transformer-hierarchical-vision)
+[PWC: Instance Segmentation on COCO test-dev](https://paperswithcode.com/sota/instance-segmentation-on-coco?p=swin-transformer-hierarchical-vision)
+[PWC: Object Detection on COCO minival](https://paperswithcode.com/sota/object-detection-on-coco-minival?p=swin-transformer-hierarchical-vision)
+[PWC: Instance Segmentation on COCO minival](https://paperswithcode.com/sota/instance-segmentation-on-coco-minival?p=swin-transformer-hierarchical-vision)
+[PWC: Semantic Segmentation on ADE20K](https://paperswithcode.com/sota/semantic-segmentation-on-ade20k?p=swin-transformer-hierarchical-vision)
+[PWC: Semantic Segmentation on ADE20K val](https://paperswithcode.com/sota/semantic-segmentation-on-ade20k-val?p=swin-transformer-hierarchical-vision)
-This repo is the official implementation of ["Swin Transformer: Hierarchical Vision Transformer using Shifted Windows"](https://arxiv.org/abs/2103.14030). The code will be coming soon.
+By [Ze Liu](https://github.com/zeliu98/)\*, [Yutong Lin](https://github.com/impiga)\*, [Yue Cao](http://yue-cao.me)\*, [Han Hu](https://ancientmooner.github.io/)\*, [Yixuan Wei](https://github.com/weiyx16), [Zheng Zhang](https://stupidzz.github.io/), [Stephen Lin](https://scholar.google.com/citations?user=c3PYmxUAAAAJ&hl=en) and [Baining Guo](https://www.microsoft.com/en-us/research/people/bainguo/).
+
+This repo is the official implementation of ["Swin Transformer: Hierarchical Vision Transformer using Shifted Windows"](https://arxiv.org/pdf/2103.14030.pdf). It currently includes code and models for the following tasks:
+
+> **Image Classification**: Included in this repo. See [get_started.md](get_started.md) for a quick start.
+
+> **Object Detection and Instance Segmentation**: See [Swin Transformer for Object Detection](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection).
+
+> **Semantic Segmentation**: See [Swin Transformer for Semantic Segmentation](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation).
+
+## Updates
+
+***04/12/2021***
+
+Initial commits:
+
+1. Pretrained models on ImageNet-1K ([Swin-T-IN1K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth), [Swin-S-IN1K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth), [Swin-B-IN1K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth)) and ImageNet-22K ([Swin-B-IN22K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth), [Swin-L-IN22K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth)) are provided.
+2. Code and models for ImageNet-1K image classification, COCO object detection and ADE20K semantic segmentation are provided.
+3. The cuda kernel implementation for the [local relation layer](https://arxiv.org/pdf/1904.11491.pdf) is provided in branch [LR-Net](https://github.com/microsoft/Swin-Transformer/tree/LR-Net).
## Introduction
-**Swin Transformer** is initially described in [arxiv](https://arxiv.org/abs/2103.14030), which capably serves as a general-purpose backbone for computer vision. Challenges in adapting Transformer from language to vision arise from differences between the two domains, such as large variations in the scale of visual entities and the high resolution of pixels in images compared to words in text. To address these differences, we propose a hierarchical Transformer whose representation is computed with shifted windows. The shifted windowing scheme brings greater efficiency by limiting self-attention computation to non-overlapping local windows while also allowing for cross-window connection. This hierarchical architecture has the flexibility to model at various scales and has linear computational complexity with respect to image size.
-These qualities of Swin Transformer make it compatible with a broad range of vision tasks, including image classification (86.4 top-1 accuracy on ImageNet-1K) and dense prediction tasks such as object detection (58.7 box AP and 51.1 mask AP on COCO test-dev) and semantic segmentation (53.5 mIoU on ADE20K val).
+**Swin Transformer**, initially described in [arxiv](https://arxiv.org/abs/2103.14030), capably serves as a
+general-purpose backbone for computer vision. It is basically a hierarchical Transformer whose representation is
+computed with shifted windows. The shifted windowing scheme brings greater efficiency by limiting self-attention
+computation to non-overlapping local windows while also allowing for cross-window connections.
+
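+The following is a minimal sketch (assuming PyTorch; the names are illustrative, not
+necessarily this repo's exact API) of how a feature map is cyclically shifted and then
+partitioned into non-overlapping windows before window-based self-attention:
+
+```python
+import torch
+
+def window_partition(x, window_size):
+    """Split a (B, H, W, C) feature map into non-overlapping windows."""
+    B, H, W, C = x.shape
+    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+    # -> (num_windows * B, window_size, window_size, C)
+    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)
+
+x = torch.randn(1, 56, 56, 96)                         # a stage-1 feature map of Swin-T
+shifted = torch.roll(x, shifts=(-3, -3), dims=(1, 2))  # cyclic shift by window_size // 2
+windows = window_partition(shifted, window_size=7)     # -> (64, 7, 7, 96)
+```
+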
+Swin Transformer achieves strong performance on COCO object detection (`58.7 box AP` and `51.1 mask AP` on test-dev) and
+ADE20K semantic segmentation (`53.5 mIoU` on val), surpassing previous models by a large margin.
+
+![teaser](figures/teaser.png)
+
+## Main Results on ImageNet with Pretrained Models
+
+**ImageNet-1K and ImageNet-22K Pretrained Models**
+
+| name | pretrain | resolution | acc@1 | acc@5 | #params | FLOPs | FPS | 22K model | 1K model |
+| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |:---: |:---: |
+| Swin-T | ImageNet-1K | 224x224 | 81.2 | 95.5 | 28M | 4.5G | 755 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/156nWJy4Q28rDlrX-rRbI3w) |
+| Swin-S | ImageNet-1K | 224x224 | 83.2 | 96.2 | 50M | 8.7G | 437 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/1KFjpj3Efey3LmtE1QqPeQg) |
+| Swin-B | ImageNet-1K | 224x224 | 83.5 | 96.5 | 88M | 15.4G | 278 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/16bqCTEc70nC_isSsgBSaqQ) |
+| Swin-B | ImageNet-1K | 384x384 | 84.5 | 97.0 | 88M | 47.1G | 85 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth)/[baidu](https://pan.baidu.com/s/1xT1cu740-ejW7htUdVLnmw) |
+| Swin-B | ImageNet-22K | 224x224 | 85.2 | 97.5 | 88M | 15.4G | 278 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1y1Ec3UlrKSI8IMtEs-oBXA) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1n_wNkcbRxVXit8r_KrfAVg) |
+| Swin-B | ImageNet-22K | 384x384 | 86.4 | 98.0 | 88M | 47.1G | 85 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth)/[baidu](https://pan.baidu.com/s/1vwJxnJcVqcLZAw9HaqiR6g) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1caKTSdoLJYoi4WBcnmWuWg) |
+| Swin-L | ImageNet-22K | 224x224 | 86.3 | 97.9 | 197M | 34.5G | 141 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1pws3rOTFuOebBYP3h6Kx8w) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1NkQApMWUhxBGjk1ne6VqBQ) |
+| Swin-L | ImageNet-22K | 384x384 | 87.3 | 98.2 | 197M | 103.9G | 42 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth)/[baidu](https://pan.baidu.com/s/1sl7o_bJA143OD7UqSLAMoA) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1X0FLHQyPOC6Kmv2CmgxJvA) |
+
+Note: access code for `baidu` is `swin`.
+
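+A downloaded checkpoint can be loaded along the following lines (a minimal sketch; it assumes
+the weights are stored under a `'model'` key and uses `build_model` from this repo's `models`
+package, as imported in [main.py](main.py)):
+
+```python
+import torch
+from models import build_model
+
+# `config` is a yacs config node, e.g. obtained from get_config(args) in config.py
+model = build_model(config)
+checkpoint = torch.load('swin_tiny_patch4_window7_224.pth', map_location='cpu')
+model.load_state_dict(checkpoint['model'])  # assumption: weights live under 'model'
+model.eval()
+```
+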
+## Main Results on Downstream Tasks
+
+**COCO Object Detection (2017 val)**
+
+| Backbone | Method | pretrain | Lr Schd | box mAP | mask mAP | #params | FLOPs |
+| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
+| Swin-T | Mask R-CNN | ImageNet-1K | 3x | 46.0 | 41.6 | 48M | 267G |
+| Swin-S | Mask R-CNN | ImageNet-1K | 3x | 48.5 | 43.3 | 69M | 359G |
+| Swin-T | Cascade Mask R-CNN | ImageNet-1K | 3x | 50.4 | 43.7 | 86M | 745G |
+| Swin-S | Cascade Mask R-CNN | ImageNet-1K | 3x | 51.9 | 45.0 | 107M | 838G |
+| Swin-B | Cascade Mask R-CNN | ImageNet-1K | 3x | 51.9 | 45.0 | 145M | 982G |
+| Swin-T | RepPoints V2 | ImageNet-1K | 3x | 50.0 | - | 45M | 283G |
+| Swin-T | Mask RepPoints V2 | ImageNet-1K | 3x | 50.3 | 43.6 | 47M | 292G |
+| Swin-B | HTC++ | ImageNet-22K | 6x | 56.4 | 49.1 | 160M | 1043G |
+| Swin-L | HTC++ | ImageNet-22K | 3x | 57.1 | 49.5 | 284M | 1470G |
+| Swin-L | HTC++* | ImageNet-22K | 3x | 58.0 | 50.4 | 284M | - |
+
+Note: * indicates multi-scale testing.
+
+**ADE20K Semantic Segmentation (val)**
+
+| Backbone | Method | pretrain | Crop Size | Lr Schd | mIoU | mIoU (ms+flip) | #params | FLOPs |
+| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
+| Swin-T | UPerNet | ImageNet-1K | 512x512 | 160K | 44.51 | 45.81 | 60M | 945G |
+| Swin-S | UPerNet | ImageNet-1K | 512x512 | 160K | 47.64 | 49.47 | 81M | 1038G |
+| Swin-B | UPerNet | ImageNet-1K | 512x512 | 160K | 48.13 | 49.72 | 121M | 1188G |
+| Swin-B | UPerNet | ImageNet-22K | 640x640 | 160K | 50.04 | 51.66 | 121M | 1841G |
+| Swin-L | UPerNet | ImageNet-22K | 640x640 | 160K | 52.05 | 53.53 | 234M | 3230G |
## Citing Swin Transformer
@@ -20,6 +96,12 @@ These qualities of Swin Transformer make it compatible with a broad range of vis
}
```
+## Getting Started
+
+- For **Image Classification**, please see [get_started.md](get_started.md) for detailed instructions.
+- For **Object Detection and Instance Segmentation**, please see [Swin Transformer for Object Detection](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection).
+- For **Semantic Segmentation**, please see [Swin Transformer for Semantic Segmentation](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation).
+
## Contributing
This project welcomes contributions and suggestions. Most contributions require you to agree to a
diff --git a/config.py b/config.py
new file mode 100644
index 0000000..5f150f8
--- /dev/null
+++ b/config.py
@@ -0,0 +1,236 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu
+# --------------------------------------------------------
+
+import os
+import yaml
+from yacs.config import CfgNode as CN
+
+_C = CN()
+
+# Base config files
+_C.BASE = ['']
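+# e.g. a config yaml may set `BASE: ['swin_base_patch4_window7_224.yaml']` to first merge
+# that file (resolved relative to its own directory) and then apply its own overrides;
+# see _update_config_from_file() below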
+
+# -----------------------------------------------------------------------------
+# Data settings
+# -----------------------------------------------------------------------------
+_C.DATA = CN()
+# Batch size for a single GPU, could be overwritten by command line argument
+_C.DATA.BATCH_SIZE = 128
+# Path to dataset, could be overwritten by command line argument
+_C.DATA.DATA_PATH = ''
+# Dataset name
+_C.DATA.DATASET = 'imagenet'
+# Input image size
+_C.DATA.IMG_SIZE = 224
+# Interpolation to resize image (random, bilinear, bicubic)
+_C.DATA.INTERPOLATION = 'bicubic'
+# Use zipped dataset instead of folder dataset
+# could be overwritten by command line argument
+_C.DATA.ZIP_MODE = False
+# Cache Data in Memory, could be overwritten by command line argument
+_C.DATA.CACHE_MODE = 'part'
+# Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
+_C.DATA.PIN_MEMORY = True
+# Number of data loading threads
+_C.DATA.NUM_WORKERS = 8
+
+# -----------------------------------------------------------------------------
+# Model settings
+# -----------------------------------------------------------------------------
+_C.MODEL = CN()
+# Model type
+_C.MODEL.TYPE = 'swin'
+# Model name
+_C.MODEL.NAME = 'swin_tiny_patch4_window7_224'
+# Checkpoint to resume, could be overwritten by command line argument
+_C.MODEL.RESUME = ''
+# Number of classes, overwritten in data preparation
+_C.MODEL.NUM_CLASSES = 1000
+# Dropout rate
+_C.MODEL.DROP_RATE = 0.0
+# Drop path rate
+_C.MODEL.DROP_PATH_RATE = 0.1
+# Label Smoothing
+_C.MODEL.LABEL_SMOOTHING = 0.1
+
+# Swin Transformer parameters
+_C.MODEL.SWIN = CN()
+_C.MODEL.SWIN.PATCH_SIZE = 4
+_C.MODEL.SWIN.IN_CHANS = 3
+_C.MODEL.SWIN.EMBED_DIM = 96
+_C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
+_C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
+_C.MODEL.SWIN.WINDOW_SIZE = 7
+_C.MODEL.SWIN.MLP_RATIO = 4.
+_C.MODEL.SWIN.QKV_BIAS = True
+_C.MODEL.SWIN.QK_SCALE = None
+_C.MODEL.SWIN.APE = False
+_C.MODEL.SWIN.PATCH_NORM = True
+
+# -----------------------------------------------------------------------------
+# Training settings
+# -----------------------------------------------------------------------------
+_C.TRAIN = CN()
+_C.TRAIN.START_EPOCH = 0
+_C.TRAIN.EPOCHS = 300
+_C.TRAIN.WARMUP_EPOCHS = 20
+_C.TRAIN.WEIGHT_DECAY = 0.05
+_C.TRAIN.BASE_LR = 5e-4
+_C.TRAIN.WARMUP_LR = 5e-7
+_C.TRAIN.MIN_LR = 5e-6
+# Clip gradient norm
+_C.TRAIN.CLIP_GRAD = 5.0
+# Auto resume from latest checkpoint
+_C.TRAIN.AUTO_RESUME = True
+# Gradient accumulation steps
+# could be overwritten by command line argument
+_C.TRAIN.ACCUMULATION_STEPS = 0
+# Whether to use gradient checkpointing to save memory
+# could be overwritten by command line argument
+_C.TRAIN.USE_CHECKPOINT = False
+
+# LR scheduler
+_C.TRAIN.LR_SCHEDULER = CN()
+_C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
+# Epoch interval to decay LR, used in StepLRScheduler
+_C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
+# LR decay rate, used in StepLRScheduler
+_C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1
+
+# Optimizer
+_C.TRAIN.OPTIMIZER = CN()
+_C.TRAIN.OPTIMIZER.NAME = 'adamw'
+# Optimizer Epsilon
+_C.TRAIN.OPTIMIZER.EPS = 1e-8
+# Optimizer Betas
+_C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
+# SGD momentum
+_C.TRAIN.OPTIMIZER.MOMENTUM = 0.9
+
+# -----------------------------------------------------------------------------
+# Augmentation settings
+# -----------------------------------------------------------------------------
+_C.AUG = CN()
+# Color jitter factor
+_C.AUG.COLOR_JITTER = 0.4
+# AutoAugment / RandAugment policy string for timm's create_transform (default is a RandAugment policy)
+_C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
+# Random erase prob
+_C.AUG.REPROB = 0.25
+# Random erase mode
+_C.AUG.REMODE = 'pixel'
+# Random erase count
+_C.AUG.RECOUNT = 1
+# Mixup alpha, mixup enabled if > 0
+_C.AUG.MIXUP = 0.8
+# Cutmix alpha, cutmix enabled if > 0
+_C.AUG.CUTMIX = 1.0
+# Cutmix min/max ratio, overrides alpha and enables cutmix if set
+_C.AUG.CUTMIX_MINMAX = None
+# Probability of performing mixup or cutmix when either/both is enabled
+_C.AUG.MIXUP_PROB = 1.0
+# Probability of switching to cutmix when both mixup and cutmix enabled
+_C.AUG.MIXUP_SWITCH_PROB = 0.5
+# How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
+_C.AUG.MIXUP_MODE = 'batch'
+
+# -----------------------------------------------------------------------------
+# Testing settings
+# -----------------------------------------------------------------------------
+_C.TEST = CN()
+# Whether to use center crop when testing
+_C.TEST.CROP = True
+
+# -----------------------------------------------------------------------------
+# Misc
+# -----------------------------------------------------------------------------
+# Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2')
+# overwritten by command line argument
+_C.AMP_OPT_LEVEL = ''
+# Path to output folder, overwritten by command line argument
+_C.OUTPUT = ''
+# Tag of experiment, overwritten by command line argument
+_C.TAG = 'default'
+# Frequency to save checkpoint
+_C.SAVE_FREQ = 1
+# Frequency to log info
+_C.PRINT_FREQ = 10
+# Fixed random seed
+_C.SEED = 0
+# Perform evaluation only, overwritten by command line argument
+_C.EVAL_MODE = False
+# Test throughput only, overwritten by command line argument
+_C.THROUGHPUT_MODE = False
+# local rank for DistributedDataParallel, given by command line argument
+_C.LOCAL_RANK = 0
+
+
+def _update_config_from_file(config, cfg_file):
+ config.defrost()
+ with open(cfg_file, 'r') as f:
+ yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)
+
+ for cfg in yaml_cfg.setdefault('BASE', ['']):
+ if cfg:
+ _update_config_from_file(
+ config, os.path.join(os.path.dirname(cfg_file), cfg)
+ )
+ print('=> merge config from {}'.format(cfg_file))
+ config.merge_from_file(cfg_file)
+ config.freeze()
+
+
+def update_config(config, args):
+ _update_config_from_file(config, args.cfg)
+
+ config.defrost()
+ if args.opts:
+ config.merge_from_list(args.opts)
+
+ # merge from specific arguments
+ if args.batch_size:
+ config.DATA.BATCH_SIZE = args.batch_size
+ if args.data_path:
+ config.DATA.DATA_PATH = args.data_path
+ if args.zip:
+ config.DATA.ZIP_MODE = True
+ if args.cache_mode:
+ config.DATA.CACHE_MODE = args.cache_mode
+ if args.resume:
+ config.MODEL.RESUME = args.resume
+ if args.accumulation_steps:
+ config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps
+ if args.use_checkpoint:
+ config.TRAIN.USE_CHECKPOINT = True
+ if args.amp_opt_level:
+ config.AMP_OPT_LEVEL = args.amp_opt_level
+ if args.output:
+ config.OUTPUT = args.output
+ if args.tag:
+ config.TAG = args.tag
+ if args.eval:
+ config.EVAL_MODE = True
+ if args.throughput:
+ config.THROUGHPUT_MODE = True
+
+ # set local rank for distributed training
+ config.LOCAL_RANK = args.local_rank
+
+ # output folder
+ config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG)
+
+ config.freeze()
+
+
+def get_config(args):
+ """Get a yacs CfgNode object with default values."""
+ # Return a clone so that the defaults will not be altered
+ # This is for the "local variable" use pattern
+ config = _C.clone()
+ update_config(config, args)
+
+ return config
diff --git a/configs/swin_base_patch4_window12_384.yaml b/configs/swin_base_patch4_window12_384.yaml
new file mode 100644
index 0000000..b54deb7
--- /dev/null
+++ b/configs/swin_base_patch4_window12_384.yaml
@@ -0,0 +1,13 @@
+# only for evaluation
+DATA:
+ IMG_SIZE: 384
+MODEL:
+ TYPE: swin
+ NAME: swin_base_patch4_window12_384
+ SWIN:
+ EMBED_DIM: 128
+ DEPTHS: [ 2, 2, 18, 2 ]
+ NUM_HEADS: [ 4, 8, 16, 32 ]
+ WINDOW_SIZE: 12
+TEST:
+ CROP: False
\ No newline at end of file
diff --git a/configs/swin_base_patch4_window7_224.yaml b/configs/swin_base_patch4_window7_224.yaml
new file mode 100644
index 0000000..b296128
--- /dev/null
+++ b/configs/swin_base_patch4_window7_224.yaml
@@ -0,0 +1,9 @@
+MODEL:
+ TYPE: swin
+ NAME: swin_base_patch4_window7_224
+ DROP_PATH_RATE: 0.5
+ SWIN:
+ EMBED_DIM: 128
+ DEPTHS: [ 2, 2, 18, 2 ]
+ NUM_HEADS: [ 4, 8, 16, 32 ]
+ WINDOW_SIZE: 7
\ No newline at end of file
diff --git a/configs/swin_large_patch4_window12_384.yaml b/configs/swin_large_patch4_window12_384.yaml
new file mode 100644
index 0000000..bacf5f6
--- /dev/null
+++ b/configs/swin_large_patch4_window12_384.yaml
@@ -0,0 +1,13 @@
+# only for evaluation
+DATA:
+ IMG_SIZE: 384
+MODEL:
+ TYPE: swin
+ NAME: swin_large_patch4_window12_384
+ SWIN:
+ EMBED_DIM: 192
+ DEPTHS: [ 2, 2, 18, 2 ]
+ NUM_HEADS: [ 6, 12, 24, 48 ]
+ WINDOW_SIZE: 12
+TEST:
+ CROP: False
\ No newline at end of file
diff --git a/configs/swin_large_patch4_window7_224.yaml b/configs/swin_large_patch4_window7_224.yaml
new file mode 100644
index 0000000..df8af4c
--- /dev/null
+++ b/configs/swin_large_patch4_window7_224.yaml
@@ -0,0 +1,9 @@
+# only for evaluation
+MODEL:
+ TYPE: swin
+ NAME: swin_large_patch4_window7_224
+ SWIN:
+ EMBED_DIM: 192
+ DEPTHS: [ 2, 2, 18, 2 ]
+ NUM_HEADS: [ 6, 12, 24, 48 ]
+ WINDOW_SIZE: 7
\ No newline at end of file
diff --git a/configs/swin_small_patch4_window7_224.yaml b/configs/swin_small_patch4_window7_224.yaml
new file mode 100644
index 0000000..8f5c40f
--- /dev/null
+++ b/configs/swin_small_patch4_window7_224.yaml
@@ -0,0 +1,9 @@
+MODEL:
+ TYPE: swin
+ NAME: swin_small_patch4_window7_224
+ DROP_PATH_RATE: 0.3
+ SWIN:
+ EMBED_DIM: 96
+ DEPTHS: [ 2, 2, 18, 2 ]
+ NUM_HEADS: [ 3, 6, 12, 24 ]
+ WINDOW_SIZE: 7
\ No newline at end of file
diff --git a/configs/swin_tiny_patch4_window7_224.yaml b/configs/swin_tiny_patch4_window7_224.yaml
new file mode 100644
index 0000000..851c745
--- /dev/null
+++ b/configs/swin_tiny_patch4_window7_224.yaml
@@ -0,0 +1,9 @@
+MODEL:
+ TYPE: swin
+ NAME: swin_tiny_patch4_window7_224
+ DROP_PATH_RATE: 0.2
+ SWIN:
+ EMBED_DIM: 96
+ DEPTHS: [ 2, 2, 6, 2 ]
+ NUM_HEADS: [ 3, 6, 12, 24 ]
+ WINDOW_SIZE: 7
\ No newline at end of file
diff --git a/data/__init__.py b/data/__init__.py
new file mode 100644
index 0000000..70c633c
--- /dev/null
+++ b/data/__init__.py
@@ -0,0 +1 @@
+from .build import build_loader
\ No newline at end of file
diff --git a/data/build.py b/data/build.py
new file mode 100644
index 0000000..840d925
--- /dev/null
+++ b/data/build.py
@@ -0,0 +1,128 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu
+# --------------------------------------------------------
+
+import os
+import torch
+import numpy as np
+import torch.distributed as dist
+from torchvision import datasets, transforms
+from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.data import Mixup
+from timm.data import create_transform
+from timm.data.transforms import _pil_interp
+
+from .cached_image_folder import CachedImageFolder
+from .samplers import SubsetRandomSampler
+
+
+def build_loader(config):
+ config.defrost()
+ dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config)
+ config.freeze()
+ print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build train dataset")
+ dataset_val, _ = build_dataset(is_train=False, config=config)
+ print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build val dataset")
+
+ num_tasks = dist.get_world_size()
+ global_rank = dist.get_rank()
+ if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part':
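+        # with 'part' caching, each rank only caches every world_size-th sample
+        # (see CachedImageFolder.init_cache), so sample exactly that strided subset here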
+ indices = np.arange(dist.get_rank(), len(dataset_train), dist.get_world_size())
+ sampler_train = SubsetRandomSampler(indices)
+ else:
+ sampler_train = torch.utils.data.DistributedSampler(
+ dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
+ )
+
+ indices = np.arange(dist.get_rank(), len(dataset_val), dist.get_world_size())
+ sampler_val = SubsetRandomSampler(indices)
+
+ data_loader_train = torch.utils.data.DataLoader(
+ dataset_train, sampler=sampler_train,
+ batch_size=config.DATA.BATCH_SIZE,
+ num_workers=config.DATA.NUM_WORKERS,
+ pin_memory=config.DATA.PIN_MEMORY,
+ drop_last=True,
+ )
+
+ data_loader_val = torch.utils.data.DataLoader(
+ dataset_val, sampler=sampler_val,
+ batch_size=config.DATA.BATCH_SIZE,
+ shuffle=False,
+ num_workers=config.DATA.NUM_WORKERS,
+ pin_memory=config.DATA.PIN_MEMORY,
+ drop_last=False
+ )
+
+ # setup mixup / cutmix
+ mixup_fn = None
+ mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None
+ if mixup_active:
+ mixup_fn = Mixup(
+ mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX,
+ prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE,
+ label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES)
+
+ return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn
+
+
+def build_dataset(is_train, config):
+ transform = build_transform(is_train, config)
+ if config.DATA.DATASET == 'imagenet':
+ prefix = 'train' if is_train else 'val'
+ if config.DATA.ZIP_MODE:
+ ann_file = prefix + "_map.txt"
+ prefix = prefix + ".zip@/"
+ dataset = CachedImageFolder(config.DATA.DATA_PATH, ann_file, prefix, transform,
+ cache_mode=config.DATA.CACHE_MODE if is_train else 'part')
+ else:
+ root = os.path.join(config.DATA.DATA_PATH, prefix)
+ dataset = datasets.ImageFolder(root, transform=transform)
+ nb_classes = 1000
+ else:
+        raise NotImplementedError("We only support ImageNet now.")
+
+ return dataset, nb_classes
+
+
+def build_transform(is_train, config):
+ resize_im = config.DATA.IMG_SIZE > 32
+ if is_train:
+ # this should always dispatch to transforms_imagenet_train
+ transform = create_transform(
+ input_size=config.DATA.IMG_SIZE,
+ is_training=True,
+ color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None,
+ auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None,
+ re_prob=config.AUG.REPROB,
+ re_mode=config.AUG.REMODE,
+ re_count=config.AUG.RECOUNT,
+ interpolation=config.DATA.INTERPOLATION,
+ )
+ if not resize_im:
+ # replace RandomResizedCropAndInterpolation with
+ # RandomCrop
+ transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4)
+ return transform
+
+ t = []
+ if resize_im:
+ if config.TEST.CROP:
+ size = int((256 / 224) * config.DATA.IMG_SIZE)
+ t.append(
+ transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)),
+ # to maintain same ratio w.r.t. 224 images
+ )
+ t.append(transforms.CenterCrop(config.DATA.IMG_SIZE))
+ else:
+ t.append(
+ transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
+ interpolation=_pil_interp(config.DATA.INTERPOLATION))
+ )
+
+ t.append(transforms.ToTensor())
+ t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
+ return transforms.Compose(t)
diff --git a/data/cached_image_folder.py b/data/cached_image_folder.py
new file mode 100644
index 0000000..79db42d
--- /dev/null
+++ b/data/cached_image_folder.py
@@ -0,0 +1,251 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu
+# --------------------------------------------------------
+
+import io
+import os
+import time
+import torch.distributed as dist
+import torch.utils.data as data
+from PIL import Image
+
+from .zipreader import is_zip_path, ZipReader
+
+
+def has_file_allowed_extension(filename, extensions):
+ """Checks if a file is an allowed extension.
+ Args:
+ filename (string): path to a file
+ Returns:
+ bool: True if the filename ends with a known image extension
+ """
+ filename_lower = filename.lower()
+ return any(filename_lower.endswith(ext) for ext in extensions)
+
+
+def find_classes(dir):
+ classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
+ classes.sort()
+ class_to_idx = {classes[i]: i for i in range(len(classes))}
+ return classes, class_to_idx
+
+
+def make_dataset(dir, class_to_idx, extensions):
+ images = []
+ dir = os.path.expanduser(dir)
+ for target in sorted(os.listdir(dir)):
+ d = os.path.join(dir, target)
+ if not os.path.isdir(d):
+ continue
+
+ for root, _, fnames in sorted(os.walk(d)):
+ for fname in sorted(fnames):
+ if has_file_allowed_extension(fname, extensions):
+ path = os.path.join(root, fname)
+ item = (path, class_to_idx[target])
+ images.append(item)
+
+ return images
+
+
+def make_dataset_with_ann(ann_file, img_prefix, extensions):
+ images = []
+ with open(ann_file, "r") as f:
+ contents = f.readlines()
+ for line_str in contents:
+ path_contents = [c for c in line_str.split('\t')]
+ im_file_name = path_contents[0]
+ class_index = int(path_contents[1])
+
+ assert str.lower(os.path.splitext(im_file_name)[-1]) in extensions
+ item = (os.path.join(img_prefix, im_file_name), class_index)
+
+ images.append(item)
+
+ return images
+
+
+class DatasetFolder(data.Dataset):
+ """A generic data loader where the samples are arranged in this way: ::
+ root/class_x/xxx.ext
+ root/class_x/xxy.ext
+ root/class_x/xxz.ext
+ root/class_y/123.ext
+ root/class_y/nsdf3.ext
+ root/class_y/asd932_.ext
+ Args:
+ root (string): Root directory path.
+ loader (callable): A function to load a sample given its path.
+ extensions (list[string]): A list of allowed extensions.
+ transform (callable, optional): A function/transform that takes in
+ a sample and returns a transformed version.
+            E.g., ``transforms.RandomCrop`` for images.
+ target_transform (callable, optional): A function/transform that takes
+ in the target and transforms it.
+ Attributes:
+ samples (list): List of (sample path, class_index) tuples
+ """
+
+ def __init__(self, root, loader, extensions, ann_file='', img_prefix='', transform=None, target_transform=None,
+ cache_mode="no"):
+ # image folder mode
+ if ann_file == '':
+ _, class_to_idx = find_classes(root)
+ samples = make_dataset(root, class_to_idx, extensions)
+ # zip mode
+ else:
+ samples = make_dataset_with_ann(os.path.join(root, ann_file),
+ os.path.join(root, img_prefix),
+ extensions)
+
+ if len(samples) == 0:
+ raise (RuntimeError("Found 0 files in subfolders of: " + root + "\n" +
+ "Supported extensions are: " + ",".join(extensions)))
+
+ self.root = root
+ self.loader = loader
+ self.extensions = extensions
+
+ self.samples = samples
+ self.labels = [y_1k for _, y_1k in samples]
+ self.classes = list(set(self.labels))
+
+ self.transform = transform
+ self.target_transform = target_transform
+
+ self.cache_mode = cache_mode
+ if self.cache_mode != "no":
+ self.init_cache()
+
+ def init_cache(self):
+ assert self.cache_mode in ["part", "full"]
+ n_sample = len(self.samples)
+ global_rank = dist.get_rank()
+ world_size = dist.get_world_size()
+
+ samples_bytes = [None for _ in range(n_sample)]
+ start_time = time.time()
+ for index in range(n_sample):
+ if index % (n_sample // 10) == 0:
+ t = time.time() - start_time
+ print(f'global_rank {dist.get_rank()} cached {index}/{n_sample} takes {t:.2f}s per block')
+ start_time = time.time()
+ path, target = self.samples[index]
+ if self.cache_mode == "full":
+ samples_bytes[index] = (ZipReader.read(path), target)
+ elif self.cache_mode == "part" and index % world_size == global_rank:
+ samples_bytes[index] = (ZipReader.read(path), target)
+ else:
+ samples_bytes[index] = (path, target)
+ self.samples = samples_bytes
+
+ def __getitem__(self, index):
+ """
+ Args:
+ index (int): Index
+ Returns:
+ tuple: (sample, target) where target is class_index of the target class.
+ """
+ path, target = self.samples[index]
+ sample = self.loader(path)
+ if self.transform is not None:
+ sample = self.transform(sample)
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return sample, target
+
+ def __len__(self):
+ return len(self.samples)
+
+ def __repr__(self):
+ fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
+ fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
+ fmt_str += ' Root Location: {}\n'.format(self.root)
+ tmp = ' Transforms (if any): '
+ fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
+ tmp = ' Target Transforms (if any): '
+ fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
+ return fmt_str
+
+
+IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
+
+
+def pil_loader(path):
+ # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
+ if isinstance(path, bytes):
+ img = Image.open(io.BytesIO(path))
+ elif is_zip_path(path):
+ data = ZipReader.read(path)
+ img = Image.open(io.BytesIO(data))
+ else:
+ with open(path, 'rb') as f:
+ img = Image.open(f)
+ return img.convert('RGB')
+
+
+def accimage_loader(path):
+ import accimage
+ try:
+ return accimage.Image(path)
+ except IOError:
+ # Potentially a decoding problem, fall back to PIL.Image
+ return pil_loader(path)
+
+
+def default_img_loader(path):
+ from torchvision import get_image_backend
+ if get_image_backend() == 'accimage':
+ return accimage_loader(path)
+ else:
+ return pil_loader(path)
+
+
+class CachedImageFolder(DatasetFolder):
+ """A generic data loader where the images are arranged in this way: ::
+ root/dog/xxx.png
+ root/dog/xxy.png
+ root/dog/xxz.png
+ root/cat/123.png
+ root/cat/nsdf3.png
+ root/cat/asd932_.png
+ Args:
+ root (string): Root directory path.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ loader (callable, optional): A function to load an image given its path.
+ Attributes:
+ imgs (list): List of (image path, class_index) tuples
+ """
+
+ def __init__(self, root, ann_file='', img_prefix='', transform=None, target_transform=None,
+ loader=default_img_loader, cache_mode="no"):
+ super(CachedImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
+ ann_file=ann_file, img_prefix=img_prefix,
+ transform=transform, target_transform=target_transform,
+ cache_mode=cache_mode)
+ self.imgs = self.samples
+
+ def __getitem__(self, index):
+ """
+ Args:
+ index (int): Index
+ Returns:
+ tuple: (image, target) where target is class_index of the target class.
+ """
+ path, target = self.samples[index]
+ image = self.loader(path)
+ if self.transform is not None:
+ img = self.transform(image)
+ else:
+ img = image
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return img, target
diff --git a/data/samplers.py b/data/samplers.py
new file mode 100644
index 0000000..596e220
--- /dev/null
+++ b/data/samplers.py
@@ -0,0 +1,29 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu
+# --------------------------------------------------------
+
+import torch
+
+
+class SubsetRandomSampler(torch.utils.data.Sampler):
+ r"""Samples elements randomly from a given list of indices, without replacement.
+
+ Arguments:
+ indices (sequence): a sequence of indices
+ """
+
+ def __init__(self, indices):
+ self.epoch = 0
+ self.indices = indices
+
+ def __iter__(self):
+ return (self.indices[i] for i in torch.randperm(len(self.indices)))
+
+ def __len__(self):
+ return len(self.indices)
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
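+        # kept for interface parity with DistributedSampler; note that __iter__ above
+        # draws from the global RNG and does not condition on the stored epoch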
diff --git a/data/zipreader.py b/data/zipreader.py
new file mode 100644
index 0000000..060bc46
--- /dev/null
+++ b/data/zipreader.py
@@ -0,0 +1,103 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu
+# --------------------------------------------------------
+
+import os
+import zipfile
+import io
+import numpy as np
+from PIL import Image
+from PIL import ImageFile
+
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+
+def is_zip_path(img_or_path):
+ """judge if this is a zip path"""
+ return '.zip@' in img_or_path
+
+
+class ZipReader(object):
+ """A class to read zipped files"""
+ zip_bank = dict()
+
+ def __init__(self):
+ super(ZipReader, self).__init__()
+
+ @staticmethod
+ def get_zipfile(path):
+ zip_bank = ZipReader.zip_bank
+ if path not in zip_bank:
+ zfile = zipfile.ZipFile(path, 'r')
+ zip_bank[path] = zfile
+ return zip_bank[path]
+
+ @staticmethod
+ def split_zip_style_path(path):
+        pos_at = path.find('@')
+        assert pos_at != -1, "character '@' not found in the given path '%s'" % path
+
+ zip_path = path[0: pos_at]
+ folder_path = path[pos_at + 1:]
+ folder_path = str.strip(folder_path, '/')
+ return zip_path, folder_path
+
+ @staticmethod
+ def list_folder(path):
+ zip_path, folder_path = ZipReader.split_zip_style_path(path)
+
+ zfile = ZipReader.get_zipfile(zip_path)
+ folder_list = []
+        for file_folder_name in zfile.namelist():
+            file_folder_name = str.strip(file_folder_name, '/')
+            if file_folder_name.startswith(folder_path) and \
+               len(os.path.splitext(file_folder_name)[-1]) == 0 and \
+               file_folder_name != folder_path:
+                if len(folder_path) == 0:
+                    folder_list.append(file_folder_name)
+                else:
+                    folder_list.append(file_folder_name[len(folder_path) + 1:])
+
+ return folder_list
+
+ @staticmethod
+ def list_files(path, extension=None):
+ if extension is None:
+ extension = ['.*']
+ zip_path, folder_path = ZipReader.split_zip_style_path(path)
+
+ zfile = ZipReader.get_zipfile(zip_path)
+ file_lists = []
+        for file_folder_name in zfile.namelist():
+            file_folder_name = str.strip(file_folder_name, '/')
+            if file_folder_name.startswith(folder_path) and \
+               str.lower(os.path.splitext(file_folder_name)[-1]) in extension:
+                if len(folder_path) == 0:
+                    file_lists.append(file_folder_name)
+                else:
+                    file_lists.append(file_folder_name[len(folder_path) + 1:])
+
+ return file_lists
+
+ @staticmethod
+ def read(path):
+ zip_path, path_img = ZipReader.split_zip_style_path(path)
+ zfile = ZipReader.get_zipfile(zip_path)
+ data = zfile.read(path_img)
+ return data
+
+ @staticmethod
+ def imread(path):
+ zip_path, path_img = ZipReader.split_zip_style_path(path)
+ zfile = ZipReader.get_zipfile(zip_path)
+ data = zfile.read(path_img)
+ try:
+ im = Image.open(io.BytesIO(data))
+        except Exception:
+ print("ERROR IMG LOADED: ", path_img)
+ random_img = np.random.rand(224, 224, 3) * 255
+ im = Image.fromarray(np.uint8(random_img))
+ return im
diff --git a/figures/teaser.png b/figures/teaser.png
new file mode 100644
index 0000000..bcd2f74
Binary files /dev/null and b/figures/teaser.png differ
diff --git a/get_started.md b/get_started.md
new file mode 100644
index 0000000..ead05f2
--- /dev/null
+++ b/get_started.md
@@ -0,0 +1,203 @@
+# Swin Transformer for Image Classification
+
+This folder contains the implementation of the Swin Transformer for image classification.
+
+## Model Zoo
+
+### Regular ImageNet-1K trained models
+
+| name | resolution | acc@1 | acc@5 | #params | FLOPs | model |
+|:---:|:---:|:---:|:---:| :---:| :---:|:---:|
+| Swin-T | 224x224 | 81.2 | 95.5 | 28M | 4.5G | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/156nWJy4Q28rDlrX-rRbI3w) |
+| Swin-S | 224x224 | 83.2 | 96.2 | 50M | 8.7G | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/1KFjpj3Efey3LmtE1QqPeQg) |
+| Swin-B | 224x224 | 83.5 | 96.5 | 88M | 15.4G | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/16bqCTEc70nC_isSsgBSaqQ) |
+| Swin-B | 384x384 | 84.5 | 97.0 | 88M | 47.1G | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth)/[baidu](https://pan.baidu.com/s/1xT1cu740-ejW7htUdVLnmw) |
+
+### ImageNet-22K pre-trained models
+
+| name | resolution | acc@1 | acc@5 | #params | FLOPs | 22K model | 1K model |
+|:---: |:---: |:---:|:---:|:---:|:---:|:---:|:---:|
+| Swin-B | 224x224 | 85.2 | 97.5 | 88M | 15.4G | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1y1Ec3UlrKSI8IMtEs-oBXA) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1n_wNkcbRxVXit8r_KrfAVg) |
+| Swin-B | 384x384 | 86.4 | 98.0 | 88M | 47.1G | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth)/[baidu](https://pan.baidu.com/s/1vwJxnJcVqcLZAw9HaqiR6g) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1caKTSdoLJYoi4WBcnmWuWg) |
+| Swin-L | 224x224 | 86.3 | 97.9 | 197M | 34.5G | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1pws3rOTFuOebBYP3h6Kx8w) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1NkQApMWUhxBGjk1ne6VqBQ) |
+| Swin-L | 384x384 | 87.3 | 98.2 | 197M | 103.9G | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth)/[baidu](https://pan.baidu.com/s/1sl7o_bJA143OD7UqSLAMoA) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1X0FLHQyPOC6Kmv2CmgxJvA) |
+
+Note: access code for `baidu` is `swin`.
+
+## Usage
+
+### Install
+
+- Clone this repo:
+
+```bash
+git clone https://github.com/microsoft/Swin-Transformer.git
+cd Swin-Transformer
+```
+
+- Create a conda virtual environment and activate it:
+
+```bash
+conda create -n swin python=3.7 -y
+conda activate swin
+```
+
+- Install `CUDA==10.1` with `cudnn7` following
+ the [official installation instructions](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html)
+- Install `PyTorch==1.7.1` and `torchvision==0.8.2` with `CUDA==10.1`:
+
+```bash
+conda install pytorch==1.7.1 torchvision==0.8.2 cudatoolkit=10.1 -c pytorch
+```
+
+- Install `timm==0.3.2`:
+
+```bash
+pip install timm==0.3.2
+```
+
+- Install `Apex`:
+
+```bash
+git clone https://github.com/NVIDIA/apex
+cd apex
+pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
+```
+
+- Install other requirements:
+
+```bash
+pip install opencv-python==4.4.0.46 termcolor==1.1.0 yacs==0.1.8
+```
+
+### Data preparation
+
+We use the standard ImageNet dataset, which you can download from http://image-net.org/. We provide the following
+two ways to load data:
+
+- For standard folder dataset, move validation images to labeled sub-folders. The file structure should look like:
+ ```bash
+ $ tree data
+ imagenet
+ ├── train
+ │ ├── class1
+ │ │ ├── img1.jpeg
+ │ │ ├── img2.jpeg
+ │ │ └── ...
+ │ ├── class2
+ │ │ ├── img3.jpeg
+ │ │ └── ...
+ │ └── ...
+ └── val
+ ├── class1
+ │ ├── img4.jpeg
+ │ ├── img5.jpeg
+ │ └── ...
+ ├── class2
+ │ ├── img6.jpeg
+ │ └── ...
+ └── ...
+
+ ```
+- To speed up reading images from a massive number of small files, we also support zipped ImageNet, which includes
+  four files (a sketch of how such zip-style paths are read follows this list):
+  - `train.zip`, `val.zip`: which store the zipped folders for the train and validation splits.
+  - `train_map.txt`, `val_map.txt`: which store the relative path in the corresponding zip file and the ground-truth
+    label. Make sure the data folder looks like this:
+
+ ```bash
+ $ tree data
+ data
+ └── ImageNet-Zip
+ ├── train_map.txt
+ ├── train.zip
+ ├── val_map.txt
+ └── val.zip
+
+ $ head -n 5 data/ImageNet-Zip/val_map.txt
+ ILSVRC2012_val_00000001.JPEG 65
+ ILSVRC2012_val_00000002.JPEG 970
+ ILSVRC2012_val_00000003.JPEG 230
+ ILSVRC2012_val_00000004.JPEG 809
+ ILSVRC2012_val_00000005.JPEG 516
+
+ $ head -n 5 data/ImageNet-Zip/train_map.txt
+ n01440764/n01440764_10026.JPEG 0
+ n01440764/n01440764_10027.JPEG 0
+ n01440764/n01440764_10029.JPEG 0
+ n01440764/n01440764_10040.JPEG 0
+ n01440764/n01440764_10042.JPEG 0
+ ```
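+
+As a rough sketch of how such zip-style paths are consumed internally (a simplification of
+[data/zipreader.py](data/zipreader.py); the helper name is illustrative, and error handling and
+zip-handle caching are omitted), a path like `train.zip@/n01440764/n01440764_10026.JPEG` is split
+at the `@`, the archive is opened, and the member bytes are decoded with PIL:
+
+```python
+import io
+import zipfile
+from PIL import Image
+
+def read_zip_style(path):
+    """Read 'archive.zip@/member' and return an RGB PIL image (simplified sketch)."""
+    zip_path, member = path.split('@', 1)
+    with zipfile.ZipFile(zip_path, 'r') as zf:  # the repo caches this handle instead
+        data = zf.read(member.strip('/'))
+    return Image.open(io.BytesIO(data)).convert('RGB')
+```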
+
+### Evaluation
+
+To evaluate a pre-trained `Swin Transformer` on ImageNet val, run:
+
+```bash
+python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> --master_port 12345 main.py --eval \
+--cfg <config-file> --resume <checkpoint> --data-path <imagenet-path>
+```
+
+For example, to evaluate the `Swin-B` with a single GPU:
+
+```bash
+python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 main.py --eval \
+--cfg configs/swin_base_patch4_window7_224.yaml --resume swin_base_patch4_window7_224.pth --data-path <imagenet-path>
+```
+
+### Training from scratch
+
+To train a `Swin Transformer` on ImageNet from scratch, run:
+
+```bash
+python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> --master_port 12345 main.py \
+--cfg <config-file> --data-path <imagenet-path> [--batch-size <batch-size-per-gpu> --output <output-directory> --tag <job-tag>]
+```
+
+**Notes**:
+
+- To use zipped ImageNet instead of folder dataset, add `--zip` to the parameters.
+ - To cache the dataset in the memory instead of reading from files every time, add `--cache-mode part`, which will
+ shard the dataset into non-overlapping pieces for different GPUs and only load the corresponding one for each GPU.
+- When GPU memory is not enough, you can try the following suggestions:
+  - Use gradient accumulation by adding `--accumulation-steps <steps>`, and set an appropriate `<steps>` according to your need.
+ - Use gradient checkpointing by adding `--use-checkpoint`, e.g., it saves about 60% memory when training `Swin-B`.
+ Please refer to [this page](https://pytorch.org/docs/stable/checkpoint.html) for more details.
+  - We recommend using multi-node training with more GPUs for very large models; a tutorial can be found
+    on [this page](https://pytorch.org/tutorials/intermediate/dist_tuto.html).
+- To change config options in general, you can use `--opts KEY1 VALUE1 KEY2 VALUE2`, e.g.,
+ `--opts TRAIN.EPOCHS 100 TRAIN.WARMUP_EPOCHS 5` will change total epochs to 100 and warm-up epochs to 5.
+- For additional options, see [config](config.py) and run `python main.py --help` for a detailed help message.
+
+For example, to train `Swin Transformer` with 8 GPUs on a single node for 300 epochs, run:
+
+`Swin-T`:
+
+```bash
+python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \
+--cfg configs/swin_tiny_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 128
+```
+
+`Swin-S`:
+
+```bash
+python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \
+--cfg configs/swin_small_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 128
+```
+
+`Swin-B`:
+
+```bash
+python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \
+--cfg configs/swin_base_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 64 \
+--accumulation-steps 2 [--use-checkpoint]
+```
+
+### Throughput
+
+To measure the throughput, run:
+
+```bash
+python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 main.py \
+--cfg <config-file> --data-path <imagenet-path> --batch-size 64 --throughput --amp-opt-level O0
+```
diff --git a/logger.py b/logger.py
new file mode 100644
index 0000000..a066e55
--- /dev/null
+++ b/logger.py
@@ -0,0 +1,41 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu
+# --------------------------------------------------------
+
+import os
+import sys
+import logging
+import functools
+from termcolor import colored
+
+
+@functools.lru_cache()
+def create_logger(output_dir, dist_rank=0, name=''):
+ # create logger
+ logger = logging.getLogger(name)
+ logger.setLevel(logging.DEBUG)
+ logger.propagate = False
+
+ # create formatter
+ fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s'
+ color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \
+ colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s'
+
+ # create console handlers for master process
+ if dist_rank == 0:
+ console_handler = logging.StreamHandler(sys.stdout)
+ console_handler.setLevel(logging.DEBUG)
+ console_handler.setFormatter(
+ logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S'))
+ logger.addHandler(console_handler)
+
+ # create file handlers
+ file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'), mode='a')
+ file_handler.setLevel(logging.DEBUG)
+ file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S'))
+ logger.addHandler(file_handler)
+
+ return logger
diff --git a/lr_scheduler.py b/lr_scheduler.py
new file mode 100644
index 0000000..4d27289
--- /dev/null
+++ b/lr_scheduler.py
@@ -0,0 +1,102 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu
+# --------------------------------------------------------
+
+import torch
+from timm.scheduler.cosine_lr import CosineLRScheduler
+from timm.scheduler.step_lr import StepLRScheduler
+from timm.scheduler.scheduler import Scheduler
+
+
+def build_scheduler(config, optimizer, n_iter_per_epoch):
+ num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch)
+ warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch)
+ decay_steps = int(config.TRAIN.LR_SCHEDULER.DECAY_EPOCHS * n_iter_per_epoch)
+
+ lr_scheduler = None
+ if config.TRAIN.LR_SCHEDULER.NAME == 'cosine':
+ lr_scheduler = CosineLRScheduler(
+ optimizer,
+ t_initial=num_steps,
+ t_mul=1.,
+ lr_min=config.TRAIN.MIN_LR,
+ warmup_lr_init=config.TRAIN.WARMUP_LR,
+ warmup_t=warmup_steps,
+ cycle_limit=1,
+ t_in_epochs=False,
+ )
+ elif config.TRAIN.LR_SCHEDULER.NAME == 'linear':
+ lr_scheduler = LinearLRScheduler(
+ optimizer,
+ t_initial=num_steps,
+ lr_min_rate=0.01,
+ warmup_lr_init=config.TRAIN.WARMUP_LR,
+ warmup_t=warmup_steps,
+ t_in_epochs=False,
+ )
+ elif config.TRAIN.LR_SCHEDULER.NAME == 'step':
+ lr_scheduler = StepLRScheduler(
+ optimizer,
+ decay_t=decay_steps,
+ decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE,
+ warmup_lr_init=config.TRAIN.WARMUP_LR,
+ warmup_t=warmup_steps,
+ t_in_epochs=False,
+ )
+
+ return lr_scheduler
+
+
+class LinearLRScheduler(Scheduler):
+ def __init__(self,
+ optimizer: torch.optim.Optimizer,
+ t_initial: int,
+ lr_min_rate: float,
+ warmup_t=0,
+ warmup_lr_init=0.,
+ t_in_epochs=True,
+ noise_range_t=None,
+ noise_pct=0.67,
+ noise_std=1.0,
+ noise_seed=42,
+ initialize=True,
+ ) -> None:
+ super().__init__(
+ optimizer, param_group_field="lr",
+ noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
+ initialize=initialize)
+
+ self.t_initial = t_initial
+ self.lr_min_rate = lr_min_rate
+ self.warmup_t = warmup_t
+ self.warmup_lr_init = warmup_lr_init
+ self.t_in_epochs = t_in_epochs
+ if self.warmup_t:
+ self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
+ super().update_groups(self.warmup_lr_init)
+ else:
+ self.warmup_steps = [1 for _ in self.base_values]
+
+ def _get_lr(self, t):
+ if t < self.warmup_t:
+ lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
+ else:
+ t = t - self.warmup_t
+ total_t = self.t_initial - self.warmup_t
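+            # decay each base lr linearly from base_lr down to base_lr * lr_min_rate over total_t steps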
+ lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values]
+ return lrs
+
+ def get_epoch_values(self, epoch: int):
+ if self.t_in_epochs:
+ return self._get_lr(epoch)
+ else:
+ return None
+
+ def get_update_values(self, num_updates: int):
+ if not self.t_in_epochs:
+ return self._get_lr(num_updates)
+ else:
+ return None
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..eba7be3
--- /dev/null
+++ b/main.py
@@ -0,0 +1,345 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu
+# --------------------------------------------------------
+
+import os
+import time
+import argparse
+import datetime
+import numpy as np
+
+import torch
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+
+from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
+from timm.utils import accuracy, AverageMeter
+
+from config import get_config
+from models import build_model
+from data import build_loader
+from lr_scheduler import build_scheduler
+from optimizer import build_optimizer
+from logger import create_logger
+from utils import load_checkpoint, save_checkpoint, get_grad_norm, auto_resume_helper, reduce_tensor
+
+try:
+ # noinspection PyUnresolvedReferences
+ from apex import amp
+except ImportError:
+ amp = None
+
+
+def parse_option():
+ parser = argparse.ArgumentParser('Swin Transformer training and evaluation script', add_help=False)
+ parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', )
+ parser.add_argument(
+ "--opts",
+ help="Modify config options by adding 'KEY VALUE' pairs. ",
+ default=None,
+ nargs='+',
+ )
+
+ # easy config modification
+ parser.add_argument('--batch-size', type=int, help="batch size for single GPU")
+ parser.add_argument('--data-path', type=str, help='path to dataset')
+ parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
+ parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
+ help='no: no cache, '
+ 'full: cache all data, '
+                             'part: shard the dataset into non-overlapping pieces and only cache one piece')
+ parser.add_argument('--resume', help='resume from checkpoint')
+ parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
+ parser.add_argument('--use-checkpoint', action='store_true',
+ help="whether to use gradient checkpointing to save memory")
+ parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'],
+ help='mixed precision opt level, if O0, no amp is used')
+ parser.add_argument('--output', default='output', type=str, metavar='PATH',
+ help='root of output folder, the full path is