Compare commits

..

12 Commits
LR-Net ... main

Author SHA1 Message Date
7486fe1cc6 convert pytorch model to paddlepaddle model 2021-05-06 23:03:23 +08:00
Han Hu
eec0e76fdb
Merge pull request #41 from kamalkraj/patch-1
Update README.md
2021-04-21 19:33:08 +08:00
Han Hu
008d5afc38
More explanation on adding links to 3rd party repo 2021-04-21 13:06:40 +08:00
Kamal Raj
0fccdb0b6f
Update README.md 2021-04-20 21:33:14 +05:30
Ze Liu
081385b99d
fix link to Swin-L-IN22K 2021-04-16 12:12:28 +08:00
Han Hu
110fb0e4e9
Merge pull request #27 from caoyue10/patch-2
Add a cross link to third party usage
2021-04-16 11:01:52 +08:00
Yue Cao
fa35a5e435
Add a cross link to third party usage 2021-04-16 10:04:48 +08:00
Han Hu
bd122d3b42
Add cross link to third party usage 2021-04-14 22:06:53 +08:00
Ze Liu
6cc8ebd200
minor bug fixes (#25) 2021-04-14 20:50:10 +08:00
Han Hu
c7e6d6efbd
Add explanation of the name "Swin" 2021-04-13 14:06:04 +08:00
Ze Liu
657aeeeefa
Merge pull request #13 from caoyue10/patch-1
Fix a typo.
2021-04-13 11:42:51 +08:00
Yue Cao
5e785157ed
Fix a typo. 2021-04-13 11:39:23 +08:00
11 changed files with 136 additions and 1471 deletions

1
.gitignore vendored
View File

@ -1,3 +1,4 @@
.idea
# Byte-compiled / optimized / DLL files # Byte-compiled / optimized / DLL files
__pycache__/ __pycache__/
*.py[cod] *.py[cod]

120
README.md
View File

@ -1,24 +1,92 @@
# Local Relation Networks V2 (LR-Net V2) # Swin Transformer
This branch is an improved implementation of ["Local Relation Networks for Image Recognition (LR-Net)"](https://arxiv.org/pdf/1904.11491.pdf). The original LR-Net utilizes sliding window based self-attention layer to replace the `3x3` convolution layers in a ResNet architecture. This improved implementation applies this layer into a stronger overall architecture based on Tranformers, dubbed as LR-Net V2. We provide cuda kernels for the local relation layers. Training scripts and pre-trained models will be provided in the future. [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/swin-transformer-hierarchical-vision/object-detection-on-coco)](https://paperswithcode.com/sota/object-detection-on-coco?p=swin-transformer-hierarchical-vision)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/swin-transformer-hierarchical-vision/instance-segmentation-on-coco)](https://paperswithcode.com/sota/instance-segmentation-on-coco?p=swin-transformer-hierarchical-vision)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/swin-transformer-hierarchical-vision/object-detection-on-coco-minival)](https://paperswithcode.com/sota/object-detection-on-coco-minival?p=swin-transformer-hierarchical-vision)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/swin-transformer-hierarchical-vision/instance-segmentation-on-coco-minival)](https://paperswithcode.com/sota/instance-segmentation-on-coco-minival?p=swin-transformer-hierarchical-vision)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/swin-transformer-hierarchical-vision/semantic-segmentation-on-ade20k)](https://paperswithcode.com/sota/semantic-segmentation-on-ade20k?p=swin-transformer-hierarchical-vision)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/swin-transformer-hierarchical-vision/semantic-segmentation-on-ade20k-val)](https://paperswithcode.com/sota/semantic-segmentation-on-ade20k-val?p=swin-transformer-hierarchical-vision)
## Install By [Ze Liu](https://github.com/zeliu98/)\*, [Yutong Lin](https://github.com/impiga)\*, [Yue Cao](http://yue-cao.me)\*, [Han Hu](https://ancientmooner.github.io/)\*, [Yixuan Wei](https://github.com/weiyx16), [Zheng Zhang](https://stupidzz.github.io/), [Stephen Lin](https://scholar.google.com/citations?user=c3PYmxUAAAAJ&hl=en) and [Baining Guo](https://www.microsoft.com/en-us/research/people/bainguo/).
```bash
cd ops/local_relation
python setup.py build_ext --inplace
```
## Citing Local Relation Networks This repo is the official implementation of ["Swin Transformer: Hierarchical Vision Transformer using Shifted Windows"](https://arxiv.org/pdf/2103.14030.pdf). It currently includes code and models for the following tasks:
> **Image Classification**: Included in this repo. See [get_started.md](get_started.md) for a quick start.
> **Object Detection and Instance Segmentation**: See [Swin Transformer for Object Detection](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection).
> **Semantic Segmentation**: See [Swin Transformer for Semantic Segmentation](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation).
## Updates
***04/12/2021***
Initial commits:
1. Pretrained models on ImageNet-1K ([Swin-T-IN1K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth), [Swin-S-IN1K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth), [Swin-B-IN1K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth)) and ImageNet-22K ([Swin-B-IN22K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth), [Swin-L-IN22K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth)) are provided.
2. The supported code and models for ImageNet-1K image classification, COCO object detection and ADE20K semantic segmentation are provided.
3. The cuda kernel implementation for the [local relation layer](https://arxiv.org/pdf/1904.11491.pdf) is provided in branch [LR-Net](https://github.com/microsoft/Swin-Transformer/tree/LR-Net).
## Introduction
**Swin Transformer** (the name `Swin` stands for **S**hifted **win**dow) is initially described in [arxiv](https://arxiv.org/abs/2103.14030), which capably serves as a
general-purpose backbone for computer vision. It is basically a hierarchical Transformer whose representation is
computed with shifted windows. The shifted windowing scheme brings greater efficiency by limiting self-attention
computation to non-overlapping local windows while also allowing for cross-window connection.
Swin Transformer achieves strong performance on COCO object detection (`58.7 box AP` and `51.1 mask AP` on test-dev) and
ADE20K semantic segmentation (`53.5 mIoU` on val), surpassing previous models by a large margin.
![teaser](figures/teaser.png)
## Main Results on ImageNet with Pretrained Models
**ImageNet-1K and ImageNet-22K Pretrained Models**
| name | pretrain | resolution |acc@1 | acc@5 | #params | FLOPs | FPS| 22K model | 1K model |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |:---: |:---: |
| Swin-T | ImageNet-1K | 224x224 | 81.2 | 95.5 | 28M | 4.5G | 755 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/156nWJy4Q28rDlrX-rRbI3w) |
| Swin-S | ImageNet-1K | 224x224 | 83.2 | 96.2 | 50M | 8.7G | 437 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/1KFjpj3Efey3LmtE1QqPeQg) |
| Swin-B | ImageNet-1K | 224x224 | 83.5 | 96.5 | 88M | 15.4G | 278 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/16bqCTEc70nC_isSsgBSaqQ) |
| Swin-B | ImageNet-1K | 384x384 | 84.5 | 97.0 | 88M | 47.1G | 85 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth)/[baidu](https://pan.baidu.com/s/1xT1cu740-ejW7htUdVLnmw) |
| Swin-B | ImageNet-22K | 224x224 | 85.2 | 97.5 | 88M | 15.4G | 278 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1y1Ec3UlrKSI8IMtEs-oBXA) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1n_wNkcbRxVXit8r_KrfAVg) |
| Swin-B | ImageNet-22K | 384x384 | 86.4 | 98.0 | 88M | 47.1G | 85 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth)/[baidu](https://pan.baidu.com/s/1vwJxnJcVqcLZAw9HaqiR6g) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1caKTSdoLJYoi4WBcnmWuWg) |
| Swin-L | ImageNet-22K | 224x224 | 86.3 | 97.9 | 197M | 34.5G | 141 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1pws3rOTFuOebBYP3h6Kx8w) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1NkQApMWUhxBGjk1ne6VqBQ) |
| Swin-L | ImageNet-22K | 384x384 | 87.3 | 98.2 | 197M | 103.9G | 42 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth)/[baidu](https://pan.baidu.com/s/1sl7o_bJA143OD7UqSLAMoA) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1X0FLHQyPOC6Kmv2CmgxJvA) |
Note: access code for `baidu` is `swin`.
## Main Results on Downstream Tasks
**COCO Object Detection (2017 val)**
| Backbone | Method | pretrain | Lr Schd | box mAP | mask mAP | #params | FLOPs |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| Swin-T | Mask R-CNN | ImageNet-1K | 3x | 46.0 | 41.6 | 48M | 267G |
| Swin-S | Mask R-CNN | ImageNet-1K | 3x | 48.5 | 43.3 | 69M | 359G |
| Swin-T | Cascade Mask R-CNN | ImageNet-1K | 3x | 50.4 | 43.7 | 86M | 745G |
| Swin-S | Cascade Mask R-CNN | ImageNet-1K | 3x | 51.9 | 45.0 | 107M | 838G |
| Swin-B | Cascade Mask R-CNN | ImageNet-1K | 3x | 51.9 | 45.0 | 145M | 982G |
| Swin-T | RepPoints V2 | ImageNet-1K | 3x | 50.0 | - | 45M | 283G |
| Swin-T | Mask RepPoints V2 | ImageNet-1K | 3x | 50.3 | 43.6 | 47M | 292G |
| Swin-B | HTC++ | ImageNet-22K | 6x | 56.4 | 49.1 | 160M | 1043G |
| Swin-L | HTC++ | ImageNet-22K | 3x | 57.1 | 49.5 | 284M | 1470G |
| Swin-L | HTC++<sup>*</sup> | ImageNet-22K | 3x | 58.0 | 50.4 | 284M | - |
Note: <sup>*</sup> indicates multi-scale testing.
**ADE20K Semantic Segmentation (val)**
| Backbone | Method | pretrain | Crop Size | Lr Schd | mIoU | mIoU (ms+flip) | #params | FLOPs |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| Swin-T | UPerNet | ImageNet-1K | 512x512 | 160K | 44.51 | 45.81 | 60M | 945G |
| Swin-S | UperNet | ImageNet-1K | 512x512 | 160K | 47.64 | 49.47 | 81M | 1038G |
| Swin-B | UperNet | ImageNet-1K | 512x512 | 160K | 48.13 | 49.72 | 121M | 1188G |
| Swin-B | UPerNet | ImageNet-22K | 640x640 | 160K | 50.04 | 51.66 | 121M | 1841G |
| Swin-L | UperNet | ImageNet-22K | 640x640 | 160K | 52.05 | 53.53 | 234M | 3230G |
## Citing Swin Transformer
```
@inproceedings{hu2019local,
title={Local relation networks for image recognition},
author={Hu, Han and Zhang, Zheng and Xie, Zhenda and Lin, Stephen},
booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
pages={3464--3473},
year={2019}
}
```
``` ```
@article{liu2021Swin, @article{liu2021Swin,
title={Swin Transformer: Hierarchical Vision Transformer using Shifted Windows}, title={Swin Transformer: Hierarchical Vision Transformer using Shifted Windows},
@ -28,6 +96,24 @@ python setup.py build_ext --inplace
} }
``` ```
## Getting Started
- For **Image Classification**, please see [get_started.md](get_started.md) for detailed instructions.
- For **Object Detection and Instance Segmentation**, please see [Swin Transformer for Object Detection](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection).
- For **Semantic Segmentation**, please see [Swin Transformer for Semantic Segmentation](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation).
## Third-party Usage and Experiments
***In this pargraph, we cross link third-party repositories which use Swin and report results. You can let us know by raising an issue***
(`Note please report accuracy numbers and provide trained models in your new repository to facilitate others to get sense of correctness and model behavior`)
[04/14/2021] Swin for RetinaNet in Detectron: https://github.com/xiaohu2015/SwinT_detectron2.
[04/16/2021] Included in a famous model zoo: https://github.com/rwightman/pytorch-image-models.
[04/20/2021] Swin-Transformer classifier inference using TorchServe: https://github.com/kamalkraj/Swin-Transformer-Serve
## Contributing ## Contributing
This project welcomes contributions and suggestions. Most contributions require you to agree to a This project welcomes contributions and suggestions. Most contributions require you to agree to a

View File

@ -110,7 +110,9 @@ def main(config):
if resume_file: if resume_file:
if config.MODEL.RESUME: if config.MODEL.RESUME:
logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}") logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}")
config.defrost()
config.MODEL.RESUME = resume_file config.MODEL.RESUME = resume_file
config.freeze()
logger.info(f'auto resuming from {resume_file}') logger.info(f'auto resuming from {resume_file}')
else: else:
logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume') logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')

View File

@ -1,4 +0,0 @@
*.so
_ext*
__pycache__
build

View File

@ -1 +0,0 @@
from .local_relation_func import local_relation

View File

@ -1,102 +0,0 @@
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2019 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Han Hu, Jiarui Xu
# Modified by Ze Liu
# --------------------------------------------------------
import torch
from torch.autograd import Function
from torch.nn.modules.utils import _pair
from . import local_relation_cuda
class LocalRelationFunction(Function):
@staticmethod
def forward(ctx,
query,
key,
value,
pos_weight,
kernel_size,
groups,
stride=1,
dilation=1,
scale=1.,
no_define_value=-100.,
norm_method=0,
sim_method=0,
batch_step=32):
for input in [query, key, value]:
if input is not None and input.dim() != 4:
raise ValueError(
"Expected 4D tensor as input, got {}D tensor instead.".format(
input.dim()))
ctx.kernel_size = _pair(kernel_size)
ctx.groups = groups
ctx.stride = stride
ctx.dilation = dilation
ctx.scale = scale
ctx.no_define_value = no_define_value
ctx.norm_method = norm_method
ctx.sim_method = sim_method
ctx.batch_step = batch_step
ctx.save_for_backward(query, key, value, pos_weight)
output = query.new_empty(
LocalRelationFunction._output_size(query, value))
scale_tensor = query.new_tensor([ctx.scale])
no_define_value_tensor = query.new_tensor([ctx.no_define_value])
if not input.is_cuda:
raise NotImplementedError
else:
batch_step = min(ctx.batch_step, query.shape[0])
local_relation_cuda.local_relation_forward_cuda(
query, key, value, pos_weight, scale_tensor, no_define_value_tensor,
output, ctx.kernel_size[1], ctx.kernel_size[0], ctx.groups,
ctx.dilation, ctx.stride, batch_step, ctx.norm_method, ctx.sim_method)
return output
@staticmethod
def backward(ctx, grad_output):
query, key, value, pos_weight = ctx.saved_tensors
grad_query = grad_key = grad_value = grad_pos_weight = None
scale_tensor = query.new_tensor(ctx.scale)
no_define_value_tensor = query.new_tensor(ctx.no_define_value)
if not grad_output.is_cuda:
raise NotImplementedError
else:
batch_step = min(ctx.batch_step, query.shape[0])
if ctx.needs_input_grad[0] or ctx.needs_input_grad[1] or ctx.needs_input_grad[2] or ctx.needs_input_grad[3]:
grad_query = torch.zeros_like(query)
grad_key = torch.zeros_like(key)
grad_value = torch.zeros_like(value)
grad_pos_weight = torch.zeros_like(pos_weight)
local_relation_cuda.local_relation_backward_cuda(
query, key, value, pos_weight,
scale_tensor, no_define_value_tensor, grad_output,
grad_query, grad_key, grad_value, grad_pos_weight,
ctx.kernel_size[1], ctx.kernel_size[0],
ctx.groups, ctx.dilation, ctx.stride, batch_step,
ctx.norm_method, ctx.sim_method)
return (grad_query, grad_key, grad_value, grad_pos_weight, None, None, None,
None, None, None, None, None, None)
@staticmethod
def _output_size(query, value):
output_size = (query.size(0), value.size(1), query.size(2), query.size(3))
return output_size
local_relation = LocalRelationFunction.apply

View File

@ -1,12 +0,0 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='local_relation',
ext_modules=[
CUDAExtension('local_relation_cuda', [
'src/local_relation_cuda.cpp',
'src/local_relation_cuda_kernel.cu',
]),
],
cmdclass={'build_ext': BuildExtension})

View File

@ -1,331 +0,0 @@
/*!
* Copyright (c) 2019 Microsoft
* Licensed under The MIT License [see LICENSE for details]
* \file local_relation_cuda.cpp
* \brief
* \author Han Hu
* \modified by Jiarui Xu, Ze Liu
*/
#include <torch/extension.h>
#include <cmath>
#include <vector>
void similarity_compute_forward(
const at::Tensor key,
const at::Tensor query,
const at::Tensor pos_weight,
const int batch_size,
const int key_channels,
const int query_channels,
const int height,
const int width,
const int kernel_height,
const int kernel_width,
const int num_group,
const at::Tensor scale,
const at::Tensor no_define_value,
const int dilate,
const int stride,
const int in_height,
const int in_width,
const int sim_method,
at::Tensor output,
const int key_offset,
const int query_offset);
void similarity_compute_backward(
const at::Tensor key,
const at::Tensor query,
const at::Tensor output_grad,
const int batch_size,
const int key_channels,
const int query_channels,
const int height,
const int width,
const int kernel_height,
const int kernel_width,
const int num_group,
const int key_per_group,
const at::Tensor scale,
const int dilate,
const int stride,
const int in_height,
const int in_width,
const int sim_method,
at::Tensor key_grad,
at::Tensor query_grad,
const int key_grad_offset,
const int query_grad_offset);
void aggregation_forward(
const at::Tensor value,
const at::Tensor softmax_data,
const int batch_size,
const int value_channels,
const int height,
const int width,
const int kernel_height,
const int kernel_width,
const int num_group,
const int dilate,
const int stride,
const int in_height,
const int in_width,
at::Tensor output,
const int value_offset,
const int output_offset);
void aggregation_value_backward(
const at::Tensor softmax_data,
const at::Tensor output_grad,
const int batch_size,
const int value_channels,
const int height,
const int width,
const int kernel_height,
const int kernel_width,
const int num_group,
const int dilate,
const int stride,
const int in_height,
const int in_width,
at::Tensor value_grad,
const int output_grad_offset,
const int value_grad_offset);
void aggregation_softmax_backward(
const at::Tensor value,
const at::Tensor output_grad,
const int batch_size,
const int value_channels,
const int height,
const int width,
const int kernel_height,
const int kernel_width,
const int num_group,
const int dilate,
const int stride,
const int in_height,
const int in_width,
at::Tensor softmax_grad,
const int value_offset,
const int output_grad_offset);
int local_relation_forward_cuda(
at::Tensor query,
at::Tensor key,
at::Tensor value,
at::Tensor pos_weight,
at::Tensor scale,
at::Tensor no_define_value,
at::Tensor output,
const int kernel_height,
const int kernel_width,
const int num_group,
const int dilate,
const int stride,
const int batch_step,
const int norm_method,
const int sim_method)
{
query = query.contiguous();
key = key.contiguous();
value = value.contiguous();
pos_weight = pos_weight.contiguous();
const int query_channels = query.size(1);
const int key_channels = key.size(1);
const int value_channels = value.size(1);
const int batch_size = key.size(0);
const int height = query.size(2);
const int width = query.size(3);
const int in_height = key.size(2);
const int in_width = key.size(3);
const int batch_step_ = std::min(batch_size, batch_step);
const int sim_size = batch_step_ * num_group * kernel_height * kernel_width * height * width;
const int key_step = batch_step_ * key_channels * in_height * in_width;
const int query_step = batch_step_ * query_channels * height * width;
const int value_step = batch_step_ * value_channels * in_height * in_width;
const int output_step = batch_step_ * value_channels * height * width;
at::Tensor sim_buffer = at::zeros({batch_step_ * num_group, kernel_height * kernel_width, height * width},
query.options());
at::Tensor softmax_buffer = at::zeros({batch_step_ * num_group, kernel_height * kernel_width, height * width},
query.options());
at::Tensor sum_softmax_buffer = at::zeros({batch_step_ * num_group, height * width});
int M = (batch_size - 1) / batch_step_ + 1;
for (int i = 0; i < M; ++i) {
int cur_batch_step = batch_step_;
if (i == M - 1) {
cur_batch_step = batch_size - (M - 1) * batch_step_;
if (cur_batch_step != batch_step_) {
sim_buffer = at::zeros({cur_batch_step * num_group, kernel_height * kernel_width, height * width}, query.options());
softmax_buffer = at::zeros({cur_batch_step * num_group, kernel_height * kernel_width, height * width},query.options());
sum_softmax_buffer = at::zeros({cur_batch_step * num_group, height * width}, query.options());
}
// TORCH_CHECK(cur_batch_step % batch_step_ == 0, "batch_step must be divided by batch_size");
}
similarity_compute_forward(key, query, pos_weight, cur_batch_step,
key_channels, query_channels, height, width,
kernel_height, kernel_width, num_group, scale, no_define_value,
dilate, stride, in_height, in_width, sim_method, sim_buffer,
key_step * i, query_step * i);
// softmax
if (norm_method == 0) {
softmax_buffer = sim_buffer.softmax(1);
}
else {
AT_ERROR("Not implemented yet");
}
aggregation_forward(value, softmax_buffer, cur_batch_step,
value_channels, height, width, kernel_height, kernel_width,
num_group, dilate, stride, in_height, in_width, output, value_step * i, output_step * i);
}
return 1;
}
int local_relation_backward_cuda(
at::Tensor query,
at::Tensor key,
at::Tensor value,
at::Tensor pos_weight,
at::Tensor scale,
at::Tensor no_define_value,
at::Tensor output_grad,
at::Tensor query_grad,
at::Tensor key_grad,
at::Tensor value_grad,
at::Tensor pos_weight_grad,
const int kernel_height,
const int kernel_width,
const int num_group,
const int dilate,
const int stride,
const int batch_step,
const int norm_method,
const int sim_method)
{
query = query.contiguous();
key = key.contiguous();
value = value.contiguous();
pos_weight = pos_weight.contiguous();
output_grad = output_grad.contiguous();
query_grad = query_grad.contiguous();
key_grad = key_grad.contiguous();
value_grad = value_grad.contiguous();
pos_weight_grad = pos_weight_grad.contiguous();
const int query_channels = query.size(1);
const int key_channels = key.size(1);
const int value_channels = value.size(1);
const int batch_size = key.size(0);
const int height = query.size(2);
const int width = query.size(3);
const int in_height = key.size(2);
const int in_width = key.size(3);
const int key_per_group = query_channels / num_group;
const int batch_step_ = std::min(batch_size, batch_step);
const int sim_size = batch_step_ * num_group * kernel_height * kernel_width * height * width;
const int key_step = batch_step_ * key_channels * in_height * in_width;
const int query_step = batch_step_ * query_channels * height * width;
const int value_step = batch_step_ * value_channels * in_height * in_width;
const int output_step = batch_step_ * value_channels * height * width;
at::Tensor sim_buffer = at::zeros({batch_step_ * num_group, kernel_height * kernel_width, height * width},
query.options());
at::Tensor softmax_buffer = at::zeros({batch_step_ * num_group, kernel_height * kernel_width, height * width},
query.options());
at::Tensor sum_softmax_buffer = at::zeros({batch_step_ * num_group, height * width},
query.options());
at::Tensor sim_grad_buffer = at::zeros({batch_step_ * num_group, kernel_height * kernel_width, height * width},
query.options());
int M = (batch_size - 1) / batch_step_ + 1;
const int pos_weight_size = num_group * kernel_height * kernel_width;
for (int i = 0; i < M; ++i) {
int cur_batch_step = batch_step_;
if (i == M - 1) {
cur_batch_step = batch_size - (M - 1) * batch_step_;
if (cur_batch_step != batch_step_) {
sim_buffer = at::zeros({cur_batch_step * num_group, kernel_height * kernel_width, height * width}, query.options());
softmax_buffer = at::zeros({cur_batch_step * num_group, kernel_height * kernel_width, height * width},query.options());
sum_softmax_buffer = at::zeros({cur_batch_step * num_group, height * width}, query.options());
sim_grad_buffer = at::zeros({cur_batch_step * num_group, kernel_height * kernel_width, height * width}, query.options());
}
// TORCH_CHECK(cur_batch_step % batch_step_ == 0, "batch_step must be divided by batch_size");
}
similarity_compute_forward(key, query, pos_weight, cur_batch_step,
key_channels, query_channels, height, width,
kernel_height, kernel_width, num_group, scale, no_define_value,
dilate, stride, in_height, in_width, sim_method, sim_buffer,
key_step * i, query_step * i);
// softmax
if (norm_method == 0) {
softmax_buffer = sim_buffer.softmax(1);
}
else {
AT_ERROR("Not implemented yet");
}
aggregation_value_backward(softmax_buffer, output_grad, cur_batch_step,
value_channels, height, width, kernel_height, kernel_width,
num_group, dilate, stride, in_height, in_width, value_grad,
output_step * i, value_step * i);
aggregation_softmax_backward(value, output_grad, cur_batch_step,
value_channels, height, width, kernel_height, kernel_width,
num_group, dilate, stride, in_height, in_width, sim_buffer,
value_step * i, output_step * i);
if (norm_method == 0) {
sum_softmax_buffer = (softmax_buffer * sim_buffer).sum(1, true);
sim_grad_buffer = softmax_buffer * (sim_buffer - sum_softmax_buffer);
}
else {
AT_ERROR("Not implemented yet");
}
similarity_compute_backward(key, query, sim_grad_buffer, cur_batch_step,
key_channels, query_channels, height, width,
kernel_height, kernel_width, num_group, key_per_group, scale,
dilate, stride, in_height, in_width, sim_method, key_grad, query_grad,
key_step * i, query_step * i);
pos_weight_grad += sim_grad_buffer.view({cur_batch_step, num_group, kernel_height, kernel_width, height * width}).sum(4).sum(0);
}
return 1;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("local_relation_forward_cuda", &local_relation_forward_cuda,
"local relation forward (CUDA)");
m.def("local_relation_backward_cuda", &local_relation_backward_cuda,
"local relation backward (CUDA)");
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,28 @@
import torch
import numpy as np
from models.swin_transformer import SwinTransformer
# 构建输入
input_data = np.random.rand(1, 3, 224, 224).astype("float32")
swin_model_cfg_map = {
"swin_tiny_patch4_window7_224": {
"EMBED_DIM": 96,
"DEPTHS": [ 2, 2, 6, 2 ],
"NUM_HEADS": [ 3, 6, 12, 24 ],
"WINDOW_SIZE": 7,
}
}
model_name = "swin_tiny_patch4_window7_224"
torch_module = SwinTransformer(**swin_model_cfg_map[model_name])
torch_state_dict = torch.load("/home/andy/data/pretrained_models/{}.pth".format(model_name))["model"]
torch_module.load_state_dict(torch_state_dict)
# 设置为eval模式
torch_module.eval()
# 进行转换
from x2paddle.convert import pytorch2paddle
pytorch2paddle(torch_module,
save_dir="pd_{}".format(model_name),
jit_type="trace",
input_examples=[torch.tensor(input_data)])

View File

@ -29,7 +29,9 @@ def load_checkpoint(config, model, optimizer, lr_scheduler, logger):
if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
optimizer.load_state_dict(checkpoint['optimizer']) optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
config.defrost()
config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1 config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
config.freeze()
if 'amp' in checkpoint and config.AMP_OPT_LEVEL != "O0" and checkpoint['config'].AMP_OPT_LEVEL != "O0": if 'amp' in checkpoint and config.AMP_OPT_LEVEL != "O0" and checkpoint['config'].AMP_OPT_LEVEL != "O0":
amp.load_state_dict(checkpoint['amp']) amp.load_state_dict(checkpoint['amp'])
logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})") logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})")