Initial commit
This commit is contained in:
parent
ce5bae042d
commit
3dc2a55301
129
.gitignore
vendored
Normal file
129
.gitignore
vendored
Normal file
|
@ -0,0 +1,129 @@
|
||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
pip-wheel-metadata/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
.python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
90
README.md
90
README.md
|
@ -1,13 +1,89 @@
|
||||||
# Swin Transformer
|
# Swin Transformer
|
||||||
|
|
||||||
By [Ze Liu](https://github.com/zeliu98/)\*, [Yutong Lin](https://github.com/impiga)\*, [Yue Cao](http://yue-cao.me)\*, [Han Hu](https://sites.google.com/site/hanhushomepage/)\*, [Yixuan Wei](https://github.com/weiyx16), [Zheng Zhang](https://stupidzz.github.io/), [Stephen Lin](https://scholar.google.com/citations?user=c3PYmxUAAAAJ&hl=en) and [Baining Guo](https://www.microsoft.com/en-us/research/people/bainguo/).
|
[](https://paperswithcode.com/sota/object-detection-on-coco?p=swin-transformer-hierarchical-vision)
|
||||||
|
[](https://paperswithcode.com/sota/instance-segmentation-on-coco?p=swin-transformer-hierarchical-vision)
|
||||||
|
[](https://paperswithcode.com/sota/object-detection-on-coco-minival?p=swin-transformer-hierarchical-vision)
|
||||||
|
[](https://paperswithcode.com/sota/instance-segmentation-on-coco-minival?p=swin-transformer-hierarchical-vision)
|
||||||
|
[](https://paperswithcode.com/sota/semantic-segmentation-on-ade20k?p=swin-transformer-hierarchical-vision)
|
||||||
|
[](https://paperswithcode.com/sota/semantic-segmentation-on-ade20k-val?p=swin-transformer-hierarchical-vision)
|
||||||
|
|
||||||
This repo is the official implementation of ["Swin Transformer: Hierarchical Vision Transformer using Shifted Windows"](https://arxiv.org/abs/2103.14030). The code will be coming soon.
|
By [Ze Liu](https://github.com/zeliu98/)\*, [Yutong Lin](https://github.com/impiga)\*, [Yue Cao](http://yue-cao.me)\*, [Han Hu](https://ancientmooner.github.io/)\*, [Yixuan Wei](https://github.com/weiyx16), [Zheng Zhang](https://stupidzz.github.io/), [Stephen Lin](https://scholar.google.com/citations?user=c3PYmxUAAAAJ&hl=en) and [Baining Guo](https://www.microsoft.com/en-us/research/people/bainguo/).
|
||||||
|
|
||||||
|
This repo is the official implementation of ["Swin Transformer: Hierarchical Vision Transformer using Shifted Windows"](https://arxiv.org/pdf/2103.14030.pdf). It currently includes code and models for the following tasks:
|
||||||
|
|
||||||
|
> **Image Classification**: Included in this repo. See [get_started.md](get_started.md) for a quick start.
|
||||||
|
|
||||||
|
> **Object Detection and Instance Segmentation**: See [Swin Transformer for Object Detection](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection).
|
||||||
|
|
||||||
|
> **Semantic Segmentation**: See [Swin Transformer for Semantic Segmentation](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation).
|
||||||
|
|
||||||
|
## Updates
|
||||||
|
|
||||||
|
***04/12/2021***
|
||||||
|
|
||||||
|
Initial commits:
|
||||||
|
|
||||||
|
1. Pretrained models on ImageNet-1K ([Swin-T-IN1K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth), [Swin-S-IN1K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth), [Swin-B-IN1K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth)) and ImageNet-22K ([Swin-B-IN22K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth), [Swin-L-IN22K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth)) are provided.
|
||||||
|
2. The supported code and models for ImageNet-1K image classification, COCO object detection and ADE20K semantic segmentation are provided.
|
||||||
|
3. The cuda kernel implementation for the [local relation layer](https://arxiv.org/pdf/1904.11491.pdf) is provided in branch [LR-Net](https://github.com/microsoft/Swin-Transformer/tree/LR-Net).
|
||||||
|
|
||||||
## Introduction
|
## Introduction
|
||||||
|
|
||||||
**Swin Transformer** is initially described in [arxiv](https://arxiv.org/abs/2103.14030), which capably serves as a general-purpose backbone for computer vision. Challenges in adapting Transformer from language to vision arise from differences between the two domains, such as large variations in the scale of visual entities and the high resolution of pixels in images compared to words in text. To address these differences, we propose a hierarchical Transformer whose representation is computed with shifted windows. The shifted windowing scheme brings greater efficiency by limiting self-attention computation to non-overlapping local windows while also allowing for cross-window connection. This hierarchical architecture has the flexibility to model at various scales and has linear computational complexity with respect to image size.
|
**Swin Transformer** is initially described in [arxiv](https://arxiv.org/abs/2103.14030), which capably serves as a
|
||||||
These qualities of Swin Transformer make it compatible with a broad range of vision tasks, including image classification (86.4 top-1 accuracy on ImageNet-1K) and dense prediction tasks such as object detection (58.7 box AP and 51.1 mask AP on COCO test-dev) and semantic segmentation (53.5 mIoU on ADE20K val).
|
general-purpose backbone for computer vision. It is basically a hierarchical Transformer whose representation is
|
||||||
|
computed with shifted windows. The shifted windowing scheme brings greater efficiency by limiting self-attention
|
||||||
|
computation to non-overlapping local windows while also allowing for cross-window connection.
|
||||||
|
|
||||||
|
Swin Transformer achieves strong performance on COCO object detection (`58.7 box AP` and `51.1 mask AP` on test-dev) and
|
||||||
|
ADE20K semantic segmentatiion (`53.5 mIoU` on val), surpassing previous models by a large margin.
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
## Main Results on ImageNet with Pretrained Models
|
||||||
|
|
||||||
|
**ImageNet-1K and ImageNet-22K Pretrained Models**
|
||||||
|
|
||||||
|
| name | pretrain | resolution |acc@1 | acc@5 | #params | FLOPs | FPS| 22K model | 1K model |
|
||||||
|
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |:---: |:---: |
|
||||||
|
| Swin-T | ImageNet-1K | 224x224 | 81.2 | 95.5 | 28M | 4.5G | 755 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/156nWJy4Q28rDlrX-rRbI3w) |
|
||||||
|
| Swin-S | ImageNet-1K | 224x224 | 83.2 | 96.2 | 50M | 8.7G | 437 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/1KFjpj3Efey3LmtE1QqPeQg) |
|
||||||
|
| Swin-B | ImageNet-1K | 224x224 | 83.5 | 96.5 | 88M | 15.4G | 278 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/16bqCTEc70nC_isSsgBSaqQ) |
|
||||||
|
| Swin-B | ImageNet-1K | 384x384 | 84.5 | 97.0 | 88M | 47.1G | 85 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth)/[baidu](https://pan.baidu.com/s/1xT1cu740-ejW7htUdVLnmw) |
|
||||||
|
| Swin-B | ImageNet-22K | 224x224 | 85.2 | 97.5 | 88M | 15.4G | 278 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1y1Ec3UlrKSI8IMtEs-oBXA) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1n_wNkcbRxVXit8r_KrfAVg) |
|
||||||
|
| Swin-B | ImageNet-22K | 384x384 | 86.4 | 98.0 | 88M | 47.1G | 85 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth)/[baidu](https://pan.baidu.com/s/1vwJxnJcVqcLZAw9HaqiR6g) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1caKTSdoLJYoi4WBcnmWuWg) |
|
||||||
|
| Swin-L | ImageNet-22K | 224x224 | 86.3 | 97.9 | 197M | 34.5G | 141 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1pws3rOTFuOebBYP3h6Kx8w) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1NkQApMWUhxBGjk1ne6VqBQ) |
|
||||||
|
| Swin-L | ImageNet-22K | 384x384 | 87.3 | 98.2 | 197M | 103.9G | 42 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth)/[baidu](https://pan.baidu.com/s/1sl7o_bJA143OD7UqSLAMoA) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1X0FLHQyPOC6Kmv2CmgxJvA) |
|
||||||
|
|
||||||
|
Note: access code for `baidu` is `swin`.
|
||||||
|
|
||||||
|
## Main Results on Downstream Tasks
|
||||||
|
|
||||||
|
**COCO Object Detection (2017 val)**
|
||||||
|
|
||||||
|
| Backbone | Method | pretrain | Lr Schd | box mAP | mask mAP | #params | FLOPs |
|
||||||
|
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
|
||||||
|
| Swin-T | Mask R-CNN | ImageNet-1K | 3x | 46.0 | 41.6 | 48M | 267G |
|
||||||
|
| Swin-S | Mask R-CNN | ImageNet-1K | 3x | 48.5 | 43.3 | 69M | 359G |
|
||||||
|
| Swin-T | Cascade Mask R-CNN | ImageNet-1K | 3x | 50.4 | 43.7 | 86M | 745G |
|
||||||
|
| Swin-S | Cascade Mask R-CNN | ImageNet-1K | 3x | 51.9 | 45.0 | 107M | 838G |
|
||||||
|
| Swin-B | Cascade Mask R-CNN | ImageNet-1K | 3x | 51.9 | 45.0 | 145M | 982G |
|
||||||
|
| Swin-T | RepPoints V2 | ImageNet-1K | 3x | 50.0 | - | 45M | 283G |
|
||||||
|
| Swin-T | Mask RepPoints V2 | ImageNet-1K | 3x | 50.3 | 43.6 | 47M | 292G |
|
||||||
|
| Swin-B | HTC++ | ImageNet-22K | 6x | 56.4 | 49.1 | 160M | 1043G |
|
||||||
|
| Swin-L | HTC++ | ImageNet-22K | 3x | 57.1 | 49.5 | 284M | 1470G |
|
||||||
|
| Swin-L | HTC++<sup>*</sup> | ImageNet-22K | 3x | 58.0 | 50.4 | 284M | - |
|
||||||
|
|
||||||
|
Note: <sup>*</sup> indicates multi-scale testing.
|
||||||
|
|
||||||
|
**ADE20K Semantic Segmentation (val)**
|
||||||
|
|
||||||
|
| Backbone | Method | pretrain | Crop Size | Lr Schd | mIoU | mIoU (ms+flip) | #params | FLOPs |
|
||||||
|
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
|
||||||
|
| Swin-T | UPerNet | ImageNet-1K | 512x512 | 160K | 44.51 | 45.81 | 60M | 945G |
|
||||||
|
| Swin-S | UperNet | ImageNet-1K | 512x512 | 160K | 47.64 | 49.47 | 81M | 1038G |
|
||||||
|
| Swin-B | UperNet | ImageNet-1K | 512x512 | 160K | 48.13 | 49.72 | 121M | 1188G |
|
||||||
|
| Swin-B | UPerNet | ImageNet-22K | 640x640 | 160K | 50.04 | 51.66 | 121M | 1841G |
|
||||||
|
| Swin-L | UperNet | ImageNet-22K | 640x640 | 160K | 52.05 | 53.53 | 234M | 3230G |
|
||||||
|
|
||||||
## Citing Swin Transformer
|
## Citing Swin Transformer
|
||||||
|
|
||||||
|
@ -20,6 +96,12 @@ These qualities of Swin Transformer make it compatible with a broad range of vis
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Getting Started
|
||||||
|
|
||||||
|
- For **Image Classification**, please see [get_started.md](get_started.md) for detailed instructions.
|
||||||
|
- For **Object Detection and Instance Segmentation**, please see [Swin Transformer for Object Detection](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection).
|
||||||
|
- For **Semantic Segmentation**, please see [Swin Transformer for Semantic Segmentation](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation).
|
||||||
|
|
||||||
## Contributing
|
## Contributing
|
||||||
|
|
||||||
This project welcomes contributions and suggestions. Most contributions require you to agree to a
|
This project welcomes contributions and suggestions. Most contributions require you to agree to a
|
||||||
|
|
236
config.py
Normal file
236
config.py
Normal file
|
@ -0,0 +1,236 @@
|
||||||
|
# --------------------------------------------------------
|
||||||
|
# Swin Transformer
|
||||||
|
# Copyright (c) 2021 Microsoft
|
||||||
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
# Written by Ze Liu
|
||||||
|
# --------------------------------------------------------'
|
||||||
|
|
||||||
|
import os
|
||||||
|
import yaml
|
||||||
|
from yacs.config import CfgNode as CN
|
||||||
|
|
||||||
|
_C = CN()
|
||||||
|
|
||||||
|
# Base config files
|
||||||
|
_C.BASE = ['']
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Data settings
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
_C.DATA = CN()
|
||||||
|
# Batch size for a single GPU, could be overwritten by command line argument
|
||||||
|
_C.DATA.BATCH_SIZE = 128
|
||||||
|
# Path to dataset, could be overwritten by command line argument
|
||||||
|
_C.DATA.DATA_PATH = ''
|
||||||
|
# Dataset name
|
||||||
|
_C.DATA.DATASET = 'imagenet'
|
||||||
|
# Input image size
|
||||||
|
_C.DATA.IMG_SIZE = 224
|
||||||
|
# Interpolation to resize image (random, bilinear, bicubic)
|
||||||
|
_C.DATA.INTERPOLATION = 'bicubic'
|
||||||
|
# Use zipped dataset instead of folder dataset
|
||||||
|
# could be overwritten by command line argument
|
||||||
|
_C.DATA.ZIP_MODE = False
|
||||||
|
# Cache Data in Memory, could be overwritten by command line argument
|
||||||
|
_C.DATA.CACHE_MODE = 'part'
|
||||||
|
# Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
|
||||||
|
_C.DATA.PIN_MEMORY = True
|
||||||
|
# Number of data loading threads
|
||||||
|
_C.DATA.NUM_WORKERS = 8
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Model settings
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
_C.MODEL = CN()
|
||||||
|
# Model type
|
||||||
|
_C.MODEL.TYPE = 'swin'
|
||||||
|
# Model name
|
||||||
|
_C.MODEL.NAME = 'swin_tiny_patch4_window7_224'
|
||||||
|
# Checkpoint to resume, could be overwritten by command line argument
|
||||||
|
_C.MODEL.RESUME = ''
|
||||||
|
# Number of classes, overwritten in data preparation
|
||||||
|
_C.MODEL.NUM_CLASSES = 1000
|
||||||
|
# Dropout rate
|
||||||
|
_C.MODEL.DROP_RATE = 0.0
|
||||||
|
# Drop path rate
|
||||||
|
_C.MODEL.DROP_PATH_RATE = 0.1
|
||||||
|
# Label Smoothing
|
||||||
|
_C.MODEL.LABEL_SMOOTHING = 0.1
|
||||||
|
|
||||||
|
# Swin Transformer parameters
|
||||||
|
_C.MODEL.SWIN = CN()
|
||||||
|
_C.MODEL.SWIN.PATCH_SIZE = 4
|
||||||
|
_C.MODEL.SWIN.IN_CHANS = 3
|
||||||
|
_C.MODEL.SWIN.EMBED_DIM = 96
|
||||||
|
_C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
|
||||||
|
_C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
|
||||||
|
_C.MODEL.SWIN.WINDOW_SIZE = 7
|
||||||
|
_C.MODEL.SWIN.MLP_RATIO = 4.
|
||||||
|
_C.MODEL.SWIN.QKV_BIAS = True
|
||||||
|
_C.MODEL.SWIN.QK_SCALE = None
|
||||||
|
_C.MODEL.SWIN.APE = False
|
||||||
|
_C.MODEL.SWIN.PATCH_NORM = True
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Training settings
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
_C.TRAIN = CN()
|
||||||
|
_C.TRAIN.START_EPOCH = 0
|
||||||
|
_C.TRAIN.EPOCHS = 300
|
||||||
|
_C.TRAIN.WARMUP_EPOCHS = 20
|
||||||
|
_C.TRAIN.WEIGHT_DECAY = 0.05
|
||||||
|
_C.TRAIN.BASE_LR = 5e-4
|
||||||
|
_C.TRAIN.WARMUP_LR = 5e-7
|
||||||
|
_C.TRAIN.MIN_LR = 5e-6
|
||||||
|
# Clip gradient norm
|
||||||
|
_C.TRAIN.CLIP_GRAD = 5.0
|
||||||
|
# Auto resume from latest checkpoint
|
||||||
|
_C.TRAIN.AUTO_RESUME = True
|
||||||
|
# Gradient accumulation steps
|
||||||
|
# could be overwritten by command line argument
|
||||||
|
_C.TRAIN.ACCUMULATION_STEPS = 0
|
||||||
|
# Whether to use gradient checkpointing to save memory
|
||||||
|
# could be overwritten by command line argument
|
||||||
|
_C.TRAIN.USE_CHECKPOINT = False
|
||||||
|
|
||||||
|
# LR scheduler
|
||||||
|
_C.TRAIN.LR_SCHEDULER = CN()
|
||||||
|
_C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
|
||||||
|
# Epoch interval to decay LR, used in StepLRScheduler
|
||||||
|
_C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
|
||||||
|
# LR decay rate, used in StepLRScheduler
|
||||||
|
_C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1
|
||||||
|
|
||||||
|
# Optimizer
|
||||||
|
_C.TRAIN.OPTIMIZER = CN()
|
||||||
|
_C.TRAIN.OPTIMIZER.NAME = 'adamw'
|
||||||
|
# Optimizer Epsilon
|
||||||
|
_C.TRAIN.OPTIMIZER.EPS = 1e-8
|
||||||
|
# Optimizer Betas
|
||||||
|
_C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
|
||||||
|
# SGD momentum
|
||||||
|
_C.TRAIN.OPTIMIZER.MOMENTUM = 0.9
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Augmentation settings
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
_C.AUG = CN()
|
||||||
|
# Color jitter factor
|
||||||
|
_C.AUG.COLOR_JITTER = 0.4
|
||||||
|
# Use AutoAugment policy. "v0" or "original"
|
||||||
|
_C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
|
||||||
|
# Random erase prob
|
||||||
|
_C.AUG.REPROB = 0.25
|
||||||
|
# Random erase mode
|
||||||
|
_C.AUG.REMODE = 'pixel'
|
||||||
|
# Random erase count
|
||||||
|
_C.AUG.RECOUNT = 1
|
||||||
|
# Mixup alpha, mixup enabled if > 0
|
||||||
|
_C.AUG.MIXUP = 0.8
|
||||||
|
# Cutmix alpha, cutmix enabled if > 0
|
||||||
|
_C.AUG.CUTMIX = 1.0
|
||||||
|
# Cutmix min/max ratio, overrides alpha and enables cutmix if set
|
||||||
|
_C.AUG.CUTMIX_MINMAX = None
|
||||||
|
# Probability of performing mixup or cutmix when either/both is enabled
|
||||||
|
_C.AUG.MIXUP_PROB = 1.0
|
||||||
|
# Probability of switching to cutmix when both mixup and cutmix enabled
|
||||||
|
_C.AUG.MIXUP_SWITCH_PROB = 0.5
|
||||||
|
# How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
|
||||||
|
_C.AUG.MIXUP_MODE = 'batch'
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Testing settings
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
_C.TEST = CN()
|
||||||
|
# Whether to use center crop when testing
|
||||||
|
_C.TEST.CROP = True
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Misc
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2')
|
||||||
|
# overwritten by command line argument
|
||||||
|
_C.AMP_OPT_LEVEL = ''
|
||||||
|
# Path to output folder, overwritten by command line argument
|
||||||
|
_C.OUTPUT = ''
|
||||||
|
# Tag of experiment, overwritten by command line argument
|
||||||
|
_C.TAG = 'default'
|
||||||
|
# Frequency to save checkpoint
|
||||||
|
_C.SAVE_FREQ = 1
|
||||||
|
# Frequency to logging info
|
||||||
|
_C.PRINT_FREQ = 10
|
||||||
|
# Fixed random seed
|
||||||
|
_C.SEED = 0
|
||||||
|
# Perform evaluation only, overwritten by command line argument
|
||||||
|
_C.EVAL_MODE = False
|
||||||
|
# Test throughput only, overwritten by command line argument
|
||||||
|
_C.THROUGHPUT_MODE = False
|
||||||
|
# local rank for DistributedDataParallel, given by command line argument
|
||||||
|
_C.LOCAL_RANK = 0
|
||||||
|
|
||||||
|
|
||||||
|
def _update_config_from_file(config, cfg_file):
|
||||||
|
config.defrost()
|
||||||
|
with open(cfg_file, 'r') as f:
|
||||||
|
yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
|
|
||||||
|
for cfg in yaml_cfg.setdefault('BASE', ['']):
|
||||||
|
if cfg:
|
||||||
|
_update_config_from_file(
|
||||||
|
config, os.path.join(os.path.dirname(cfg_file), cfg)
|
||||||
|
)
|
||||||
|
print('=> merge config from {}'.format(cfg_file))
|
||||||
|
config.merge_from_file(cfg_file)
|
||||||
|
config.freeze()
|
||||||
|
|
||||||
|
|
||||||
|
def update_config(config, args):
|
||||||
|
_update_config_from_file(config, args.cfg)
|
||||||
|
|
||||||
|
config.defrost()
|
||||||
|
if args.opts:
|
||||||
|
config.merge_from_list(args.opts)
|
||||||
|
|
||||||
|
# merge from specific arguments
|
||||||
|
if args.batch_size:
|
||||||
|
config.DATA.BATCH_SIZE = args.batch_size
|
||||||
|
if args.data_path:
|
||||||
|
config.DATA.DATA_PATH = args.data_path
|
||||||
|
if args.zip:
|
||||||
|
config.DATA.ZIP_MODE = True
|
||||||
|
if args.cache_mode:
|
||||||
|
config.DATA.CACHE_MODE = args.cache_mode
|
||||||
|
if args.resume:
|
||||||
|
config.MODEL.RESUME = args.resume
|
||||||
|
if args.accumulation_steps:
|
||||||
|
config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps
|
||||||
|
if args.use_checkpoint:
|
||||||
|
config.TRAIN.USE_CHECKPOINT = True
|
||||||
|
if args.amp_opt_level:
|
||||||
|
config.AMP_OPT_LEVEL = args.amp_opt_level
|
||||||
|
if args.output:
|
||||||
|
config.OUTPUT = args.output
|
||||||
|
if args.tag:
|
||||||
|
config.TAG = args.tag
|
||||||
|
if args.eval:
|
||||||
|
config.EVAL_MODE = True
|
||||||
|
if args.throughput:
|
||||||
|
config.THROUGHPUT_MODE = True
|
||||||
|
|
||||||
|
# set local rank for distributed training
|
||||||
|
config.LOCAL_RANK = args.local_rank
|
||||||
|
|
||||||
|
# output folder
|
||||||
|
config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG)
|
||||||
|
|
||||||
|
config.freeze()
|
||||||
|
|
||||||
|
|
||||||
|
def get_config(args):
|
||||||
|
"""Get a yacs CfgNode object with default values."""
|
||||||
|
# Return a clone so that the defaults will not be altered
|
||||||
|
# This is for the "local variable" use pattern
|
||||||
|
config = _C.clone()
|
||||||
|
update_config(config, args)
|
||||||
|
|
||||||
|
return config
|
13
configs/swin_base_patch4_window12_384.yaml
Normal file
13
configs/swin_base_patch4_window12_384.yaml
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
# only for evaluation
|
||||||
|
DATA:
|
||||||
|
IMG_SIZE: 384
|
||||||
|
MODEL:
|
||||||
|
TYPE: swin
|
||||||
|
NAME: swin_base_patch4_window12_384
|
||||||
|
SWIN:
|
||||||
|
EMBED_DIM: 128
|
||||||
|
DEPTHS: [ 2, 2, 18, 2 ]
|
||||||
|
NUM_HEADS: [ 4, 8, 16, 32 ]
|
||||||
|
WINDOW_SIZE: 12
|
||||||
|
TEST:
|
||||||
|
CROP: False
|
9
configs/swin_base_patch4_window7_224.yaml
Normal file
9
configs/swin_base_patch4_window7_224.yaml
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
MODEL:
|
||||||
|
TYPE: swin
|
||||||
|
NAME: swin_base_patch4_window7_224
|
||||||
|
DROP_PATH_RATE: 0.5
|
||||||
|
SWIN:
|
||||||
|
EMBED_DIM: 128
|
||||||
|
DEPTHS: [ 2, 2, 18, 2 ]
|
||||||
|
NUM_HEADS: [ 4, 8, 16, 32 ]
|
||||||
|
WINDOW_SIZE: 7
|
13
configs/swin_large_patch4_window12_384.yaml
Normal file
13
configs/swin_large_patch4_window12_384.yaml
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
# only for evaluation
|
||||||
|
DATA:
|
||||||
|
IMG_SIZE: 384
|
||||||
|
MODEL:
|
||||||
|
TYPE: swin
|
||||||
|
NAME: swin_large_patch4_window12_384
|
||||||
|
SWIN:
|
||||||
|
EMBED_DIM: 192
|
||||||
|
DEPTHS: [ 2, 2, 18, 2 ]
|
||||||
|
NUM_HEADS: [ 6, 12, 24, 48 ]
|
||||||
|
WINDOW_SIZE: 12
|
||||||
|
TEST:
|
||||||
|
CROP: False
|
9
configs/swin_large_patch4_window7_224.yaml
Normal file
9
configs/swin_large_patch4_window7_224.yaml
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
# only for evaluation
|
||||||
|
MODEL:
|
||||||
|
TYPE: swin
|
||||||
|
NAME: swin_large_patch4_window7_224
|
||||||
|
SWIN:
|
||||||
|
EMBED_DIM: 192
|
||||||
|
DEPTHS: [ 2, 2, 18, 2 ]
|
||||||
|
NUM_HEADS: [ 6, 12, 24, 48 ]
|
||||||
|
WINDOW_SIZE: 7
|
9
configs/swin_small_patch4_window7_224.yaml
Normal file
9
configs/swin_small_patch4_window7_224.yaml
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
MODEL:
|
||||||
|
TYPE: swin
|
||||||
|
NAME: swin_small_patch4_window7_224
|
||||||
|
DROP_PATH_RATE: 0.3
|
||||||
|
SWIN:
|
||||||
|
EMBED_DIM: 96
|
||||||
|
DEPTHS: [ 2, 2, 18, 2 ]
|
||||||
|
NUM_HEADS: [ 3, 6, 12, 24 ]
|
||||||
|
WINDOW_SIZE: 7
|
9
configs/swin_tiny_patch4_window7_224.yaml
Normal file
9
configs/swin_tiny_patch4_window7_224.yaml
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
MODEL:
|
||||||
|
TYPE: swin
|
||||||
|
NAME: swin_tiny_patch4_window7_224
|
||||||
|
DROP_PATH_RATE: 0.2
|
||||||
|
SWIN:
|
||||||
|
EMBED_DIM: 96
|
||||||
|
DEPTHS: [ 2, 2, 6, 2 ]
|
||||||
|
NUM_HEADS: [ 3, 6, 12, 24 ]
|
||||||
|
WINDOW_SIZE: 7
|
1
data/__init__.py
Normal file
1
data/__init__.py
Normal file
|
@ -0,0 +1 @@
|
||||||
|
from .build import build_loader
|
128
data/build.py
Normal file
128
data/build.py
Normal file
|
@ -0,0 +1,128 @@
|
||||||
|
# --------------------------------------------------------
|
||||||
|
# Swin Transformer
|
||||||
|
# Copyright (c) 2021 Microsoft
|
||||||
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
# Written by Ze Liu
|
||||||
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import torch.distributed as dist
|
||||||
|
from torchvision import datasets, transforms
|
||||||
|
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
|
from timm.data import Mixup
|
||||||
|
from timm.data import create_transform
|
||||||
|
from timm.data.transforms import _pil_interp
|
||||||
|
|
||||||
|
from .cached_image_folder import CachedImageFolder
|
||||||
|
from .samplers import SubsetRandomSampler
|
||||||
|
|
||||||
|
|
||||||
|
def build_loader(config):
|
||||||
|
config.defrost()
|
||||||
|
dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config)
|
||||||
|
config.freeze()
|
||||||
|
print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build train dataset")
|
||||||
|
dataset_val, _ = build_dataset(is_train=False, config=config)
|
||||||
|
print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build val dataset")
|
||||||
|
|
||||||
|
num_tasks = dist.get_world_size()
|
||||||
|
global_rank = dist.get_rank()
|
||||||
|
if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part':
|
||||||
|
indices = np.arange(dist.get_rank(), len(dataset_train), dist.get_world_size())
|
||||||
|
sampler_train = SubsetRandomSampler(indices)
|
||||||
|
else:
|
||||||
|
sampler_train = torch.utils.data.DistributedSampler(
|
||||||
|
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
|
||||||
|
)
|
||||||
|
|
||||||
|
indices = np.arange(dist.get_rank(), len(dataset_val), dist.get_world_size())
|
||||||
|
sampler_val = SubsetRandomSampler(indices)
|
||||||
|
|
||||||
|
data_loader_train = torch.utils.data.DataLoader(
|
||||||
|
dataset_train, sampler=sampler_train,
|
||||||
|
batch_size=config.DATA.BATCH_SIZE,
|
||||||
|
num_workers=config.DATA.NUM_WORKERS,
|
||||||
|
pin_memory=config.DATA.PIN_MEMORY,
|
||||||
|
drop_last=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
data_loader_val = torch.utils.data.DataLoader(
|
||||||
|
dataset_val, sampler=sampler_val,
|
||||||
|
batch_size=config.DATA.BATCH_SIZE,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=config.DATA.NUM_WORKERS,
|
||||||
|
pin_memory=config.DATA.PIN_MEMORY,
|
||||||
|
drop_last=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# setup mixup / cutmix
|
||||||
|
mixup_fn = None
|
||||||
|
mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None
|
||||||
|
if mixup_active:
|
||||||
|
mixup_fn = Mixup(
|
||||||
|
mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX,
|
||||||
|
prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE,
|
||||||
|
label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES)
|
||||||
|
|
||||||
|
return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn
|
||||||
|
|
||||||
|
|
||||||
|
def build_dataset(is_train, config):
|
||||||
|
transform = build_transform(is_train, config)
|
||||||
|
if config.DATA.DATASET == 'imagenet':
|
||||||
|
prefix = 'train' if is_train else 'val'
|
||||||
|
if config.DATA.ZIP_MODE:
|
||||||
|
ann_file = prefix + "_map.txt"
|
||||||
|
prefix = prefix + ".zip@/"
|
||||||
|
dataset = CachedImageFolder(config.DATA.DATA_PATH, ann_file, prefix, transform,
|
||||||
|
cache_mode=config.DATA.CACHE_MODE if is_train else 'part')
|
||||||
|
else:
|
||||||
|
root = os.path.join(config.DATA.DATA_PATH, prefix)
|
||||||
|
dataset = datasets.ImageFolder(root, transform=transform)
|
||||||
|
nb_classes = 1000
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("We only support ImageNet Now.")
|
||||||
|
|
||||||
|
return dataset, nb_classes
|
||||||
|
|
||||||
|
|
||||||
|
def build_transform(is_train, config):
|
||||||
|
resize_im = config.DATA.IMG_SIZE > 32
|
||||||
|
if is_train:
|
||||||
|
# this should always dispatch to transforms_imagenet_train
|
||||||
|
transform = create_transform(
|
||||||
|
input_size=config.DATA.IMG_SIZE,
|
||||||
|
is_training=True,
|
||||||
|
color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None,
|
||||||
|
auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None,
|
||||||
|
re_prob=config.AUG.REPROB,
|
||||||
|
re_mode=config.AUG.REMODE,
|
||||||
|
re_count=config.AUG.RECOUNT,
|
||||||
|
interpolation=config.DATA.INTERPOLATION,
|
||||||
|
)
|
||||||
|
if not resize_im:
|
||||||
|
# replace RandomResizedCropAndInterpolation with
|
||||||
|
# RandomCrop
|
||||||
|
transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4)
|
||||||
|
return transform
|
||||||
|
|
||||||
|
t = []
|
||||||
|
if resize_im:
|
||||||
|
if config.TEST.CROP:
|
||||||
|
size = int((256 / 224) * config.DATA.IMG_SIZE)
|
||||||
|
t.append(
|
||||||
|
transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)),
|
||||||
|
# to maintain same ratio w.r.t. 224 images
|
||||||
|
)
|
||||||
|
t.append(transforms.CenterCrop(config.DATA.IMG_SIZE))
|
||||||
|
else:
|
||||||
|
t.append(
|
||||||
|
transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
|
||||||
|
interpolation=_pil_interp(config.DATA.INTERPOLATION))
|
||||||
|
)
|
||||||
|
|
||||||
|
t.append(transforms.ToTensor())
|
||||||
|
t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
|
||||||
|
return transforms.Compose(t)
|
251
data/cached_image_folder.py
Normal file
251
data/cached_image_folder.py
Normal file
|
@ -0,0 +1,251 @@
|
||||||
|
# --------------------------------------------------------
|
||||||
|
# Swin Transformer
|
||||||
|
# Copyright (c) 2021 Microsoft
|
||||||
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
# Written by Ze Liu
|
||||||
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
import io
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import torch.distributed as dist
|
||||||
|
import torch.utils.data as data
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from .zipreader import is_zip_path, ZipReader
|
||||||
|
|
||||||
|
|
||||||
|
def has_file_allowed_extension(filename, extensions):
|
||||||
|
"""Checks if a file is an allowed extension.
|
||||||
|
Args:
|
||||||
|
filename (string): path to a file
|
||||||
|
Returns:
|
||||||
|
bool: True if the filename ends with a known image extension
|
||||||
|
"""
|
||||||
|
filename_lower = filename.lower()
|
||||||
|
return any(filename_lower.endswith(ext) for ext in extensions)
|
||||||
|
|
||||||
|
|
||||||
|
def find_classes(dir):
|
||||||
|
classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
|
||||||
|
classes.sort()
|
||||||
|
class_to_idx = {classes[i]: i for i in range(len(classes))}
|
||||||
|
return classes, class_to_idx
|
||||||
|
|
||||||
|
|
||||||
|
def make_dataset(dir, class_to_idx, extensions):
|
||||||
|
images = []
|
||||||
|
dir = os.path.expanduser(dir)
|
||||||
|
for target in sorted(os.listdir(dir)):
|
||||||
|
d = os.path.join(dir, target)
|
||||||
|
if not os.path.isdir(d):
|
||||||
|
continue
|
||||||
|
|
||||||
|
for root, _, fnames in sorted(os.walk(d)):
|
||||||
|
for fname in sorted(fnames):
|
||||||
|
if has_file_allowed_extension(fname, extensions):
|
||||||
|
path = os.path.join(root, fname)
|
||||||
|
item = (path, class_to_idx[target])
|
||||||
|
images.append(item)
|
||||||
|
|
||||||
|
return images
|
||||||
|
|
||||||
|
|
||||||
|
def make_dataset_with_ann(ann_file, img_prefix, extensions):
|
||||||
|
images = []
|
||||||
|
with open(ann_file, "r") as f:
|
||||||
|
contents = f.readlines()
|
||||||
|
for line_str in contents:
|
||||||
|
path_contents = [c for c in line_str.split('\t')]
|
||||||
|
im_file_name = path_contents[0]
|
||||||
|
class_index = int(path_contents[1])
|
||||||
|
|
||||||
|
assert str.lower(os.path.splitext(im_file_name)[-1]) in extensions
|
||||||
|
item = (os.path.join(img_prefix, im_file_name), class_index)
|
||||||
|
|
||||||
|
images.append(item)
|
||||||
|
|
||||||
|
return images
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetFolder(data.Dataset):
|
||||||
|
"""A generic data loader where the samples are arranged in this way: ::
|
||||||
|
root/class_x/xxx.ext
|
||||||
|
root/class_x/xxy.ext
|
||||||
|
root/class_x/xxz.ext
|
||||||
|
root/class_y/123.ext
|
||||||
|
root/class_y/nsdf3.ext
|
||||||
|
root/class_y/asd932_.ext
|
||||||
|
Args:
|
||||||
|
root (string): Root directory path.
|
||||||
|
loader (callable): A function to load a sample given its path.
|
||||||
|
extensions (list[string]): A list of allowed extensions.
|
||||||
|
transform (callable, optional): A function/transform that takes in
|
||||||
|
a sample and returns a transformed version.
|
||||||
|
E.g, ``transforms.RandomCrop`` for images.
|
||||||
|
target_transform (callable, optional): A function/transform that takes
|
||||||
|
in the target and transforms it.
|
||||||
|
Attributes:
|
||||||
|
samples (list): List of (sample path, class_index) tuples
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, root, loader, extensions, ann_file='', img_prefix='', transform=None, target_transform=None,
|
||||||
|
cache_mode="no"):
|
||||||
|
# image folder mode
|
||||||
|
if ann_file == '':
|
||||||
|
_, class_to_idx = find_classes(root)
|
||||||
|
samples = make_dataset(root, class_to_idx, extensions)
|
||||||
|
# zip mode
|
||||||
|
else:
|
||||||
|
samples = make_dataset_with_ann(os.path.join(root, ann_file),
|
||||||
|
os.path.join(root, img_prefix),
|
||||||
|
extensions)
|
||||||
|
|
||||||
|
if len(samples) == 0:
|
||||||
|
raise (RuntimeError("Found 0 files in subfolders of: " + root + "\n" +
|
||||||
|
"Supported extensions are: " + ",".join(extensions)))
|
||||||
|
|
||||||
|
self.root = root
|
||||||
|
self.loader = loader
|
||||||
|
self.extensions = extensions
|
||||||
|
|
||||||
|
self.samples = samples
|
||||||
|
self.labels = [y_1k for _, y_1k in samples]
|
||||||
|
self.classes = list(set(self.labels))
|
||||||
|
|
||||||
|
self.transform = transform
|
||||||
|
self.target_transform = target_transform
|
||||||
|
|
||||||
|
self.cache_mode = cache_mode
|
||||||
|
if self.cache_mode != "no":
|
||||||
|
self.init_cache()
|
||||||
|
|
||||||
|
def init_cache(self):
|
||||||
|
assert self.cache_mode in ["part", "full"]
|
||||||
|
n_sample = len(self.samples)
|
||||||
|
global_rank = dist.get_rank()
|
||||||
|
world_size = dist.get_world_size()
|
||||||
|
|
||||||
|
samples_bytes = [None for _ in range(n_sample)]
|
||||||
|
start_time = time.time()
|
||||||
|
for index in range(n_sample):
|
||||||
|
if index % (n_sample // 10) == 0:
|
||||||
|
t = time.time() - start_time
|
||||||
|
print(f'global_rank {dist.get_rank()} cached {index}/{n_sample} takes {t:.2f}s per block')
|
||||||
|
start_time = time.time()
|
||||||
|
path, target = self.samples[index]
|
||||||
|
if self.cache_mode == "full":
|
||||||
|
samples_bytes[index] = (ZipReader.read(path), target)
|
||||||
|
elif self.cache_mode == "part" and index % world_size == global_rank:
|
||||||
|
samples_bytes[index] = (ZipReader.read(path), target)
|
||||||
|
else:
|
||||||
|
samples_bytes[index] = (path, target)
|
||||||
|
self.samples = samples_bytes
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
index (int): Index
|
||||||
|
Returns:
|
||||||
|
tuple: (sample, target) where target is class_index of the target class.
|
||||||
|
"""
|
||||||
|
path, target = self.samples[index]
|
||||||
|
sample = self.loader(path)
|
||||||
|
if self.transform is not None:
|
||||||
|
sample = self.transform(sample)
|
||||||
|
if self.target_transform is not None:
|
||||||
|
target = self.target_transform(target)
|
||||||
|
|
||||||
|
return sample, target
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.samples)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
|
||||||
|
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
|
||||||
|
fmt_str += ' Root Location: {}\n'.format(self.root)
|
||||||
|
tmp = ' Transforms (if any): '
|
||||||
|
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
|
||||||
|
tmp = ' Target Transforms (if any): '
|
||||||
|
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
|
||||||
|
return fmt_str
|
||||||
|
|
||||||
|
|
||||||
|
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
|
||||||
|
|
||||||
|
|
||||||
|
def pil_loader(path):
|
||||||
|
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
|
||||||
|
if isinstance(path, bytes):
|
||||||
|
img = Image.open(io.BytesIO(path))
|
||||||
|
elif is_zip_path(path):
|
||||||
|
data = ZipReader.read(path)
|
||||||
|
img = Image.open(io.BytesIO(data))
|
||||||
|
else:
|
||||||
|
with open(path, 'rb') as f:
|
||||||
|
img = Image.open(f)
|
||||||
|
return img.convert('RGB')
|
||||||
|
|
||||||
|
|
||||||
|
def accimage_loader(path):
|
||||||
|
import accimage
|
||||||
|
try:
|
||||||
|
return accimage.Image(path)
|
||||||
|
except IOError:
|
||||||
|
# Potentially a decoding problem, fall back to PIL.Image
|
||||||
|
return pil_loader(path)
|
||||||
|
|
||||||
|
|
||||||
|
def default_img_loader(path):
|
||||||
|
from torchvision import get_image_backend
|
||||||
|
if get_image_backend() == 'accimage':
|
||||||
|
return accimage_loader(path)
|
||||||
|
else:
|
||||||
|
return pil_loader(path)
|
||||||
|
|
||||||
|
|
||||||
|
class CachedImageFolder(DatasetFolder):
|
||||||
|
"""A generic data loader where the images are arranged in this way: ::
|
||||||
|
root/dog/xxx.png
|
||||||
|
root/dog/xxy.png
|
||||||
|
root/dog/xxz.png
|
||||||
|
root/cat/123.png
|
||||||
|
root/cat/nsdf3.png
|
||||||
|
root/cat/asd932_.png
|
||||||
|
Args:
|
||||||
|
root (string): Root directory path.
|
||||||
|
transform (callable, optional): A function/transform that takes in an PIL image
|
||||||
|
and returns a transformed version. E.g, ``transforms.RandomCrop``
|
||||||
|
target_transform (callable, optional): A function/transform that takes in the
|
||||||
|
target and transforms it.
|
||||||
|
loader (callable, optional): A function to load an image given its path.
|
||||||
|
Attributes:
|
||||||
|
imgs (list): List of (image path, class_index) tuples
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, root, ann_file='', img_prefix='', transform=None, target_transform=None,
|
||||||
|
loader=default_img_loader, cache_mode="no"):
|
||||||
|
super(CachedImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
|
||||||
|
ann_file=ann_file, img_prefix=img_prefix,
|
||||||
|
transform=transform, target_transform=target_transform,
|
||||||
|
cache_mode=cache_mode)
|
||||||
|
self.imgs = self.samples
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
index (int): Index
|
||||||
|
Returns:
|
||||||
|
tuple: (image, target) where target is class_index of the target class.
|
||||||
|
"""
|
||||||
|
path, target = self.samples[index]
|
||||||
|
image = self.loader(path)
|
||||||
|
if self.transform is not None:
|
||||||
|
img = self.transform(image)
|
||||||
|
else:
|
||||||
|
img = image
|
||||||
|
if self.target_transform is not None:
|
||||||
|
target = self.target_transform(target)
|
||||||
|
|
||||||
|
return img, target
|
29
data/samplers.py
Normal file
29
data/samplers.py
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
# --------------------------------------------------------
|
||||||
|
# Swin Transformer
|
||||||
|
# Copyright (c) 2021 Microsoft
|
||||||
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
# Written by Ze Liu
|
||||||
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class SubsetRandomSampler(torch.utils.data.Sampler):
|
||||||
|
r"""Samples elements randomly from a given list of indices, without replacement.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
indices (sequence): a sequence of indices
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, indices):
|
||||||
|
self.epoch = 0
|
||||||
|
self.indices = indices
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return (self.indices[i] for i in torch.randperm(len(self.indices)))
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.indices)
|
||||||
|
|
||||||
|
def set_epoch(self, epoch):
|
||||||
|
self.epoch = epoch
|
103
data/zipreader.py
Normal file
103
data/zipreader.py
Normal file
|
@ -0,0 +1,103 @@
|
||||||
|
# --------------------------------------------------------
|
||||||
|
# Swin Transformer
|
||||||
|
# Copyright (c) 2021 Microsoft
|
||||||
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
# Written by Ze Liu
|
||||||
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
import os
|
||||||
|
import zipfile
|
||||||
|
import io
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
from PIL import ImageFile
|
||||||
|
|
||||||
|
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||||
|
|
||||||
|
|
||||||
|
def is_zip_path(img_or_path):
|
||||||
|
"""judge if this is a zip path"""
|
||||||
|
return '.zip@' in img_or_path
|
||||||
|
|
||||||
|
|
||||||
|
class ZipReader(object):
|
||||||
|
"""A class to read zipped files"""
|
||||||
|
zip_bank = dict()
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(ZipReader, self).__init__()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_zipfile(path):
|
||||||
|
zip_bank = ZipReader.zip_bank
|
||||||
|
if path not in zip_bank:
|
||||||
|
zfile = zipfile.ZipFile(path, 'r')
|
||||||
|
zip_bank[path] = zfile
|
||||||
|
return zip_bank[path]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def split_zip_style_path(path):
|
||||||
|
pos_at = path.index('@')
|
||||||
|
assert pos_at != -1, "character '@' is not found from the given path '%s'" % path
|
||||||
|
|
||||||
|
zip_path = path[0: pos_at]
|
||||||
|
folder_path = path[pos_at + 1:]
|
||||||
|
folder_path = str.strip(folder_path, '/')
|
||||||
|
return zip_path, folder_path
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def list_folder(path):
|
||||||
|
zip_path, folder_path = ZipReader.split_zip_style_path(path)
|
||||||
|
|
||||||
|
zfile = ZipReader.get_zipfile(zip_path)
|
||||||
|
folder_list = []
|
||||||
|
for file_foler_name in zfile.namelist():
|
||||||
|
file_foler_name = str.strip(file_foler_name, '/')
|
||||||
|
if file_foler_name.startswith(folder_path) and \
|
||||||
|
len(os.path.splitext(file_foler_name)[-1]) == 0 and \
|
||||||
|
file_foler_name != folder_path:
|
||||||
|
if len(folder_path) == 0:
|
||||||
|
folder_list.append(file_foler_name)
|
||||||
|
else:
|
||||||
|
folder_list.append(file_foler_name[len(folder_path) + 1:])
|
||||||
|
|
||||||
|
return folder_list
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def list_files(path, extension=None):
|
||||||
|
if extension is None:
|
||||||
|
extension = ['.*']
|
||||||
|
zip_path, folder_path = ZipReader.split_zip_style_path(path)
|
||||||
|
|
||||||
|
zfile = ZipReader.get_zipfile(zip_path)
|
||||||
|
file_lists = []
|
||||||
|
for file_foler_name in zfile.namelist():
|
||||||
|
file_foler_name = str.strip(file_foler_name, '/')
|
||||||
|
if file_foler_name.startswith(folder_path) and \
|
||||||
|
str.lower(os.path.splitext(file_foler_name)[-1]) in extension:
|
||||||
|
if len(folder_path) == 0:
|
||||||
|
file_lists.append(file_foler_name)
|
||||||
|
else:
|
||||||
|
file_lists.append(file_foler_name[len(folder_path) + 1:])
|
||||||
|
|
||||||
|
return file_lists
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def read(path):
|
||||||
|
zip_path, path_img = ZipReader.split_zip_style_path(path)
|
||||||
|
zfile = ZipReader.get_zipfile(zip_path)
|
||||||
|
data = zfile.read(path_img)
|
||||||
|
return data
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def imread(path):
|
||||||
|
zip_path, path_img = ZipReader.split_zip_style_path(path)
|
||||||
|
zfile = ZipReader.get_zipfile(zip_path)
|
||||||
|
data = zfile.read(path_img)
|
||||||
|
try:
|
||||||
|
im = Image.open(io.BytesIO(data))
|
||||||
|
except:
|
||||||
|
print("ERROR IMG LOADED: ", path_img)
|
||||||
|
random_img = np.random.rand(224, 224, 3) * 255
|
||||||
|
im = Image.fromarray(np.uint8(random_img))
|
||||||
|
return im
|
BIN
figures/teaser.png
Normal file
BIN
figures/teaser.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 909 KiB |
203
get_started.md
Normal file
203
get_started.md
Normal file
|
@ -0,0 +1,203 @@
|
||||||
|
# Swin Transformer for Image Classification
|
||||||
|
|
||||||
|
This folder contains the implementation of the Swin Transformer for image classification.
|
||||||
|
|
||||||
|
## Model Zoo
|
||||||
|
|
||||||
|
### Regular ImageNet-1K trained models
|
||||||
|
|
||||||
|
| name | resolution |acc@1 | acc@5 | #params | FLOPs | model |
|
||||||
|
|:---:|:---:|:---:|:---:| :---:| :---:|:---:|
|
||||||
|
| Swin-T | 224x224 | 81.2 | 95.5 | 28M | 4.5G | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/156nWJy4Q28rDlrX-rRbI3w) |
|
||||||
|
| Swin-S | 224x224 | 83.2 | 96.2 | 50M | 8.7G | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/1KFjpj3Efey3LmtE1QqPeQg) |
|
||||||
|
| Swin-B | 224x224 | 83.5 | 96.5 | 88M | 15.4G | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/16bqCTEc70nC_isSsgBSaqQ) |
|
||||||
|
| Swin-B | 384x384 | 84.5 | 97.0 | 88M | 47.1G | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth)/[baidu](https://pan.baidu.com/s/1xT1cu740-ejW7htUdVLnmw) |
|
||||||
|
|
||||||
|
### ImageNet-22K pre-trained models
|
||||||
|
|
||||||
|
| name | resolution |acc@1 | acc@5 | #params | FLOPs | 22K model | 1K model |
|
||||||
|
|:---: |:---: |:---:|:---:|:---:|:---:|:---:|:---:|
|
||||||
|
| Swin-B | 224x224 | 85.2 | 97.5 | 88M | 15.4G | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1y1Ec3UlrKSI8IMtEs-oBXA) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1n_wNkcbRxVXit8r_KrfAVg) |
|
||||||
|
| Swin-B | 384x384 | 86.4 | 98.0 | 88M | 47.1G | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth)/[baidu](https://pan.baidu.com/s/1vwJxnJcVqcLZAw9HaqiR6g) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1caKTSdoLJYoi4WBcnmWuWg) |
|
||||||
|
| Swin-L | 224x224 | 86.3 | 97.9 | 197M | 34.5G | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1pws3rOTFuOebBYP3h6Kx8w) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1NkQApMWUhxBGjk1ne6VqBQ) |
|
||||||
|
| Swin-L | 384x384 | 87.3 | 98.2 | 197M | 103.9G | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth)/[baidu](https://pan.baidu.com/s/1sl7o_bJA143OD7UqSLAMoA) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1X0FLHQyPOC6Kmv2CmgxJvA) |
|
||||||
|
|
||||||
|
Note: access code for `baidu` is `swin`.
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
### Install
|
||||||
|
|
||||||
|
- Clone this repo:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/microsoft/Swin-Transformer.git
|
||||||
|
cd Swin-Transformer
|
||||||
|
```
|
||||||
|
|
||||||
|
- Create a conda virtual environment and activate it:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
conda create -n swin python=3.7 -y
|
||||||
|
conda activate swin
|
||||||
|
```
|
||||||
|
|
||||||
|
- Install `CUDA==10.1` with `cudnn7` following
|
||||||
|
the [official installation instructions](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html)
|
||||||
|
- Install `PyTorch==1.7.1` and `torchvision==0.8.2` with `CUDA==10.1`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
conda install pytorch==1.7.1 torchvision==0.8.2 cudatoolkit=10.1 -c pytorch
|
||||||
|
```
|
||||||
|
|
||||||
|
- Install `timm==0.3.2`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install timm==0.3.2
|
||||||
|
```
|
||||||
|
|
||||||
|
- Install `Apex`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/NVIDIA/apex
|
||||||
|
cd apex
|
||||||
|
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
|
||||||
|
```
|
||||||
|
|
||||||
|
- Install other requirements:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install opencv-python==4.4.0.46 termcolor==1.1.0 yacs==0.1.8
|
||||||
|
```
|
||||||
|
|
||||||
|
### Data preparation
|
||||||
|
|
||||||
|
We use standard ImageNet dataset, you can download it from http://image-net.org/. We provide the following two ways to
|
||||||
|
load data:
|
||||||
|
|
||||||
|
- For standard folder dataset, move validation images to labeled sub-folders. The file structure should look like:
|
||||||
|
```bash
|
||||||
|
$ tree data
|
||||||
|
imagenet
|
||||||
|
├── train
|
||||||
|
│ ├── class1
|
||||||
|
│ │ ├── img1.jpeg
|
||||||
|
│ │ ├── img2.jpeg
|
||||||
|
│ │ └── ...
|
||||||
|
│ ├── class2
|
||||||
|
│ │ ├── img3.jpeg
|
||||||
|
│ │ └── ...
|
||||||
|
│ └── ...
|
||||||
|
└── val
|
||||||
|
├── class1
|
||||||
|
│ ├── img4.jpeg
|
||||||
|
│ ├── img5.jpeg
|
||||||
|
│ └── ...
|
||||||
|
├── class2
|
||||||
|
│ ├── img6.jpeg
|
||||||
|
│ └── ...
|
||||||
|
└── ...
|
||||||
|
|
||||||
|
```
|
||||||
|
- To boost the slow speed when reading images from massive small files, we also support zipped ImageNet, which includes
|
||||||
|
four files:
|
||||||
|
- `train.zip`, `val.zip`: which store the zipped folder for train and validate splits.
|
||||||
|
- `train_map.txt`, `val_map.txt`: which store the relative path in the corresponding zip file and ground truth
|
||||||
|
label. Make sure the data folder looks like this:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ tree data
|
||||||
|
data
|
||||||
|
└── ImageNet-Zip
|
||||||
|
├── train_map.txt
|
||||||
|
├── train.zip
|
||||||
|
├── val_map.txt
|
||||||
|
└── val.zip
|
||||||
|
|
||||||
|
$ head -n 5 data/ImageNet-Zip/val_map.txt
|
||||||
|
ILSVRC2012_val_00000001.JPEG 65
|
||||||
|
ILSVRC2012_val_00000002.JPEG 970
|
||||||
|
ILSVRC2012_val_00000003.JPEG 230
|
||||||
|
ILSVRC2012_val_00000004.JPEG 809
|
||||||
|
ILSVRC2012_val_00000005.JPEG 516
|
||||||
|
|
||||||
|
$ head -n 5 data/ImageNet-Zip/train_map.txt
|
||||||
|
n01440764/n01440764_10026.JPEG 0
|
||||||
|
n01440764/n01440764_10027.JPEG 0
|
||||||
|
n01440764/n01440764_10029.JPEG 0
|
||||||
|
n01440764/n01440764_10040.JPEG 0
|
||||||
|
n01440764/n01440764_10042.JPEG 0
|
||||||
|
```
|
||||||
|
|
||||||
|
### Evaluation
|
||||||
|
|
||||||
|
To evaluate a pre-trained `Swin Transformer` on ImageNet val, run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> --master_port 12345 main.py --eval \
|
||||||
|
--cfg <config-file> --resume <checkpoint> --data-path <imagenet-path>
|
||||||
|
```
|
||||||
|
|
||||||
|
For example, to evaluate the `Swin-B` with a single GPU:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 main.py --eval \
|
||||||
|
--cfg configs/swin_base_patch4_window7_224.yaml --resume swin_base_patch4_window7_224.pth --data-path <imagenet-path>
|
||||||
|
```
|
||||||
|
|
||||||
|
### Training from scratch
|
||||||
|
|
||||||
|
To train a `Swin Transformer` on ImageNet from scratch, run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> --master_port 12345 main.py \
|
||||||
|
--cfg <config-file> --data-path <imagenet-path> [--batch-size <batch-size-per-gpu> --output <output-directory> --tag <job-tag>]
|
||||||
|
```
|
||||||
|
|
||||||
|
**Notes**:
|
||||||
|
|
||||||
|
- To use zipped ImageNet instead of folder dataset, add `--zip` to the parameters.
|
||||||
|
- To cache the dataset in the memory instead of reading from files every time, add `--cache-mode part`, which will
|
||||||
|
shard the dataset into non-overlapping pieces for different GPUs and only load the corresponding one for each GPU.
|
||||||
|
- When GPU memory is not enough, you can try the following suggestions:
|
||||||
|
- Use gradient accumulation by adding `--accumulation-steps <steps>`, set appropriate `<steps>` according to your need.
|
||||||
|
- Use gradient checkpointing by adding `--use-checkpoint`, e.g., it saves about 60% memory when training `Swin-B`.
|
||||||
|
Please refer to [this page](https://pytorch.org/docs/stable/checkpoint.html) for more details.
|
||||||
|
- We recommend using multi-node with more GPUs for training very large models, a tutorial can be found
|
||||||
|
in [this page](https://pytorch.org/tutorials/intermediate/dist_tuto.html).
|
||||||
|
- To change config options in general, you can use `--opts KEY1 VALUE1 KEY2 VALUE2`, e.g.,
|
||||||
|
`--opts TRAIN.EPOCHS 100 TRAIN.WARMUP_EPOCHS 5` will change total epochs to 100 and warm-up epochs to 5.
|
||||||
|
- For additional options, see [config](config.py) and run `python main.py --help` to get detailed message.
|
||||||
|
|
||||||
|
For example, to train `Swin Transformer` with 8 GPU on a single node for 300 epochs, run:
|
||||||
|
|
||||||
|
`Swin-T`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \
|
||||||
|
--cfg configs/swin_tiny_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 128
|
||||||
|
```
|
||||||
|
|
||||||
|
`Swin-S`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \
|
||||||
|
--cfg configs/swin_small_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 128
|
||||||
|
```
|
||||||
|
|
||||||
|
`Swin-B`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \
|
||||||
|
--cfg configs/swin_base_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 64 \
|
||||||
|
--accumulation-steps 2 [--use-checkpoint]
|
||||||
|
```
|
||||||
|
|
||||||
|
### Throughput
|
||||||
|
|
||||||
|
To measure the throughput, run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 main.py \
|
||||||
|
--cfg <config-file> --data-path <imagenet-path> --batch-size 64 --throughput --amp-opt-level O0
|
||||||
|
```
|
41
logger.py
Normal file
41
logger.py
Normal file
|
@ -0,0 +1,41 @@
|
||||||
|
# --------------------------------------------------------
|
||||||
|
# Swin Transformer
|
||||||
|
# Copyright (c) 2021 Microsoft
|
||||||
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
# Written by Ze Liu
|
||||||
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import logging
|
||||||
|
import functools
|
||||||
|
from termcolor import colored
|
||||||
|
|
||||||
|
|
||||||
|
@functools.lru_cache()
|
||||||
|
def create_logger(output_dir, dist_rank=0, name=''):
|
||||||
|
# create logger
|
||||||
|
logger = logging.getLogger(name)
|
||||||
|
logger.setLevel(logging.DEBUG)
|
||||||
|
logger.propagate = False
|
||||||
|
|
||||||
|
# create formatter
|
||||||
|
fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s'
|
||||||
|
color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \
|
||||||
|
colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s'
|
||||||
|
|
||||||
|
# create console handlers for master process
|
||||||
|
if dist_rank == 0:
|
||||||
|
console_handler = logging.StreamHandler(sys.stdout)
|
||||||
|
console_handler.setLevel(logging.DEBUG)
|
||||||
|
console_handler.setFormatter(
|
||||||
|
logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S'))
|
||||||
|
logger.addHandler(console_handler)
|
||||||
|
|
||||||
|
# create file handlers
|
||||||
|
file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'), mode='a')
|
||||||
|
file_handler.setLevel(logging.DEBUG)
|
||||||
|
file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S'))
|
||||||
|
logger.addHandler(file_handler)
|
||||||
|
|
||||||
|
return logger
|
102
lr_scheduler.py
Normal file
102
lr_scheduler.py
Normal file
|
@ -0,0 +1,102 @@
|
||||||
|
# --------------------------------------------------------
|
||||||
|
# Swin Transformer
|
||||||
|
# Copyright (c) 2021 Microsoft
|
||||||
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
# Written by Ze Liu
|
||||||
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from timm.scheduler.cosine_lr import CosineLRScheduler
|
||||||
|
from timm.scheduler.step_lr import StepLRScheduler
|
||||||
|
from timm.scheduler.scheduler import Scheduler
|
||||||
|
|
||||||
|
|
||||||
|
def build_scheduler(config, optimizer, n_iter_per_epoch):
|
||||||
|
num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch)
|
||||||
|
warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch)
|
||||||
|
decay_steps = int(config.TRAIN.LR_SCHEDULER.DECAY_EPOCHS * n_iter_per_epoch)
|
||||||
|
|
||||||
|
lr_scheduler = None
|
||||||
|
if config.TRAIN.LR_SCHEDULER.NAME == 'cosine':
|
||||||
|
lr_scheduler = CosineLRScheduler(
|
||||||
|
optimizer,
|
||||||
|
t_initial=num_steps,
|
||||||
|
t_mul=1.,
|
||||||
|
lr_min=config.TRAIN.MIN_LR,
|
||||||
|
warmup_lr_init=config.TRAIN.WARMUP_LR,
|
||||||
|
warmup_t=warmup_steps,
|
||||||
|
cycle_limit=1,
|
||||||
|
t_in_epochs=False,
|
||||||
|
)
|
||||||
|
elif config.TRAIN.LR_SCHEDULER.NAME == 'linear':
|
||||||
|
lr_scheduler = LinearLRScheduler(
|
||||||
|
optimizer,
|
||||||
|
t_initial=num_steps,
|
||||||
|
lr_min_rate=0.01,
|
||||||
|
warmup_lr_init=config.TRAIN.WARMUP_LR,
|
||||||
|
warmup_t=warmup_steps,
|
||||||
|
t_in_epochs=False,
|
||||||
|
)
|
||||||
|
elif config.TRAIN.LR_SCHEDULER.NAME == 'step':
|
||||||
|
lr_scheduler = StepLRScheduler(
|
||||||
|
optimizer,
|
||||||
|
decay_t=decay_steps,
|
||||||
|
decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE,
|
||||||
|
warmup_lr_init=config.TRAIN.WARMUP_LR,
|
||||||
|
warmup_t=warmup_steps,
|
||||||
|
t_in_epochs=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
return lr_scheduler
|
||||||
|
|
||||||
|
|
||||||
|
class LinearLRScheduler(Scheduler):
|
||||||
|
def __init__(self,
|
||||||
|
optimizer: torch.optim.Optimizer,
|
||||||
|
t_initial: int,
|
||||||
|
lr_min_rate: float,
|
||||||
|
warmup_t=0,
|
||||||
|
warmup_lr_init=0.,
|
||||||
|
t_in_epochs=True,
|
||||||
|
noise_range_t=None,
|
||||||
|
noise_pct=0.67,
|
||||||
|
noise_std=1.0,
|
||||||
|
noise_seed=42,
|
||||||
|
initialize=True,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
optimizer, param_group_field="lr",
|
||||||
|
noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
|
||||||
|
initialize=initialize)
|
||||||
|
|
||||||
|
self.t_initial = t_initial
|
||||||
|
self.lr_min_rate = lr_min_rate
|
||||||
|
self.warmup_t = warmup_t
|
||||||
|
self.warmup_lr_init = warmup_lr_init
|
||||||
|
self.t_in_epochs = t_in_epochs
|
||||||
|
if self.warmup_t:
|
||||||
|
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
|
||||||
|
super().update_groups(self.warmup_lr_init)
|
||||||
|
else:
|
||||||
|
self.warmup_steps = [1 for _ in self.base_values]
|
||||||
|
|
||||||
|
def _get_lr(self, t):
|
||||||
|
if t < self.warmup_t:
|
||||||
|
lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
|
||||||
|
else:
|
||||||
|
t = t - self.warmup_t
|
||||||
|
total_t = self.t_initial - self.warmup_t
|
||||||
|
lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values]
|
||||||
|
return lrs
|
||||||
|
|
||||||
|
def get_epoch_values(self, epoch: int):
|
||||||
|
if self.t_in_epochs:
|
||||||
|
return self._get_lr(epoch)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_update_values(self, num_updates: int):
|
||||||
|
if not self.t_in_epochs:
|
||||||
|
return self._get_lr(num_updates)
|
||||||
|
else:
|
||||||
|
return None
|
345
main.py
Normal file
345
main.py
Normal file
|
@ -0,0 +1,345 @@
|
||||||
|
# --------------------------------------------------------
|
||||||
|
# Swin Transformer
|
||||||
|
# Copyright (c) 2021 Microsoft
|
||||||
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
# Written by Ze Liu
|
||||||
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import argparse
|
||||||
|
import datetime
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.backends.cudnn as cudnn
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
|
||||||
|
from timm.utils import accuracy, AverageMeter
|
||||||
|
|
||||||
|
from config import get_config
|
||||||
|
from models import build_model
|
||||||
|
from data import build_loader
|
||||||
|
from lr_scheduler import build_scheduler
|
||||||
|
from optimizer import build_optimizer
|
||||||
|
from logger import create_logger
|
||||||
|
from utils import load_checkpoint, save_checkpoint, get_grad_norm, auto_resume_helper, reduce_tensor
|
||||||
|
|
||||||
|
try:
|
||||||
|
# noinspection PyUnresolvedReferences
|
||||||
|
from apex import amp
|
||||||
|
except ImportError:
|
||||||
|
amp = None
|
||||||
|
|
||||||
|
|
||||||
|
def parse_option():
|
||||||
|
parser = argparse.ArgumentParser('Swin Transformer training and evaluation script', add_help=False)
|
||||||
|
parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', )
|
||||||
|
parser.add_argument(
|
||||||
|
"--opts",
|
||||||
|
help="Modify config options by adding 'KEY VALUE' pairs. ",
|
||||||
|
default=None,
|
||||||
|
nargs='+',
|
||||||
|
)
|
||||||
|
|
||||||
|
# easy config modification
|
||||||
|
parser.add_argument('--batch-size', type=int, help="batch size for single GPU")
|
||||||
|
parser.add_argument('--data-path', type=str, help='path to dataset')
|
||||||
|
parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
|
||||||
|
parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
|
||||||
|
help='no: no cache, '
|
||||||
|
'full: cache all data, '
|
||||||
|
'part: sharding the dataset into nonoverlapping pieces and only cache one piece')
|
||||||
|
parser.add_argument('--resume', help='resume from checkpoint')
|
||||||
|
parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
|
||||||
|
parser.add_argument('--use-checkpoint', action='store_true',
|
||||||
|
help="whether to use gradient checkpointing to save memory")
|
||||||
|
parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'],
|
||||||
|
help='mixed precision opt level, if O0, no amp is used')
|
||||||
|
parser.add_argument('--output', default='output', type=str, metavar='PATH',
|
||||||
|
help='root of output folder, the full path is <output>/<model_name>/<tag> (default: output)')
|
||||||
|
parser.add_argument('--tag', help='tag of experiment')
|
||||||
|
parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
|
||||||
|
parser.add_argument('--throughput', action='store_true', help='Test throughput only')
|
||||||
|
|
||||||
|
# distributed training
|
||||||
|
parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')
|
||||||
|
|
||||||
|
args, unparsed = parser.parse_known_args()
|
||||||
|
|
||||||
|
config = get_config(args)
|
||||||
|
|
||||||
|
return args, config
|
||||||
|
|
||||||
|
|
||||||
|
def main(config):
|
||||||
|
dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config)
|
||||||
|
|
||||||
|
logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
|
||||||
|
model = build_model(config)
|
||||||
|
model.cuda()
|
||||||
|
logger.info(str(model))
|
||||||
|
|
||||||
|
optimizer = build_optimizer(config, model)
|
||||||
|
if config.AMP_OPT_LEVEL != "O0":
|
||||||
|
model, optimizer = amp.initialize(model, optimizer, opt_level=config.AMP_OPT_LEVEL)
|
||||||
|
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False)
|
||||||
|
model_without_ddp = model.module
|
||||||
|
|
||||||
|
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
|
logger.info(f"number of params: {n_parameters}")
|
||||||
|
if hasattr(model_without_ddp, 'flops'):
|
||||||
|
flops = model_without_ddp.flops()
|
||||||
|
logger.info(f"number of GFLOPs: {flops / 1e9}")
|
||||||
|
|
||||||
|
lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train))
|
||||||
|
|
||||||
|
if config.AUG.MIXUP > 0.:
|
||||||
|
# smoothing is handled with mixup label transform
|
||||||
|
criterion = SoftTargetCrossEntropy()
|
||||||
|
elif config.MODEL.LABEL_SMOOTHING > 0.:
|
||||||
|
criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING)
|
||||||
|
else:
|
||||||
|
criterion = torch.nn.CrossEntropyLoss()
|
||||||
|
|
||||||
|
max_accuracy = 0.0
|
||||||
|
|
||||||
|
if config.TRAIN.AUTO_RESUME:
|
||||||
|
resume_file = auto_resume_helper(config.OUTPUT)
|
||||||
|
if resume_file:
|
||||||
|
if config.MODEL.RESUME:
|
||||||
|
logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}")
|
||||||
|
config.MODEL.RESUME = resume_file
|
||||||
|
logger.info(f'auto resuming from {resume_file}')
|
||||||
|
else:
|
||||||
|
logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')
|
||||||
|
|
||||||
|
if config.MODEL.RESUME:
|
||||||
|
max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, logger)
|
||||||
|
acc1, acc5, loss = validate(config, data_loader_val, model)
|
||||||
|
logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
|
||||||
|
if config.EVAL_MODE:
|
||||||
|
return
|
||||||
|
|
||||||
|
if config.THROUGHPUT_MODE:
|
||||||
|
throughput(data_loader_val, model, logger)
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("Start training")
|
||||||
|
start_time = time.time()
|
||||||
|
for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
|
||||||
|
data_loader_train.sampler.set_epoch(epoch)
|
||||||
|
|
||||||
|
train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler)
|
||||||
|
if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)):
|
||||||
|
save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, logger)
|
||||||
|
|
||||||
|
acc1, acc5, loss = validate(config, data_loader_val, model)
|
||||||
|
logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
|
||||||
|
max_accuracy = max(max_accuracy, acc1)
|
||||||
|
logger.info(f'Max accuracy: {max_accuracy:.2f}%')
|
||||||
|
|
||||||
|
total_time = time.time() - start_time
|
||||||
|
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
||||||
|
logger.info('Training time {}'.format(total_time_str))
|
||||||
|
|
||||||
|
|
||||||
|
def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler):
|
||||||
|
model.train()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
num_steps = len(data_loader)
|
||||||
|
batch_time = AverageMeter()
|
||||||
|
loss_meter = AverageMeter()
|
||||||
|
norm_meter = AverageMeter()
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
end = time.time()
|
||||||
|
for idx, (samples, targets) in enumerate(data_loader):
|
||||||
|
samples = samples.cuda(non_blocking=True)
|
||||||
|
targets = targets.cuda(non_blocking=True)
|
||||||
|
|
||||||
|
if mixup_fn is not None:
|
||||||
|
samples, targets = mixup_fn(samples, targets)
|
||||||
|
|
||||||
|
outputs = model(samples)
|
||||||
|
|
||||||
|
if config.TRAIN.ACCUMULATION_STEPS > 1:
|
||||||
|
loss = criterion(outputs, targets)
|
||||||
|
loss = loss / config.TRAIN.ACCUMULATION_STEPS
|
||||||
|
if config.AMP_OPT_LEVEL != "O0":
|
||||||
|
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
||||||
|
scaled_loss.backward()
|
||||||
|
if config.TRAIN.CLIP_GRAD:
|
||||||
|
grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.TRAIN.CLIP_GRAD)
|
||||||
|
else:
|
||||||
|
grad_norm = get_grad_norm(amp.master_params(optimizer))
|
||||||
|
else:
|
||||||
|
loss.backward()
|
||||||
|
if config.TRAIN.CLIP_GRAD:
|
||||||
|
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD)
|
||||||
|
else:
|
||||||
|
grad_norm = get_grad_norm(model.parameters())
|
||||||
|
if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
|
||||||
|
optimizer.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
lr_scheduler.step_update(epoch * num_steps + idx)
|
||||||
|
else:
|
||||||
|
loss = criterion(outputs, targets)
|
||||||
|
optimizer.zero_grad()
|
||||||
|
if config.AMP_OPT_LEVEL != "O0":
|
||||||
|
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
||||||
|
scaled_loss.backward()
|
||||||
|
if config.TRAIN.CLIP_GRAD:
|
||||||
|
grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.TRAIN.CLIP_GRAD)
|
||||||
|
else:
|
||||||
|
grad_norm = get_grad_norm(amp.master_params(optimizer))
|
||||||
|
else:
|
||||||
|
loss.backward()
|
||||||
|
if config.TRAIN.CLIP_GRAD:
|
||||||
|
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD)
|
||||||
|
else:
|
||||||
|
grad_norm = get_grad_norm(model.parameters())
|
||||||
|
optimizer.step()
|
||||||
|
lr_scheduler.step_update(epoch * num_steps + idx)
|
||||||
|
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
loss_meter.update(loss.item(), targets.size(0))
|
||||||
|
norm_meter.update(grad_norm)
|
||||||
|
batch_time.update(time.time() - end)
|
||||||
|
end = time.time()
|
||||||
|
|
||||||
|
if idx % config.PRINT_FREQ == 0:
|
||||||
|
lr = optimizer.param_groups[0]['lr']
|
||||||
|
memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
|
||||||
|
etas = batch_time.avg * (num_steps - idx)
|
||||||
|
logger.info(
|
||||||
|
f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
|
||||||
|
f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
|
||||||
|
f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
|
||||||
|
f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
|
||||||
|
f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
|
||||||
|
f'mem {memory_used:.0f}MB')
|
||||||
|
epoch_time = time.time() - start
|
||||||
|
logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}")
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def validate(config, data_loader, model):
|
||||||
|
criterion = torch.nn.CrossEntropyLoss()
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
batch_time = AverageMeter()
|
||||||
|
loss_meter = AverageMeter()
|
||||||
|
acc1_meter = AverageMeter()
|
||||||
|
acc5_meter = AverageMeter()
|
||||||
|
|
||||||
|
end = time.time()
|
||||||
|
for idx, (images, target) in enumerate(data_loader):
|
||||||
|
images = images.cuda(non_blocking=True)
|
||||||
|
target = target.cuda(non_blocking=True)
|
||||||
|
|
||||||
|
# compute output
|
||||||
|
output = model(images)
|
||||||
|
|
||||||
|
# measure accuracy and record loss
|
||||||
|
loss = criterion(output, target)
|
||||||
|
acc1, acc5 = accuracy(output, target, topk=(1, 5))
|
||||||
|
|
||||||
|
acc1 = reduce_tensor(acc1)
|
||||||
|
acc5 = reduce_tensor(acc5)
|
||||||
|
loss = reduce_tensor(loss)
|
||||||
|
|
||||||
|
loss_meter.update(loss.item(), target.size(0))
|
||||||
|
acc1_meter.update(acc1.item(), target.size(0))
|
||||||
|
acc5_meter.update(acc5.item(), target.size(0))
|
||||||
|
|
||||||
|
# measure elapsed time
|
||||||
|
batch_time.update(time.time() - end)
|
||||||
|
end = time.time()
|
||||||
|
|
||||||
|
if idx % config.PRINT_FREQ == 0:
|
||||||
|
memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
|
||||||
|
logger.info(
|
||||||
|
f'Test: [{idx}/{len(data_loader)}]\t'
|
||||||
|
f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
|
||||||
|
f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
|
||||||
|
f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
|
||||||
|
f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
|
||||||
|
f'Mem {memory_used:.0f}MB')
|
||||||
|
logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
|
||||||
|
return acc1_meter.avg, acc5_meter.avg, loss_meter.avg
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def throughput(data_loader, model, logger):
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
for idx, (images, _) in enumerate(data_loader):
|
||||||
|
images = images.cuda(non_blocking=True)
|
||||||
|
batch_size = images.shape[0]
|
||||||
|
for i in range(50):
|
||||||
|
model(images)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
logger.info(f"throughput averaged with 30 times")
|
||||||
|
tic1 = time.time()
|
||||||
|
for i in range(30):
|
||||||
|
model(images)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
tic2 = time.time()
|
||||||
|
logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}")
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
_, config = parse_option()
|
||||||
|
|
||||||
|
if config.AMP_OPT_LEVEL != "O0":
|
||||||
|
assert amp is not None, "amp not installed!"
|
||||||
|
|
||||||
|
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
|
||||||
|
rank = int(os.environ["RANK"])
|
||||||
|
world_size = int(os.environ['WORLD_SIZE'])
|
||||||
|
print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}")
|
||||||
|
else:
|
||||||
|
rank = -1
|
||||||
|
world_size = -1
|
||||||
|
torch.cuda.set_device(config.LOCAL_RANK)
|
||||||
|
torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
|
||||||
|
torch.distributed.barrier()
|
||||||
|
|
||||||
|
seed = config.SEED + dist.get_rank()
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
cudnn.benchmark = True
|
||||||
|
|
||||||
|
# linear scale the learning rate according to total batch size, may not be optimal
|
||||||
|
linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
|
||||||
|
linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
|
||||||
|
linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
|
||||||
|
# gradient accumulation also need to scale the learning rate
|
||||||
|
if config.TRAIN.ACCUMULATION_STEPS > 1:
|
||||||
|
linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS
|
||||||
|
linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS
|
||||||
|
linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS
|
||||||
|
config.defrost()
|
||||||
|
config.TRAIN.BASE_LR = linear_scaled_lr
|
||||||
|
config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr
|
||||||
|
config.TRAIN.MIN_LR = linear_scaled_min_lr
|
||||||
|
config.freeze()
|
||||||
|
|
||||||
|
os.makedirs(config.OUTPUT, exist_ok=True)
|
||||||
|
logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}")
|
||||||
|
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
path = os.path.join(config.OUTPUT, "config.json")
|
||||||
|
with open(path, "w") as f:
|
||||||
|
f.write(config.dump())
|
||||||
|
logger.info(f"Full config saved to {path}")
|
||||||
|
|
||||||
|
# print config
|
||||||
|
logger.info(config.dump())
|
||||||
|
|
||||||
|
main(config)
|
1
models/__init__.py
Normal file
1
models/__init__.py
Normal file
|
@ -0,0 +1 @@
|
||||||
|
from .build import build_model
|
33
models/build.py
Normal file
33
models/build.py
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
# --------------------------------------------------------
|
||||||
|
# Swin Transformer
|
||||||
|
# Copyright (c) 2021 Microsoft
|
||||||
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
# Written by Ze Liu
|
||||||
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
from .swin_transformer import SwinTransformer
|
||||||
|
|
||||||
|
|
||||||
|
def build_model(config):
|
||||||
|
model_type = config.MODEL.TYPE
|
||||||
|
if model_type == 'swin':
|
||||||
|
model = SwinTransformer(img_size=config.DATA.IMG_SIZE,
|
||||||
|
patch_size=config.MODEL.SWIN.PATCH_SIZE,
|
||||||
|
in_chans=config.MODEL.SWIN.IN_CHANS,
|
||||||
|
num_classes=config.MODEL.NUM_CLASSES,
|
||||||
|
embed_dim=config.MODEL.SWIN.EMBED_DIM,
|
||||||
|
depths=config.MODEL.SWIN.DEPTHS,
|
||||||
|
num_heads=config.MODEL.SWIN.NUM_HEADS,
|
||||||
|
window_size=config.MODEL.SWIN.WINDOW_SIZE,
|
||||||
|
mlp_ratio=config.MODEL.SWIN.MLP_RATIO,
|
||||||
|
qkv_bias=config.MODEL.SWIN.QKV_BIAS,
|
||||||
|
qk_scale=config.MODEL.SWIN.QK_SCALE,
|
||||||
|
drop_rate=config.MODEL.DROP_RATE,
|
||||||
|
drop_path_rate=config.MODEL.DROP_PATH_RATE,
|
||||||
|
ape=config.MODEL.SWIN.APE,
|
||||||
|
patch_norm=config.MODEL.SWIN.PATCH_NORM,
|
||||||
|
use_checkpoint=config.TRAIN.USE_CHECKPOINT)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Unkown model: {model_type}")
|
||||||
|
|
||||||
|
return model
|
585
models/swin_transformer.py
Normal file
585
models/swin_transformer.py
Normal file
|
@ -0,0 +1,585 @@
|
||||||
|
# --------------------------------------------------------
|
||||||
|
# Swin Transformer
|
||||||
|
# Copyright (c) 2021 Microsoft
|
||||||
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
# Written by Ze Liu
|
||||||
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.utils.checkpoint as checkpoint
|
||||||
|
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
||||||
|
|
||||||
|
|
||||||
|
class Mlp(nn.Module):
|
||||||
|
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
||||||
|
super().__init__()
|
||||||
|
out_features = out_features or in_features
|
||||||
|
hidden_features = hidden_features or in_features
|
||||||
|
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||||
|
self.act = act_layer()
|
||||||
|
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||||
|
self.drop = nn.Dropout(drop)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.fc1(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.drop(x)
|
||||||
|
x = self.fc2(x)
|
||||||
|
x = self.drop(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def window_partition(x, window_size):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: (B, H, W, C)
|
||||||
|
window_size (int): window size
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
windows: (num_windows*B, window_size, window_size, C)
|
||||||
|
"""
|
||||||
|
B, H, W, C = x.shape
|
||||||
|
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
||||||
|
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
||||||
|
return windows
|
||||||
|
|
||||||
|
|
||||||
|
def window_reverse(windows, window_size, H, W):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
windows: (num_windows*B, window_size, window_size, C)
|
||||||
|
window_size (int): Window size
|
||||||
|
H (int): Height of image
|
||||||
|
W (int): Width of image
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
x: (B, H, W, C)
|
||||||
|
"""
|
||||||
|
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
||||||
|
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
||||||
|
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class WindowAttention(nn.Module):
|
||||||
|
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
||||||
|
It supports both of shifted and non-shifted window.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dim (int): Number of input channels.
|
||||||
|
window_size (tuple[int]): The height and width of the window.
|
||||||
|
num_heads (int): Number of attention heads.
|
||||||
|
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
||||||
|
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
||||||
|
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
||||||
|
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.window_size = window_size # Wh, Ww
|
||||||
|
self.num_heads = num_heads
|
||||||
|
head_dim = dim // num_heads
|
||||||
|
self.scale = qk_scale or head_dim ** -0.5
|
||||||
|
|
||||||
|
# define a parameter table of relative position bias
|
||||||
|
self.relative_position_bias_table = nn.Parameter(
|
||||||
|
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
||||||
|
|
||||||
|
# get pair-wise relative position index for each token inside the window
|
||||||
|
coords_h = torch.arange(self.window_size[0])
|
||||||
|
coords_w = torch.arange(self.window_size[1])
|
||||||
|
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
||||||
|
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
||||||
|
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
||||||
|
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
||||||
|
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
|
||||||
|
relative_coords[:, :, 1] += self.window_size[1] - 1
|
||||||
|
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
||||||
|
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
||||||
|
self.register_buffer("relative_position_index", relative_position_index)
|
||||||
|
|
||||||
|
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||||
|
self.attn_drop = nn.Dropout(attn_drop)
|
||||||
|
self.proj = nn.Linear(dim, dim)
|
||||||
|
self.proj_drop = nn.Dropout(proj_drop)
|
||||||
|
|
||||||
|
trunc_normal_(self.relative_position_bias_table, std=.02)
|
||||||
|
self.softmax = nn.Softmax(dim=-1)
|
||||||
|
|
||||||
|
def forward(self, x, mask=None):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: input features with shape of (num_windows*B, N, C)
|
||||||
|
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
||||||
|
"""
|
||||||
|
B_, N, C = x.shape
|
||||||
|
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||||
|
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
||||||
|
|
||||||
|
q = q * self.scale
|
||||||
|
attn = (q @ k.transpose(-2, -1))
|
||||||
|
|
||||||
|
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
||||||
|
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
||||||
|
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||||
|
attn = attn + relative_position_bias.unsqueeze(0)
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
nW = mask.shape[0]
|
||||||
|
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
||||||
|
attn = attn.view(-1, self.num_heads, N, N)
|
||||||
|
attn = self.softmax(attn)
|
||||||
|
else:
|
||||||
|
attn = self.softmax(attn)
|
||||||
|
|
||||||
|
attn = self.attn_drop(attn)
|
||||||
|
|
||||||
|
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
||||||
|
x = self.proj(x)
|
||||||
|
x = self.proj_drop(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def extra_repr(self) -> str:
|
||||||
|
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
|
||||||
|
|
||||||
|
def flops(self, N):
|
||||||
|
# calculate flops for 1 window with token length of N
|
||||||
|
flops = 0
|
||||||
|
# qkv = self.qkv(x)
|
||||||
|
flops += N * self.dim * 3 * self.dim
|
||||||
|
# attn = (q @ k.transpose(-2, -1))
|
||||||
|
flops += self.num_heads * N * (self.dim // self.num_heads) * N
|
||||||
|
# x = (attn @ v)
|
||||||
|
flops += self.num_heads * N * N * (self.dim // self.num_heads)
|
||||||
|
# x = self.proj(x)
|
||||||
|
flops += N * self.dim * self.dim
|
||||||
|
return flops
|
||||||
|
|
||||||
|
|
||||||
|
class SwinTransformerBlock(nn.Module):
|
||||||
|
r""" Swin Transformer Block.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dim (int): Number of input channels.
|
||||||
|
input_resolution (tuple[int]): Input resulotion.
|
||||||
|
num_heads (int): Number of attention heads.
|
||||||
|
window_size (int): Window size.
|
||||||
|
shift_size (int): Shift size for SW-MSA.
|
||||||
|
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
||||||
|
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
||||||
|
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
||||||
|
drop (float, optional): Dropout rate. Default: 0.0
|
||||||
|
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
||||||
|
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
||||||
|
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
||||||
|
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
|
||||||
|
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
|
||||||
|
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.input_resolution = input_resolution
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.window_size = window_size
|
||||||
|
self.shift_size = shift_size
|
||||||
|
self.mlp_ratio = mlp_ratio
|
||||||
|
if min(self.input_resolution) <= self.window_size:
|
||||||
|
# if window size is larger than input resolution, we don't partition windows
|
||||||
|
self.shift_size = 0
|
||||||
|
self.window_size = min(self.input_resolution)
|
||||||
|
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
|
||||||
|
|
||||||
|
self.norm1 = norm_layer(dim)
|
||||||
|
self.attn = WindowAttention(
|
||||||
|
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
|
||||||
|
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
||||||
|
|
||||||
|
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||||
|
self.norm2 = norm_layer(dim)
|
||||||
|
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||||
|
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
||||||
|
|
||||||
|
if self.shift_size > 0:
|
||||||
|
# calculate attention mask for SW-MSA
|
||||||
|
H, W = self.input_resolution
|
||||||
|
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
|
||||||
|
h_slices = (slice(0, -self.window_size),
|
||||||
|
slice(-self.window_size, -self.shift_size),
|
||||||
|
slice(-self.shift_size, None))
|
||||||
|
w_slices = (slice(0, -self.window_size),
|
||||||
|
slice(-self.window_size, -self.shift_size),
|
||||||
|
slice(-self.shift_size, None))
|
||||||
|
cnt = 0
|
||||||
|
for h in h_slices:
|
||||||
|
for w in w_slices:
|
||||||
|
img_mask[:, h, w, :] = cnt
|
||||||
|
cnt += 1
|
||||||
|
|
||||||
|
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
|
||||||
|
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
||||||
|
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
||||||
|
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
||||||
|
else:
|
||||||
|
attn_mask = None
|
||||||
|
|
||||||
|
self.register_buffer("attn_mask", attn_mask)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
H, W = self.input_resolution
|
||||||
|
B, L, C = x.shape
|
||||||
|
assert L == H * W, "input feature has wrong size"
|
||||||
|
|
||||||
|
shortcut = x
|
||||||
|
x = self.norm1(x)
|
||||||
|
x = x.view(B, H, W, C)
|
||||||
|
|
||||||
|
# cyclic shift
|
||||||
|
if self.shift_size > 0:
|
||||||
|
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
||||||
|
else:
|
||||||
|
shifted_x = x
|
||||||
|
|
||||||
|
# partition windows
|
||||||
|
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
|
||||||
|
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
|
||||||
|
|
||||||
|
# W-MSA/SW-MSA
|
||||||
|
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
|
||||||
|
|
||||||
|
# merge windows
|
||||||
|
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
||||||
|
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
|
||||||
|
|
||||||
|
# reverse cyclic shift
|
||||||
|
if self.shift_size > 0:
|
||||||
|
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
||||||
|
else:
|
||||||
|
x = shifted_x
|
||||||
|
x = x.view(B, H * W, C)
|
||||||
|
|
||||||
|
# FFN
|
||||||
|
x = shortcut + self.drop_path(x)
|
||||||
|
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def extra_repr(self) -> str:
|
||||||
|
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
|
||||||
|
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
|
||||||
|
|
||||||
|
def flops(self):
|
||||||
|
flops = 0
|
||||||
|
H, W = self.input_resolution
|
||||||
|
# norm1
|
||||||
|
flops += self.dim * H * W
|
||||||
|
# W-MSA/SW-MSA
|
||||||
|
nW = H * W / self.window_size / self.window_size
|
||||||
|
flops += nW * self.attn.flops(self.window_size * self.window_size)
|
||||||
|
# mlp
|
||||||
|
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
|
||||||
|
# norm2
|
||||||
|
flops += self.dim * H * W
|
||||||
|
return flops
|
||||||
|
|
||||||
|
|
||||||
|
class PatchMerging(nn.Module):
|
||||||
|
r""" Patch Merging Layer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_resolution (tuple[int]): Resolution of input feature.
|
||||||
|
dim (int): Number of input channels.
|
||||||
|
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
|
||||||
|
super().__init__()
|
||||||
|
self.input_resolution = input_resolution
|
||||||
|
self.dim = dim
|
||||||
|
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
||||||
|
self.norm = norm_layer(4 * dim)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
x: B, H*W, C
|
||||||
|
"""
|
||||||
|
H, W = self.input_resolution
|
||||||
|
B, L, C = x.shape
|
||||||
|
assert L == H * W, "input feature has wrong size"
|
||||||
|
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
|
||||||
|
|
||||||
|
x = x.view(B, H, W, C)
|
||||||
|
|
||||||
|
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
||||||
|
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
||||||
|
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
||||||
|
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
||||||
|
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
||||||
|
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
||||||
|
|
||||||
|
x = self.norm(x)
|
||||||
|
x = self.reduction(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def extra_repr(self) -> str:
|
||||||
|
return f"input_resolution={self.input_resolution}, dim={self.dim}"
|
||||||
|
|
||||||
|
def flops(self):
|
||||||
|
H, W = self.input_resolution
|
||||||
|
flops = H * W * self.dim
|
||||||
|
flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
|
||||||
|
return flops
|
||||||
|
|
||||||
|
|
||||||
|
class BasicLayer(nn.Module):
|
||||||
|
""" A basic Swin Transformer layer for one stage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dim (int): Number of input channels.
|
||||||
|
input_resolution (tuple[int]): Input resolution.
|
||||||
|
depth (int): Number of blocks.
|
||||||
|
num_heads (int): Number of attention heads.
|
||||||
|
window_size (int): Local window size.
|
||||||
|
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
||||||
|
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
||||||
|
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
||||||
|
drop (float, optional): Dropout rate. Default: 0.0
|
||||||
|
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
||||||
|
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
||||||
|
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||||
|
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
||||||
|
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
|
||||||
|
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
|
||||||
|
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.input_resolution = input_resolution
|
||||||
|
self.depth = depth
|
||||||
|
self.use_checkpoint = use_checkpoint
|
||||||
|
|
||||||
|
# build blocks
|
||||||
|
self.blocks = nn.ModuleList([
|
||||||
|
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
|
||||||
|
num_heads=num_heads, window_size=window_size,
|
||||||
|
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
||||||
|
mlp_ratio=mlp_ratio,
|
||||||
|
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
||||||
|
drop=drop, attn_drop=attn_drop,
|
||||||
|
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
||||||
|
norm_layer=norm_layer)
|
||||||
|
for i in range(depth)])
|
||||||
|
|
||||||
|
# patch merging layer
|
||||||
|
if downsample is not None:
|
||||||
|
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
|
||||||
|
else:
|
||||||
|
self.downsample = None
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
for blk in self.blocks:
|
||||||
|
if self.use_checkpoint:
|
||||||
|
x = checkpoint.checkpoint(blk, x)
|
||||||
|
else:
|
||||||
|
x = blk(x)
|
||||||
|
if self.downsample is not None:
|
||||||
|
x = self.downsample(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def extra_repr(self) -> str:
|
||||||
|
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
|
||||||
|
|
||||||
|
def flops(self):
|
||||||
|
flops = 0
|
||||||
|
for blk in self.blocks:
|
||||||
|
flops += blk.flops()
|
||||||
|
if self.downsample is not None:
|
||||||
|
flops += self.downsample.flops()
|
||||||
|
return flops
|
||||||
|
|
||||||
|
|
||||||
|
class PatchEmbed(nn.Module):
|
||||||
|
r""" Image to Patch Embedding
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img_size (int): Image size. Default: 224.
|
||||||
|
patch_size (int): Patch token size. Default: 4.
|
||||||
|
in_chans (int): Number of input image channels. Default: 3.
|
||||||
|
embed_dim (int): Number of linear projection output channels. Default: 96.
|
||||||
|
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
||||||
|
super().__init__()
|
||||||
|
img_size = to_2tuple(img_size)
|
||||||
|
patch_size = to_2tuple(patch_size)
|
||||||
|
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
|
||||||
|
self.img_size = img_size
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.patches_resolution = patches_resolution
|
||||||
|
self.num_patches = patches_resolution[0] * patches_resolution[1]
|
||||||
|
|
||||||
|
self.in_chans = in_chans
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
|
||||||
|
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
||||||
|
if norm_layer is not None:
|
||||||
|
self.norm = norm_layer(embed_dim)
|
||||||
|
else:
|
||||||
|
self.norm = None
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
B, C, H, W = x.shape
|
||||||
|
# FIXME look at relaxing size constraints
|
||||||
|
assert H == self.img_size[0] and W == self.img_size[1], \
|
||||||
|
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
||||||
|
x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
|
||||||
|
if self.norm is not None:
|
||||||
|
x = self.norm(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def flops(self):
|
||||||
|
Ho, Wo = self.patches_resolution
|
||||||
|
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
|
||||||
|
if self.norm is not None:
|
||||||
|
flops += Ho * Wo * self.embed_dim
|
||||||
|
return flops
|
||||||
|
|
||||||
|
|
||||||
|
class SwinTransformer(nn.Module):
|
||||||
|
r""" Swin Transformer
|
||||||
|
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
|
||||||
|
https://arxiv.org/pdf/2103.14030
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img_size (int | tuple(int)): Input image size. Default 224
|
||||||
|
patch_size (int | tuple(int)): Patch size. Default: 4
|
||||||
|
in_chans (int): Number of input image channels. Default: 3
|
||||||
|
num_classes (int): Number of classes for classification head. Default: 1000
|
||||||
|
embed_dim (int): Patch embedding dimension. Default: 96
|
||||||
|
depths (tuple(int)): Depth of each Swin Transformer layer.
|
||||||
|
num_heads (tuple(int)): Number of attention heads in different layers.
|
||||||
|
window_size (int): Window size. Default: 7
|
||||||
|
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
|
||||||
|
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
||||||
|
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
|
||||||
|
drop_rate (float): Dropout rate. Default: 0
|
||||||
|
attn_drop_rate (float): Attention dropout rate. Default: 0
|
||||||
|
drop_path_rate (float): Stochastic depth rate. Default: 0.1
|
||||||
|
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
||||||
|
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
|
||||||
|
patch_norm (bool): If True, add normalization after patch embedding. Default: True
|
||||||
|
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
|
||||||
|
embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
|
||||||
|
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
|
||||||
|
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
||||||
|
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
||||||
|
use_checkpoint=False, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.num_classes = num_classes
|
||||||
|
self.num_layers = len(depths)
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
self.ape = ape
|
||||||
|
self.patch_norm = patch_norm
|
||||||
|
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
|
||||||
|
self.mlp_ratio = mlp_ratio
|
||||||
|
|
||||||
|
# split image into non-overlapping patches
|
||||||
|
self.patch_embed = PatchEmbed(
|
||||||
|
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
|
||||||
|
norm_layer=norm_layer if self.patch_norm else None)
|
||||||
|
num_patches = self.patch_embed.num_patches
|
||||||
|
patches_resolution = self.patch_embed.patches_resolution
|
||||||
|
self.patches_resolution = patches_resolution
|
||||||
|
|
||||||
|
# absolute position embedding
|
||||||
|
if self.ape:
|
||||||
|
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
|
||||||
|
trunc_normal_(self.absolute_pos_embed, std=.02)
|
||||||
|
|
||||||
|
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||||
|
|
||||||
|
# stochastic depth
|
||||||
|
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
||||||
|
|
||||||
|
# build layers
|
||||||
|
self.layers = nn.ModuleList()
|
||||||
|
for i_layer in range(self.num_layers):
|
||||||
|
layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
|
||||||
|
input_resolution=(patches_resolution[0] // (2 ** i_layer),
|
||||||
|
patches_resolution[1] // (2 ** i_layer)),
|
||||||
|
depth=depths[i_layer],
|
||||||
|
num_heads=num_heads[i_layer],
|
||||||
|
window_size=window_size,
|
||||||
|
mlp_ratio=self.mlp_ratio,
|
||||||
|
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
||||||
|
drop=drop_rate, attn_drop=attn_drop_rate,
|
||||||
|
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
|
||||||
|
use_checkpoint=use_checkpoint)
|
||||||
|
self.layers.append(layer)
|
||||||
|
|
||||||
|
self.norm = norm_layer(self.num_features)
|
||||||
|
self.avgpool = nn.AdaptiveAvgPool1d(1)
|
||||||
|
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||||
|
|
||||||
|
self.apply(self._init_weights)
|
||||||
|
|
||||||
|
def _init_weights(self, m):
|
||||||
|
if isinstance(m, nn.Linear):
|
||||||
|
trunc_normal_(m.weight, std=.02)
|
||||||
|
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
elif isinstance(m, nn.LayerNorm):
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
nn.init.constant_(m.weight, 1.0)
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def no_weight_decay(self):
|
||||||
|
return {'absolute_pos_embed'}
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def no_weight_decay_keywords(self):
|
||||||
|
return {'relative_position_bias_table'}
|
||||||
|
|
||||||
|
def forward_features(self, x):
|
||||||
|
x = self.patch_embed(x)
|
||||||
|
if self.ape:
|
||||||
|
x = x + self.absolute_pos_embed
|
||||||
|
x = self.pos_drop(x)
|
||||||
|
|
||||||
|
for layer in self.layers:
|
||||||
|
x = layer(x)
|
||||||
|
|
||||||
|
x = self.norm(x) # B L C
|
||||||
|
x = self.avgpool(x.transpose(1, 2)) # B C 1
|
||||||
|
x = torch.flatten(x, 1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.forward_features(x)
|
||||||
|
x = self.head(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def flops(self):
|
||||||
|
flops = 0
|
||||||
|
flops += self.patch_embed.flops()
|
||||||
|
for i, layer in enumerate(self.layers):
|
||||||
|
flops += layer.flops()
|
||||||
|
flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
|
||||||
|
flops += self.num_features * self.num_classes
|
||||||
|
return flops
|
57
optimizer.py
Normal file
57
optimizer.py
Normal file
|
@ -0,0 +1,57 @@
|
||||||
|
# --------------------------------------------------------
|
||||||
|
# Swin Transformer
|
||||||
|
# Copyright (c) 2021 Microsoft
|
||||||
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
# Written by Ze Liu
|
||||||
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
from torch import optim as optim
|
||||||
|
|
||||||
|
|
||||||
|
def build_optimizer(config, model):
|
||||||
|
"""
|
||||||
|
Build optimizer, set weight decay of normalization to 0 by default.
|
||||||
|
"""
|
||||||
|
skip = {}
|
||||||
|
skip_keywords = {}
|
||||||
|
if hasattr(model, 'no_weight_decay'):
|
||||||
|
skip = model.no_weight_decay()
|
||||||
|
if hasattr(model, 'no_weight_decay_keywords'):
|
||||||
|
skip_keywords = model.no_weight_decay_keywords()
|
||||||
|
parameters = set_weight_decay(model, skip, skip_keywords)
|
||||||
|
|
||||||
|
opt_lower = config.TRAIN.OPTIMIZER.NAME.lower()
|
||||||
|
optimizer = None
|
||||||
|
if opt_lower == 'sgd':
|
||||||
|
optimizer = optim.SGD(parameters, momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True,
|
||||||
|
lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)
|
||||||
|
elif opt_lower == 'adamw':
|
||||||
|
optimizer = optim.AdamW(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS,
|
||||||
|
lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)
|
||||||
|
|
||||||
|
return optimizer
|
||||||
|
|
||||||
|
|
||||||
|
def set_weight_decay(model, skip_list=(), skip_keywords=()):
|
||||||
|
has_decay = []
|
||||||
|
no_decay = []
|
||||||
|
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if not param.requires_grad:
|
||||||
|
continue # frozen weights
|
||||||
|
if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \
|
||||||
|
check_keywords_in_name(name, skip_keywords):
|
||||||
|
no_decay.append(param)
|
||||||
|
# print(f"{name} has no weight decay")
|
||||||
|
else:
|
||||||
|
has_decay.append(param)
|
||||||
|
return [{'params': has_decay},
|
||||||
|
{'params': no_decay, 'weight_decay': 0.}]
|
||||||
|
|
||||||
|
|
||||||
|
def check_keywords_in_name(name, keywords=()):
|
||||||
|
isin = False
|
||||||
|
for keyword in keywords:
|
||||||
|
if keyword in name:
|
||||||
|
isin = True
|
||||||
|
return isin
|
90
utils.py
Normal file
90
utils.py
Normal file
|
@ -0,0 +1,90 @@
|
||||||
|
# --------------------------------------------------------
|
||||||
|
# Swin Transformer
|
||||||
|
# Copyright (c) 2021 Microsoft
|
||||||
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
# Written by Ze Liu
|
||||||
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
try:
|
||||||
|
# noinspection PyUnresolvedReferences
|
||||||
|
from apex import amp
|
||||||
|
except ImportError:
|
||||||
|
amp = None
|
||||||
|
|
||||||
|
|
||||||
|
def load_checkpoint(config, model, optimizer, lr_scheduler, logger):
|
||||||
|
logger.info(f"==============> Resuming form {config.MODEL.RESUME}....................")
|
||||||
|
if config.MODEL.RESUME.startswith('https'):
|
||||||
|
checkpoint = torch.hub.load_state_dict_from_url(
|
||||||
|
config.MODEL.RESUME, map_location='cpu', check_hash=True)
|
||||||
|
else:
|
||||||
|
checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')
|
||||||
|
msg = model.load_state_dict(checkpoint['model'], strict=False)
|
||||||
|
logger.info(msg)
|
||||||
|
max_accuracy = 0.0
|
||||||
|
if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
|
||||||
|
optimizer.load_state_dict(checkpoint['optimizer'])
|
||||||
|
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
|
||||||
|
config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
|
||||||
|
if 'amp' in checkpoint and config.AMP_OPT_LEVEL != "O0" and checkpoint['config'].AMP_OPT_LEVEL != "O0":
|
||||||
|
amp.load_state_dict(checkpoint['amp'])
|
||||||
|
logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})")
|
||||||
|
if 'max_accuracy' in checkpoint:
|
||||||
|
max_accuracy = checkpoint['max_accuracy']
|
||||||
|
|
||||||
|
del checkpoint
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
return max_accuracy
|
||||||
|
|
||||||
|
|
||||||
|
def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, logger):
|
||||||
|
save_state = {'model': model.state_dict(),
|
||||||
|
'optimizer': optimizer.state_dict(),
|
||||||
|
'lr_scheduler': lr_scheduler.state_dict(),
|
||||||
|
'max_accuracy': max_accuracy,
|
||||||
|
'epoch': epoch,
|
||||||
|
'config': config}
|
||||||
|
if config.AMP_OPT_LEVEL != "O0":
|
||||||
|
save_state['amp'] = amp.state_dict()
|
||||||
|
|
||||||
|
save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth')
|
||||||
|
logger.info(f"{save_path} saving......")
|
||||||
|
torch.save(save_state, save_path)
|
||||||
|
logger.info(f"{save_path} saved !!!")
|
||||||
|
|
||||||
|
|
||||||
|
def get_grad_norm(parameters, norm_type=2):
|
||||||
|
if isinstance(parameters, torch.Tensor):
|
||||||
|
parameters = [parameters]
|
||||||
|
parameters = list(filter(lambda p: p.grad is not None, parameters))
|
||||||
|
norm_type = float(norm_type)
|
||||||
|
total_norm = 0
|
||||||
|
for p in parameters:
|
||||||
|
param_norm = p.grad.data.norm(norm_type)
|
||||||
|
total_norm += param_norm.item() ** norm_type
|
||||||
|
total_norm = total_norm ** (1. / norm_type)
|
||||||
|
return total_norm
|
||||||
|
|
||||||
|
|
||||||
|
def auto_resume_helper(output_dir):
|
||||||
|
checkpoints = os.listdir(output_dir)
|
||||||
|
checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')]
|
||||||
|
print(f"All checkpoints founded in {output_dir}: {checkpoints}")
|
||||||
|
if len(checkpoints) > 0:
|
||||||
|
latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime)
|
||||||
|
print(f"The latest checkpoint founded: {latest_checkpoint}")
|
||||||
|
resume_file = latest_checkpoint
|
||||||
|
else:
|
||||||
|
resume_file = None
|
||||||
|
return resume_file
|
||||||
|
|
||||||
|
|
||||||
|
def reduce_tensor(tensor):
|
||||||
|
rt = tensor.clone()
|
||||||
|
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
|
||||||
|
rt /= dist.get_world_size()
|
||||||
|
return rt
|
Loading…
Reference in New Issue
Block a user