From be50b6cc5187e1a759e6bff2bff0553814ef1a91 Mon Sep 17 00:00:00 2001 From: v-zeliu1 Date: Tue, 13 Apr 2021 00:44:08 +0800 Subject: [PATCH] add LR-Net V2 --- README.md | 108 +- ops/local_relation/.gitignore | 4 + ops/local_relation/__init__.py | 1 + ops/local_relation/local_relation_func.py | 102 ++ ops/local_relation/setup.py | 12 + .../src/local_relation_cuda.cpp | 331 ++++++ .../src/local_relation_cuda_kernel.cu | 1004 +++++++++++++++++ 7 files changed, 1471 insertions(+), 91 deletions(-) create mode 100644 ops/local_relation/.gitignore create mode 100644 ops/local_relation/__init__.py create mode 100644 ops/local_relation/local_relation_func.py create mode 100644 ops/local_relation/setup.py create mode 100644 ops/local_relation/src/local_relation_cuda.cpp create mode 100644 ops/local_relation/src/local_relation_cuda_kernel.cu diff --git a/README.md b/README.md index 76299ef..a591402 100644 --- a/README.md +++ b/README.md @@ -1,92 +1,24 @@ -# Swin Transformer +# Local Relation Networks V2 (LR-Net V2) -[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/swin-transformer-hierarchical-vision/object-detection-on-coco)](https://paperswithcode.com/sota/object-detection-on-coco?p=swin-transformer-hierarchical-vision) -[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/swin-transformer-hierarchical-vision/instance-segmentation-on-coco)](https://paperswithcode.com/sota/instance-segmentation-on-coco?p=swin-transformer-hierarchical-vision) -[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/swin-transformer-hierarchical-vision/object-detection-on-coco-minival)](https://paperswithcode.com/sota/object-detection-on-coco-minival?p=swin-transformer-hierarchical-vision) 
-[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/swin-transformer-hierarchical-vision/instance-segmentation-on-coco-minival)](https://paperswithcode.com/sota/instance-segmentation-on-coco-minival?p=swin-transformer-hierarchical-vision) -[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/swin-transformer-hierarchical-vision/semantic-segmentation-on-ade20k)](https://paperswithcode.com/sota/semantic-segmentation-on-ade20k?p=swin-transformer-hierarchical-vision) -[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/swin-transformer-hierarchical-vision/semantic-segmentation-on-ade20k-val)](https://paperswithcode.com/sota/semantic-segmentation-on-ade20k-val?p=swin-transformer-hierarchical-vision) +This branch is an improved implementation of ["Local Relation Networks for Image Recognition (LR-Net)"](https://arxiv.org/pdf/1904.11491.pdf). The original LR-Net uses a sliding-window-based self-attention layer to replace the `3x3` convolution layers in a ResNet architecture. This improved implementation applies the same layer to a stronger overall architecture based on Transformers, dubbed LR-Net V2. We provide CUDA kernels for the local relation layers. Training scripts and pre-trained models will be provided in the future. -By [Ze Liu](https://github.com/zeliu98/)\*, [Yutong Lin](https://github.com/impiga)\*, [Yue Cao](http://yue-cao.me)\*, [Han Hu](https://ancientmooner.github.io/)\*, [Yixuan Wei](https://github.com/weiyx16), [Zheng Zhang](https://stupidzz.github.io/), [Stephen Lin](https://scholar.google.com/citations?user=c3PYmxUAAAAJ&hl=en) and [Baining Guo](https://www.microsoft.com/en-us/research/people/bainguo/). +## Install +```bash +cd ops/local_relation +python setup.py build_ext --inplace +``` -This repo is the official implementation of ["Swin Transformer: Hierarchical Vision Transformer using Shifted Windows"](https://arxiv.org/pdf/2103.14030.pdf). 
It currently includes code and models for the following tasks: - -> **Image Classification**: Included in this repo. See [get_started.md](get_started.md) for a quick start. - -> **Object Detection and Instance Segmentation**: See [Swin Transformer for Object Detection](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection). - -> **Semantic Segmentation**: See [Swin Transformer for Semantic Segmentation](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation). - -## Updates - -***04/12/2021*** - -Initial commits: - -1. Pretrained models on ImageNet-1K ([Swin-T-IN1K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth), [Swin-S-IN1K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth), [Swin-B-IN1K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth)) and ImageNet-22K ([Swin-B-IN22K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth), [Swin-L-IN22K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth)) are provided. -2. The supported code and models for ImageNet-1K image classification, COCO object detection and ADE20K semantic segmentation are provided. -3. The cuda kernel implementation for the [local relation layer](https://arxiv.org/pdf/1904.11491.pdf) is provided in branch [LR-Net](https://github.com/microsoft/Swin-Transformer/tree/LR-Net). - -## Introduction - -**Swin Transformer** is initially described in [arxiv](https://arxiv.org/abs/2103.14030), which capably serves as a -general-purpose backbone for computer vision. It is basically a hierarchical Transformer whose representation is -computed with shifted windows. 
The shifted windowing scheme brings greater efficiency by limiting self-attention -computation to non-overlapping local windows while also allowing for cross-window connection. - -Swin Transformer achieves strong performance on COCO object detection (`58.7 box AP` and `51.1 mask AP` on test-dev) and -ADE20K semantic segmentatiion (`53.5 mIoU` on val), surpassing previous models by a large margin. - -![teaser](figures/teaser.png) - -## Main Results on ImageNet with Pretrained Models - -**ImageNet-1K and ImageNet-22K Pretrained Models** - -| name | pretrain | resolution |acc@1 | acc@5 | #params | FLOPs | FPS| 22K model | 1K model | -| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |:---: |:---: | -| Swin-T | ImageNet-1K | 224x224 | 81.2 | 95.5 | 28M | 4.5G | 755 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/156nWJy4Q28rDlrX-rRbI3w) | -| Swin-S | ImageNet-1K | 224x224 | 83.2 | 96.2 | 50M | 8.7G | 437 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/1KFjpj3Efey3LmtE1QqPeQg) | -| Swin-B | ImageNet-1K | 224x224 | 83.5 | 96.5 | 88M | 15.4G | 278 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/16bqCTEc70nC_isSsgBSaqQ) | -| Swin-B | ImageNet-1K | 384x384 | 84.5 | 97.0 | 88M | 47.1G | 85 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth)/[baidu](https://pan.baidu.com/s/1xT1cu740-ejW7htUdVLnmw) | -| Swin-B | ImageNet-22K | 224x224 | 85.2 | 97.5 | 88M | 15.4G | 278 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1y1Ec3UlrKSI8IMtEs-oBXA) | 
[github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1n_wNkcbRxVXit8r_KrfAVg) | -| Swin-B | ImageNet-22K | 384x384 | 86.4 | 98.0 | 88M | 47.1G | 85 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth)/[baidu](https://pan.baidu.com/s/1vwJxnJcVqcLZAw9HaqiR6g) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1caKTSdoLJYoi4WBcnmWuWg) | -| Swin-L | ImageNet-22K | 224x224 | 86.3 | 97.9 | 197M | 34.5G | 141 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1pws3rOTFuOebBYP3h6Kx8w) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1NkQApMWUhxBGjk1ne6VqBQ) | -| Swin-L | ImageNet-22K | 384x384 | 87.3 | 98.2 | 197M | 103.9G | 42 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth)/[baidu](https://pan.baidu.com/s/1sl7o_bJA143OD7UqSLAMoA) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1X0FLHQyPOC6Kmv2CmgxJvA) | - -Note: access code for `baidu` is `swin`. 
- -## Main Results on Downstream Tasks - -**COCO Object Detection (2017 val)** - -| Backbone | Method | pretrain | Lr Schd | box mAP | mask mAP | #params | FLOPs | -| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | -| Swin-T | Mask R-CNN | ImageNet-1K | 3x | 46.0 | 41.6 | 48M | 267G | -| Swin-S | Mask R-CNN | ImageNet-1K | 3x | 48.5 | 43.3 | 69M | 359G | -| Swin-T | Cascade Mask R-CNN | ImageNet-1K | 3x | 50.4 | 43.7 | 86M | 745G | -| Swin-S | Cascade Mask R-CNN | ImageNet-1K | 3x | 51.9 | 45.0 | 107M | 838G | -| Swin-B | Cascade Mask R-CNN | ImageNet-1K | 3x | 51.9 | 45.0 | 145M | 982G | -| Swin-T | RepPoints V2 | ImageNet-1K | 3x | 50.0 | - | 45M | 283G | -| Swin-T | Mask RepPoints V2 | ImageNet-1K | 3x | 50.3 | 43.6 | 47M | 292G | -| Swin-B | HTC++ | ImageNet-22K | 6x | 56.4 | 49.1 | 160M | 1043G | -| Swin-L | HTC++ | ImageNet-22K | 3x | 57.1 | 49.5 | 284M | 1470G | -| Swin-L | HTC++* | ImageNet-22K | 3x | 58.0 | 50.4 | 284M | - | - -Note: * indicates multi-scale testing. 
- -**ADE20K Semantic Segmentation (val)** - -| Backbone | Method | pretrain | Crop Size | Lr Schd | mIoU | mIoU (ms+flip) | #params | FLOPs | -| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | -| Swin-T | UPerNet | ImageNet-1K | 512x512 | 160K | 44.51 | 45.81 | 60M | 945G | -| Swin-S | UperNet | ImageNet-1K | 512x512 | 160K | 47.64 | 49.47 | 81M | 1038G | -| Swin-B | UperNet | ImageNet-1K | 512x512 | 160K | 48.13 | 49.72 | 121M | 1188G | -| Swin-B | UPerNet | ImageNet-22K | 640x640 | 160K | 50.04 | 51.66 | 121M | 1841G | -| Swin-L | UperNet | ImageNet-22K | 640x640 | 160K | 52.05 | 53.53 | 234M | 3230G | - -## Citing Swin Transformer +## Citing Local Relation Networks +``` +@inproceedings{hu2019local, + title={Local relation networks for image recognition}, + author={Hu, Han and Zhang, Zheng and Xie, Zhenda and Lin, Stephen}, + booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, + pages={3464--3473}, + year={2019} +} +``` ``` @article{liu2021Swin, title={Swin Transformer: Hierarchical Vision Transformer using Shifted Windows}, @@ -96,12 +28,6 @@ Note: * indicates multi-scale testing. } ``` -## Getting Started - -- For **Image Classification**, please see [get_started.md](get_started.md) for detailed instructions. -- For **Object Detection and Instance Segmentation**, please see [Swin Transformer for Object Detection](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection). -- For **Semantic Segmentation**, please see [Swin Transformer for Semantic Segmentation](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation). - ## Contributing This project welcomes contributions and suggestions. 
class LocalRelationFunction(Function):
    """Autograd wrapper around the CUDA local-relation kernels.

    ``forward`` hands ``query``, ``key``, ``value`` and ``pos_weight`` to
    ``local_relation_cuda.local_relation_forward_cuda`` and returns a tensor of
    shape ``(query.size(0), value.size(1), query.size(2), query.size(3))``.
    ``backward`` calls ``local_relation_cuda.local_relation_backward_cuda`` and
    returns gradients for the four tensor inputs only.

    Only CUDA tensors are supported; CPU tensors raise ``NotImplementedError``.
    """

    @staticmethod
    def forward(ctx,
                query,
                key,
                value,
                pos_weight,
                kernel_size,
                groups,
                stride=1,
                dilation=1,
                scale=1.,
                no_define_value=-100.,
                norm_method=0,
                sim_method=0,
                batch_step=32):
        """Run the local-relation forward pass.

        Args:
            query, key, value: 4D tensors.
            pos_weight: tensor added to the similarity logits by the kernel.
            kernel_size: int or pair, expanded with ``_pair``.
            groups: number of channel groups.
            stride, dilation: sampling parameters forwarded to the kernel.
            scale: multiplier applied to the similarity.
            no_define_value: logit used for out-of-bounds window positions.
            norm_method, sim_method: integer selectors forwarded to the kernel.
            batch_step: maximum number of samples processed per chunk.

        Raises:
            ValueError: a tensor input is not 4D, or ``batch_step`` < 1.
            NotImplementedError: any tensor input is not on a CUDA device.
        """
        for tensor in (query, key, value):
            if tensor is not None and tensor.dim() != 4:
                raise ValueError(
                    "Expected 4D tensor as input, got {}D tensor instead.".format(
                        tensor.dim()))
        # A non-positive chunk size would reach an integer division by zero
        # inside the extension, so reject it here.
        if batch_step < 1:
            raise ValueError(
                "batch_step must be a positive integer, got {}".format(batch_step))
        # Check every tensor handed to the kernel, not just the last one
        # visited by the validation loop above.
        if not all(t.is_cuda for t in (query, key, value, pos_weight)):
            raise NotImplementedError(
                "local_relation is only implemented for CUDA tensors")

        ctx.kernel_size = _pair(kernel_size)
        ctx.groups = groups
        ctx.stride = stride
        ctx.dilation = dilation
        ctx.scale = scale
        ctx.no_define_value = no_define_value
        ctx.norm_method = norm_method
        ctx.sim_method = sim_method
        ctx.batch_step = batch_step

        ctx.save_for_backward(query, key, value, pos_weight)

        output = query.new_empty(
            LocalRelationFunction._output_size(query, value))
        if query.shape[0] == 0:
            # Nothing to compute; also avoids a zero chunk size in the kernel.
            return output

        # The kernels read these through a pointer of the query's dtype.
        scale_tensor = query.new_tensor([ctx.scale])
        no_define_value_tensor = query.new_tensor([ctx.no_define_value])

        batch_step = min(ctx.batch_step, query.shape[0])
        # NOTE(review): kernel_size[1] is passed in the slot the C++ signature
        # names kernel_height and kernel_size[0] in kernel_width. Harmless for
        # square kernels; confirm the intended order for rectangular ones.
        local_relation_cuda.local_relation_forward_cuda(
            query, key, value, pos_weight, scale_tensor, no_define_value_tensor,
            output, ctx.kernel_size[1], ctx.kernel_size[0], ctx.groups,
            ctx.dilation, ctx.stride, batch_step, ctx.norm_method, ctx.sim_method)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Compute gradients for ``query``, ``key``, ``value`` and ``pos_weight``.

        Returns a 13-tuple matching the arguments of ``forward``; entries for
        non-tensor arguments, and for tensors that do not require a gradient,
        are ``None``.
        """
        query, key, value, pos_weight = ctx.saved_tensors

        if not grad_output.is_cuda:
            raise NotImplementedError(
                "local_relation is only implemented for CUDA tensors")

        grads = [None, None, None, None]
        needs_grad = ctx.needs_input_grad[:4]

        if any(needs_grad):
            # The kernels write through raw pointers, so the buffers must be
            # densely packed regardless of the inputs' memory format.
            grads = [
                torch.zeros_like(t, memory_format=torch.contiguous_format)
                for t in (query, key, value, pos_weight)
            ]
            if query.shape[0] > 0:
                # Same 1-element layout as in forward.
                scale_tensor = query.new_tensor([ctx.scale])
                no_define_value_tensor = query.new_tensor([ctx.no_define_value])
                batch_step = min(ctx.batch_step, query.shape[0])
                local_relation_cuda.local_relation_backward_cuda(
                    query, key, value, pos_weight,
                    scale_tensor, no_define_value_tensor, grad_output,
                    grads[0], grads[1], grads[2], grads[3],
                    ctx.kernel_size[1], ctx.kernel_size[0],
                    ctx.groups, ctx.dilation, ctx.stride, batch_step,
                    ctx.norm_method, ctx.sim_method)
            # Do not hand back gradients nobody asked for.
            grads = [g if need else None for g, need in zip(grads, needs_grad)]

        grad_query, grad_key, grad_value, grad_pos_weight = grads
        return (grad_query, grad_key, grad_value, grad_pos_weight, None, None, None,
                None, None, None, None, None, None)

    @staticmethod
    def _output_size(query, value):
        """Return the output shape: query's batch/spatial dims, value's channels."""
        output_size = (query.size(0), value.size(1), query.size(2), query.size(3))
        return output_size


local_relation = LocalRelationFunction.apply
+ * Copyright (c) 2019 Microsoft + * Licensed under The MIT License [see LICENSE for details] + * \file local_relation_cuda.cpp + * \brief + * \author Han Hu + * \modified by Jiarui Xu, Ze Liu +*/ + +#include + +#include +#include + +void similarity_compute_forward( + const at::Tensor key, + const at::Tensor query, + const at::Tensor pos_weight, + const int batch_size, + const int key_channels, + const int query_channels, + const int height, + const int width, + const int kernel_height, + const int kernel_width, + const int num_group, + const at::Tensor scale, + const at::Tensor no_define_value, + const int dilate, + const int stride, + const int in_height, + const int in_width, + const int sim_method, + at::Tensor output, + const int key_offset, + const int query_offset); + +void similarity_compute_backward( + const at::Tensor key, + const at::Tensor query, + const at::Tensor output_grad, + const int batch_size, + const int key_channels, + const int query_channels, + const int height, + const int width, + const int kernel_height, + const int kernel_width, + const int num_group, + const int key_per_group, + const at::Tensor scale, + const int dilate, + const int stride, + const int in_height, + const int in_width, + const int sim_method, + at::Tensor key_grad, + at::Tensor query_grad, + const int key_grad_offset, + const int query_grad_offset); + +void aggregation_forward( + const at::Tensor value, + const at::Tensor softmax_data, + const int batch_size, + const int value_channels, + const int height, + const int width, + const int kernel_height, + const int kernel_width, + const int num_group, + const int dilate, + const int stride, + const int in_height, + const int in_width, + at::Tensor output, + const int value_offset, + const int output_offset); + +void aggregation_value_backward( + const at::Tensor softmax_data, + const at::Tensor output_grad, + const int batch_size, + const int value_channels, + const int height, + const int width, + const int 
kernel_height, + const int kernel_width, + const int num_group, + const int dilate, + const int stride, + const int in_height, + const int in_width, + at::Tensor value_grad, + const int output_grad_offset, + const int value_grad_offset); + +void aggregation_softmax_backward( + const at::Tensor value, + const at::Tensor output_grad, + const int batch_size, + const int value_channels, + const int height, + const int width, + const int kernel_height, + const int kernel_width, + const int num_group, + const int dilate, + const int stride, + const int in_height, + const int in_width, + at::Tensor softmax_grad, + const int value_offset, + const int output_grad_offset); + + +int local_relation_forward_cuda( + at::Tensor query, + at::Tensor key, + at::Tensor value, + at::Tensor pos_weight, + at::Tensor scale, + at::Tensor no_define_value, + at::Tensor output, + const int kernel_height, + const int kernel_width, + const int num_group, + const int dilate, + const int stride, + const int batch_step, + const int norm_method, + const int sim_method) +{ + query = query.contiguous(); + key = key.contiguous(); + value = value.contiguous(); + pos_weight = pos_weight.contiguous(); + + const int query_channels = query.size(1); + const int key_channels = key.size(1); + const int value_channels = value.size(1); + const int batch_size = key.size(0); + const int height = query.size(2); + const int width = query.size(3); + const int in_height = key.size(2); + const int in_width = key.size(3); + + const int batch_step_ = std::min(batch_size, batch_step); + const int sim_size = batch_step_ * num_group * kernel_height * kernel_width * height * width; + + const int key_step = batch_step_ * key_channels * in_height * in_width; + const int query_step = batch_step_ * query_channels * height * width; + const int value_step = batch_step_ * value_channels * in_height * in_width; + const int output_step = batch_step_ * value_channels * height * width; + + at::Tensor sim_buffer = 
at::zeros({batch_step_ * num_group, kernel_height * kernel_width, height * width}, + query.options()); + + at::Tensor softmax_buffer = at::zeros({batch_step_ * num_group, kernel_height * kernel_width, height * width}, + query.options()); + + at::Tensor sum_softmax_buffer = at::zeros({batch_step_ * num_group, height * width}); + + int M = (batch_size - 1) / batch_step_ + 1; + for (int i = 0; i < M; ++i) { + int cur_batch_step = batch_step_; + if (i == M - 1) { + cur_batch_step = batch_size - (M - 1) * batch_step_; + if (cur_batch_step != batch_step_) { + sim_buffer = at::zeros({cur_batch_step * num_group, kernel_height * kernel_width, height * width}, query.options()); + softmax_buffer = at::zeros({cur_batch_step * num_group, kernel_height * kernel_width, height * width},query.options()); + sum_softmax_buffer = at::zeros({cur_batch_step * num_group, height * width}, query.options()); + } + + // TORCH_CHECK(cur_batch_step % batch_step_ == 0, "batch_step must be divided by batch_size"); + } + similarity_compute_forward(key, query, pos_weight, cur_batch_step, + key_channels, query_channels, height, width, + kernel_height, kernel_width, num_group, scale, no_define_value, + dilate, stride, in_height, in_width, sim_method, sim_buffer, + key_step * i, query_step * i); + + // softmax + if (norm_method == 0) { + softmax_buffer = sim_buffer.softmax(1); + } + else { + AT_ERROR("Not implemented yet"); + } + + aggregation_forward(value, softmax_buffer, cur_batch_step, + value_channels, height, width, kernel_height, kernel_width, + num_group, dilate, stride, in_height, in_width, output, value_step * i, output_step * i); + } + + return 1; + +} + +int local_relation_backward_cuda( + at::Tensor query, + at::Tensor key, + at::Tensor value, + at::Tensor pos_weight, + at::Tensor scale, + at::Tensor no_define_value, + at::Tensor output_grad, + at::Tensor query_grad, + at::Tensor key_grad, + at::Tensor value_grad, + at::Tensor pos_weight_grad, + const int kernel_height, + const int 
kernel_width, + const int num_group, + const int dilate, + const int stride, + const int batch_step, + const int norm_method, + const int sim_method) +{ + query = query.contiguous(); + key = key.contiguous(); + value = value.contiguous(); + pos_weight = pos_weight.contiguous(); + + output_grad = output_grad.contiguous(); + query_grad = query_grad.contiguous(); + key_grad = key_grad.contiguous(); + value_grad = value_grad.contiguous(); + pos_weight_grad = pos_weight_grad.contiguous(); + + const int query_channels = query.size(1); + const int key_channels = key.size(1); + const int value_channels = value.size(1); + const int batch_size = key.size(0); + const int height = query.size(2); + const int width = query.size(3); + const int in_height = key.size(2); + const int in_width = key.size(3); + const int key_per_group = query_channels / num_group; + + const int batch_step_ = std::min(batch_size, batch_step); + const int sim_size = batch_step_ * num_group * kernel_height * kernel_width * height * width; + + const int key_step = batch_step_ * key_channels * in_height * in_width; + const int query_step = batch_step_ * query_channels * height * width; + const int value_step = batch_step_ * value_channels * in_height * in_width; + const int output_step = batch_step_ * value_channels * height * width; + + at::Tensor sim_buffer = at::zeros({batch_step_ * num_group, kernel_height * kernel_width, height * width}, + query.options()); + + at::Tensor softmax_buffer = at::zeros({batch_step_ * num_group, kernel_height * kernel_width, height * width}, + query.options()); + + at::Tensor sum_softmax_buffer = at::zeros({batch_step_ * num_group, height * width}, + query.options()); + + at::Tensor sim_grad_buffer = at::zeros({batch_step_ * num_group, kernel_height * kernel_width, height * width}, + query.options()); + + int M = (batch_size - 1) / batch_step_ + 1; + + const int pos_weight_size = num_group * kernel_height * kernel_width; + + for (int i = 0; i < M; ++i) { + int 
cur_batch_step = batch_step_; + if (i == M - 1) { + cur_batch_step = batch_size - (M - 1) * batch_step_; + if (cur_batch_step != batch_step_) { + sim_buffer = at::zeros({cur_batch_step * num_group, kernel_height * kernel_width, height * width}, query.options()); + softmax_buffer = at::zeros({cur_batch_step * num_group, kernel_height * kernel_width, height * width},query.options()); + sum_softmax_buffer = at::zeros({cur_batch_step * num_group, height * width}, query.options()); + sim_grad_buffer = at::zeros({cur_batch_step * num_group, kernel_height * kernel_width, height * width}, query.options()); + } + // TORCH_CHECK(cur_batch_step % batch_step_ == 0, "batch_step must be divided by batch_size"); + } + + similarity_compute_forward(key, query, pos_weight, cur_batch_step, + key_channels, query_channels, height, width, + kernel_height, kernel_width, num_group, scale, no_define_value, + dilate, stride, in_height, in_width, sim_method, sim_buffer, + key_step * i, query_step * i); + + // softmax + + if (norm_method == 0) { + softmax_buffer = sim_buffer.softmax(1); + } + else { + AT_ERROR("Not implemented yet"); + } + + aggregation_value_backward(softmax_buffer, output_grad, cur_batch_step, + value_channels, height, width, kernel_height, kernel_width, + num_group, dilate, stride, in_height, in_width, value_grad, + output_step * i, value_step * i); + + aggregation_softmax_backward(value, output_grad, cur_batch_step, + value_channels, height, width, kernel_height, kernel_width, + num_group, dilate, stride, in_height, in_width, sim_buffer, + value_step * i, output_step * i); + + if (norm_method == 0) { + sum_softmax_buffer = (softmax_buffer * sim_buffer).sum(1, true); + sim_grad_buffer = softmax_buffer * (sim_buffer - sum_softmax_buffer); + } + else { + AT_ERROR("Not implemented yet"); + } + + similarity_compute_backward(key, query, sim_grad_buffer, cur_batch_step, + key_channels, query_channels, height, width, + kernel_height, kernel_width, num_group, key_per_group, scale, 
+ dilate, stride, in_height, in_width, sim_method, key_grad, query_grad, + key_step * i, query_step * i); + + pos_weight_grad += sim_grad_buffer.view({cur_batch_step, num_group, kernel_height, kernel_width, height * width}).sum(4).sum(0); + + } + + return 1; + +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("local_relation_forward_cuda", &local_relation_forward_cuda, + "local relation forward (CUDA)"); + m.def("local_relation_backward_cuda", &local_relation_backward_cuda, + "local relation backward (CUDA)"); +} diff --git a/ops/local_relation/src/local_relation_cuda_kernel.cu b/ops/local_relation/src/local_relation_cuda_kernel.cu new file mode 100644 index 0000000..1e84eb5 --- /dev/null +++ b/ops/local_relation/src/local_relation_cuda_kernel.cu @@ -0,0 +1,1004 @@ +/*! + * Copyright (c) 2019 Microsoft + * Licensed under The MIT License [see LICENSE for details] + * \file local_relation_cuda_kernel.cu + * \brief + * \author Han Hu + * \modified by Jiarui Xu, Ze Liu +*/ + +#include +#include +#include +#include +#include + +using namespace at; + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 1024; +const int kMaxGridNum = 65535; + +inline int GET_BLOCKS(const int N) +{ + return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS); +} + +template +__global__ void similarity_compute_forward_kernel(const int n, + const scalar_t* key, + const scalar_t* query, + const scalar_t* pos_weight, + const int batch_size, + const int key_channels, + const int query_channels, + const int height, + const int width, + const int kernel_height, + const int kernel_width, + const int num_group, + const scalar_t* scale_ptr, + const scalar_t* no_define_value_ptr, + const int dilate, + const int stride, + const int in_height, + const int in_width, + const int sim_method, + scalar_t* output) { + // n = batch_size * num_group * kernel_height * kernel_width 
* height * width + const scalar_t scale = scale_ptr[0]; + const scalar_t no_define_value = no_define_value_ptr[0]; + CUDA_KERNEL_LOOP(index, n) { + const int w = index % width; + int h = index / width; + int kw = h / height; + h = h % height; + int kh = kw / kernel_width; + kw = kw % kernel_width; + int g = kh / kernel_height; + kh = kh % kernel_height; + const int b = g / num_group; + g = g % num_group; + + scalar_t sum_sim = 0; + const int half_kh = kernel_height / 2; + const int half_kw = kernel_width / 2; + if (sim_method >= 0){ + const int key_per_group = query_channels / num_group; + + const int spatial_dim = height * width; + const int in_spatial_dim = in_height * in_width; + int query_inds = 0; + if (sim_method != 1) { + query_inds = ((b * num_group + g) * key_per_group * height + h) * width + w; + } + const int key_saliency_group = key_channels - query_channels; + + if (w * stride + dilate * (kw - half_kw) >= 0 && w * stride + dilate * (kw - half_kw) < in_width && h * stride + dilate * (kh - half_kh) >= 0 && h * stride + dilate * (kh - half_kh) < in_height) { + int key_inds = ((b * key_channels + g * key_per_group) * in_height + h * stride + dilate * (kh - half_kh)) * in_width + w * stride + dilate * (kw - half_kw); + + for (int i = 0; i < key_per_group; ++i) { + if (sim_method == 0) { + sum_sim += query[query_inds + i * spatial_dim] * key[key_inds + i * in_spatial_dim] * scale; + } + else if (sim_method == 1) { + sum_sim += key[key_inds + i * in_spatial_dim] * scale; + } + else if (sim_method == 2) { + sum_sim += -abs(query[query_inds + i * spatial_dim] - key[key_inds + i * in_spatial_dim]) * scale; + } + else if (sim_method == 3) { + scalar_t query_val = query[query_inds + i * spatial_dim]; + scalar_t key_val = key[key_inds + i * in_spatial_dim]; + sum_sim += -abs(query_val - key_val) / (abs(query_val) + abs(key_val) + scalar_t(1.0)) * scale; + } + else if (sim_method == 4) { + scalar_t query_val = query[query_inds + i * spatial_dim]; + scalar_t key_val 
= key[key_inds + i * in_spatial_dim]; + sum_sim += -(query_val - key_val) * (query_val - key_val) / (abs(query_val * key_val) + scalar_t(1.0)) * scale; + } + else if (sim_method == 5) { + scalar_t query_val = query[query_inds + i * spatial_dim]; + scalar_t key_val = key[key_inds + i * in_spatial_dim]; + sum_sim += -(query_val - key_val) * (query_val - key_val) * scale; + } + if (key_saliency_group > 0) { + int key_sal_inds = (b * key_channels + query_channels + int(g * key_saliency_group) / num_group) * in_spatial_dim + + (h * stride + dilate * (kh - half_kh)) * in_width + w * stride + dilate * (kw - half_kw); + sum_sim += key[key_sal_inds]; + } + } + } + else{ + sum_sim = no_define_value; + } + } + + if (w * stride + dilate * (kw - half_kw) >= 0 && w * stride + dilate * (kw - half_kw) < in_width && h * stride + dilate * (kh - half_kh) >= 0 && h * stride + dilate * (kh - half_kh) < in_height) { + } + else { + sum_sim = no_define_value; + } + int pos_inds = (g * kernel_height + kh) * kernel_width + kw; + sum_sim += pos_weight[pos_inds]; + + output[index] = sum_sim; + } +} + /* similarity_compute_backward_kernel: backward pass of the similarity scores. Each iteration of CUDA_KERNEL_LOOP decodes index into (b, g, kpg, h, w), w varying fastest. It first adds to query_grad[index] the scale-multiplied sum, over every in-bounds kernel tap, of output_grad times a sim_method-specific factor (sim_method 1 has no branch there, so it adds nothing to the query gradient). It then adds to key_grad the corresponding sum for the key element at (h * stride, w * stride) and, when stride == 2, for the elements one row down, one column right and diagonal to it. When key_channels > query_channels the plain output_grad sums are additionally accumulated into the saliency channel with atomicAdd. */ +template +__global__ void similarity_compute_backward_kernel(const int n, + const scalar_t* key, + const scalar_t* query, + const scalar_t* output_grad, + const int batch_size, + const int key_channels, + const int query_channels, + const int height, + const int width, + const int kernel_height, + const int kernel_width, + const int num_group, + const int key_per_group, + const scalar_t* scale_ptr, + const int dilate, + const int stride, + const int in_height, + const int in_width, + const int sim_method, + scalar_t* key_grad, + scalar_t* query_grad) { + const scalar_t scale = scale_ptr[0]; + CUDA_KERNEL_LOOP(index, n) { + const int w = index % width; + int h = index / width; + int kpg = h / height; + h = h % height; + int g = kpg / key_per_group; + kpg = kpg % key_per_group; + const int b = g / num_group; + g = g % num_group; + + const int half_kh = kernel_height / 2; + const int half_kw =
kernel_width / 2; + + const int spatial_dim = height * width; + const int key_saliency_group = key_channels - query_channels; + + int output_inds = ((b * num_group + g) * kernel_height * kernel_width * height + h) * width + w; + scalar_t sum_query_grad = 0; + + int key_inds = ((b * key_channels + g * key_per_group + kpg) * in_height + h * stride) * in_width + w * stride; + /* query gradient: accumulate over every kernel tap whose key position lies inside the input */ for (int kh = 0; kh < kernel_height; ++kh) { + for (int kw = 0; kw < kernel_width; ++kw) { + if (w * stride + dilate * (kw - half_kw) >= 0 && w * stride + dilate * (kw - half_kw) < in_width + && h * stride + dilate * (kh - half_kh) >= 0 && h * stride + dilate * (kh - half_kh) < in_height) { + scalar_t c_out_grad = output_grad[output_inds + (kh * kernel_width + kw) * spatial_dim]; + if (sim_method == 0) { + sum_query_grad += c_out_grad + * key[key_inds + dilate * (kh - half_kh) * in_width + dilate * (kw - half_kw)]; + } + else if (sim_method == 2) { + scalar_t key_val = key[key_inds + dilate * (kh - half_kh) * in_width + dilate * (kw - half_kw)]; + scalar_t query_val = query[index]; + if (key_val > query_val) { + sum_query_grad += c_out_grad; + } + else if (key_val < query_val) { + sum_query_grad += -c_out_grad; + } + } + else if (sim_method == 3) { + scalar_t key_val = key[key_inds + dilate * (kh - half_kh) * in_width + dilate * (kw - half_kw)]; + scalar_t query_val = query[index]; + if (key_val > query_val) { + sum_query_grad += c_out_grad / (abs(key_val) + abs(query_val) + scalar_t(1.0)); + } + else if (key_val < query_val) { + sum_query_grad += -c_out_grad / (abs(key_val) + abs(query_val) + scalar_t(1.0)); + } + + if (query_val > 0) { + sum_query_grad += c_out_grad * abs(key_val - query_val) / ((abs(key_val) +abs(query_val) + scalar_t(1.0)) * (abs(key_val) +abs(query_val) + scalar_t(1.0))); + } + else if (query_val < 0) { + sum_query_grad += -c_out_grad * abs(key_val - query_val) / ((abs(key_val) +abs(query_val) + scalar_t(1.0)) * (abs(key_val) +abs(query_val) + scalar_t(1.0))); + } + } + else
if (sim_method == 4) { + scalar_t key_val = key[key_inds + dilate * (kh - half_kh) * in_width + dilate * (kw - half_kw)]; + scalar_t query_val = query[index]; + sum_query_grad += 2 * c_out_grad * (key_val - query_val) / (abs(key_val * query_val) + scalar_t(1.0)); + + if (key_val * query_val > 0) { + sum_query_grad += c_out_grad * key_val * (key_val - query_val) * (key_val - query_val) / ((abs(key_val * query_val) + scalar_t(1.0)) * (abs(key_val * query_val) + scalar_t(1.0))); + } + else if(key_val * query_val < 0) { + sum_query_grad += -c_out_grad * key_val * (key_val - query_val) * (key_val - query_val) / ((abs(key_val * query_val) + scalar_t(1.0)) * (abs(key_val * query_val) + scalar_t(1.0))); + } + } + else if (sim_method == 5) { + scalar_t key_val = key[key_inds + dilate * (kh - half_kh) * in_width + dilate * (kw - half_kw)]; + scalar_t query_val = query[index]; + sum_query_grad += 2 * c_out_grad * (key_val - query_val); + } + } + } + } + sum_query_grad *= scale; + query_grad[index] += sum_query_grad; + + /* key gradient for the key element at (h * stride, w * stride); kh and kw now range over output-position offsets */ scalar_t sum_key_grad = 0; + int start_kh = -half_kh / stride; + int end_kh = half_kh / stride; + int start_kw = -half_kw / stride; + int end_kw = half_kw / stride; + int key_sal_inds = (b * key_channels + query_channels + int(g * key_saliency_group) / num_group) * in_height * in_width + + h * stride * in_width + w * stride; + + scalar_t sum_key_sal_grad = 0; + for (int kh = start_kh; kh <= end_kh; ++kh) { + for (int kw = start_kw; kw <= end_kw; ++kw) { + if (dilate * kh + h >= 0 && dilate * kh + h < height && dilate * kw + w >= 0 && dilate * kw + w < width) { + int spatial_offset = dilate * kh * width + dilate * kw; + scalar_t c_out_grad = output_grad[output_inds + ((half_kh - kh * stride) * kernel_width + half_kw - kw * stride) * spatial_dim + spatial_offset]; + scalar_t query_val = query[index + spatial_offset]; + if (sim_method == 0) { + sum_key_grad += c_out_grad + * query_val * scalar_t(scale); + } + else if (sim_method == 1) { + sum_key_grad +=
c_out_grad * scalar_t(scale); + } + else if (sim_method == 2) { + scalar_t key_val = key[key_inds]; + if (key_val > query_val) { + sum_key_grad += scalar_t(-scale) * c_out_grad; + } + else if (key_val < query_val) { + sum_key_grad += scalar_t(scale) * c_out_grad; + } + } + else if (sim_method == 3) { + scalar_t key_val = key[key_inds]; + if (key_val > query_val) { + sum_key_grad += -scalar_t(scale) * c_out_grad / (abs(key_val) + abs(query_val) + scalar_t(1.0)); + } + else if (key_val < query_val) { + sum_key_grad += scalar_t(scale) * c_out_grad / (abs(key_val) + abs(query_val) + scalar_t(1.0)); + } + + if (key_val > 0) { + sum_key_grad += c_out_grad * scalar_t(scale) * abs(key_val - query_val) + / ((abs(key_val) +abs(query_val) + scalar_t(1.0)) + * (abs(key_val) +abs(query_val) + scalar_t(1.0))); + } + else if (key_val < 0){ + sum_key_grad += -c_out_grad * scalar_t(scale) * abs(key_val - query_val) + / ((abs(key_val) +abs(query_val) + scalar_t(1.0)) + * (abs(key_val) +abs(query_val) + scalar_t(1.0))); + } + } + else if (sim_method == 4) { + scalar_t key_val = key[key_inds]; + sum_key_grad += 2 * scalar_t(scale) * c_out_grad * (query_val - key_val) / (abs(key_val * query_val) + scalar_t(1.0)); + + if (key_val * query_val > 0) { + sum_key_grad += scalar_t(scale) * c_out_grad * query_val * (key_val - query_val) * (key_val - query_val) / ((abs(key_val * query_val) + scalar_t(1.0)) * (abs(key_val * query_val) + scalar_t(1.0))); + } + else if(key_val * query_val < 0) { + sum_key_grad += -scalar_t(scale) * c_out_grad * query_val * (key_val - query_val) * (key_val - query_val) / ((abs(key_val * query_val) + scalar_t(1.0)) * (abs(key_val * query_val) + scalar_t(1.0))); + } + } + else if (sim_method == 5) { + scalar_t key_val = key[key_inds]; + sum_key_grad += scalar_t(scale) * c_out_grad * (query_val - key_val) * 2; + } + + if (key_saliency_group > 0) { + sum_key_sal_grad += c_out_grad; + } + } + } + } + key_grad[key_inds] += sum_key_grad; + if (key_saliency_group > 0) { +
atomicAdd(key_grad + key_sal_inds, sum_key_sal_grad); + } + + /* stride 2 only: also update the key positions one row down, one column right and diagonal from (h * stride, w * stride), each guarded by the input bounds */ if (stride == 2){ + if (h * stride + 1 < in_height) { + sum_key_grad = 0; + sum_key_sal_grad = 0; + start_kh = (1 - half_kh) / stride; + end_kh = (half_kh + 1) / stride; + for (int kh = start_kh; kh <= end_kh; ++kh) { + for (int kw = start_kw; kw <= end_kw; ++kw) { + if (dilate * kh + h >= 0 && dilate * kh + h < height && dilate * kw + w >= 0 && dilate * kw + w < width) { + int spatial_offset = dilate * kh * width + dilate * kw; + scalar_t c_out_grad = output_grad[output_inds + ((half_kh - kh * stride + 1) * kernel_width + half_kw - kw * stride) * spatial_dim + spatial_offset]; + scalar_t query_val = query[index + spatial_offset]; + if (sim_method == 0) { + sum_key_grad += c_out_grad + * query_val * scalar_t(scale); + } + else if (sim_method == 1) { + sum_key_grad += c_out_grad * scalar_t(scale); + } + else if (sim_method == 2) { + scalar_t key_val = key[key_inds + in_width]; + if (key_val > query_val) { + sum_key_grad += scalar_t(-scale) * c_out_grad; + } + else if (key_val < query_val) { + sum_key_grad += scalar_t(scale) * c_out_grad; + } + else { + sum_key_grad += scalar_t(0.0); + } + } + else if (sim_method == 3) { + scalar_t key_val = key[key_inds + in_width]; + if (key_val > query_val) { + sum_key_grad += -scalar_t(scale) * c_out_grad / (abs(key_val) + abs(query_val) + scalar_t(1.0)); + } + else if (key_val < query_val) { + sum_key_grad += scalar_t(scale) * c_out_grad / (abs(key_val) + abs(query_val) + scalar_t(1.0)); + } + + if (key_val > 0) { + sum_key_grad += scalar_t(scale) * c_out_grad * abs(key_val - query_val) + / ((abs(key_val) +abs(query_val) + scalar_t(1.0)) + * (abs(key_val) +abs(query_val) + scalar_t(1.0))); + } + else if (key_val < 0){ + sum_key_grad += -scalar_t(scale) * c_out_grad * abs(key_val - query_val) + / ((abs(key_val) +abs(query_val) + scalar_t(1.0)) + * (abs(key_val) +abs(query_val) + scalar_t(1.0))); + } + } + else if (sim_method == 4) { + scalar_t key_val = key[key_inds +
in_width]; + sum_key_grad += 2 * scalar_t(scale) * c_out_grad * (query_val - key_val) / (abs(key_val * query_val) + scalar_t(1.0)); + + if (key_val * query_val > 0) { + sum_key_grad += scalar_t(scale) * c_out_grad * query_val * (key_val - query_val) * (key_val - query_val) / ((abs(key_val * query_val) + scalar_t(1.0)) * (abs(key_val * query_val) + scalar_t(1.0))); + } + else if(key_val * query_val < 0) { + sum_key_grad += -scalar_t(scale) * c_out_grad * query_val * (key_val - query_val) * (key_val - query_val) / ((abs(key_val * query_val) + scalar_t(1.0)) * (abs(key_val * query_val) + scalar_t(1.0))); + } + } + else if (sim_method == 5) { + scalar_t key_val = key[key_inds + in_width]; + sum_key_grad += scalar_t(scale) * c_out_grad * (query_val - key_val) * 2; + } + + if (key_saliency_group > 0) { + sum_key_sal_grad += c_out_grad; + } + } + } + } + key_grad[key_inds + in_width] += sum_key_grad; + if (key_saliency_group > 0) { + atomicAdd(key_grad + key_sal_inds + in_width, sum_key_sal_grad); + } + } + if (w * stride + 1 < in_width) { + sum_key_grad = 0; + sum_key_sal_grad = 0; + start_kh = -half_kh / stride; + end_kh = half_kh / stride; + start_kw = (1 - half_kw) / stride; + end_kw = (half_kw + 1) / stride; + for (int kh = start_kh; kh <= end_kh; ++kh) { + for (int kw = start_kw; kw <= end_kw; ++kw) { + if (dilate * kh + h >= 0 && dilate * kh + h < height && dilate * kw + w >= 0 && dilate * kw + w < width) { + int spatial_offset = dilate * kh * width + dilate * kw; + scalar_t c_out_grad = output_grad[output_inds + ((half_kh - kh * stride) * kernel_width + half_kw - kw * stride + 1) * spatial_dim + spatial_offset]; + scalar_t query_val = query[index + spatial_offset]; + if (sim_method == 0) { + sum_key_grad += c_out_grad + * query_val * scalar_t(scale); + } + else if (sim_method == 1) { + sum_key_grad += c_out_grad * scalar_t(scale); + } + else if (sim_method == 2) { + scalar_t key_val = key[key_inds + 1]; + if (key_val > query_val) { + sum_key_grad +=
scalar_t(-scale) * c_out_grad; + } + else if (key_val < query_val) { + sum_key_grad += scalar_t(scale) * c_out_grad; + } + else { + sum_key_grad += scalar_t(0.0); + } + } + else if (sim_method == 3) { + scalar_t key_val = key[key_inds + 1]; + if (key_val > query_val) { + sum_key_grad += -scalar_t(scale) * c_out_grad / (abs(key_val) + abs(query_val) + scalar_t(1.0)); + } + else if (key_val < query_val) { + sum_key_grad += scalar_t(scale) * c_out_grad / (abs(key_val) + abs(query_val) + scalar_t(1.0)); + } + + if (key_val > 0) { + sum_key_grad += scalar_t(scale) * c_out_grad * abs(key_val - query_val) + / ((abs(key_val) +abs(query_val) + scalar_t(1.0)) + * (abs(key_val) +abs(query_val) + scalar_t(1.0))); + } + else if (key_val < 0){ + sum_key_grad += -scalar_t(scale) * c_out_grad * abs(key_val - query_val) + / ((abs(key_val) +abs(query_val) + scalar_t(1.0)) + * (abs(key_val) +abs(query_val) + scalar_t(1.0))); + } + } + else if (sim_method == 4) { + scalar_t key_val = key[key_inds + 1]; + sum_key_grad += 2 * scalar_t(scale) * c_out_grad * (query_val - key_val) / (abs(key_val * query_val) + scalar_t(1.0)); + + if (key_val * query_val > 0) { + sum_key_grad += scalar_t(scale) * c_out_grad * query_val * (key_val - query_val) * (key_val - query_val) / ((abs(key_val * query_val) + scalar_t(1.0)) * (abs(key_val * query_val) + scalar_t(1.0))); + } + else if(key_val * query_val < 0) { + sum_key_grad += -scalar_t(scale) * c_out_grad * query_val * (key_val - query_val) * (key_val - query_val) / ((abs(key_val * query_val) + scalar_t(1.0)) * (abs(key_val * query_val) + scalar_t(1.0))); + } + } + else if (sim_method == 5) { + scalar_t key_val = key[key_inds + 1]; + sum_key_grad += scalar_t(scale) * c_out_grad * (query_val - key_val) * 2; + } + + if (key_saliency_group > 0) { + sum_key_sal_grad += c_out_grad; + } + } + } + } + key_grad[key_inds + 1] += sum_key_grad; + if (key_saliency_group > 0) { + atomicAdd(key_grad + key_sal_inds + 1, sum_key_sal_grad); + } + } + if (h * stride +
1 < in_height && w * stride + 1 < in_width) { + sum_key_grad = 0; + sum_key_sal_grad = 0; + start_kh = (1 - half_kh) / stride; + end_kh = (half_kh + 1) / stride; + start_kw = (1 - half_kw) / stride; + end_kw = (half_kw + 1) / stride; + for (int kh = start_kh; kh <= end_kh; ++kh) { + for (int kw = start_kw; kw <= end_kw; ++kw) { + if (dilate * kh + h >= 0 && dilate * kh + h < height && dilate * kw + w >= 0 && dilate * kw + w < width) { + int spatial_offset = dilate * kh * width + dilate * kw; + scalar_t c_out_grad = output_grad[output_inds + ((half_kh - kh * stride + 1) * kernel_width + half_kw - kw * stride + 1) * spatial_dim + spatial_offset]; + scalar_t query_val = query[index + spatial_offset]; + if (sim_method == 0) { + sum_key_grad += c_out_grad + * query_val * scalar_t(scale); + } + else if (sim_method == 1) { + sum_key_grad += c_out_grad * scalar_t(scale); + } + else if (sim_method == 2) { + scalar_t key_val = key[key_inds + in_width + 1]; + if (key_val > query_val) { + sum_key_grad += scalar_t(-scale) * c_out_grad; + } + else if (key_val < query_val) { + sum_key_grad += scalar_t(scale) * c_out_grad; + } + } + else if (sim_method == 3) { + scalar_t key_val = key[key_inds + in_width + 1]; + if (key_val > query_val) { + sum_key_grad += -scalar_t(scale) * c_out_grad / (abs(key_val) + abs(query_val) + scalar_t(1.0)); + } + else if (key_val < query_val) { + sum_key_grad += scalar_t(scale) * c_out_grad / (abs(key_val) + abs(query_val) + scalar_t(1.0)); + } + + if (key_val > 0) { + sum_key_grad += scalar_t(scale) * c_out_grad * abs(key_val - query_val) + / ((abs(key_val) +abs(query_val) + scalar_t(1.0)) + * (abs(key_val) +abs(query_val) + scalar_t(1.0))); + } + else if (key_val < 0){ + sum_key_grad += -scalar_t(scale) * c_out_grad * abs(key_val - query_val) + / ((abs(key_val) +abs(query_val) + scalar_t(1.0)) + * (abs(key_val) +abs(query_val) + scalar_t(1.0))); + } + } + else if (sim_method == 4) { + scalar_t key_val = key[key_inds + in_width + 1]; + sum_key_grad +=
2 * scalar_t(scale) * c_out_grad * (query_val - key_val) / (abs(key_val * query_val) + scalar_t(1.0)); + + if (key_val * query_val > 0) { + sum_key_grad += scalar_t(scale) * c_out_grad * query_val * (key_val - query_val) * (key_val - query_val) / ((abs(key_val * query_val) + scalar_t(1.0)) * (abs(key_val * query_val) + scalar_t(1.0))); + } + else if(key_val * query_val < 0) { + sum_key_grad += -scalar_t(scale) * c_out_grad * query_val * (key_val - query_val) * (key_val - query_val) / ((abs(key_val * query_val) + scalar_t(1.0)) * (abs(key_val * query_val) + scalar_t(1.0))); + } + } + else if (sim_method == 5) { + scalar_t key_val = key[key_inds + in_width + 1]; + sum_key_grad += scalar_t(scale) * c_out_grad * (query_val - key_val) * 2; + } + + + if (key_saliency_group > 0) { + sum_key_sal_grad += c_out_grad; + } + } + } + } + key_grad[key_inds + in_width + 1] += sum_key_grad; + if (key_saliency_group > 0) { + atomicAdd(key_grad + key_sal_inds + in_width + 1, sum_key_sal_grad); + } + } + } + } +} + +/* +# [batch_size, num_group, 49, height, width] +app_geo_sim = mx.sym.softmax(app_geo_sim, axis=2) +# [batch_size, num_group, 1, 49, height, width] +app_geo_sim = mx.sym.expand_dims(app_geo_sim, axis=2) +output_value = mx.sym.reshape(mx.sym.sum(mx.sym.broadcast_mul(app_geo_sim, warp_value_data_reshape), axis=3), shape=(0, -3, -2)) +*/ +// value: [batch_size, value_channels, height, width] +// softmax_data: [batch_size, num_group * kernel_height * kernel_width, height, width] +// num_group: +// output: [batch_size, value_channels, height, width] + /* aggregation_forward_kernel: for each output element (b, c, h, w) decoded from index, sums value at the dilated kernel taps around (h * stride, w * stride), each weighted by the softmax_data entry of that tap for group g = c / (value_channels / num_group). Taps outside [0, in_height) x [0, in_width) are skipped. Note that value is indexed with in_height and in_width, not height and width. The sum overwrites output[index]. */ +template +__global__ void aggregation_forward_kernel(const int n, + const scalar_t* value, + const scalar_t* softmax_data, + const int batch_size, + const int value_channels, + const int height, + const int width, + const int kernel_height, + const int kernel_width, + const int num_group, + const int dilate, + const int stride, + const int in_height, + const int in_width, + scalar_t* output) { + // n = batch_size *
value_channels * height * width + CUDA_KERNEL_LOOP(index, n) { + const int w = index % width; + int h = index / width; + int c = h / height; + h = h % height; + const int b = c / value_channels; + c = c % value_channels; + + const int value_per_group = value_channels / num_group; + + const int g = c / value_per_group; + const int g_in_group = c % value_per_group; + + const int half_kh = kernel_height / 2; + const int half_kw = kernel_width / 2; + + const int spatial_dim = height * width; + scalar_t sum_val = 0; + + int value_inds = (((b * num_group + g) * value_per_group + g_in_group) * in_height + h * stride) * in_width + w * stride; + int softmax_inds = ((b * num_group + g) * kernel_height * kernel_width * height + h) * width + w; + for (int kh = 0; kh < kernel_height; ++kh) { + for (int kw = 0; kw < kernel_width; ++kw) { + if (w * stride + dilate * (kw - half_kw) >= 0 && w * stride + dilate * (kw - half_kw) < in_width + && h * stride + dilate * (kh - half_kh) >= 0 && h * stride + dilate * (kh - half_kh) < in_height) { + sum_val += value[value_inds + dilate * (kh - half_kh) * in_width + dilate * (kw - half_kw)] * softmax_data[softmax_inds + kh * kernel_width * spatial_dim + kw * spatial_dim]; + //if ((value_inds) == 10001) { + // printf("b: %d, g: %d, h: %d, w: %d, softmax_inds: %d, value_inds: %d, sum_val: %.4f, k:%d, w:%d, softmax: %.4f, val: %.4f\n", + // b, g, h, w, softmax_inds, value_inds, sum_val, kh, kw, + // softmax_data[softmax_inds + kh * kernel_width * spatial_dim + kw * spatial_dim], + // value[value_inds + dilate * (kh - half_kh) * width + dilate * (kw - half_kw)]); + // } + } + } + } + //if (value_inds % 10000 == 1) { + // printf("b: %d, g: %d, h: %d, w: %d, softmax_inds: %d, value_inds: %d, sum_val: %.4f\n", + // b, g, h, w, softmax_inds, value_inds, sum_val); + //} + output[index] = sum_val; + } +} + /* aggregation_value_backward_kernel: gradient with respect to value. For each (b, c, h, w) decoded from index it sums output_grad[index + spatial_offset] times the matching softmax_data tap over the output positions (h + dilate * kh, w + dilate * kw) that lie inside height x width, and adds the sum into value_grad at (h * stride, w * stride). When stride == 2 the same is repeated for the input positions one row down, one column right and diagonal, with the kernel tap index shifted by one accordingly. Accumulation uses +=, so value_grad presumably has to be zero-initialised by the caller - TODO confirm. */ +template +__global__ void aggregation_value_backward_kernel(const int n, + const scalar_t* softmax_data, + const scalar_t* output_grad, + const int
batch_size, + const int value_channels, + const int height, + const int width, + const int kernel_height, + const int kernel_width, + const int num_group, + const int dilate, + const int stride, + const int in_height, + const int in_width, + scalar_t* value_grad) { + // n = batch_size * value_channels * height * width + CUDA_KERNEL_LOOP(index, n) { + const int w = index % width; + int h = index / width; + int c = h / height; + h = h % height; + const int b = c / value_channels; + c = c % value_channels; + + const int value_per_group = value_channels / num_group; + + const int g = c / value_per_group; + const int g_in_group = c % value_per_group; + + const int half_kh = kernel_height / 2; + const int half_kw = kernel_width / 2; + + const int spatial_dim = height * width; + scalar_t sum_val = 0; + + int value_inds = (((b * num_group + g) * value_per_group + g_in_group) * in_height + h * stride) * in_width + w * stride; + int softmax_inds = ((b * num_group + g) * kernel_height * kernel_width * height + h) * width + w; + + int start_kh = -half_kh / stride; + int end_kh = half_kh / stride; + int start_kw = -half_kw / stride; + int end_kw = half_kw / stride; + for (int kh = start_kh; kh <= end_kh; ++kh) { + for (int kw = start_kw; kw <= end_kw; ++kw) { + if (dilate * kh + h >= 0 && dilate * kh + h < height && dilate * kw + w >= 0 && dilate * kw + w < width) { + int spatial_offset = dilate * kh * width + dilate * kw; + sum_val += output_grad[index + spatial_offset] + * softmax_data[softmax_inds + spatial_offset + ((half_kh - kh * stride) * kernel_width + half_kw - kw * stride) * spatial_dim]; + } + } + } + value_grad[value_inds] += sum_val; + + if (stride == 2){ + if (h * stride + 1 < in_height) { + sum_val = 0; + start_kh = (1 - half_kh) / stride; + end_kh = (half_kh + 1) / stride; + start_kw = -half_kw / stride; + end_kw = half_kw / stride; + for (int kh = start_kh; kh <= end_kh; ++kh) { + for (int kw = start_kw; kw <= end_kw; ++kw) { + if (dilate * kh + h >= 0 &&
dilate * kh + h < height && dilate * kw + w >= 0 && dilate * kw + w < width) { + int spatial_offset = dilate * kh * width + dilate * kw; + sum_val += output_grad[index + spatial_offset] + * softmax_data[softmax_inds + spatial_offset + ((half_kh - kh * stride + 1) * kernel_width + half_kw - kw * stride) * spatial_dim]; + } + } + } + value_grad[value_inds + in_width] += sum_val; + } + if (w * stride + 1 < in_width) { + sum_val = 0; + start_kh = -half_kh / stride; + end_kh = half_kh / stride; + start_kw = (1 - half_kw) / stride; + end_kw = (half_kw + 1) / stride; + for (int kh = start_kh; kh <= end_kh; ++kh) { + for (int kw = start_kw; kw <= end_kw; ++kw) { + if (dilate * kh + h >= 0 && dilate * kh + h < height && dilate * kw + w >= 0 && dilate * kw + w < width) { + int spatial_offset = dilate * kh * width + dilate * kw; + sum_val += output_grad[index + spatial_offset] + * softmax_data[softmax_inds + spatial_offset + ((half_kh - kh * stride) * kernel_width + half_kw - kw * stride + 1) * spatial_dim]; + } + } + } + value_grad[value_inds + 1] += sum_val; + } + if (h * stride + 1 < in_height && w * stride + 1 < in_width) { + sum_val = 0; + start_kh = (1 - half_kh) / stride; + end_kh = (half_kh + 1) / stride; + start_kw = (1 - half_kw) / stride; + end_kw = (half_kw + 1) / stride; + for (int kh = start_kh; kh <= end_kh; ++kh) { + for (int kw = start_kw; kw <= end_kw; ++kw) { + if (dilate * kh + h >= 0 && dilate * kh + h < height && dilate * kw + w >= 0 && dilate * kw + w < width) { + int spatial_offset = dilate * kh * width + dilate * kw; + sum_val += output_grad[index + spatial_offset] + * softmax_data[softmax_inds + spatial_offset + ((half_kh - kh * stride + 1) * kernel_width + half_kw - kw * stride + 1) * spatial_dim]; + } + } + } + value_grad[value_inds + in_width + 1] += sum_val; + } + } + } +} + /* aggregation_softmax_backward_kernel: gradient with respect to the aggregation weights. index decodes to (b, g, kh, kw, h, w); the result is the sum, over the value_channels / num_group channels of group g, of output_grad at (h, w) times value at the dilated tap (kh, kw) around (h * stride, w * stride), or 0 when that tap lies outside the input. The result overwrites softmax_grad[index]. */ +template +__global__ void aggregation_softmax_backward_kernel(const int n, + const scalar_t* value, + const scalar_t* output_grad, + const int batch_size, + const int
value_channels, + const int height, + const int width, + const int kernel_height, + const int kernel_width, + const int num_group, + const int dilate, + const int stride, + const int in_height, + const int in_width, + scalar_t* softmax_grad) { + // n = batch_size * num_group * kernel_height * kernel_width * height * width + CUDA_KERNEL_LOOP(index, n) { + const int w = index % width; + int h = index / width; + int kw = h / height; + h = h % height; + int kh = kw / kernel_width; + kw = kw % kernel_width; + int g = kh / kernel_height; + kh = kh % kernel_height; + const int b = g / num_group; + g = g % num_group; + + const int half_kh = kernel_height / 2; + const int half_kw = kernel_width / 2; + + const int value_per_group = value_channels / num_group; + + const int spatial_dim = height * width; + const int in_spatial_dim = in_height * in_width; + scalar_t sum_val = 0; + + int value_inds = ((b * num_group + g) * value_per_group * in_height + h * stride) * in_width + w * stride; + int output_inds = ((b * num_group + g) * value_per_group * height + h) * width + w; + + if (w * stride + dilate * (kw - half_kw) >= 0 && w * stride + dilate * (kw - half_kw) < in_width && h * stride + dilate * (kh - half_kh) >= 0 && h * stride + dilate * (kh - half_kh) < in_height) { + for (int iv = 0; iv < value_per_group; ++iv) { + sum_val += output_grad[output_inds + iv * spatial_dim] * value[value_inds + iv * in_spatial_dim + dilate * (kh - half_kh) * in_width + dilate * (kw - half_kw)]; + } + } + softmax_grad[index] = sum_val; + } +} + /* similarity_compute_forward: host-side launcher for similarity_compute_forward_kernel. It launches over batch_size * num_group * kernel_width * kernel_height * height * width elements, dispatching on the dtype of key. key_offset and query_offset are element offsets added to the key and query data pointers. A launch error reported by cudaGetLastError is only printed; it is not propagated to the caller. NOTE(review): in this copy the launch reads "<<>>" and data_ptr() has no type argument - presumably stripped during extraction; confirm against the original .cu file. */ +void similarity_compute_forward( + const at::Tensor key, + const at::Tensor query, + const at::Tensor pos_weight, + const int batch_size, + const int key_channels, + const int query_channels, + const int height, + const int width, + const int kernel_height, + const int kernel_width, + const int num_group, + const at::Tensor scale, + const at::Tensor no_define_value, + const int dilate, + const int stride, + const int in_height, + const int in_width, +
const int sim_method, + at::Tensor output, + const int key_offset, + const int query_offset) +{ + + const int num_kernels = batch_size * num_group * kernel_width * kernel_height * height * width; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + key.type(), "similarity_compute_forward_gpu", ([&] { + const scalar_t *key_ptr = key.data_ptr() + key_offset; + const scalar_t *query_ptr = query.data_ptr() + query_offset; + const scalar_t *pos_weight_ptr = pos_weight.data_ptr(); + scalar_t *output_ptr = output.data_ptr(); + const scalar_t *scale_ptr = scale.data_ptr(); + const scalar_t *no_define_value_ptr = no_define_value.data_ptr(); + + similarity_compute_forward_kernel<<>>( + num_kernels, key_ptr, query_ptr, pos_weight_ptr, + batch_size, key_channels, query_channels, height, width, + kernel_height, kernel_width, num_group, + scale_ptr, no_define_value_ptr, dilate, stride, in_height, in_width, sim_method, output_ptr); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in similarity_compute_forward: %s\n", cudaGetErrorString(err)); + } +} + /* similarity_compute_backward: host-side launcher for similarity_compute_backward_kernel. It launches over batch_size * query_channels * height * width elements, dispatching on the dtype of key. key_grad_offset is added to both the key and key_grad pointers, and query_grad_offset to both the query and query_grad pointers. A launch error reported by cudaGetLastError is only printed; it is not propagated to the caller. */ +void similarity_compute_backward( + const at::Tensor key, + const at::Tensor query, + const at::Tensor output_grad, + const int batch_size, + const int key_channels, + const int query_channels, + const int height, + const int width, + const int kernel_height, + const int kernel_width, + const int num_group, + const int key_per_group, + const at::Tensor scale, + const int dilate, + const int stride, + const int in_height, + const int in_width, + const int sim_method, + at::Tensor key_grad, + at::Tensor query_grad, + const int key_grad_offset, + const int query_grad_offset) +{ + const int num_kernels = batch_size * query_channels * height * width; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + key.type(), "similarity_compute_backward_gpu", ([&] { + // fixbug: add offset to key and query + const scalar_t *key_ptr = key.data_ptr() + key_grad_offset; + const scalar_t *query_ptr = query.data_ptr() +
query_grad_offset; + const scalar_t *output_grad_ptr = output_grad.data_ptr(); + scalar_t *key_grad_ptr = key_grad.data_ptr() + key_grad_offset; + scalar_t *query_grad_ptr = query_grad.data_ptr() + query_grad_offset; + const scalar_t *scale_ptr = scale.data_ptr(); + + similarity_compute_backward_kernel<<>>( + num_kernels, key_ptr, query_ptr, output_grad_ptr, batch_size, + key_channels, query_channels, height, width, + kernel_height, kernel_width, num_group, + key_per_group, scale_ptr, dilate, stride, in_height, in_width, + sim_method, key_grad_ptr, query_grad_ptr); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in similarity_compute_backward: %s\n", cudaGetErrorString(err)); + } + +} + /* aggregation_forward: host-side launcher for aggregation_forward_kernel. It launches over batch_size * value_channels * height * width elements, dispatching on the dtype of value. value_offset and output_offset are element offsets added to the value and output data pointers. A launch error reported by cudaGetLastError is only printed; it is not propagated to the caller. */ +void aggregation_forward( + const at::Tensor value, + const at::Tensor softmax_data, + const int batch_size, + const int value_channels, + const int height, + const int width, + const int kernel_height, + const int kernel_width, + const int num_group, + const int dilate, + const int stride, + const int in_height, + const int in_width, + at::Tensor output, + const int value_offset, + const int output_offset) +{ + const int num_kernels = batch_size * value_channels * height * width; + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + value.type(), "aggregation_forward_gpu", ([&] { + const scalar_t *value_ptr = value.data_ptr() + value_offset; + const scalar_t *softmax_data_ptr = softmax_data.data_ptr(); + scalar_t *output_ptr = output.data_ptr() + output_offset; + + aggregation_forward_kernel<<>>( + num_kernels, value_ptr, softmax_data_ptr, + batch_size, value_channels, height, width, + kernel_height, kernel_width, num_group, + dilate, stride, in_height, in_width, + output_ptr); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in aggregation_forward: %s\n", cudaGetErrorString(err)); + } + +} + /* aggregation_value_backward: host-side launcher for aggregation_value_backward_kernel. It launches over batch_size * value_channels * height * width elements, dispatching on the dtype of output_grad. output_grad_offset and value_grad_offset are element offsets added to the output_grad and value_grad data pointers. A launch error reported by cudaGetLastError is only printed; it is not propagated to the caller. */ +void aggregation_value_backward( + const at::Tensor softmax_data, + const at::Tensor output_grad, + const int
batch_size, + const int value_channels, + const int height, + const int width, + const int kernel_height, + const int kernel_width, + const int num_group, + const int dilate, + const int stride, + const int in_height, + const int in_width, + at::Tensor value_grad, + const int output_grad_offset, + const int value_grad_offset) +{ + const int num_kernels = batch_size * value_channels * height * width; + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + output_grad.type(), "aggregation_value_backward_gpu", ([&] { + const scalar_t *softmax_data_ptr = softmax_data.data_ptr(); + const scalar_t *output_grad_ptr = output_grad.data_ptr() + output_grad_offset; + scalar_t *value_grad_ptr = value_grad.data_ptr() + value_grad_offset; + + aggregation_value_backward_kernel<<>>( + num_kernels, softmax_data_ptr, output_grad_ptr, + batch_size, value_channels, height, width, + kernel_height, kernel_width, num_group, + dilate, stride, in_height, in_width, + value_grad_ptr); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in aggregation_value_backward: %s\n", cudaGetErrorString(err)); + } + +} + /* aggregation_softmax_backward: host-side launcher for aggregation_softmax_backward_kernel. It launches over batch_size * num_group * kernel_height * kernel_width * height * width elements, dispatching on the dtype of value. value_offset and output_grad_offset are element offsets added to the value and output_grad data pointers; softmax_grad is used without an offset. A launch error reported by cudaGetLastError is only printed; it is not propagated to the caller. */ +void aggregation_softmax_backward( + const at::Tensor value, + const at::Tensor output_grad, + const int batch_size, + const int value_channels, + const int height, + const int width, + const int kernel_height, + const int kernel_width, + const int num_group, + const int dilate, + const int stride, + const int in_height, + const int in_width, + at::Tensor softmax_grad, + const int value_offset, + const int output_grad_offset) +{ + const int num_kernels = batch_size * num_group * kernel_height * kernel_width * height * width; + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + value.type(), "aggregation_softmax_backward_gpu", ([&] { + + const scalar_t *value_ptr = value.data_ptr() + value_offset; + const scalar_t *output_grad_ptr = output_grad.data_ptr() + output_grad_offset; + scalar_t *softmax_grad_ptr = softmax_grad.data_ptr(); + +
aggregation_softmax_backward_kernel<<>>( + num_kernels, value_ptr, output_grad_ptr, + batch_size, value_channels, height, width, + kernel_height, kernel_width, num_group, + dilate, stride, in_height, in_width, + softmax_grad_ptr); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in aggregation_softmax_backward: %s\n", cudaGetErrorString(err)); + } +}