Initial commit

This commit is contained in:
v-zeliu1 2021-04-13 00:34:56 +08:00
parent ce5bae042d
commit 3dc2a55301
24 changed files with 2482 additions and 4 deletions

129
.gitignore vendored Normal file
View File

@ -0,0 +1,129 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/

View File

@ -1,13 +1,89 @@
# Swin Transformer
By [Ze Liu](https://github.com/zeliu98/)\*, [Yutong Lin](https://github.com/impiga)\*, [Yue Cao](http://yue-cao.me)\*, [Han Hu](https://sites.google.com/site/hanhushomepage/)\*, [Yixuan Wei](https://github.com/weiyx16), [Zheng Zhang](https://stupidzz.github.io/), [Stephen Lin](https://scholar.google.com/citations?user=c3PYmxUAAAAJ&hl=en) and [Baining Guo](https://www.microsoft.com/en-us/research/people/bainguo/).
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/swin-transformer-hierarchical-vision/object-detection-on-coco)](https://paperswithcode.com/sota/object-detection-on-coco?p=swin-transformer-hierarchical-vision)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/swin-transformer-hierarchical-vision/instance-segmentation-on-coco)](https://paperswithcode.com/sota/instance-segmentation-on-coco?p=swin-transformer-hierarchical-vision)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/swin-transformer-hierarchical-vision/object-detection-on-coco-minival)](https://paperswithcode.com/sota/object-detection-on-coco-minival?p=swin-transformer-hierarchical-vision)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/swin-transformer-hierarchical-vision/instance-segmentation-on-coco-minival)](https://paperswithcode.com/sota/instance-segmentation-on-coco-minival?p=swin-transformer-hierarchical-vision)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/swin-transformer-hierarchical-vision/semantic-segmentation-on-ade20k)](https://paperswithcode.com/sota/semantic-segmentation-on-ade20k?p=swin-transformer-hierarchical-vision)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/swin-transformer-hierarchical-vision/semantic-segmentation-on-ade20k-val)](https://paperswithcode.com/sota/semantic-segmentation-on-ade20k-val?p=swin-transformer-hierarchical-vision)
This repo is the official implementation of ["Swin Transformer: Hierarchical Vision Transformer using Shifted Windows"](https://arxiv.org/abs/2103.14030). The code will be coming soon.
By [Ze Liu](https://github.com/zeliu98/)\*, [Yutong Lin](https://github.com/impiga)\*, [Yue Cao](http://yue-cao.me)\*, [Han Hu](https://ancientmooner.github.io/)\*, [Yixuan Wei](https://github.com/weiyx16), [Zheng Zhang](https://stupidzz.github.io/), [Stephen Lin](https://scholar.google.com/citations?user=c3PYmxUAAAAJ&hl=en) and [Baining Guo](https://www.microsoft.com/en-us/research/people/bainguo/).
This repo is the official implementation of ["Swin Transformer: Hierarchical Vision Transformer using Shifted Windows"](https://arxiv.org/pdf/2103.14030.pdf). It currently includes code and models for the following tasks:
> **Image Classification**: Included in this repo. See [get_started.md](get_started.md) for a quick start.
> **Object Detection and Instance Segmentation**: See [Swin Transformer for Object Detection](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection).
> **Semantic Segmentation**: See [Swin Transformer for Semantic Segmentation](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation).
## Updates
***04/12/2021***
Initial commits:
1. Pretrained models on ImageNet-1K ([Swin-T-IN1K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth), [Swin-S-IN1K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth), [Swin-B-IN1K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth)) and ImageNet-22K ([Swin-B-IN22K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth), [Swin-L-IN22K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth)) are provided.
2. The supported code and models for ImageNet-1K image classification, COCO object detection and ADE20K semantic segmentation are provided.
3. The cuda kernel implementation for the [local relation layer](https://arxiv.org/pdf/1904.11491.pdf) is provided in branch [LR-Net](https://github.com/microsoft/Swin-Transformer/tree/LR-Net).
## Introduction
**Swin Transformer** is initially described in [arxiv](https://arxiv.org/abs/2103.14030), which capably serves as a general-purpose backbone for computer vision. Challenges in adapting Transformer from language to vision arise from differences between the two domains, such as large variations in the scale of visual entities and the high resolution of pixels in images compared to words in text. To address these differences, we propose a hierarchical Transformer whose representation is computed with shifted windows. The shifted windowing scheme brings greater efficiency by limiting self-attention computation to non-overlapping local windows while also allowing for cross-window connection. This hierarchical architecture has the flexibility to model at various scales and has linear computational complexity with respect to image size.
These qualities of Swin Transformer make it compatible with a broad range of vision tasks, including image classification (86.4 top-1 accuracy on ImageNet-1K) and dense prediction tasks such as object detection (58.7 box AP and 51.1 mask AP on COCO test-dev) and semantic segmentation (53.5 mIoU on ADE20K val).
**Swin Transformer** is initially described in [arxiv](https://arxiv.org/abs/2103.14030), which capably serves as a
general-purpose backbone for computer vision. It is basically a hierarchical Transformer whose representation is
computed with shifted windows. The shifted windowing scheme brings greater efficiency by limiting self-attention
computation to non-overlapping local windows while also allowing for cross-window connection.
Swin Transformer achieves strong performance on COCO object detection (`58.7 box AP` and `51.1 mask AP` on test-dev) and
ADE20K semantic segmentatiion (`53.5 mIoU` on val), surpassing previous models by a large margin.
![teaser](figures/teaser.png)
## Main Results on ImageNet with Pretrained Models
**ImageNet-1K and ImageNet-22K Pretrained Models**
| name | pretrain | resolution |acc@1 | acc@5 | #params | FLOPs | FPS| 22K model | 1K model |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |:---: |:---: |
| Swin-T | ImageNet-1K | 224x224 | 81.2 | 95.5 | 28M | 4.5G | 755 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/156nWJy4Q28rDlrX-rRbI3w) |
| Swin-S | ImageNet-1K | 224x224 | 83.2 | 96.2 | 50M | 8.7G | 437 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/1KFjpj3Efey3LmtE1QqPeQg) |
| Swin-B | ImageNet-1K | 224x224 | 83.5 | 96.5 | 88M | 15.4G | 278 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/16bqCTEc70nC_isSsgBSaqQ) |
| Swin-B | ImageNet-1K | 384x384 | 84.5 | 97.0 | 88M | 47.1G | 85 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth)/[baidu](https://pan.baidu.com/s/1xT1cu740-ejW7htUdVLnmw) |
| Swin-B | ImageNet-22K | 224x224 | 85.2 | 97.5 | 88M | 15.4G | 278 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1y1Ec3UlrKSI8IMtEs-oBXA) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1n_wNkcbRxVXit8r_KrfAVg) |
| Swin-B | ImageNet-22K | 384x384 | 86.4 | 98.0 | 88M | 47.1G | 85 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth)/[baidu](https://pan.baidu.com/s/1vwJxnJcVqcLZAw9HaqiR6g) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1caKTSdoLJYoi4WBcnmWuWg) |
| Swin-L | ImageNet-22K | 224x224 | 86.3 | 97.9 | 197M | 34.5G | 141 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1pws3rOTFuOebBYP3h6Kx8w) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1NkQApMWUhxBGjk1ne6VqBQ) |
| Swin-L | ImageNet-22K | 384x384 | 87.3 | 98.2 | 197M | 103.9G | 42 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth)/[baidu](https://pan.baidu.com/s/1sl7o_bJA143OD7UqSLAMoA) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1X0FLHQyPOC6Kmv2CmgxJvA) |
Note: access code for `baidu` is `swin`.
## Main Results on Downstream Tasks
**COCO Object Detection (2017 val)**
| Backbone | Method | pretrain | Lr Schd | box mAP | mask mAP | #params | FLOPs |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| Swin-T | Mask R-CNN | ImageNet-1K | 3x | 46.0 | 41.6 | 48M | 267G |
| Swin-S | Mask R-CNN | ImageNet-1K | 3x | 48.5 | 43.3 | 69M | 359G |
| Swin-T | Cascade Mask R-CNN | ImageNet-1K | 3x | 50.4 | 43.7 | 86M | 745G |
| Swin-S | Cascade Mask R-CNN | ImageNet-1K | 3x | 51.9 | 45.0 | 107M | 838G |
| Swin-B | Cascade Mask R-CNN | ImageNet-1K | 3x | 51.9 | 45.0 | 145M | 982G |
| Swin-T | RepPoints V2 | ImageNet-1K | 3x | 50.0 | - | 45M | 283G |
| Swin-T | Mask RepPoints V2 | ImageNet-1K | 3x | 50.3 | 43.6 | 47M | 292G |
| Swin-B | HTC++ | ImageNet-22K | 6x | 56.4 | 49.1 | 160M | 1043G |
| Swin-L | HTC++ | ImageNet-22K | 3x | 57.1 | 49.5 | 284M | 1470G |
| Swin-L | HTC++<sup>*</sup> | ImageNet-22K | 3x | 58.0 | 50.4 | 284M | - |
Note: <sup>*</sup> indicates multi-scale testing.
**ADE20K Semantic Segmentation (val)**
| Backbone | Method | pretrain | Crop Size | Lr Schd | mIoU | mIoU (ms+flip) | #params | FLOPs |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| Swin-T | UPerNet | ImageNet-1K | 512x512 | 160K | 44.51 | 45.81 | 60M | 945G |
| Swin-S | UperNet | ImageNet-1K | 512x512 | 160K | 47.64 | 49.47 | 81M | 1038G |
| Swin-B | UperNet | ImageNet-1K | 512x512 | 160K | 48.13 | 49.72 | 121M | 1188G |
| Swin-B | UPerNet | ImageNet-22K | 640x640 | 160K | 50.04 | 51.66 | 121M | 1841G |
| Swin-L | UperNet | ImageNet-22K | 640x640 | 160K | 52.05 | 53.53 | 234M | 3230G |
## Citing Swin Transformer
@ -20,6 +96,12 @@ These qualities of Swin Transformer make it compatible with a broad range of vis
}
```
## Getting Started
- For **Image Classification**, please see [get_started.md](get_started.md) for detailed instructions.
- For **Object Detection and Instance Segmentation**, please see [Swin Transformer for Object Detection](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection).
- For **Semantic Segmentation**, please see [Swin Transformer for Semantic Segmentation](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation).
## Contributing
This project welcomes contributions and suggestions. Most contributions require you to agree to a

236
config.py Normal file
View File

@ -0,0 +1,236 @@
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------'
import os
import yaml
from yacs.config import CfgNode as CN
_C = CN()
# Base config files
_C.BASE = ['']
# -----------------------------------------------------------------------------
# Data settings
# -----------------------------------------------------------------------------
_C.DATA = CN()
# Batch size for a single GPU, could be overwritten by command line argument
_C.DATA.BATCH_SIZE = 128
# Path to dataset, could be overwritten by command line argument
_C.DATA.DATA_PATH = ''
# Dataset name
_C.DATA.DATASET = 'imagenet'
# Input image size
_C.DATA.IMG_SIZE = 224
# Interpolation to resize image (random, bilinear, bicubic)
_C.DATA.INTERPOLATION = 'bicubic'
# Use zipped dataset instead of folder dataset
# could be overwritten by command line argument
_C.DATA.ZIP_MODE = False
# Cache Data in Memory, could be overwritten by command line argument
_C.DATA.CACHE_MODE = 'part'
# Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
_C.DATA.PIN_MEMORY = True
# Number of data loading threads
_C.DATA.NUM_WORKERS = 8
# -----------------------------------------------------------------------------
# Model settings
# -----------------------------------------------------------------------------
_C.MODEL = CN()
# Model type
_C.MODEL.TYPE = 'swin'
# Model name
_C.MODEL.NAME = 'swin_tiny_patch4_window7_224'
# Checkpoint to resume, could be overwritten by command line argument
_C.MODEL.RESUME = ''
# Number of classes, overwritten in data preparation
_C.MODEL.NUM_CLASSES = 1000
# Dropout rate
_C.MODEL.DROP_RATE = 0.0
# Drop path rate
_C.MODEL.DROP_PATH_RATE = 0.1
# Label Smoothing
_C.MODEL.LABEL_SMOOTHING = 0.1
# Swin Transformer parameters
_C.MODEL.SWIN = CN()
_C.MODEL.SWIN.PATCH_SIZE = 4
_C.MODEL.SWIN.IN_CHANS = 3
_C.MODEL.SWIN.EMBED_DIM = 96
_C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
_C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
_C.MODEL.SWIN.WINDOW_SIZE = 7
_C.MODEL.SWIN.MLP_RATIO = 4.
_C.MODEL.SWIN.QKV_BIAS = True
_C.MODEL.SWIN.QK_SCALE = None
_C.MODEL.SWIN.APE = False
_C.MODEL.SWIN.PATCH_NORM = True
# -----------------------------------------------------------------------------
# Training settings
# -----------------------------------------------------------------------------
_C.TRAIN = CN()
_C.TRAIN.START_EPOCH = 0
_C.TRAIN.EPOCHS = 300
_C.TRAIN.WARMUP_EPOCHS = 20
_C.TRAIN.WEIGHT_DECAY = 0.05
_C.TRAIN.BASE_LR = 5e-4
_C.TRAIN.WARMUP_LR = 5e-7
_C.TRAIN.MIN_LR = 5e-6
# Clip gradient norm
_C.TRAIN.CLIP_GRAD = 5.0
# Auto resume from latest checkpoint
_C.TRAIN.AUTO_RESUME = True
# Gradient accumulation steps
# could be overwritten by command line argument
_C.TRAIN.ACCUMULATION_STEPS = 0
# Whether to use gradient checkpointing to save memory
# could be overwritten by command line argument
_C.TRAIN.USE_CHECKPOINT = False
# LR scheduler
_C.TRAIN.LR_SCHEDULER = CN()
_C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
# Epoch interval to decay LR, used in StepLRScheduler
_C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
# LR decay rate, used in StepLRScheduler
_C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1
# Optimizer
_C.TRAIN.OPTIMIZER = CN()
_C.TRAIN.OPTIMIZER.NAME = 'adamw'
# Optimizer Epsilon
_C.TRAIN.OPTIMIZER.EPS = 1e-8
# Optimizer Betas
_C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
# SGD momentum
_C.TRAIN.OPTIMIZER.MOMENTUM = 0.9
# -----------------------------------------------------------------------------
# Augmentation settings
# -----------------------------------------------------------------------------
_C.AUG = CN()
# Color jitter factor
_C.AUG.COLOR_JITTER = 0.4
# Use AutoAugment policy. "v0" or "original"
_C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
# Random erase prob
_C.AUG.REPROB = 0.25
# Random erase mode
_C.AUG.REMODE = 'pixel'
# Random erase count
_C.AUG.RECOUNT = 1
# Mixup alpha, mixup enabled if > 0
_C.AUG.MIXUP = 0.8
# Cutmix alpha, cutmix enabled if > 0
_C.AUG.CUTMIX = 1.0
# Cutmix min/max ratio, overrides alpha and enables cutmix if set
_C.AUG.CUTMIX_MINMAX = None
# Probability of performing mixup or cutmix when either/both is enabled
_C.AUG.MIXUP_PROB = 1.0
# Probability of switching to cutmix when both mixup and cutmix enabled
_C.AUG.MIXUP_SWITCH_PROB = 0.5
# How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
_C.AUG.MIXUP_MODE = 'batch'
# -----------------------------------------------------------------------------
# Testing settings
# -----------------------------------------------------------------------------
_C.TEST = CN()
# Whether to use center crop when testing
_C.TEST.CROP = True
# -----------------------------------------------------------------------------
# Misc
# -----------------------------------------------------------------------------
# Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2')
# overwritten by command line argument
_C.AMP_OPT_LEVEL = ''
# Path to output folder, overwritten by command line argument
_C.OUTPUT = ''
# Tag of experiment, overwritten by command line argument
_C.TAG = 'default'
# Frequency to save checkpoint
_C.SAVE_FREQ = 1
# Frequency to logging info
_C.PRINT_FREQ = 10
# Fixed random seed
_C.SEED = 0
# Perform evaluation only, overwritten by command line argument
_C.EVAL_MODE = False
# Test throughput only, overwritten by command line argument
_C.THROUGHPUT_MODE = False
# local rank for DistributedDataParallel, given by command line argument
_C.LOCAL_RANK = 0
def _update_config_from_file(config, cfg_file):
config.defrost()
with open(cfg_file, 'r') as f:
yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)
for cfg in yaml_cfg.setdefault('BASE', ['']):
if cfg:
_update_config_from_file(
config, os.path.join(os.path.dirname(cfg_file), cfg)
)
print('=> merge config from {}'.format(cfg_file))
config.merge_from_file(cfg_file)
config.freeze()
def update_config(config, args):
_update_config_from_file(config, args.cfg)
config.defrost()
if args.opts:
config.merge_from_list(args.opts)
# merge from specific arguments
if args.batch_size:
config.DATA.BATCH_SIZE = args.batch_size
if args.data_path:
config.DATA.DATA_PATH = args.data_path
if args.zip:
config.DATA.ZIP_MODE = True
if args.cache_mode:
config.DATA.CACHE_MODE = args.cache_mode
if args.resume:
config.MODEL.RESUME = args.resume
if args.accumulation_steps:
config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps
if args.use_checkpoint:
config.TRAIN.USE_CHECKPOINT = True
if args.amp_opt_level:
config.AMP_OPT_LEVEL = args.amp_opt_level
if args.output:
config.OUTPUT = args.output
if args.tag:
config.TAG = args.tag
if args.eval:
config.EVAL_MODE = True
if args.throughput:
config.THROUGHPUT_MODE = True
# set local rank for distributed training
config.LOCAL_RANK = args.local_rank
# output folder
config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG)
config.freeze()
def get_config(args):
"""Get a yacs CfgNode object with default values."""
# Return a clone so that the defaults will not be altered
# This is for the "local variable" use pattern
config = _C.clone()
update_config(config, args)
return config

View File

@ -0,0 +1,13 @@
# only for evaluation
DATA:
IMG_SIZE: 384
MODEL:
TYPE: swin
NAME: swin_base_patch4_window12_384
SWIN:
EMBED_DIM: 128
DEPTHS: [ 2, 2, 18, 2 ]
NUM_HEADS: [ 4, 8, 16, 32 ]
WINDOW_SIZE: 12
TEST:
CROP: False

View File

@ -0,0 +1,9 @@
MODEL:
TYPE: swin
NAME: swin_base_patch4_window7_224
DROP_PATH_RATE: 0.5
SWIN:
EMBED_DIM: 128
DEPTHS: [ 2, 2, 18, 2 ]
NUM_HEADS: [ 4, 8, 16, 32 ]
WINDOW_SIZE: 7

View File

@ -0,0 +1,13 @@
# only for evaluation
DATA:
IMG_SIZE: 384
MODEL:
TYPE: swin
NAME: swin_large_patch4_window12_384
SWIN:
EMBED_DIM: 192
DEPTHS: [ 2, 2, 18, 2 ]
NUM_HEADS: [ 6, 12, 24, 48 ]
WINDOW_SIZE: 12
TEST:
CROP: False

View File

@ -0,0 +1,9 @@
# only for evaluation
MODEL:
TYPE: swin
NAME: swin_large_patch4_window7_224
SWIN:
EMBED_DIM: 192
DEPTHS: [ 2, 2, 18, 2 ]
NUM_HEADS: [ 6, 12, 24, 48 ]
WINDOW_SIZE: 7

View File

@ -0,0 +1,9 @@
MODEL:
TYPE: swin
NAME: swin_small_patch4_window7_224
DROP_PATH_RATE: 0.3
SWIN:
EMBED_DIM: 96
DEPTHS: [ 2, 2, 18, 2 ]
NUM_HEADS: [ 3, 6, 12, 24 ]
WINDOW_SIZE: 7

View File

@ -0,0 +1,9 @@
MODEL:
TYPE: swin
NAME: swin_tiny_patch4_window7_224
DROP_PATH_RATE: 0.2
SWIN:
EMBED_DIM: 96
DEPTHS: [ 2, 2, 6, 2 ]
NUM_HEADS: [ 3, 6, 12, 24 ]
WINDOW_SIZE: 7

1
data/__init__.py Normal file
View File

@ -0,0 +1 @@
from .build import build_loader

128
data/build.py Normal file
View File

@ -0,0 +1,128 @@
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
import os
import torch
import numpy as np
import torch.distributed as dist
from torchvision import datasets, transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import Mixup
from timm.data import create_transform
from timm.data.transforms import _pil_interp
from .cached_image_folder import CachedImageFolder
from .samplers import SubsetRandomSampler
def build_loader(config):
config.defrost()
dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config)
config.freeze()
print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build train dataset")
dataset_val, _ = build_dataset(is_train=False, config=config)
print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build val dataset")
num_tasks = dist.get_world_size()
global_rank = dist.get_rank()
if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part':
indices = np.arange(dist.get_rank(), len(dataset_train), dist.get_world_size())
sampler_train = SubsetRandomSampler(indices)
else:
sampler_train = torch.utils.data.DistributedSampler(
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
)
indices = np.arange(dist.get_rank(), len(dataset_val), dist.get_world_size())
sampler_val = SubsetRandomSampler(indices)
data_loader_train = torch.utils.data.DataLoader(
dataset_train, sampler=sampler_train,
batch_size=config.DATA.BATCH_SIZE,
num_workers=config.DATA.NUM_WORKERS,
pin_memory=config.DATA.PIN_MEMORY,
drop_last=True,
)
data_loader_val = torch.utils.data.DataLoader(
dataset_val, sampler=sampler_val,
batch_size=config.DATA.BATCH_SIZE,
shuffle=False,
num_workers=config.DATA.NUM_WORKERS,
pin_memory=config.DATA.PIN_MEMORY,
drop_last=False
)
# setup mixup / cutmix
mixup_fn = None
mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None
if mixup_active:
mixup_fn = Mixup(
mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX,
prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE,
label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES)
return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn
def build_dataset(is_train, config):
transform = build_transform(is_train, config)
if config.DATA.DATASET == 'imagenet':
prefix = 'train' if is_train else 'val'
if config.DATA.ZIP_MODE:
ann_file = prefix + "_map.txt"
prefix = prefix + ".zip@/"
dataset = CachedImageFolder(config.DATA.DATA_PATH, ann_file, prefix, transform,
cache_mode=config.DATA.CACHE_MODE if is_train else 'part')
else:
root = os.path.join(config.DATA.DATA_PATH, prefix)
dataset = datasets.ImageFolder(root, transform=transform)
nb_classes = 1000
else:
raise NotImplementedError("We only support ImageNet Now.")
return dataset, nb_classes
def build_transform(is_train, config):
resize_im = config.DATA.IMG_SIZE > 32
if is_train:
# this should always dispatch to transforms_imagenet_train
transform = create_transform(
input_size=config.DATA.IMG_SIZE,
is_training=True,
color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None,
auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None,
re_prob=config.AUG.REPROB,
re_mode=config.AUG.REMODE,
re_count=config.AUG.RECOUNT,
interpolation=config.DATA.INTERPOLATION,
)
if not resize_im:
# replace RandomResizedCropAndInterpolation with
# RandomCrop
transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4)
return transform
t = []
if resize_im:
if config.TEST.CROP:
size = int((256 / 224) * config.DATA.IMG_SIZE)
t.append(
transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)),
# to maintain same ratio w.r.t. 224 images
)
t.append(transforms.CenterCrop(config.DATA.IMG_SIZE))
else:
t.append(
transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
interpolation=_pil_interp(config.DATA.INTERPOLATION))
)
t.append(transforms.ToTensor())
t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
return transforms.Compose(t)

251
data/cached_image_folder.py Normal file
View File

@ -0,0 +1,251 @@
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
import io
import os
import time
import torch.distributed as dist
import torch.utils.data as data
from PIL import Image
from .zipreader import is_zip_path, ZipReader
def has_file_allowed_extension(filename, extensions):
"""Checks if a file is an allowed extension.
Args:
filename (string): path to a file
Returns:
bool: True if the filename ends with a known image extension
"""
filename_lower = filename.lower()
return any(filename_lower.endswith(ext) for ext in extensions)
def find_classes(dir):
classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
classes.sort()
class_to_idx = {classes[i]: i for i in range(len(classes))}
return classes, class_to_idx
def make_dataset(dir, class_to_idx, extensions):
images = []
dir = os.path.expanduser(dir)
for target in sorted(os.listdir(dir)):
d = os.path.join(dir, target)
if not os.path.isdir(d):
continue
for root, _, fnames in sorted(os.walk(d)):
for fname in sorted(fnames):
if has_file_allowed_extension(fname, extensions):
path = os.path.join(root, fname)
item = (path, class_to_idx[target])
images.append(item)
return images
def make_dataset_with_ann(ann_file, img_prefix, extensions):
images = []
with open(ann_file, "r") as f:
contents = f.readlines()
for line_str in contents:
path_contents = [c for c in line_str.split('\t')]
im_file_name = path_contents[0]
class_index = int(path_contents[1])
assert str.lower(os.path.splitext(im_file_name)[-1]) in extensions
item = (os.path.join(img_prefix, im_file_name), class_index)
images.append(item)
return images
class DatasetFolder(data.Dataset):
"""A generic data loader where the samples are arranged in this way: ::
root/class_x/xxx.ext
root/class_x/xxy.ext
root/class_x/xxz.ext
root/class_y/123.ext
root/class_y/nsdf3.ext
root/class_y/asd932_.ext
Args:
root (string): Root directory path.
loader (callable): A function to load a sample given its path.
extensions (list[string]): A list of allowed extensions.
transform (callable, optional): A function/transform that takes in
a sample and returns a transformed version.
E.g, ``transforms.RandomCrop`` for images.
target_transform (callable, optional): A function/transform that takes
in the target and transforms it.
Attributes:
samples (list): List of (sample path, class_index) tuples
"""
def __init__(self, root, loader, extensions, ann_file='', img_prefix='', transform=None, target_transform=None,
cache_mode="no"):
# image folder mode
if ann_file == '':
_, class_to_idx = find_classes(root)
samples = make_dataset(root, class_to_idx, extensions)
# zip mode
else:
samples = make_dataset_with_ann(os.path.join(root, ann_file),
os.path.join(root, img_prefix),
extensions)
if len(samples) == 0:
raise (RuntimeError("Found 0 files in subfolders of: " + root + "\n" +
"Supported extensions are: " + ",".join(extensions)))
self.root = root
self.loader = loader
self.extensions = extensions
self.samples = samples
self.labels = [y_1k for _, y_1k in samples]
self.classes = list(set(self.labels))
self.transform = transform
self.target_transform = target_transform
self.cache_mode = cache_mode
if self.cache_mode != "no":
self.init_cache()
def init_cache(self):
assert self.cache_mode in ["part", "full"]
n_sample = len(self.samples)
global_rank = dist.get_rank()
world_size = dist.get_world_size()
samples_bytes = [None for _ in range(n_sample)]
start_time = time.time()
for index in range(n_sample):
if index % (n_sample // 10) == 0:
t = time.time() - start_time
print(f'global_rank {dist.get_rank()} cached {index}/{n_sample} takes {t:.2f}s per block')
start_time = time.time()
path, target = self.samples[index]
if self.cache_mode == "full":
samples_bytes[index] = (ZipReader.read(path), target)
elif self.cache_mode == "part" and index % world_size == global_rank:
samples_bytes[index] = (ZipReader.read(path), target)
else:
samples_bytes[index] = (path, target)
self.samples = samples_bytes
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (sample, target) where target is class_index of the target class.
"""
path, target = self.samples[index]
sample = self.loader(path)
if self.transform is not None:
sample = self.transform(sample)
if self.target_transform is not None:
target = self.target_transform(target)
return sample, target
def __len__(self):
return len(self.samples)
def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
def pil_loader(path):
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
if isinstance(path, bytes):
img = Image.open(io.BytesIO(path))
elif is_zip_path(path):
data = ZipReader.read(path)
img = Image.open(io.BytesIO(data))
else:
with open(path, 'rb') as f:
img = Image.open(f)
return img.convert('RGB')
def accimage_loader(path):
import accimage
try:
return accimage.Image(path)
except IOError:
# Potentially a decoding problem, fall back to PIL.Image
return pil_loader(path)
def default_img_loader(path):
from torchvision import get_image_backend
if get_image_backend() == 'accimage':
return accimage_loader(path)
else:
return pil_loader(path)
class CachedImageFolder(DatasetFolder):
"""A generic data loader where the images are arranged in this way: ::
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
Args:
root (string): Root directory path.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
loader (callable, optional): A function to load an image given its path.
Attributes:
imgs (list): List of (image path, class_index) tuples
"""
def __init__(self, root, ann_file='', img_prefix='', transform=None, target_transform=None,
loader=default_img_loader, cache_mode="no"):
super(CachedImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
ann_file=ann_file, img_prefix=img_prefix,
transform=transform, target_transform=target_transform,
cache_mode=cache_mode)
self.imgs = self.samples
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is class_index of the target class.
"""
path, target = self.samples[index]
image = self.loader(path)
if self.transform is not None:
img = self.transform(image)
else:
img = image
if self.target_transform is not None:
target = self.target_transform(target)
return img, target

29
data/samplers.py Normal file
View File

@ -0,0 +1,29 @@
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
import torch
class SubsetRandomSampler(torch.utils.data.Sampler):
r"""Samples elements randomly from a given list of indices, without replacement.
Arguments:
indices (sequence): a sequence of indices
"""
def __init__(self, indices):
self.epoch = 0
self.indices = indices
def __iter__(self):
return (self.indices[i] for i in torch.randperm(len(self.indices)))
def __len__(self):
return len(self.indices)
def set_epoch(self, epoch):
self.epoch = epoch

103
data/zipreader.py Normal file
View File

@ -0,0 +1,103 @@
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
import os
import zipfile
import io
import numpy as np
from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
def is_zip_path(img_or_path):
"""judge if this is a zip path"""
return '.zip@' in img_or_path
class ZipReader(object):
"""A class to read zipped files"""
zip_bank = dict()
def __init__(self):
super(ZipReader, self).__init__()
@staticmethod
def get_zipfile(path):
zip_bank = ZipReader.zip_bank
if path not in zip_bank:
zfile = zipfile.ZipFile(path, 'r')
zip_bank[path] = zfile
return zip_bank[path]
@staticmethod
def split_zip_style_path(path):
pos_at = path.index('@')
assert pos_at != -1, "character '@' is not found from the given path '%s'" % path
zip_path = path[0: pos_at]
folder_path = path[pos_at + 1:]
folder_path = str.strip(folder_path, '/')
return zip_path, folder_path
@staticmethod
def list_folder(path):
zip_path, folder_path = ZipReader.split_zip_style_path(path)
zfile = ZipReader.get_zipfile(zip_path)
folder_list = []
for file_foler_name in zfile.namelist():
file_foler_name = str.strip(file_foler_name, '/')
if file_foler_name.startswith(folder_path) and \
len(os.path.splitext(file_foler_name)[-1]) == 0 and \
file_foler_name != folder_path:
if len(folder_path) == 0:
folder_list.append(file_foler_name)
else:
folder_list.append(file_foler_name[len(folder_path) + 1:])
return folder_list
@staticmethod
def list_files(path, extension=None):
if extension is None:
extension = ['.*']
zip_path, folder_path = ZipReader.split_zip_style_path(path)
zfile = ZipReader.get_zipfile(zip_path)
file_lists = []
for file_foler_name in zfile.namelist():
file_foler_name = str.strip(file_foler_name, '/')
if file_foler_name.startswith(folder_path) and \
str.lower(os.path.splitext(file_foler_name)[-1]) in extension:
if len(folder_path) == 0:
file_lists.append(file_foler_name)
else:
file_lists.append(file_foler_name[len(folder_path) + 1:])
return file_lists
@staticmethod
def read(path):
zip_path, path_img = ZipReader.split_zip_style_path(path)
zfile = ZipReader.get_zipfile(zip_path)
data = zfile.read(path_img)
return data
@staticmethod
def imread(path):
zip_path, path_img = ZipReader.split_zip_style_path(path)
zfile = ZipReader.get_zipfile(zip_path)
data = zfile.read(path_img)
try:
im = Image.open(io.BytesIO(data))
except:
print("ERROR IMG LOADED: ", path_img)
random_img = np.random.rand(224, 224, 3) * 255
im = Image.fromarray(np.uint8(random_img))
return im

BIN
figures/teaser.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 909 KiB

203
get_started.md Normal file
View File

@ -0,0 +1,203 @@
# Swin Transformer for Image Classification
This folder contains the implementation of the Swin Transformer for image classification.
## Model Zoo
### Regular ImageNet-1K trained models
| name | resolution |acc@1 | acc@5 | #params | FLOPs | model |
|:---:|:---:|:---:|:---:| :---:| :---:|:---:|
| Swin-T | 224x224 | 81.2 | 95.5 | 28M | 4.5G | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/156nWJy4Q28rDlrX-rRbI3w) |
| Swin-S | 224x224 | 83.2 | 96.2 | 50M | 8.7G | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/1KFjpj3Efey3LmtE1QqPeQg) |
| Swin-B | 224x224 | 83.5 | 96.5 | 88M | 15.4G | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/16bqCTEc70nC_isSsgBSaqQ) |
| Swin-B | 384x384 | 84.5 | 97.0 | 88M | 47.1G | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth)/[baidu](https://pan.baidu.com/s/1xT1cu740-ejW7htUdVLnmw) |
### ImageNet-22K pre-trained models
| name | resolution |acc@1 | acc@5 | #params | FLOPs | 22K model | 1K model |
|:---: |:---: |:---:|:---:|:---:|:---:|:---:|:---:|
| Swin-B | 224x224 | 85.2 | 97.5 | 88M | 15.4G | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1y1Ec3UlrKSI8IMtEs-oBXA) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1n_wNkcbRxVXit8r_KrfAVg) |
| Swin-B | 384x384 | 86.4 | 98.0 | 88M | 47.1G | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth)/[baidu](https://pan.baidu.com/s/1vwJxnJcVqcLZAw9HaqiR6g) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1caKTSdoLJYoi4WBcnmWuWg) |
| Swin-L | 224x224 | 86.3 | 97.9 | 197M | 34.5G | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1pws3rOTFuOebBYP3h6Kx8w) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1NkQApMWUhxBGjk1ne6VqBQ) |
| Swin-L | 384x384 | 87.3 | 98.2 | 197M | 103.9G | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth)/[baidu](https://pan.baidu.com/s/1sl7o_bJA143OD7UqSLAMoA) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1X0FLHQyPOC6Kmv2CmgxJvA) |
Note: access code for `baidu` is `swin`.
## Usage
### Install
- Clone this repo:
```bash
git clone https://github.com/microsoft/Swin-Transformer.git
cd Swin-Transformer
```
- Create a conda virtual environment and activate it:
```bash
conda create -n swin python=3.7 -y
conda activate swin
```
- Install `CUDA==10.1` with `cudnn7` following
the [official installation instructions](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html)
- Install `PyTorch==1.7.1` and `torchvision==0.8.2` with `CUDA==10.1`:
```bash
conda install pytorch==1.7.1 torchvision==0.8.2 cudatoolkit=10.1 -c pytorch
```
- Install `timm==0.3.2`:
```bash
pip install timm==0.3.2
```
- Install `Apex`:
```bash
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
```
- Install other requirements:
```bash
pip install opencv-python==4.4.0.46 termcolor==1.1.0 yacs==0.1.8
```
### Data preparation
We use standard ImageNet dataset, you can download it from http://image-net.org/. We provide the following two ways to
load data:
- For standard folder dataset, move validation images to labeled sub-folders. The file structure should look like:
```bash
$ tree data
imagenet
├── train
│ ├── class1
│ │ ├── img1.jpeg
│ │ ├── img2.jpeg
│ │ └── ...
│ ├── class2
│ │ ├── img3.jpeg
│ │ └── ...
│ └── ...
└── val
├── class1
│ ├── img4.jpeg
│ ├── img5.jpeg
│ └── ...
├── class2
│ ├── img6.jpeg
│ └── ...
└── ...
```
- To boost the slow speed when reading images from massive small files, we also support zipped ImageNet, which includes
four files:
- `train.zip`, `val.zip`: which store the zipped folder for train and validate splits.
- `train_map.txt`, `val_map.txt`: which store the relative path in the corresponding zip file and ground truth
label. Make sure the data folder looks like this:
```bash
$ tree data
data
└── ImageNet-Zip
├── train_map.txt
├── train.zip
├── val_map.txt
└── val.zip
$ head -n 5 data/ImageNet-Zip/val_map.txt
ILSVRC2012_val_00000001.JPEG 65
ILSVRC2012_val_00000002.JPEG 970
ILSVRC2012_val_00000003.JPEG 230
ILSVRC2012_val_00000004.JPEG 809
ILSVRC2012_val_00000005.JPEG 516
$ head -n 5 data/ImageNet-Zip/train_map.txt
n01440764/n01440764_10026.JPEG 0
n01440764/n01440764_10027.JPEG 0
n01440764/n01440764_10029.JPEG 0
n01440764/n01440764_10040.JPEG 0
n01440764/n01440764_10042.JPEG 0
```
### Evaluation
To evaluate a pre-trained `Swin Transformer` on ImageNet val, run:
```bash
python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> --master_port 12345 main.py --eval \
--cfg <config-file> --resume <checkpoint> --data-path <imagenet-path>
```
For example, to evaluate the `Swin-B` with a single GPU:
```bash
python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 main.py --eval \
--cfg configs/swin_base_patch4_window7_224.yaml --resume swin_base_patch4_window7_224.pth --data-path <imagenet-path>
```
### Training from scratch
To train a `Swin Transformer` on ImageNet from scratch, run:
```bash
python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> --master_port 12345 main.py \
--cfg <config-file> --data-path <imagenet-path> [--batch-size <batch-size-per-gpu> --output <output-directory> --tag <job-tag>]
```
**Notes**:
- To use zipped ImageNet instead of folder dataset, add `--zip` to the parameters.
- To cache the dataset in the memory instead of reading from files every time, add `--cache-mode part`, which will
shard the dataset into non-overlapping pieces for different GPUs and only load the corresponding one for each GPU.
- When GPU memory is not enough, you can try the following suggestions:
- Use gradient accumulation by adding `--accumulation-steps <steps>`, set appropriate `<steps>` according to your need.
- Use gradient checkpointing by adding `--use-checkpoint`, e.g., it saves about 60% memory when training `Swin-B`.
Please refer to [this page](https://pytorch.org/docs/stable/checkpoint.html) for more details.
- We recommend using multi-node with more GPUs for training very large models, a tutorial can be found
in [this page](https://pytorch.org/tutorials/intermediate/dist_tuto.html).
- To change config options in general, you can use `--opts KEY1 VALUE1 KEY2 VALUE2`, e.g.,
`--opts TRAIN.EPOCHS 100 TRAIN.WARMUP_EPOCHS 5` will change total epochs to 100 and warm-up epochs to 5.
- For additional options, see [config](config.py) and run `python main.py --help` to get detailed message.
For example, to train `Swin Transformer` with 8 GPU on a single node for 300 epochs, run:
`Swin-T`:
```bash
python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \
--cfg configs/swin_tiny_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 128
```
`Swin-S`:
```bash
python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \
--cfg configs/swin_small_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 128
```
`Swin-B`:
```bash
python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \
--cfg configs/swin_base_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 64 \
--accumulation-steps 2 [--use-checkpoint]
```
### Throughput
To measure the throughput, run:
```bash
python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 main.py \
--cfg <config-file> --data-path <imagenet-path> --batch-size 64 --throughput --amp-opt-level O0
```

41
logger.py Normal file
View File

@ -0,0 +1,41 @@
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
import os
import sys
import logging
import functools
from termcolor import colored
@functools.lru_cache()
def create_logger(output_dir, dist_rank=0, name=''):
# create logger
logger = logging.getLogger(name)
logger.setLevel(logging.DEBUG)
logger.propagate = False
# create formatter
fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s'
color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \
colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s'
# create console handlers for master process
if dist_rank == 0:
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.DEBUG)
console_handler.setFormatter(
logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S'))
logger.addHandler(console_handler)
# create file handlers
file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'), mode='a')
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S'))
logger.addHandler(file_handler)
return logger

102
lr_scheduler.py Normal file
View File

@ -0,0 +1,102 @@
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
import torch
from timm.scheduler.cosine_lr import CosineLRScheduler
from timm.scheduler.step_lr import StepLRScheduler
from timm.scheduler.scheduler import Scheduler
def build_scheduler(config, optimizer, n_iter_per_epoch):
num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch)
warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch)
decay_steps = int(config.TRAIN.LR_SCHEDULER.DECAY_EPOCHS * n_iter_per_epoch)
lr_scheduler = None
if config.TRAIN.LR_SCHEDULER.NAME == 'cosine':
lr_scheduler = CosineLRScheduler(
optimizer,
t_initial=num_steps,
t_mul=1.,
lr_min=config.TRAIN.MIN_LR,
warmup_lr_init=config.TRAIN.WARMUP_LR,
warmup_t=warmup_steps,
cycle_limit=1,
t_in_epochs=False,
)
elif config.TRAIN.LR_SCHEDULER.NAME == 'linear':
lr_scheduler = LinearLRScheduler(
optimizer,
t_initial=num_steps,
lr_min_rate=0.01,
warmup_lr_init=config.TRAIN.WARMUP_LR,
warmup_t=warmup_steps,
t_in_epochs=False,
)
elif config.TRAIN.LR_SCHEDULER.NAME == 'step':
lr_scheduler = StepLRScheduler(
optimizer,
decay_t=decay_steps,
decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE,
warmup_lr_init=config.TRAIN.WARMUP_LR,
warmup_t=warmup_steps,
t_in_epochs=False,
)
return lr_scheduler
class LinearLRScheduler(Scheduler):
def __init__(self,
optimizer: torch.optim.Optimizer,
t_initial: int,
lr_min_rate: float,
warmup_t=0,
warmup_lr_init=0.,
t_in_epochs=True,
noise_range_t=None,
noise_pct=0.67,
noise_std=1.0,
noise_seed=42,
initialize=True,
) -> None:
super().__init__(
optimizer, param_group_field="lr",
noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
initialize=initialize)
self.t_initial = t_initial
self.lr_min_rate = lr_min_rate
self.warmup_t = warmup_t
self.warmup_lr_init = warmup_lr_init
self.t_in_epochs = t_in_epochs
if self.warmup_t:
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
super().update_groups(self.warmup_lr_init)
else:
self.warmup_steps = [1 for _ in self.base_values]
def _get_lr(self, t):
if t < self.warmup_t:
lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
else:
t = t - self.warmup_t
total_t = self.t_initial - self.warmup_t
lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values]
return lrs
def get_epoch_values(self, epoch: int):
if self.t_in_epochs:
return self._get_lr(epoch)
else:
return None
def get_update_values(self, num_updates: int):
if not self.t_in_epochs:
return self._get_lr(num_updates)
else:
return None

345
main.py Normal file
View File

@ -0,0 +1,345 @@
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
import os
import time
import argparse
import datetime
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.utils import accuracy, AverageMeter
from config import get_config
from models import build_model
from data import build_loader
from lr_scheduler import build_scheduler
from optimizer import build_optimizer
from logger import create_logger
from utils import load_checkpoint, save_checkpoint, get_grad_norm, auto_resume_helper, reduce_tensor
try:
# noinspection PyUnresolvedReferences
from apex import amp
except ImportError:
amp = None
def parse_option():
parser = argparse.ArgumentParser('Swin Transformer training and evaluation script', add_help=False)
parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', )
parser.add_argument(
"--opts",
help="Modify config options by adding 'KEY VALUE' pairs. ",
default=None,
nargs='+',
)
# easy config modification
parser.add_argument('--batch-size', type=int, help="batch size for single GPU")
parser.add_argument('--data-path', type=str, help='path to dataset')
parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
help='no: no cache, '
'full: cache all data, '
'part: sharding the dataset into nonoverlapping pieces and only cache one piece')
parser.add_argument('--resume', help='resume from checkpoint')
parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
parser.add_argument('--use-checkpoint', action='store_true',
help="whether to use gradient checkpointing to save memory")
parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'],
help='mixed precision opt level, if O0, no amp is used')
parser.add_argument('--output', default='output', type=str, metavar='PATH',
help='root of output folder, the full path is <output>/<model_name>/<tag> (default: output)')
parser.add_argument('--tag', help='tag of experiment')
parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
parser.add_argument('--throughput', action='store_true', help='Test throughput only')
# distributed training
parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')
args, unparsed = parser.parse_known_args()
config = get_config(args)
return args, config
def main(config):
dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config)
logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
model = build_model(config)
model.cuda()
logger.info(str(model))
optimizer = build_optimizer(config, model)
if config.AMP_OPT_LEVEL != "O0":
model, optimizer = amp.initialize(model, optimizer, opt_level=config.AMP_OPT_LEVEL)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False)
model_without_ddp = model.module
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
logger.info(f"number of params: {n_parameters}")
if hasattr(model_without_ddp, 'flops'):
flops = model_without_ddp.flops()
logger.info(f"number of GFLOPs: {flops / 1e9}")
lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train))
if config.AUG.MIXUP > 0.:
# smoothing is handled with mixup label transform
criterion = SoftTargetCrossEntropy()
elif config.MODEL.LABEL_SMOOTHING > 0.:
criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING)
else:
criterion = torch.nn.CrossEntropyLoss()
max_accuracy = 0.0
if config.TRAIN.AUTO_RESUME:
resume_file = auto_resume_helper(config.OUTPUT)
if resume_file:
if config.MODEL.RESUME:
logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}")
config.MODEL.RESUME = resume_file
logger.info(f'auto resuming from {resume_file}')
else:
logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')
if config.MODEL.RESUME:
max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, logger)
acc1, acc5, loss = validate(config, data_loader_val, model)
logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
if config.EVAL_MODE:
return
if config.THROUGHPUT_MODE:
throughput(data_loader_val, model, logger)
return
logger.info("Start training")
start_time = time.time()
for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
data_loader_train.sampler.set_epoch(epoch)
train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler)
if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)):
save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, logger)
acc1, acc5, loss = validate(config, data_loader_val, model)
logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
max_accuracy = max(max_accuracy, acc1)
logger.info(f'Max accuracy: {max_accuracy:.2f}%')
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
logger.info('Training time {}'.format(total_time_str))
def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler):
model.train()
optimizer.zero_grad()
num_steps = len(data_loader)
batch_time = AverageMeter()
loss_meter = AverageMeter()
norm_meter = AverageMeter()
start = time.time()
end = time.time()
for idx, (samples, targets) in enumerate(data_loader):
samples = samples.cuda(non_blocking=True)
targets = targets.cuda(non_blocking=True)
if mixup_fn is not None:
samples, targets = mixup_fn(samples, targets)
outputs = model(samples)
if config.TRAIN.ACCUMULATION_STEPS > 1:
loss = criterion(outputs, targets)
loss = loss / config.TRAIN.ACCUMULATION_STEPS
if config.AMP_OPT_LEVEL != "O0":
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
if config.TRAIN.CLIP_GRAD:
grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.TRAIN.CLIP_GRAD)
else:
grad_norm = get_grad_norm(amp.master_params(optimizer))
else:
loss.backward()
if config.TRAIN.CLIP_GRAD:
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD)
else:
grad_norm = get_grad_norm(model.parameters())
if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
optimizer.step()
optimizer.zero_grad()
lr_scheduler.step_update(epoch * num_steps + idx)
else:
loss = criterion(outputs, targets)
optimizer.zero_grad()
if config.AMP_OPT_LEVEL != "O0":
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
if config.TRAIN.CLIP_GRAD:
grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.TRAIN.CLIP_GRAD)
else:
grad_norm = get_grad_norm(amp.master_params(optimizer))
else:
loss.backward()
if config.TRAIN.CLIP_GRAD:
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD)
else:
grad_norm = get_grad_norm(model.parameters())
optimizer.step()
lr_scheduler.step_update(epoch * num_steps + idx)
torch.cuda.synchronize()
loss_meter.update(loss.item(), targets.size(0))
norm_meter.update(grad_norm)
batch_time.update(time.time() - end)
end = time.time()
if idx % config.PRINT_FREQ == 0:
lr = optimizer.param_groups[0]['lr']
memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
etas = batch_time.avg * (num_steps - idx)
logger.info(
f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
f'mem {memory_used:.0f}MB')
epoch_time = time.time() - start
logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}")
@torch.no_grad()
def validate(config, data_loader, model):
criterion = torch.nn.CrossEntropyLoss()
model.eval()
batch_time = AverageMeter()
loss_meter = AverageMeter()
acc1_meter = AverageMeter()
acc5_meter = AverageMeter()
end = time.time()
for idx, (images, target) in enumerate(data_loader):
images = images.cuda(non_blocking=True)
target = target.cuda(non_blocking=True)
# compute output
output = model(images)
# measure accuracy and record loss
loss = criterion(output, target)
acc1, acc5 = accuracy(output, target, topk=(1, 5))
acc1 = reduce_tensor(acc1)
acc5 = reduce_tensor(acc5)
loss = reduce_tensor(loss)
loss_meter.update(loss.item(), target.size(0))
acc1_meter.update(acc1.item(), target.size(0))
acc5_meter.update(acc5.item(), target.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if idx % config.PRINT_FREQ == 0:
memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
logger.info(
f'Test: [{idx}/{len(data_loader)}]\t'
f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
f'Mem {memory_used:.0f}MB')
logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
return acc1_meter.avg, acc5_meter.avg, loss_meter.avg
@torch.no_grad()
def throughput(data_loader, model, logger):
model.eval()
for idx, (images, _) in enumerate(data_loader):
images = images.cuda(non_blocking=True)
batch_size = images.shape[0]
for i in range(50):
model(images)
torch.cuda.synchronize()
logger.info(f"throughput averaged with 30 times")
tic1 = time.time()
for i in range(30):
model(images)
torch.cuda.synchronize()
tic2 = time.time()
logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}")
return
if __name__ == '__main__':
_, config = parse_option()
if config.AMP_OPT_LEVEL != "O0":
assert amp is not None, "amp not installed!"
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
rank = int(os.environ["RANK"])
world_size = int(os.environ['WORLD_SIZE'])
print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}")
else:
rank = -1
world_size = -1
torch.cuda.set_device(config.LOCAL_RANK)
torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
torch.distributed.barrier()
seed = config.SEED + dist.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
cudnn.benchmark = True
# linear scale the learning rate according to total batch size, may not be optimal
linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
# gradient accumulation also need to scale the learning rate
if config.TRAIN.ACCUMULATION_STEPS > 1:
linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS
linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS
linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS
config.defrost()
config.TRAIN.BASE_LR = linear_scaled_lr
config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr
config.TRAIN.MIN_LR = linear_scaled_min_lr
config.freeze()
os.makedirs(config.OUTPUT, exist_ok=True)
logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}")
if dist.get_rank() == 0:
path = os.path.join(config.OUTPUT, "config.json")
with open(path, "w") as f:
f.write(config.dump())
logger.info(f"Full config saved to {path}")
# print config
logger.info(config.dump())
main(config)

1
models/__init__.py Normal file
View File

@ -0,0 +1 @@
from .build import build_model

33
models/build.py Normal file
View File

@ -0,0 +1,33 @@
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
from .swin_transformer import SwinTransformer
def build_model(config):
model_type = config.MODEL.TYPE
if model_type == 'swin':
model = SwinTransformer(img_size=config.DATA.IMG_SIZE,
patch_size=config.MODEL.SWIN.PATCH_SIZE,
in_chans=config.MODEL.SWIN.IN_CHANS,
num_classes=config.MODEL.NUM_CLASSES,
embed_dim=config.MODEL.SWIN.EMBED_DIM,
depths=config.MODEL.SWIN.DEPTHS,
num_heads=config.MODEL.SWIN.NUM_HEADS,
window_size=config.MODEL.SWIN.WINDOW_SIZE,
mlp_ratio=config.MODEL.SWIN.MLP_RATIO,
qkv_bias=config.MODEL.SWIN.QKV_BIAS,
qk_scale=config.MODEL.SWIN.QK_SCALE,
drop_rate=config.MODEL.DROP_RATE,
drop_path_rate=config.MODEL.DROP_PATH_RATE,
ape=config.MODEL.SWIN.APE,
patch_norm=config.MODEL.SWIN.PATCH_NORM,
use_checkpoint=config.TRAIN.USE_CHECKPOINT)
else:
raise NotImplementedError(f"Unkown model: {model_type}")
return model

585
models/swin_transformer.py Normal file
View File

@ -0,0 +1,585 @@
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
class WindowAttention(nn.Module):
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
def extra_repr(self) -> str:
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
def flops(self, N):
# calculate flops for 1 window with token length of N
flops = 0
# qkv = self.qkv(x)
flops += N * self.dim * 3 * self.dim
# attn = (q @ k.transpose(-2, -1))
flops += self.num_heads * N * (self.dim // self.num_heads) * N
# x = (attn @ v)
flops += self.num_heads * N * N * (self.dim // self.num_heads)
# x = self.proj(x)
flops += N * self.dim * self.dim
return flops
class SwinTransformerBlock(nn.Module):
r""" Swin Transformer Block.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resulotion.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
if min(self.input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows
self.shift_size = 0
self.window_size = min(self.input_resolution)
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
if self.shift_size > 0:
# calculate attention mask for SW-MSA
H, W = self.input_resolution
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
def forward(self, x):
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
shifted_x = x
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
def flops(self):
flops = 0
H, W = self.input_resolution
# norm1
flops += self.dim * H * W
# W-MSA/SW-MSA
nW = H * W / self.window_size / self.window_size
flops += nW * self.attn.flops(self.window_size * self.window_size)
# mlp
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
# norm2
flops += self.dim * H * W
return flops
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
def extra_repr(self) -> str:
return f"input_resolution={self.input_resolution}, dim={self.dim}"
def flops(self):
H, W = self.input_resolution
flops = H * W * self.dim
flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
return flops
class BasicLayer(nn.Module):
""" A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList([
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
num_heads=num_heads, window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop, attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer)
for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x):
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x)
if self.downsample is not None:
x = self.downsample(x)
return x
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
def flops(self):
flops = 0
for blk in self.blocks:
flops += blk.flops()
if self.downsample is not None:
flops += self.downsample.flops()
return flops
class PatchEmbed(nn.Module):
r""" Image to Patch Embedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
if self.norm is not None:
x = self.norm(x)
return x
def flops(self):
Ho, Wo = self.patches_resolution
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
if self.norm is not None:
flops += Ho * Wo * self.embed_dim
return flops
class SwinTransformer(nn.Module):
r""" Swin Transformer
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
https://arxiv.org/pdf/2103.14030
Args:
img_size (int | tuple(int)): Input image size. Default 224
patch_size (int | tuple(int)): Patch size. Default: 4
in_chans (int): Number of input image channels. Default: 3
num_classes (int): Number of classes for classification head. Default: 1000
embed_dim (int): Patch embedding dimension. Default: 96
depths (tuple(int)): Depth of each Swin Transformer layer.
num_heads (tuple(int)): Number of attention heads in different layers.
window_size (int): Window size. Default: 7
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
drop_rate (float): Dropout rate. Default: 0
attn_drop_rate (float): Attention dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
patch_norm (bool): If True, add normalization after patch embedding. Default: True
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
use_checkpoint=False, **kwargs):
super().__init__()
self.num_classes = num_classes
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
self.mlp_ratio = mlp_ratio
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
# absolute position embedding
if self.ape:
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
trunc_normal_(self.absolute_pos_embed, std=.02)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
# build layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
input_resolution=(patches_resolution[0] // (2 ** i_layer),
patches_resolution[1] // (2 ** i_layer)),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint)
self.layers.append(layer)
self.norm = norm_layer(self.num_features)
self.avgpool = nn.AdaptiveAvgPool1d(1)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
return {'absolute_pos_embed'}
@torch.jit.ignore
def no_weight_decay_keywords(self):
return {'relative_position_bias_table'}
def forward_features(self, x):
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for layer in self.layers:
x = layer(x)
x = self.norm(x) # B L C
x = self.avgpool(x.transpose(1, 2)) # B C 1
x = torch.flatten(x, 1)
return x
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x
def flops(self):
flops = 0
flops += self.patch_embed.flops()
for i, layer in enumerate(self.layers):
flops += layer.flops()
flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
flops += self.num_features * self.num_classes
return flops

57
optimizer.py Normal file
View File

@ -0,0 +1,57 @@
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
from torch import optim as optim
def build_optimizer(config, model):
"""
Build optimizer, set weight decay of normalization to 0 by default.
"""
skip = {}
skip_keywords = {}
if hasattr(model, 'no_weight_decay'):
skip = model.no_weight_decay()
if hasattr(model, 'no_weight_decay_keywords'):
skip_keywords = model.no_weight_decay_keywords()
parameters = set_weight_decay(model, skip, skip_keywords)
opt_lower = config.TRAIN.OPTIMIZER.NAME.lower()
optimizer = None
if opt_lower == 'sgd':
optimizer = optim.SGD(parameters, momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True,
lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)
elif opt_lower == 'adamw':
optimizer = optim.AdamW(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS,
lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)
return optimizer
def set_weight_decay(model, skip_list=(), skip_keywords=()):
has_decay = []
no_decay = []
for name, param in model.named_parameters():
if not param.requires_grad:
continue # frozen weights
if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \
check_keywords_in_name(name, skip_keywords):
no_decay.append(param)
# print(f"{name} has no weight decay")
else:
has_decay.append(param)
return [{'params': has_decay},
{'params': no_decay, 'weight_decay': 0.}]
def check_keywords_in_name(name, keywords=()):
isin = False
for keyword in keywords:
if keyword in name:
isin = True
return isin

90
utils.py Normal file
View File

@ -0,0 +1,90 @@
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
import os
import torch
import torch.distributed as dist
try:
# noinspection PyUnresolvedReferences
from apex import amp
except ImportError:
amp = None
def load_checkpoint(config, model, optimizer, lr_scheduler, logger):
logger.info(f"==============> Resuming form {config.MODEL.RESUME}....................")
if config.MODEL.RESUME.startswith('https'):
checkpoint = torch.hub.load_state_dict_from_url(
config.MODEL.RESUME, map_location='cpu', check_hash=True)
else:
checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')
msg = model.load_state_dict(checkpoint['model'], strict=False)
logger.info(msg)
max_accuracy = 0.0
if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
if 'amp' in checkpoint and config.AMP_OPT_LEVEL != "O0" and checkpoint['config'].AMP_OPT_LEVEL != "O0":
amp.load_state_dict(checkpoint['amp'])
logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})")
if 'max_accuracy' in checkpoint:
max_accuracy = checkpoint['max_accuracy']
del checkpoint
torch.cuda.empty_cache()
return max_accuracy
def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, logger):
save_state = {'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'lr_scheduler': lr_scheduler.state_dict(),
'max_accuracy': max_accuracy,
'epoch': epoch,
'config': config}
if config.AMP_OPT_LEVEL != "O0":
save_state['amp'] = amp.state_dict()
save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth')
logger.info(f"{save_path} saving......")
torch.save(save_state, save_path)
logger.info(f"{save_path} saved !!!")
def get_grad_norm(parameters, norm_type=2):
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters))
norm_type = float(norm_type)
total_norm = 0
for p in parameters:
param_norm = p.grad.data.norm(norm_type)
total_norm += param_norm.item() ** norm_type
total_norm = total_norm ** (1. / norm_type)
return total_norm
def auto_resume_helper(output_dir):
checkpoints = os.listdir(output_dir)
checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')]
print(f"All checkpoints founded in {output_dir}: {checkpoints}")
if len(checkpoints) > 0:
latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime)
print(f"The latest checkpoint founded: {latest_checkpoint}")
resume_file = latest_checkpoint
else:
resume_file = None
return resume_file
def reduce_tensor(tensor):
rt = tensor.clone()
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
rt /= dist.get_world_size()
return rt