diff --git a/README.md b/README.md
index 76299ef..a591402 100644
--- a/README.md
+++ b/README.md
@@ -1,92 +1,24 @@
-# Swin Transformer
+# Local Relation Networks V2 (LR-Net V2)
-[](https://paperswithcode.com/sota/object-detection-on-coco?p=swin-transformer-hierarchical-vision)
-[](https://paperswithcode.com/sota/instance-segmentation-on-coco?p=swin-transformer-hierarchical-vision)
-[](https://paperswithcode.com/sota/object-detection-on-coco-minival?p=swin-transformer-hierarchical-vision)
-[](https://paperswithcode.com/sota/instance-segmentation-on-coco-minival?p=swin-transformer-hierarchical-vision)
-[](https://paperswithcode.com/sota/semantic-segmentation-on-ade20k?p=swin-transformer-hierarchical-vision)
-[](https://paperswithcode.com/sota/semantic-segmentation-on-ade20k-val?p=swin-transformer-hierarchical-vision)
+This branch is an improved implementation of ["Local Relation Networks for Image Recognition (LR-Net)"](https://arxiv.org/pdf/1904.11491.pdf). The original LR-Net utilizes sliding window based self-attention layer to replace the `3x3` convolution layers in a ResNet architecture. This improved implementation applies this layer into a stronger overall architecture based on Tranformers, dubbed as LR-Net V2. We provide cuda kernels for the local relation layers. Training scripts and pre-trained models will be provided in the future.
-By [Ze Liu](https://github.com/zeliu98/)\*, [Yutong Lin](https://github.com/impiga)\*, [Yue Cao](http://yue-cao.me)\*, [Han Hu](https://ancientmooner.github.io/)\*, [Yixuan Wei](https://github.com/weiyx16), [Zheng Zhang](https://stupidzz.github.io/), [Stephen Lin](https://scholar.google.com/citations?user=c3PYmxUAAAAJ&hl=en) and [Baining Guo](https://www.microsoft.com/en-us/research/people/bainguo/).
+## Install
+```bash
+cd ops/local_relation
+python setup.py build_ext --inplace
+```
-This repo is the official implementation of ["Swin Transformer: Hierarchical Vision Transformer using Shifted Windows"](https://arxiv.org/pdf/2103.14030.pdf). It currently includes code and models for the following tasks:
-
-> **Image Classification**: Included in this repo. See [get_started.md](get_started.md) for a quick start.
-
-> **Object Detection and Instance Segmentation**: See [Swin Transformer for Object Detection](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection).
-
-> **Semantic Segmentation**: See [Swin Transformer for Semantic Segmentation](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation).
-
-## Updates
-
-***04/12/2021***
-
-Initial commits:
-
-1. Pretrained models on ImageNet-1K ([Swin-T-IN1K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth), [Swin-S-IN1K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth), [Swin-B-IN1K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth)) and ImageNet-22K ([Swin-B-IN22K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth), [Swin-L-IN22K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth)) are provided.
-2. The supported code and models for ImageNet-1K image classification, COCO object detection and ADE20K semantic segmentation are provided.
-3. The cuda kernel implementation for the [local relation layer](https://arxiv.org/pdf/1904.11491.pdf) is provided in branch [LR-Net](https://github.com/microsoft/Swin-Transformer/tree/LR-Net).
-
-## Introduction
-
-**Swin Transformer** is initially described in [arxiv](https://arxiv.org/abs/2103.14030), which capably serves as a
-general-purpose backbone for computer vision. It is basically a hierarchical Transformer whose representation is
-computed with shifted windows. The shifted windowing scheme brings greater efficiency by limiting self-attention
-computation to non-overlapping local windows while also allowing for cross-window connection.
-
-Swin Transformer achieves strong performance on COCO object detection (`58.7 box AP` and `51.1 mask AP` on test-dev) and
-ADE20K semantic segmentatiion (`53.5 mIoU` on val), surpassing previous models by a large margin.
-
-
-
-## Main Results on ImageNet with Pretrained Models
-
-**ImageNet-1K and ImageNet-22K Pretrained Models**
-
-| name | pretrain | resolution |acc@1 | acc@5 | #params | FLOPs | FPS| 22K model | 1K model |
-| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |:---: |:---: |
-| Swin-T | ImageNet-1K | 224x224 | 81.2 | 95.5 | 28M | 4.5G | 755 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/156nWJy4Q28rDlrX-rRbI3w) |
-| Swin-S | ImageNet-1K | 224x224 | 83.2 | 96.2 | 50M | 8.7G | 437 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/1KFjpj3Efey3LmtE1QqPeQg) |
-| Swin-B | ImageNet-1K | 224x224 | 83.5 | 96.5 | 88M | 15.4G | 278 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/16bqCTEc70nC_isSsgBSaqQ) |
-| Swin-B | ImageNet-1K | 384x384 | 84.5 | 97.0 | 88M | 47.1G | 85 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth)/[baidu](https://pan.baidu.com/s/1xT1cu740-ejW7htUdVLnmw) |
-| Swin-B | ImageNet-22K | 224x224 | 85.2 | 97.5 | 88M | 15.4G | 278 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1y1Ec3UlrKSI8IMtEs-oBXA) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1n_wNkcbRxVXit8r_KrfAVg) |
-| Swin-B | ImageNet-22K | 384x384 | 86.4 | 98.0 | 88M | 47.1G | 85 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth)/[baidu](https://pan.baidu.com/s/1vwJxnJcVqcLZAw9HaqiR6g) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1caKTSdoLJYoi4WBcnmWuWg) |
-| Swin-L | ImageNet-22K | 224x224 | 86.3 | 97.9 | 197M | 34.5G | 141 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1pws3rOTFuOebBYP3h6Kx8w) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1NkQApMWUhxBGjk1ne6VqBQ) |
-| Swin-L | ImageNet-22K | 384x384 | 87.3 | 98.2 | 197M | 103.9G | 42 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth)/[baidu](https://pan.baidu.com/s/1sl7o_bJA143OD7UqSLAMoA) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1X0FLHQyPOC6Kmv2CmgxJvA) |
-
-Note: access code for `baidu` is `swin`.
-
-## Main Results on Downstream Tasks
-
-**COCO Object Detection (2017 val)**
-
-| Backbone | Method | pretrain | Lr Schd | box mAP | mask mAP | #params | FLOPs |
-| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
-| Swin-T | Mask R-CNN | ImageNet-1K | 3x | 46.0 | 41.6 | 48M | 267G |
-| Swin-S | Mask R-CNN | ImageNet-1K | 3x | 48.5 | 43.3 | 69M | 359G |
-| Swin-T | Cascade Mask R-CNN | ImageNet-1K | 3x | 50.4 | 43.7 | 86M | 745G |
-| Swin-S | Cascade Mask R-CNN | ImageNet-1K | 3x | 51.9 | 45.0 | 107M | 838G |
-| Swin-B | Cascade Mask R-CNN | ImageNet-1K | 3x | 51.9 | 45.0 | 145M | 982G |
-| Swin-T | RepPoints V2 | ImageNet-1K | 3x | 50.0 | - | 45M | 283G |
-| Swin-T | Mask RepPoints V2 | ImageNet-1K | 3x | 50.3 | 43.6 | 47M | 292G |
-| Swin-B | HTC++ | ImageNet-22K | 6x | 56.4 | 49.1 | 160M | 1043G |
-| Swin-L | HTC++ | ImageNet-22K | 3x | 57.1 | 49.5 | 284M | 1470G |
-| Swin-L | HTC++* | ImageNet-22K | 3x | 58.0 | 50.4 | 284M | - |
-
-Note: * indicates multi-scale testing.
-
-**ADE20K Semantic Segmentation (val)**
-
-| Backbone | Method | pretrain | Crop Size | Lr Schd | mIoU | mIoU (ms+flip) | #params | FLOPs |
-| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
-| Swin-T | UPerNet | ImageNet-1K | 512x512 | 160K | 44.51 | 45.81 | 60M | 945G |
-| Swin-S | UperNet | ImageNet-1K | 512x512 | 160K | 47.64 | 49.47 | 81M | 1038G |
-| Swin-B | UperNet | ImageNet-1K | 512x512 | 160K | 48.13 | 49.72 | 121M | 1188G |
-| Swin-B | UPerNet | ImageNet-22K | 640x640 | 160K | 50.04 | 51.66 | 121M | 1841G |
-| Swin-L | UperNet | ImageNet-22K | 640x640 | 160K | 52.05 | 53.53 | 234M | 3230G |
-
-## Citing Swin Transformer
+## Citing Local Relation Networks
+```
+@inproceedings{hu2019local,
+ title={Local relation networks for image recognition},
+ author={Hu, Han and Zhang, Zheng and Xie, Zhenda and Lin, Stephen},
+ booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
+ pages={3464--3473},
+ year={2019}
+}
+```
```
@article{liu2021Swin,
title={Swin Transformer: Hierarchical Vision Transformer using Shifted Windows},
@@ -96,12 +28,6 @@ Note: * indicates multi-scale testing.
}
```
-## Getting Started
-
-- For **Image Classification**, please see [get_started.md](get_started.md) for detailed instructions.
-- For **Object Detection and Instance Segmentation**, please see [Swin Transformer for Object Detection](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection).
-- For **Semantic Segmentation**, please see [Swin Transformer for Semantic Segmentation](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation).
-
## Contributing
This project welcomes contributions and suggestions. Most contributions require you to agree to a
diff --git a/ops/local_relation/.gitignore b/ops/local_relation/.gitignore
new file mode 100644
index 0000000..c9c5499
--- /dev/null
+++ b/ops/local_relation/.gitignore
@@ -0,0 +1,4 @@
+*.so
+_ext*
+__pycache__
+build
diff --git a/ops/local_relation/__init__.py b/ops/local_relation/__init__.py
new file mode 100644
index 0000000..866a932
--- /dev/null
+++ b/ops/local_relation/__init__.py
@@ -0,0 +1 @@
+from .local_relation_func import local_relation
diff --git a/ops/local_relation/local_relation_func.py b/ops/local_relation/local_relation_func.py
new file mode 100644
index 0000000..534a4bd
--- /dev/null
+++ b/ops/local_relation/local_relation_func.py
@@ -0,0 +1,102 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2019 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Han Hu, Jiarui Xu
+# Modified by Ze Liu
+# --------------------------------------------------------
+
+import torch
+from torch.autograd import Function
+from torch.nn.modules.utils import _pair
+
+from . import local_relation_cuda
+
+
+class LocalRelationFunction(Function):
+
+ @staticmethod
+ def forward(ctx,
+ query,
+ key,
+ value,
+ pos_weight,
+ kernel_size,
+ groups,
+ stride=1,
+ dilation=1,
+ scale=1.,
+ no_define_value=-100.,
+ norm_method=0,
+ sim_method=0,
+ batch_step=32):
+ for input in [query, key, value]:
+ if input is not None and input.dim() != 4:
+ raise ValueError(
+ "Expected 4D tensor as input, got {}D tensor instead.".format(
+ input.dim()))
+ ctx.kernel_size = _pair(kernel_size)
+ ctx.groups = groups
+ ctx.stride = stride
+ ctx.dilation = dilation
+ ctx.scale = scale
+ ctx.no_define_value = no_define_value
+ ctx.norm_method = norm_method
+ ctx.sim_method = sim_method
+ ctx.batch_step = batch_step
+
+ ctx.save_for_backward(query, key, value, pos_weight)
+
+ output = query.new_empty(
+ LocalRelationFunction._output_size(query, value))
+
+ scale_tensor = query.new_tensor([ctx.scale])
+ no_define_value_tensor = query.new_tensor([ctx.no_define_value])
+
+ if not input.is_cuda:
+ raise NotImplementedError
+ else:
+ batch_step = min(ctx.batch_step, query.shape[0])
+ local_relation_cuda.local_relation_forward_cuda(
+ query, key, value, pos_weight, scale_tensor, no_define_value_tensor,
+ output, ctx.kernel_size[1], ctx.kernel_size[0], ctx.groups,
+ ctx.dilation, ctx.stride, batch_step, ctx.norm_method, ctx.sim_method)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ query, key, value, pos_weight = ctx.saved_tensors
+
+ grad_query = grad_key = grad_value = grad_pos_weight = None
+
+ scale_tensor = query.new_tensor(ctx.scale)
+ no_define_value_tensor = query.new_tensor(ctx.no_define_value)
+
+ if not grad_output.is_cuda:
+ raise NotImplementedError
+ else:
+ batch_step = min(ctx.batch_step, query.shape[0])
+
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[1] or ctx.needs_input_grad[2] or ctx.needs_input_grad[3]:
+ grad_query = torch.zeros_like(query)
+ grad_key = torch.zeros_like(key)
+ grad_value = torch.zeros_like(value)
+ grad_pos_weight = torch.zeros_like(pos_weight)
+ local_relation_cuda.local_relation_backward_cuda(
+ query, key, value, pos_weight,
+ scale_tensor, no_define_value_tensor, grad_output,
+ grad_query, grad_key, grad_value, grad_pos_weight,
+ ctx.kernel_size[1], ctx.kernel_size[0],
+ ctx.groups, ctx.dilation, ctx.stride, batch_step,
+ ctx.norm_method, ctx.sim_method)
+
+ return (grad_query, grad_key, grad_value, grad_pos_weight, None, None, None,
+ None, None, None, None, None, None)
+
+ @staticmethod
+ def _output_size(query, value):
+ output_size = (query.size(0), value.size(1), query.size(2), query.size(3))
+ return output_size
+
+
+local_relation = LocalRelationFunction.apply
diff --git a/ops/local_relation/setup.py b/ops/local_relation/setup.py
new file mode 100644
index 0000000..abddebc
--- /dev/null
+++ b/ops/local_relation/setup.py
@@ -0,0 +1,12 @@
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+setup(
+ name='local_relation',
+ ext_modules=[
+ CUDAExtension('local_relation_cuda', [
+ 'src/local_relation_cuda.cpp',
+ 'src/local_relation_cuda_kernel.cu',
+ ]),
+ ],
+ cmdclass={'build_ext': BuildExtension})
diff --git a/ops/local_relation/src/local_relation_cuda.cpp b/ops/local_relation/src/local_relation_cuda.cpp
new file mode 100644
index 0000000..aad58ca
--- /dev/null
+++ b/ops/local_relation/src/local_relation_cuda.cpp
@@ -0,0 +1,331 @@
+/*!
+ * Copyright (c) 2019 Microsoft
+ * Licensed under The MIT License [see LICENSE for details]
+ * \file local_relation_cuda.cpp
+ * \brief
+ * \author Han Hu
+ * \modified by Jiarui Xu, Ze Liu
+*/
+
+#include
+
+#include
+#include
+
+void similarity_compute_forward(
+ const at::Tensor key,
+ const at::Tensor query,
+ const at::Tensor pos_weight,
+ const int batch_size,
+ const int key_channels,
+ const int query_channels,
+ const int height,
+ const int width,
+ const int kernel_height,
+ const int kernel_width,
+ const int num_group,
+ const at::Tensor scale,
+ const at::Tensor no_define_value,
+ const int dilate,
+ const int stride,
+ const int in_height,
+ const int in_width,
+ const int sim_method,
+ at::Tensor output,
+ const int key_offset,
+ const int query_offset);
+
+void similarity_compute_backward(
+ const at::Tensor key,
+ const at::Tensor query,
+ const at::Tensor output_grad,
+ const int batch_size,
+ const int key_channels,
+ const int query_channels,
+ const int height,
+ const int width,
+ const int kernel_height,
+ const int kernel_width,
+ const int num_group,
+ const int key_per_group,
+ const at::Tensor scale,
+ const int dilate,
+ const int stride,
+ const int in_height,
+ const int in_width,
+ const int sim_method,
+ at::Tensor key_grad,
+ at::Tensor query_grad,
+ const int key_grad_offset,
+ const int query_grad_offset);
+
+void aggregation_forward(
+ const at::Tensor value,
+ const at::Tensor softmax_data,
+ const int batch_size,
+ const int value_channels,
+ const int height,
+ const int width,
+ const int kernel_height,
+ const int kernel_width,
+ const int num_group,
+ const int dilate,
+ const int stride,
+ const int in_height,
+ const int in_width,
+ at::Tensor output,
+ const int value_offset,
+ const int output_offset);
+
+void aggregation_value_backward(
+ const at::Tensor softmax_data,
+ const at::Tensor output_grad,
+ const int batch_size,
+ const int value_channels,
+ const int height,
+ const int width,
+ const int kernel_height,
+ const int kernel_width,
+ const int num_group,
+ const int dilate,
+ const int stride,
+ const int in_height,
+ const int in_width,
+ at::Tensor value_grad,
+ const int output_grad_offset,
+ const int value_grad_offset);
+
+void aggregation_softmax_backward(
+ const at::Tensor value,
+ const at::Tensor output_grad,
+ const int batch_size,
+ const int value_channels,
+ const int height,
+ const int width,
+ const int kernel_height,
+ const int kernel_width,
+ const int num_group,
+ const int dilate,
+ const int stride,
+ const int in_height,
+ const int in_width,
+ at::Tensor softmax_grad,
+ const int value_offset,
+ const int output_grad_offset);
+
+
+int local_relation_forward_cuda(
+ at::Tensor query,
+ at::Tensor key,
+ at::Tensor value,
+ at::Tensor pos_weight,
+ at::Tensor scale,
+ at::Tensor no_define_value,
+ at::Tensor output,
+ const int kernel_height,
+ const int kernel_width,
+ const int num_group,
+ const int dilate,
+ const int stride,
+ const int batch_step,
+ const int norm_method,
+ const int sim_method)
+{
+ query = query.contiguous();
+ key = key.contiguous();
+ value = value.contiguous();
+ pos_weight = pos_weight.contiguous();
+
+ const int query_channels = query.size(1);
+ const int key_channels = key.size(1);
+ const int value_channels = value.size(1);
+ const int batch_size = key.size(0);
+ const int height = query.size(2);
+ const int width = query.size(3);
+ const int in_height = key.size(2);
+ const int in_width = key.size(3);
+
+ const int batch_step_ = std::min(batch_size, batch_step);
+ const int sim_size = batch_step_ * num_group * kernel_height * kernel_width * height * width;
+
+ const int key_step = batch_step_ * key_channels * in_height * in_width;
+ const int query_step = batch_step_ * query_channels * height * width;
+ const int value_step = batch_step_ * value_channels * in_height * in_width;
+ const int output_step = batch_step_ * value_channels * height * width;
+
+ at::Tensor sim_buffer = at::zeros({batch_step_ * num_group, kernel_height * kernel_width, height * width},
+ query.options());
+
+ at::Tensor softmax_buffer = at::zeros({batch_step_ * num_group, kernel_height * kernel_width, height * width},
+ query.options());
+
+ at::Tensor sum_softmax_buffer = at::zeros({batch_step_ * num_group, height * width});
+
+ int M = (batch_size - 1) / batch_step_ + 1;
+ for (int i = 0; i < M; ++i) {
+ int cur_batch_step = batch_step_;
+ if (i == M - 1) {
+ cur_batch_step = batch_size - (M - 1) * batch_step_;
+ if (cur_batch_step != batch_step_) {
+ sim_buffer = at::zeros({cur_batch_step * num_group, kernel_height * kernel_width, height * width}, query.options());
+ softmax_buffer = at::zeros({cur_batch_step * num_group, kernel_height * kernel_width, height * width},query.options());
+ sum_softmax_buffer = at::zeros({cur_batch_step * num_group, height * width}, query.options());
+ }
+
+ // TORCH_CHECK(cur_batch_step % batch_step_ == 0, "batch_step must be divided by batch_size");
+ }
+ similarity_compute_forward(key, query, pos_weight, cur_batch_step,
+ key_channels, query_channels, height, width,
+ kernel_height, kernel_width, num_group, scale, no_define_value,
+ dilate, stride, in_height, in_width, sim_method, sim_buffer,
+ key_step * i, query_step * i);
+
+ // softmax
+ if (norm_method == 0) {
+ softmax_buffer = sim_buffer.softmax(1);
+ }
+ else {
+ AT_ERROR("Not implemented yet");
+ }
+
+ aggregation_forward(value, softmax_buffer, cur_batch_step,
+ value_channels, height, width, kernel_height, kernel_width,
+ num_group, dilate, stride, in_height, in_width, output, value_step * i, output_step * i);
+ }
+
+ return 1;
+
+}
+
+int local_relation_backward_cuda(
+ at::Tensor query,
+ at::Tensor key,
+ at::Tensor value,
+ at::Tensor pos_weight,
+ at::Tensor scale,
+ at::Tensor no_define_value,
+ at::Tensor output_grad,
+ at::Tensor query_grad,
+ at::Tensor key_grad,
+ at::Tensor value_grad,
+ at::Tensor pos_weight_grad,
+ const int kernel_height,
+ const int kernel_width,
+ const int num_group,
+ const int dilate,
+ const int stride,
+ const int batch_step,
+ const int norm_method,
+ const int sim_method)
+{
+ query = query.contiguous();
+ key = key.contiguous();
+ value = value.contiguous();
+ pos_weight = pos_weight.contiguous();
+
+ output_grad = output_grad.contiguous();
+ query_grad = query_grad.contiguous();
+ key_grad = key_grad.contiguous();
+ value_grad = value_grad.contiguous();
+ pos_weight_grad = pos_weight_grad.contiguous();
+
+ const int query_channels = query.size(1);
+ const int key_channels = key.size(1);
+ const int value_channels = value.size(1);
+ const int batch_size = key.size(0);
+ const int height = query.size(2);
+ const int width = query.size(3);
+ const int in_height = key.size(2);
+ const int in_width = key.size(3);
+ const int key_per_group = query_channels / num_group;
+
+ const int batch_step_ = std::min(batch_size, batch_step);
+ const int sim_size = batch_step_ * num_group * kernel_height * kernel_width * height * width;
+
+ const int key_step = batch_step_ * key_channels * in_height * in_width;
+ const int query_step = batch_step_ * query_channels * height * width;
+ const int value_step = batch_step_ * value_channels * in_height * in_width;
+ const int output_step = batch_step_ * value_channels * height * width;
+
+ at::Tensor sim_buffer = at::zeros({batch_step_ * num_group, kernel_height * kernel_width, height * width},
+ query.options());
+
+ at::Tensor softmax_buffer = at::zeros({batch_step_ * num_group, kernel_height * kernel_width, height * width},
+ query.options());
+
+ at::Tensor sum_softmax_buffer = at::zeros({batch_step_ * num_group, height * width},
+ query.options());
+
+ at::Tensor sim_grad_buffer = at::zeros({batch_step_ * num_group, kernel_height * kernel_width, height * width},
+ query.options());
+
+ int M = (batch_size - 1) / batch_step_ + 1;
+
+ const int pos_weight_size = num_group * kernel_height * kernel_width;
+
+ for (int i = 0; i < M; ++i) {
+ int cur_batch_step = batch_step_;
+ if (i == M - 1) {
+ cur_batch_step = batch_size - (M - 1) * batch_step_;
+ if (cur_batch_step != batch_step_) {
+ sim_buffer = at::zeros({cur_batch_step * num_group, kernel_height * kernel_width, height * width}, query.options());
+ softmax_buffer = at::zeros({cur_batch_step * num_group, kernel_height * kernel_width, height * width},query.options());
+ sum_softmax_buffer = at::zeros({cur_batch_step * num_group, height * width}, query.options());
+ sim_grad_buffer = at::zeros({cur_batch_step * num_group, kernel_height * kernel_width, height * width}, query.options());
+ }
+ // TORCH_CHECK(cur_batch_step % batch_step_ == 0, "batch_step must be divided by batch_size");
+ }
+
+ similarity_compute_forward(key, query, pos_weight, cur_batch_step,
+ key_channels, query_channels, height, width,
+ kernel_height, kernel_width, num_group, scale, no_define_value,
+ dilate, stride, in_height, in_width, sim_method, sim_buffer,
+ key_step * i, query_step * i);
+
+ // softmax
+
+ if (norm_method == 0) {
+ softmax_buffer = sim_buffer.softmax(1);
+ }
+ else {
+ AT_ERROR("Not implemented yet");
+ }
+
+ aggregation_value_backward(softmax_buffer, output_grad, cur_batch_step,
+ value_channels, height, width, kernel_height, kernel_width,
+ num_group, dilate, stride, in_height, in_width, value_grad,
+ output_step * i, value_step * i);
+
+ aggregation_softmax_backward(value, output_grad, cur_batch_step,
+ value_channels, height, width, kernel_height, kernel_width,
+ num_group, dilate, stride, in_height, in_width, sim_buffer,
+ value_step * i, output_step * i);
+
+ if (norm_method == 0) {
+ sum_softmax_buffer = (softmax_buffer * sim_buffer).sum(1, true);
+ sim_grad_buffer = softmax_buffer * (sim_buffer - sum_softmax_buffer);
+ }
+ else {
+ AT_ERROR("Not implemented yet");
+ }
+
+ similarity_compute_backward(key, query, sim_grad_buffer, cur_batch_step,
+ key_channels, query_channels, height, width,
+ kernel_height, kernel_width, num_group, key_per_group, scale,
+ dilate, stride, in_height, in_width, sim_method, key_grad, query_grad,
+ key_step * i, query_step * i);
+
+ pos_weight_grad += sim_grad_buffer.view({cur_batch_step, num_group, kernel_height, kernel_width, height * width}).sum(4).sum(0);
+
+ }
+
+ return 1;
+
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("local_relation_forward_cuda", &local_relation_forward_cuda,
+ "local relation forward (CUDA)");
+ m.def("local_relation_backward_cuda", &local_relation_backward_cuda,
+ "local relation backward (CUDA)");
+}
diff --git a/ops/local_relation/src/local_relation_cuda_kernel.cu b/ops/local_relation/src/local_relation_cuda_kernel.cu
new file mode 100644
index 0000000..1e84eb5
--- /dev/null
+++ b/ops/local_relation/src/local_relation_cuda_kernel.cu
@@ -0,0 +1,1004 @@
+/*!
+ * Copyright (c) 2019 Microsoft
+ * Licensed under The MIT License [see LICENSE for details]
+ * \file local_relation_cuda_kernel.cu
+ * \brief
+ * \author Han Hu
+ * \modified by Jiarui Xu, Ze Liu
+*/
+
+#include
+#include
+#include
+#include
+#include
+
+using namespace at;
+
+#define CUDA_KERNEL_LOOP(i, n) \
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
+ i += blockDim.x * gridDim.x)
+
+const int CUDA_NUM_THREADS = 1024;
+const int kMaxGridNum = 65535;
+
+inline int GET_BLOCKS(const int N)
+{
+ return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS);
+}
+
+template
+__global__ void similarity_compute_forward_kernel(const int n,
+ const scalar_t* key,
+ const scalar_t* query,
+ const scalar_t* pos_weight,
+ const int batch_size,
+ const int key_channels,
+ const int query_channels,
+ const int height,
+ const int width,
+ const int kernel_height,
+ const int kernel_width,
+ const int num_group,
+ const scalar_t* scale_ptr,
+ const scalar_t* no_define_value_ptr,
+ const int dilate,
+ const int stride,
+ const int in_height,
+ const int in_width,
+ const int sim_method,
+ scalar_t* output) {
+ // n = batch_size * num_group * kernel_height * kernel_width * height * width
+ const scalar_t scale = scale_ptr[0];
+ const scalar_t no_define_value = no_define_value_ptr[0];
+ CUDA_KERNEL_LOOP(index, n) {
+ const int w = index % width;
+ int h = index / width;
+ int kw = h / height;
+ h = h % height;
+ int kh = kw / kernel_width;
+ kw = kw % kernel_width;
+ int g = kh / kernel_height;
+ kh = kh % kernel_height;
+ const int b = g / num_group;
+ g = g % num_group;
+
+ scalar_t sum_sim = 0;
+ const int half_kh = kernel_height / 2;
+ const int half_kw = kernel_width / 2;
+ if (sim_method >= 0){
+ const int key_per_group = query_channels / num_group;
+
+ const int spatial_dim = height * width;
+ const int in_spatial_dim = in_height * in_width;
+ int query_inds = 0;
+ if (sim_method != 1) {
+ query_inds = ((b * num_group + g) * key_per_group * height + h) * width + w;
+ }
+ const int key_saliency_group = key_channels - query_channels;
+
+ if (w * stride + dilate * (kw - half_kw) >= 0 && w * stride + dilate * (kw - half_kw) < in_width && h * stride + dilate * (kh - half_kh) >= 0 && h * stride + dilate * (kh - half_kh) < in_height) {
+ int key_inds = ((b * key_channels + g * key_per_group) * in_height + h * stride + dilate * (kh - half_kh)) * in_width + w * stride + dilate * (kw - half_kw);
+
+ for (int i = 0; i < key_per_group; ++i) {
+ if (sim_method == 0) {
+ sum_sim += query[query_inds + i * spatial_dim] * key[key_inds + i * in_spatial_dim] * scale;
+ }
+ else if (sim_method == 1) {
+ sum_sim += key[key_inds + i * in_spatial_dim] * scale;
+ }
+ else if (sim_method == 2) {
+ sum_sim += -abs(query[query_inds + i * spatial_dim] - key[key_inds + i * in_spatial_dim]) * scale;
+ }
+ else if (sim_method == 3) {
+ scalar_t query_val = query[query_inds + i * spatial_dim];
+ scalar_t key_val = key[key_inds + i * in_spatial_dim];
+ sum_sim += -abs(query_val - key_val) / (abs(query_val) + abs(key_val) + scalar_t(1.0)) * scale;
+ }
+ else if (sim_method == 4) {
+ scalar_t query_val = query[query_inds + i * spatial_dim];
+ scalar_t key_val = key[key_inds + i * in_spatial_dim];
+ sum_sim += -(query_val - key_val) * (query_val - key_val) / (abs(query_val * key_val) + scalar_t(1.0)) * scale;
+ }
+ else if (sim_method == 5) {
+ scalar_t query_val = query[query_inds + i * spatial_dim];
+ scalar_t key_val = key[key_inds + i * in_spatial_dim];
+ sum_sim += -(query_val - key_val) * (query_val - key_val) * scale;
+ }
+ if (key_saliency_group > 0) {
+ int key_sal_inds = (b * key_channels + query_channels + int(g * key_saliency_group) / num_group) * in_spatial_dim
+ + (h * stride + dilate * (kh - half_kh)) * in_width + w * stride + dilate * (kw - half_kw);
+ sum_sim += key[key_sal_inds];
+ }
+ }
+ }
+ else{
+ sum_sim = no_define_value;
+ }
+ }
+
+ if (w * stride + dilate * (kw - half_kw) >= 0 && w * stride + dilate * (kw - half_kw) < in_width && h * stride + dilate * (kh - half_kh) >= 0 && h * stride + dilate * (kh - half_kh) < in_height) {
+ }
+ else {
+ sum_sim = no_define_value;
+ }
+ int pos_inds = (g * kernel_height + kh) * kernel_width + kw;
+ sum_sim += pos_weight[pos_inds];
+
+ output[index] = sum_sim;
+ }
+}
+
+template
+__global__ void similarity_compute_backward_kernel(const int n,
+ const scalar_t* key,
+ const scalar_t* query,
+ const scalar_t* output_grad,
+ const int batch_size,
+ const int key_channels,
+ const int query_channels,
+ const int height,
+ const int width,
+ const int kernel_height,
+ const int kernel_width,
+ const int num_group,
+ const int key_per_group,
+ const scalar_t* scale_ptr,
+ const int dilate,
+ const int stride,
+ const int in_height,
+ const int in_width,
+ const int sim_method,
+ scalar_t* key_grad,
+ scalar_t* query_grad) {
+ const scalar_t scale = scale_ptr[0];
+ CUDA_KERNEL_LOOP(index, n) {
+ const int w = index % width;
+ int h = index / width;
+ int kpg = h / height;
+ h = h % height;
+ int g = kpg / key_per_group;
+ kpg = kpg % key_per_group;
+ const int b = g / num_group;
+ g = g % num_group;
+
+ const int half_kh = kernel_height / 2;
+ const int half_kw = kernel_width / 2;
+
+ const int spatial_dim = height * width;
+ const int key_saliency_group = key_channels - query_channels;
+
+ int output_inds = ((b * num_group + g) * kernel_height * kernel_width * height + h) * width + w;
+ scalar_t sum_query_grad = 0;
+
+ int key_inds = ((b * key_channels + g * key_per_group + kpg) * in_height + h * stride) * in_width + w * stride;
+ for (int kh = 0; kh < kernel_height; ++kh) {
+ for (int kw = 0; kw < kernel_width; ++kw) {
+ if (w * stride + dilate * (kw - half_kw) >= 0 && w * stride + dilate * (kw - half_kw) < in_width
+ && h * stride + dilate * (kh - half_kh) >= 0 && h * stride + dilate * (kh - half_kh) < in_height) {
+ scalar_t c_out_grad = output_grad[output_inds + (kh * kernel_width + kw) * spatial_dim];
+ if (sim_method == 0) {
+ sum_query_grad += c_out_grad
+ * key[key_inds + dilate * (kh - half_kh) * in_width + dilate * (kw - half_kw)];
+ }
+ else if (sim_method == 2) {
+ scalar_t key_val = key[key_inds + dilate * (kh - half_kh) * in_width + dilate * (kw - half_kw)];
+ scalar_t query_val = query[index];
+ if (key_val > query_val) {
+ sum_query_grad += c_out_grad;
+ }
+ else if (key_val < query_val) {
+ sum_query_grad += -c_out_grad;
+ }
+ }
+ else if (sim_method == 3) {
+ scalar_t key_val = key[key_inds + dilate * (kh - half_kh) * in_width + dilate * (kw - half_kw)];
+ scalar_t query_val = query[index];
+ if (key_val > query_val) {
+ sum_query_grad += c_out_grad / (abs(key_val) + abs(query_val) + scalar_t(1.0));
+ }
+ else if (key_val < query_val) {
+ sum_query_grad += -c_out_grad / (abs(key_val) + abs(query_val) + scalar_t(1.0));
+ }
+
+ if (query_val > 0) {
+ sum_query_grad += c_out_grad * abs(key_val - query_val) / ((abs(key_val) +abs(query_val) + scalar_t(1.0)) * (abs(key_val) +abs(query_val) + scalar_t(1.0)));
+ }
+ else if (query_val < 0) {
+ sum_query_grad += -c_out_grad * abs(key_val - query_val) / ((abs(key_val) +abs(query_val) + scalar_t(1.0)) * (abs(key_val) +abs(query_val) + scalar_t(1.0)));
+ }
+ }
+ else if (sim_method == 4) {
+ scalar_t key_val = key[key_inds + dilate * (kh - half_kh) * in_width + dilate * (kw - half_kw)];
+ scalar_t query_val = query[index];
+ sum_query_grad += 2 * c_out_grad * (key_val - query_val) / (abs(key_val * query_val) + scalar_t(1.0));
+
+ if (key_val * query_val > 0) {
+ sum_query_grad += c_out_grad * key_val * (key_val - query_val) * (key_val - query_val) / ((abs(key_val * query_val) + scalar_t(1.0)) * (abs(key_val * query_val) + scalar_t(1.0)));
+ }
+ else if(key_val * query_val < 0) {
+ sum_query_grad += -c_out_grad * key_val * (key_val - query_val) * (key_val - query_val) / ((abs(key_val * query_val) + scalar_t(1.0)) * (abs(key_val * query_val) + scalar_t(1.0)));
+ }
+ }
+ else if (sim_method == 5) {
+ scalar_t key_val = key[key_inds + dilate * (kh - half_kh) * in_width + dilate * (kw - half_kw)];
+ scalar_t query_val = query[index];
+ sum_query_grad += 2 * c_out_grad * (key_val - query_val);
+ }
+ }
+ }
+ }
+ sum_query_grad *= scale;
+ query_grad[index] += sum_query_grad;
+
+ scalar_t sum_key_grad = 0;
+ int start_kh = -half_kh / stride;
+ int end_kh = half_kh / stride;
+ int start_kw = -half_kw / stride;
+ int end_kw = half_kw / stride;
+ int key_sal_inds = (b * key_channels + query_channels + int(g * key_saliency_group) / num_group) * in_height * in_width
+ + h * stride * in_width + w * stride;
+
+ scalar_t sum_key_sal_grad = 0;
+ for (int kh = start_kh; kh <= end_kh; ++kh) {
+ for (int kw = start_kw; kw <= end_kw; ++kw) {
+ if (dilate * kh + h >= 0 && dilate * kh + h < height && dilate * kw + w >= 0 && dilate * kw + w < width) {
+ int spatial_offset = dilate * kh * width + dilate * kw;
+ scalar_t c_out_grad = output_grad[output_inds + ((half_kh - kh * stride) * kernel_width + half_kw - kw * stride) * spatial_dim + spatial_offset];
+ scalar_t query_val = query[index + spatial_offset];
+ if (sim_method == 0) {
+ sum_key_grad += c_out_grad
+ * query_val * scalar_t(scale);
+ }
+ else if (sim_method == 1) {
+ sum_key_grad += c_out_grad * scalar_t(scale);
+ }
+ else if (sim_method == 2) {
+ scalar_t key_val = key[key_inds];
+ if (key_val > query_val) {
+ sum_key_grad += scalar_t(-scale) * c_out_grad;
+ }
+ else if (key_val < query_val) {
+ sum_key_grad += scalar_t(scale) * c_out_grad;
+ }
+ }
+ else if (sim_method == 3) {
+ scalar_t key_val = key[key_inds];
+ if (key_val > query_val) {
+ sum_key_grad += -scalar_t(scale) * c_out_grad / (abs(key_val) + abs(query_val) + scalar_t(1.0));
+ }
+ else if (key_val < query_val) {
+ sum_key_grad += scalar_t(scale) * c_out_grad / (abs(key_val) + abs(query_val) + scalar_t(1.0));
+ }
+
+ if (key_val > 0) {
+ sum_key_grad += c_out_grad * scalar_t(scale) * abs(key_val - query_val)
+ / ((abs(key_val) +abs(query_val) + scalar_t(1.0))
+ * (abs(key_val) +abs(query_val) + scalar_t(1.0)));
+ }
+ else if (key_val < 0){
+ sum_key_grad += -c_out_grad * scalar_t(scale) * abs(key_val - query_val)
+ / ((abs(key_val) +abs(query_val) + scalar_t(1.0))
+ * (abs(key_val) +abs(query_val) + scalar_t(1.0)));
+ }
+ }
+ else if (sim_method == 4) {
+ scalar_t key_val = key[key_inds];
+ sum_key_grad += 2 * scalar_t(scale) * c_out_grad * (query_val - key_val) / (abs(key_val * query_val) + scalar_t(1.0));
+
+ if (key_val * query_val > 0) {
+ sum_key_grad += scalar_t(scale) * c_out_grad * query_val * (key_val - query_val) * (key_val - query_val) / ((abs(key_val * query_val) + scalar_t(1.0)) * (abs(key_val * query_val) + scalar_t(1.0)));
+ }
+ else if(key_val * query_val < 0) {
+ sum_key_grad += -scalar_t(scale) * c_out_grad * query_val * (key_val - query_val) * (key_val - query_val) / ((abs(key_val * query_val) + scalar_t(1.0)) * (abs(key_val * query_val) + scalar_t(1.0)));
+ }
+ }
+ else if (sim_method == 5) {
+ scalar_t key_val = key[key_inds];
+ sum_key_grad += scalar_t(scale) * c_out_grad * (query_val - key_val) * 2;
+ }
+
+ if (key_saliency_group > 0) {
+ sum_key_sal_grad += c_out_grad;
+ }
+ }
+ }
+ }
+ key_grad[key_inds] += sum_key_grad;
+ if (key_saliency_group > 0) {
+ atomicAdd(key_grad + key_sal_inds, sum_key_sal_grad);
+ }
+
+ if (stride == 2){
+ if (h * stride + 1 < in_height) {
+ sum_key_grad = 0;
+ sum_key_sal_grad = 0;
+ start_kh = (1 - half_kh) / stride;
+ end_kh = (half_kh + 1) / stride;
+ for (int kh = start_kh; kh <= end_kh; ++kh) {
+ for (int kw = start_kw; kw <= end_kw; ++kw) {
+ if (dilate * kh + h >= 0 && dilate * kh + h < height && dilate * kw + w >= 0 && dilate * kw + w < width) {
+ int spatial_offset = dilate * kh * width + dilate * kw;
+ scalar_t c_out_grad = output_grad[output_inds + ((half_kh - kh * stride + 1) * kernel_width + half_kw - kw * stride) * spatial_dim + spatial_offset];
+ scalar_t query_val = query[index + spatial_offset];
+ if (sim_method == 0) {
+ sum_key_grad += c_out_grad
+ * query_val * scalar_t(scale);
+ }
+ else if (sim_method == 1) {
+ sum_key_grad += c_out_grad * scalar_t(scale);
+ }
+ else if (sim_method == 2) {
+ scalar_t key_val = key[key_inds + in_width];
+ if (key_val > query_val) {
+ sum_key_grad += scalar_t(-scale) * c_out_grad;
+ }
+ else if (key_val < query_val) {
+ sum_key_grad += scalar_t(scale) * c_out_grad;
+ }
+ else {
+ sum_key_grad += scalar_t(0.0);
+ }
+ }
+ else if (sim_method == 3) {
+ scalar_t key_val = key[key_inds + in_width];
+ if (key_val > query_val) {
+ sum_key_grad += -scalar_t(scale) * c_out_grad / (abs(key_val) + abs(query_val) + scalar_t(1.0));
+ }
+ else if (key_val < query_val) {
+ sum_key_grad += scalar_t(scale) * c_out_grad / (abs(key_val) + abs(query_val) + scalar_t(1.0));
+ }
+
+ if (key_val > 0) {
+ sum_key_grad += scalar_t(scale) * c_out_grad * abs(key_val - query_val)
+ / ((abs(key_val) +abs(query_val) + scalar_t(1.0))
+ * (abs(key_val) +abs(query_val) + scalar_t(1.0)));
+ }
+ else if (key_val < 0){
+ sum_key_grad += -scalar_t(scale) * c_out_grad * abs(key_val - query_val)
+ / ((abs(key_val) +abs(query_val) + scalar_t(1.0))
+ * (abs(key_val) +abs(query_val) + scalar_t(1.0)));
+ }
+ }
+ else if (sim_method == 4) {
+ scalar_t key_val = key[key_inds + in_width];
+ sum_key_grad += 2 * scalar_t(scale) * c_out_grad * (query_val - key_val) / (abs(key_val * query_val) + scalar_t(1.0));
+
+ if (key_val * query_val > 0) {
+ sum_key_grad += scalar_t(scale) * c_out_grad * query_val * (key_val - query_val) * (key_val - query_val) / ((abs(key_val * query_val) + scalar_t(1.0)) * (abs(key_val * query_val) + scalar_t(1.0)));
+ }
+ else if(key_val * query_val < 0) {
+ sum_key_grad += -scalar_t(scale) * c_out_grad * query_val * (key_val - query_val) * (key_val - query_val) / ((abs(key_val * query_val) + scalar_t(1.0)) * (abs(key_val * query_val) + scalar_t(1.0)));
+ }
+ }
+ else if (sim_method == 5) {
+ scalar_t key_val = key[key_inds + in_width];
+ sum_key_grad += scalar_t(scale) * c_out_grad * (query_val - key_val) * 2;
+ }
+
+ if (key_saliency_group > 0) {
+ sum_key_sal_grad += c_out_grad;
+ }
+ }
+ }
+ }
+ key_grad[key_inds + in_width] += sum_key_grad;
+ if (key_saliency_group > 0) {
+ atomicAdd(key_grad + key_sal_inds + in_width, sum_key_sal_grad);
+ }
+ }
+ if (w * stride + 1 < in_width) {
+ sum_key_grad = 0;
+ sum_key_sal_grad = 0;
+ start_kh = -half_kh / stride;
+ end_kh = half_kh / stride;
+ start_kw = (1 - half_kw) / stride;
+ end_kw = (half_kw + 1) / stride;
+ for (int kh = start_kh; kh <= end_kh; ++kh) {
+ for (int kw = start_kw; kw <= end_kw; ++kw) {
+ if (dilate * kh + h >= 0 && dilate * kh + h < height && dilate * kw + w >= 0 && dilate * kw + w < width) {
+ int spatial_offset = dilate * kh * width + dilate * kw;
+ scalar_t c_out_grad = output_grad[output_inds + ((half_kh - kh * stride) * kernel_width + half_kw - kw * stride + 1) * spatial_dim + spatial_offset];
+ scalar_t query_val = query[index + spatial_offset];
+ if (sim_method == 0) {
+ sum_key_grad += c_out_grad
+ * query_val * scalar_t(scale);
+ }
+ else if (sim_method == 1) {
+ sum_key_grad += c_out_grad * scalar_t(scale);
+ }
+ else if (sim_method == 2) {
+ scalar_t key_val = key[key_inds + 1];
+ if (key_val > query_val) {
+ sum_key_grad += scalar_t(-scale) * c_out_grad;
+ }
+ else if (key_val < query_val) {
+ sum_key_grad += scalar_t(scale) * c_out_grad;
+ }
+ else {
+ sum_key_grad += scalar_t(0.0);
+ }
+ }
+ else if (sim_method == 3) {
+ scalar_t key_val = key[key_inds + 1];
+ if (key_val > query_val) {
+ sum_key_grad += -scalar_t(scale) * c_out_grad / (abs(key_val) + abs(query_val) + scalar_t(1.0));
+ }
+ else if (key_val < query_val) {
+ sum_key_grad += scalar_t(scale) * c_out_grad / (abs(key_val) + abs(query_val) + scalar_t(1.0));
+ }
+
+ if (key_val > 0) {
+ sum_key_grad += scalar_t(scale) * c_out_grad * abs(key_val - query_val)
+ / ((abs(key_val) +abs(query_val) + scalar_t(1.0))
+ * (abs(key_val) +abs(query_val) + scalar_t(1.0)));
+ }
+ else if (key_val < 0){
+ sum_key_grad += -scalar_t(scale) * c_out_grad * abs(key_val - query_val)
+ / ((abs(key_val) +abs(query_val) + scalar_t(1.0))
+ * (abs(key_val) +abs(query_val) + scalar_t(1.0)));
+ }
+ }
+ else if (sim_method == 4) {
+ scalar_t key_val = key[key_inds + 1];
+ sum_key_grad += 2 * scalar_t(scale) * c_out_grad * (query_val - key_val) / (abs(key_val * query_val) + scalar_t(1.0));
+
+ if (key_val * query_val > 0) {
+ sum_key_grad += scalar_t(scale) * c_out_grad * query_val * (key_val - query_val) * (key_val - query_val) / ((abs(key_val * query_val) + scalar_t(1.0)) * (abs(key_val * query_val) + scalar_t(1.0)));
+ }
+ else if(key_val * query_val < 0) {
+ sum_key_grad += -scalar_t(scale) * c_out_grad * query_val * (key_val - query_val) * (key_val - query_val) / ((abs(key_val * query_val) + scalar_t(1.0)) * (abs(key_val * query_val) + scalar_t(1.0)));
+ }
+ }
+ else if (sim_method == 5) {
+ scalar_t key_val = key[key_inds + 1];
+ sum_key_grad += scalar_t(scale) * c_out_grad * (query_val - key_val) * 2;
+ }
+
+ if (key_saliency_group > 0) {
+ sum_key_sal_grad += c_out_grad;
+ }
+ }
+ }
+ }
+ key_grad[key_inds + 1] += sum_key_grad;
+ if (key_saliency_group > 0) {
+ atomicAdd(key_grad + key_sal_inds + 1, sum_key_sal_grad);
+ }
+ }
+ if (h * stride + 1 < in_height && w * stride + 1 < in_width) {
+ sum_key_grad = 0;
+ sum_key_sal_grad = 0;
+ start_kh = (1 - half_kh) / stride;
+ end_kh = (half_kh + 1) / stride;
+ start_kw = (1 - half_kw) / stride;
+ end_kw = (half_kw + 1) / stride;
+ for (int kh = start_kh; kh <= end_kh; ++kh) {
+ for (int kw = start_kw; kw <= end_kw; ++kw) {
+ if (dilate * kh + h >= 0 && dilate * kh + h < height && dilate * kw + w >= 0 && dilate * kw + w < width) {
+ int spatial_offset = dilate * kh * width + dilate * kw;
+ scalar_t c_out_grad = output_grad[output_inds + ((half_kh - kh * stride + 1) * kernel_width + half_kw - kw * stride + 1) * spatial_dim + spatial_offset];
+ scalar_t query_val = query[index + spatial_offset];
+ if (sim_method == 0) {
+ sum_key_grad += c_out_grad
+ * query_val * scalar_t(scale);
+ }
+ else if (sim_method == 1) {
+ sum_key_grad += c_out_grad * scalar_t(scale);
+ }
+ else if (sim_method == 2) {
+ scalar_t key_val = key[key_inds + in_width + 1];
+ if (key_val > query_val) {
+ sum_key_grad += scalar_t(-scale) * c_out_grad;
+ }
+ else if (key_val < query_val) {
+ sum_key_grad += scalar_t(scale) * c_out_grad;
+ }
+ }
+ else if (sim_method == 3) {
+ scalar_t key_val = key[key_inds + in_width + 1];
+ if (key_val > query_val) {
+ sum_key_grad += -scalar_t(scale) * c_out_grad / (abs(key_val) + abs(query_val) + scalar_t(1.0));
+ }
+ else if (key_val < query_val) {
+ sum_key_grad += scalar_t(scale) * c_out_grad / (abs(key_val) + abs(query_val) + scalar_t(1.0));
+ }
+
+ if (key_val > 0) {
+ sum_key_grad += scalar_t(scale) * c_out_grad * abs(key_val - query_val)
+ / ((abs(key_val) +abs(query_val) + scalar_t(1.0))
+ * (abs(key_val) +abs(query_val) + scalar_t(1.0)));
+ }
+ else if (key_val < 0){
+ sum_key_grad += -scalar_t(scale) * c_out_grad * abs(key_val - query_val)
+ / ((abs(key_val) +abs(query_val) + scalar_t(1.0))
+ * (abs(key_val) +abs(query_val) + scalar_t(1.0)));
+ }
+ }
+ else if (sim_method == 4) {
+ scalar_t key_val = key[key_inds + in_width + 1];
+ sum_key_grad += 2 * scalar_t(scale) * c_out_grad * (query_val - key_val) / (abs(key_val * query_val) + scalar_t(1.0));
+
+ if (key_val * query_val > 0) {
+ sum_key_grad += scalar_t(scale) * c_out_grad * query_val * (key_val - query_val) * (key_val - query_val) / ((abs(key_val * query_val) + scalar_t(1.0)) * (abs(key_val * query_val) + scalar_t(1.0)));
+ }
+ else if(key_val * query_val < 0) {
+ sum_key_grad += -scalar_t(scale) * c_out_grad * query_val * (key_val - query_val) * (key_val - query_val) / ((abs(key_val * query_val) + scalar_t(1.0)) * (abs(key_val * query_val) + scalar_t(1.0)));
+ }
+ }
+ else if (sim_method == 5) {
+ scalar_t key_val = key[key_inds + in_width + 1];
+ sum_key_grad += scalar_t(scale) * c_out_grad * (query_val - key_val) * 2;
+ }
+
+
+ if (key_saliency_group > 0) {
+ sum_key_sal_grad += c_out_grad;
+ }
+ }
+ }
+ }
+ key_grad[key_inds + in_width + 1] += sum_key_grad;
+ if (key_saliency_group > 0) {
+ atomicAdd(key_grad + key_sal_inds + in_width + 1, sum_key_sal_grad);
+ }
+ }
+ }
+ }
+}
+
+/*
+# [batch_size, num_group, 49, height, width]
+app_geo_sim = mx.sym.softmax(app_geo_sim, axis=2)
+# [batch_size, num_group, 1, 49, height, width]
+app_geo_sim = mx.sym.expand_dims(app_geo_sim, axis=2)
+output_value = mx.sym.reshape(mx.sym.sum(mx.sym.broadcast_mul(app_geo_sim, warp_value_data_reshape), axis=3), shape=(0, -3, -2))
+*/
+// value: [batch_size, value_channels, height, width]
+// softmax_data: [batch_size, num_group * kernel_height * kernel_width, height, width]
+// num_group:
+// output: [batch_size, value_channels, height, width]
+
+template
+__global__ void aggregation_forward_kernel(const int n,
+ const scalar_t* value,
+ const scalar_t* softmax_data,
+ const int batch_size,
+ const int value_channels,
+ const int height,
+ const int width,
+ const int kernel_height,
+ const int kernel_width,
+ const int num_group,
+ const int dilate,
+ const int stride,
+ const int in_height,
+ const int in_width,
+ scalar_t* output) {
+ // n = batch_size * value_channels * height * width
+ CUDA_KERNEL_LOOP(index, n) {
+ const int w = index % width;
+ int h = index / width;
+ int c = h / height;
+ h = h % height;
+ const int b = c / value_channels;
+ c = c % value_channels;
+
+ const int value_per_group = value_channels / num_group;
+
+ const int g = c / value_per_group;
+ const int g_in_group = c % value_per_group;
+
+ const int half_kh = kernel_height / 2;
+ const int half_kw = kernel_width / 2;
+
+ const int spatial_dim = height * width;
+ scalar_t sum_val = 0;
+
+ int value_inds = (((b * num_group + g) * value_per_group + g_in_group) * in_height + h * stride) * in_width + w * stride;
+ int softmax_inds = ((b * num_group + g) * kernel_height * kernel_width * height + h) * width + w;
+ for (int kh = 0; kh < kernel_height; ++kh) {
+ for (int kw = 0; kw < kernel_width; ++kw) {
+ if (w * stride + dilate * (kw - half_kw) >= 0 && w * stride + dilate * (kw - half_kw) < in_width
+ && h * stride + dilate * (kh - half_kh) >= 0 && h * stride + dilate * (kh - half_kh) < in_height) {
+ sum_val += value[value_inds + dilate * (kh - half_kh) * in_width + dilate * (kw - half_kw)] * softmax_data[softmax_inds + kh * kernel_width * spatial_dim + kw * spatial_dim];
+ //if ((value_inds) == 10001) {
+ // printf("b: %d, g: %d, h: %d, w: %d, softmax_inds: %d, value_inds: %d, sum_val: %.4f, k:%d, w:%d, softmax: %.4f, val: %.4f\n",
+ // b, g, h, w, softmax_inds, value_inds, sum_val, kh, kw,
+ // softmax_data[softmax_inds + kh * kernel_width * spatial_dim + kw * spatial_dim],
+ // value[value_inds + dilate * (kh - half_kh) * width + dilate * (kw - half_kw)]);
+ // }
+ }
+ }
+ }
+ //if (value_inds % 10000 == 1) {
+ // printf("b: %d, g: %d, h: %d, w: %d, softmax_inds: %d, value_inds: %d, sum_val: %.4f\n",
+ // b, g, h, w, softmax_inds, value_inds, sum_val);
+ //}
+ output[index] = sum_val;
+ }
+}
+
+template
+__global__ void aggregation_value_backward_kernel(const int n,
+ const scalar_t* softmax_data,
+ const scalar_t* output_grad,
+ const int batch_size,
+ const int value_channels,
+ const int height,
+ const int width,
+ const int kernel_height,
+ const int kernel_width,
+ const int num_group,
+ const int dilate,
+ const int stride,
+ const int in_height,
+ const int in_width,
+ scalar_t* value_grad) {
+ // n = batch_size * value_channels * height * width
+ CUDA_KERNEL_LOOP(index, n) {
+ const int w = index % width;
+ int h = index / width;
+ int c = h / height;
+ h = h % height;
+ const int b = c / value_channels;
+ c = c % value_channels;
+
+ const int value_per_group = value_channels / num_group;
+
+ const int g = c / value_per_group;
+ const int g_in_group = c % value_per_group;
+
+ const int half_kh = kernel_height / 2;
+ const int half_kw = kernel_width / 2;
+
+ const int spatial_dim = height * width;
+ scalar_t sum_val = 0;
+
+ int value_inds = (((b * num_group + g) * value_per_group + g_in_group) * in_height + h * stride) * in_width + w * stride;
+ int softmax_inds = ((b * num_group + g) * kernel_height * kernel_width * height + h) * width + w;
+
+ int start_kh = -half_kh / stride;
+ int end_kh = half_kh / stride;
+ int start_kw = -half_kw / stride;
+ int end_kw = half_kw / stride;
+ for (int kh = start_kh; kh <= end_kh; ++kh) {
+ for (int kw = start_kw; kw <= end_kw; ++kw) {
+ if (dilate * kh + h >= 0 && dilate * kh + h < height && dilate * kw + w >= 0 && dilate * kw + w < width) {
+ int spatial_offset = dilate * kh * width + dilate * kw;
+ sum_val += output_grad[index + spatial_offset]
+ * softmax_data[softmax_inds + spatial_offset + ((half_kh - kh * stride) * kernel_width + half_kw - kw * stride) * spatial_dim];
+ }
+ }
+ }
+ value_grad[value_inds] += sum_val;
+
+ if (stride == 2){
+ if (h * stride + 1 < in_height) {
+ sum_val = 0;
+ start_kh = (1 - half_kh) / stride;
+ end_kh = (half_kh + 1) / stride;
+ start_kw = -half_kw / stride;
+ end_kw = half_kw / stride;
+ for (int kh = start_kh; kh <= end_kh; ++kh) {
+ for (int kw = start_kw; kw <= end_kw; ++kw) {
+ if (dilate * kh + h >= 0 && dilate * kh + h < height && dilate * kw + w >= 0 && dilate * kw + w < width) {
+ int spatial_offset = dilate * kh * width + dilate * kw;
+ sum_val += output_grad[index + spatial_offset]
+ * softmax_data[softmax_inds + spatial_offset + ((half_kh - kh * stride + 1) * kernel_width + half_kw - kw * stride) * spatial_dim];
+ }
+ }
+ }
+ value_grad[value_inds + in_width] += sum_val;
+ }
+ if (w * stride + 1 < in_width) {
+ sum_val = 0;
+ start_kh = -half_kh / stride;
+ end_kh = half_kh / stride;
+ start_kw = (1 - half_kw) / stride;
+ end_kw = (half_kw + 1) / stride;
+ for (int kh = start_kh; kh <= end_kh; ++kh) {
+ for (int kw = start_kw; kw <= end_kw; ++kw) {
+ if (dilate * kh + h >= 0 && dilate * kh + h < height && dilate * kw + w >= 0 && dilate * kw + w < width) {
+ int spatial_offset = dilate * kh * width + dilate * kw;
+ sum_val += output_grad[index + spatial_offset]
+ * softmax_data[softmax_inds + spatial_offset + ((half_kh - kh * stride) * kernel_width + half_kw - kw * stride + 1) * spatial_dim];
+ }
+ }
+ }
+ value_grad[value_inds + 1] += sum_val;
+ }
+ if (h * stride + 1 < in_height && w * stride + 1 < in_width) {
+ sum_val = 0;
+ start_kh = (1 - half_kh) / stride;
+ end_kh = (half_kh + 1) / stride;
+ start_kw = (1 - half_kw) / stride;
+ end_kw = (half_kw + 1) / stride;
+ for (int kh = start_kh; kh <= end_kh; ++kh) {
+ for (int kw = start_kw; kw <= end_kw; ++kw) {
+ if (dilate * kh + h >= 0 && dilate * kh + h < height && dilate * kw + w >= 0 && dilate * kw + w < width) {
+ int spatial_offset = dilate * kh * width + dilate * kw;
+ sum_val += output_grad[index + spatial_offset]
+ * softmax_data[softmax_inds + spatial_offset + ((half_kh - kh * stride + 1) * kernel_width + half_kw - kw * stride + 1) * spatial_dim];
+ }
+ }
+ }
+ value_grad[value_inds + in_width + 1] += sum_val;
+ }
+ }
+ }
+}
+
+template
+__global__ void aggregation_softmax_backward_kernel(const int n,
+ const scalar_t* value,
+ const scalar_t* output_grad,
+ const int batch_size,
+ const int value_channels,
+ const int height,
+ const int width,
+ const int kernel_height,
+ const int kernel_width,
+ const int num_group,
+ const int dilate,
+ const int stride,
+ const int in_height,
+ const int in_width,
+ scalar_t* softmax_grad) {
+ // n = batch_size * num_group * kernel_height * kernel_width * height * width
+ CUDA_KERNEL_LOOP(index, n) {
+ const int w = index % width;
+ int h = index / width;
+ int kw = h / height;
+ h = h % height;
+ int kh = kw / kernel_width;
+ kw = kw % kernel_width;
+ int g = kh / kernel_height;
+ kh = kh % kernel_height;
+ const int b = g / num_group;
+ g = g % num_group;
+
+ const int half_kh = kernel_height / 2;
+ const int half_kw = kernel_width / 2;
+
+ const int value_per_group = value_channels / num_group;
+
+ const int spatial_dim = height * width;
+ const int in_spatial_dim = in_height * in_width;
+ scalar_t sum_val = 0;
+
+ int value_inds = ((b * num_group + g) * value_per_group * in_height + h * stride) * in_width + w * stride;
+ int output_inds = ((b * num_group + g) * value_per_group * height + h) * width + w;
+
+ if (w * stride + dilate * (kw - half_kw) >= 0 && w * stride + dilate * (kw - half_kw) < in_width && h * stride + dilate * (kh - half_kh) >= 0 && h * stride + dilate * (kh - half_kh) < in_height) {
+ for (int iv = 0; iv < value_per_group; ++iv) {
+ sum_val += output_grad[output_inds + iv * spatial_dim] * value[value_inds + iv * in_spatial_dim + dilate * (kh - half_kh) * in_width + dilate * (kw - half_kw)];
+ }
+ }
+ softmax_grad[index] = sum_val;
+ }
+}
+
+void similarity_compute_forward(
+ const at::Tensor key,
+ const at::Tensor query,
+ const at::Tensor pos_weight,
+ const int batch_size,
+ const int key_channels,
+ const int query_channels,
+ const int height,
+ const int width,
+ const int kernel_height,
+ const int kernel_width,
+ const int num_group,
+ const at::Tensor scale,
+ const at::Tensor no_define_value,
+ const int dilate,
+ const int stride,
+ const int in_height,
+ const int in_width,
+ const int sim_method,
+ at::Tensor output,
+ const int key_offset,
+ const int query_offset)
+{
+
+ const int num_kernels = batch_size * num_group * kernel_width * kernel_height * height * width;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ key.type(), "similarity_compute_forward_gpu", ([&] {
+ const scalar_t *key_ptr = key.data_ptr() + key_offset;
+ const scalar_t *query_ptr = query.data_ptr() + query_offset;
+ const scalar_t *pos_weight_ptr = pos_weight.data_ptr();
+ scalar_t *output_ptr = output.data_ptr();
+ const scalar_t *scale_ptr = scale.data_ptr();
+ const scalar_t *no_define_value_ptr = no_define_value.data_ptr();
+
+ similarity_compute_forward_kernel<<>>(
+ num_kernels, key_ptr, query_ptr, pos_weight_ptr,
+ batch_size, key_channels, query_channels, height, width,
+ kernel_height, kernel_width, num_group,
+ scale_ptr, no_define_value_ptr, dilate, stride, in_height, in_width, sim_method, output_ptr);
+ }));
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in similarity_compute_forward: %s\n", cudaGetErrorString(err));
+ }
+}
+
+void similarity_compute_backward(
+ const at::Tensor key,
+ const at::Tensor query,
+ const at::Tensor output_grad,
+ const int batch_size,
+ const int key_channels,
+ const int query_channels,
+ const int height,
+ const int width,
+ const int kernel_height,
+ const int kernel_width,
+ const int num_group,
+ const int key_per_group,
+ const at::Tensor scale,
+ const int dilate,
+ const int stride,
+ const int in_height,
+ const int in_width,
+ const int sim_method,
+ at::Tensor key_grad,
+ at::Tensor query_grad,
+ const int key_grad_offset,
+ const int query_grad_offset)
+{
+ const int num_kernels = batch_size * query_channels * height * width;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ key.type(), "similarity_compute_backward_gpu", ([&] {
+ // fixbug: add offset to key and query
+ const scalar_t *key_ptr = key.data_ptr() + key_grad_offset;
+ const scalar_t *query_ptr = query.data_ptr() + query_grad_offset;
+ const scalar_t *output_grad_ptr = output_grad.data_ptr();
+ scalar_t *key_grad_ptr = key_grad.data_ptr() + key_grad_offset;
+ scalar_t *query_grad_ptr = query_grad.data_ptr() + query_grad_offset;
+ const scalar_t *scale_ptr = scale.data_ptr();
+
+ similarity_compute_backward_kernel<<>>(
+ num_kernels, key_ptr, query_ptr, output_grad_ptr, batch_size,
+ key_channels, query_channels, height, width,
+ kernel_height, kernel_width, num_group,
+ key_per_group, scale_ptr, dilate, stride, in_height, in_width,
+ sim_method, key_grad_ptr, query_grad_ptr);
+ }));
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in similarity_compute_backward: %s\n", cudaGetErrorString(err));
+ }
+
+}
+
+void aggregation_forward(
+ const at::Tensor value,
+ const at::Tensor softmax_data,
+ const int batch_size,
+ const int value_channels,
+ const int height,
+ const int width,
+ const int kernel_height,
+ const int kernel_width,
+ const int num_group,
+ const int dilate,
+ const int stride,
+ const int in_height,
+ const int in_width,
+ at::Tensor output,
+ const int value_offset,
+ const int output_offset)
+{
+ const int num_kernels = batch_size * value_channels * height * width;
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ value.type(), "aggregation_forward_gpu", ([&] {
+ const scalar_t *value_ptr = value.data_ptr() + value_offset;
+ const scalar_t *softmax_data_ptr = softmax_data.data_ptr();
+ scalar_t *output_ptr = output.data_ptr() + output_offset;
+
+ aggregation_forward_kernel<<>>(
+ num_kernels, value_ptr, softmax_data_ptr,
+ batch_size, value_channels, height, width,
+ kernel_height, kernel_width, num_group,
+ dilate, stride, in_height, in_width,
+ output_ptr);
+ }));
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in aggregation_forward: %s\n", cudaGetErrorString(err));
+ }
+
+}
+
+void aggregation_value_backward(
+ const at::Tensor softmax_data,
+ const at::Tensor output_grad,
+ const int batch_size,
+ const int value_channels,
+ const int height,
+ const int width,
+ const int kernel_height,
+ const int kernel_width,
+ const int num_group,
+ const int dilate,
+ const int stride,
+ const int in_height,
+ const int in_width,
+ at::Tensor value_grad,
+ const int output_grad_offset,
+ const int value_grad_offset)
+{
+ const int num_kernels = batch_size * value_channels * height * width;
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ output_grad.type(), "aggregation_value_backward_gpu", ([&] {
+ const scalar_t *softmax_data_ptr = softmax_data.data_ptr();
+ const scalar_t *output_grad_ptr = output_grad.data_ptr() + output_grad_offset;
+ scalar_t *value_grad_ptr = value_grad.data_ptr() + value_grad_offset;
+
+ aggregation_value_backward_kernel<<>>(
+ num_kernels, softmax_data_ptr, output_grad_ptr,
+ batch_size, value_channels, height, width,
+ kernel_height, kernel_width, num_group,
+ dilate, stride, in_height, in_width,
+ value_grad_ptr);
+ }));
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in aggregation_value_backward: %s\n", cudaGetErrorString(err));
+ }
+
+}
+
+void aggregation_softmax_backward(
+ const at::Tensor value,
+ const at::Tensor output_grad,
+ const int batch_size,
+ const int value_channels,
+ const int height,
+ const int width,
+ const int kernel_height,
+ const int kernel_width,
+ const int num_group,
+ const int dilate,
+ const int stride,
+ const int in_height,
+ const int in_width,
+ at::Tensor softmax_grad,
+ const int value_offset,
+ const int output_grad_offset)
+{
+ const int num_kernels = batch_size * num_group * kernel_height * kernel_width * height * width;
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ value.type(), "aggregation_softmax_backward_gpu", ([&] {
+
+ const scalar_t *value_ptr = value.data_ptr() + value_offset;
+ const scalar_t *output_grad_ptr = output_grad.data_ptr() + output_grad_offset;
+ scalar_t *softmax_grad_ptr = softmax_grad.data_ptr();
+
+ aggregation_softmax_backward_kernel<<>>(
+ num_kernels, value_ptr, output_grad_ptr,
+ batch_size, value_channels, height, width,
+ kernel_height, kernel_width, num_group,
+ dilate, stride, in_height, in_width,
+ softmax_grad_ptr);
+ }));
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in aggregation_softmax_backward: %s\n", cudaGetErrorString(err));
+ }
+}