add LR-Net V2

This commit is contained in:
v-zeliu1
2021-04-13 00:44:08 +08:00
parent 3dc2a55301
commit be50b6cc51
7 changed files with 1471 additions and 91 deletions

4
ops/local_relation/.gitignore vendored Normal file
View File

@@ -0,0 +1,4 @@
*.so
_ext*
__pycache__
build

View File

@@ -0,0 +1 @@
from .local_relation_func import local_relation

View File

@@ -0,0 +1,102 @@
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2019 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Han Hu, Jiarui Xu
# Modified by Ze Liu
# --------------------------------------------------------
import torch
from torch.autograd import Function
from torch.nn.modules.utils import _pair
from . import local_relation_cuda
class LocalRelationFunction(Function):
@staticmethod
def forward(ctx,
query,
key,
value,
pos_weight,
kernel_size,
groups,
stride=1,
dilation=1,
scale=1.,
no_define_value=-100.,
norm_method=0,
sim_method=0,
batch_step=32):
for input in [query, key, value]:
if input is not None and input.dim() != 4:
raise ValueError(
"Expected 4D tensor as input, got {}D tensor instead.".format(
input.dim()))
ctx.kernel_size = _pair(kernel_size)
ctx.groups = groups
ctx.stride = stride
ctx.dilation = dilation
ctx.scale = scale
ctx.no_define_value = no_define_value
ctx.norm_method = norm_method
ctx.sim_method = sim_method
ctx.batch_step = batch_step
ctx.save_for_backward(query, key, value, pos_weight)
output = query.new_empty(
LocalRelationFunction._output_size(query, value))
scale_tensor = query.new_tensor([ctx.scale])
no_define_value_tensor = query.new_tensor([ctx.no_define_value])
if not input.is_cuda:
raise NotImplementedError
else:
batch_step = min(ctx.batch_step, query.shape[0])
local_relation_cuda.local_relation_forward_cuda(
query, key, value, pos_weight, scale_tensor, no_define_value_tensor,
output, ctx.kernel_size[1], ctx.kernel_size[0], ctx.groups,
ctx.dilation, ctx.stride, batch_step, ctx.norm_method, ctx.sim_method)
return output
@staticmethod
def backward(ctx, grad_output):
query, key, value, pos_weight = ctx.saved_tensors
grad_query = grad_key = grad_value = grad_pos_weight = None
scale_tensor = query.new_tensor(ctx.scale)
no_define_value_tensor = query.new_tensor(ctx.no_define_value)
if not grad_output.is_cuda:
raise NotImplementedError
else:
batch_step = min(ctx.batch_step, query.shape[0])
if ctx.needs_input_grad[0] or ctx.needs_input_grad[1] or ctx.needs_input_grad[2] or ctx.needs_input_grad[3]:
grad_query = torch.zeros_like(query)
grad_key = torch.zeros_like(key)
grad_value = torch.zeros_like(value)
grad_pos_weight = torch.zeros_like(pos_weight)
local_relation_cuda.local_relation_backward_cuda(
query, key, value, pos_weight,
scale_tensor, no_define_value_tensor, grad_output,
grad_query, grad_key, grad_value, grad_pos_weight,
ctx.kernel_size[1], ctx.kernel_size[0],
ctx.groups, ctx.dilation, ctx.stride, batch_step,
ctx.norm_method, ctx.sim_method)
return (grad_query, grad_key, grad_value, grad_pos_weight, None, None, None,
None, None, None, None, None, None)
@staticmethod
def _output_size(query, value):
output_size = (query.size(0), value.size(1), query.size(2), query.size(3))
return output_size
local_relation = LocalRelationFunction.apply

View File

@@ -0,0 +1,12 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='local_relation',
ext_modules=[
CUDAExtension('local_relation_cuda', [
'src/local_relation_cuda.cpp',
'src/local_relation_cuda_kernel.cu',
]),
],
cmdclass={'build_ext': BuildExtension})

View File

@@ -0,0 +1,331 @@
/*!
* Copyright (c) 2019 Microsoft
* Licensed under The MIT License [see LICENSE for details]
* \file local_relation_cuda.cpp
* \brief
* \author Han Hu
* \modified by Jiarui Xu, Ze Liu
*/
#include <torch/extension.h>
#include <cmath>
#include <vector>
void similarity_compute_forward(
const at::Tensor key,
const at::Tensor query,
const at::Tensor pos_weight,
const int batch_size,
const int key_channels,
const int query_channels,
const int height,
const int width,
const int kernel_height,
const int kernel_width,
const int num_group,
const at::Tensor scale,
const at::Tensor no_define_value,
const int dilate,
const int stride,
const int in_height,
const int in_width,
const int sim_method,
at::Tensor output,
const int key_offset,
const int query_offset);
void similarity_compute_backward(
const at::Tensor key,
const at::Tensor query,
const at::Tensor output_grad,
const int batch_size,
const int key_channels,
const int query_channels,
const int height,
const int width,
const int kernel_height,
const int kernel_width,
const int num_group,
const int key_per_group,
const at::Tensor scale,
const int dilate,
const int stride,
const int in_height,
const int in_width,
const int sim_method,
at::Tensor key_grad,
at::Tensor query_grad,
const int key_grad_offset,
const int query_grad_offset);
void aggregation_forward(
const at::Tensor value,
const at::Tensor softmax_data,
const int batch_size,
const int value_channels,
const int height,
const int width,
const int kernel_height,
const int kernel_width,
const int num_group,
const int dilate,
const int stride,
const int in_height,
const int in_width,
at::Tensor output,
const int value_offset,
const int output_offset);
void aggregation_value_backward(
const at::Tensor softmax_data,
const at::Tensor output_grad,
const int batch_size,
const int value_channels,
const int height,
const int width,
const int kernel_height,
const int kernel_width,
const int num_group,
const int dilate,
const int stride,
const int in_height,
const int in_width,
at::Tensor value_grad,
const int output_grad_offset,
const int value_grad_offset);
void aggregation_softmax_backward(
const at::Tensor value,
const at::Tensor output_grad,
const int batch_size,
const int value_channels,
const int height,
const int width,
const int kernel_height,
const int kernel_width,
const int num_group,
const int dilate,
const int stride,
const int in_height,
const int in_width,
at::Tensor softmax_grad,
const int value_offset,
const int output_grad_offset);
int local_relation_forward_cuda(
at::Tensor query,
at::Tensor key,
at::Tensor value,
at::Tensor pos_weight,
at::Tensor scale,
at::Tensor no_define_value,
at::Tensor output,
const int kernel_height,
const int kernel_width,
const int num_group,
const int dilate,
const int stride,
const int batch_step,
const int norm_method,
const int sim_method)
{
query = query.contiguous();
key = key.contiguous();
value = value.contiguous();
pos_weight = pos_weight.contiguous();
const int query_channels = query.size(1);
const int key_channels = key.size(1);
const int value_channels = value.size(1);
const int batch_size = key.size(0);
const int height = query.size(2);
const int width = query.size(3);
const int in_height = key.size(2);
const int in_width = key.size(3);
const int batch_step_ = std::min(batch_size, batch_step);
const int sim_size = batch_step_ * num_group * kernel_height * kernel_width * height * width;
const int key_step = batch_step_ * key_channels * in_height * in_width;
const int query_step = batch_step_ * query_channels * height * width;
const int value_step = batch_step_ * value_channels * in_height * in_width;
const int output_step = batch_step_ * value_channels * height * width;
at::Tensor sim_buffer = at::zeros({batch_step_ * num_group, kernel_height * kernel_width, height * width},
query.options());
at::Tensor softmax_buffer = at::zeros({batch_step_ * num_group, kernel_height * kernel_width, height * width},
query.options());
at::Tensor sum_softmax_buffer = at::zeros({batch_step_ * num_group, height * width});
int M = (batch_size - 1) / batch_step_ + 1;
for (int i = 0; i < M; ++i) {
int cur_batch_step = batch_step_;
if (i == M - 1) {
cur_batch_step = batch_size - (M - 1) * batch_step_;
if (cur_batch_step != batch_step_) {
sim_buffer = at::zeros({cur_batch_step * num_group, kernel_height * kernel_width, height * width}, query.options());
softmax_buffer = at::zeros({cur_batch_step * num_group, kernel_height * kernel_width, height * width},query.options());
sum_softmax_buffer = at::zeros({cur_batch_step * num_group, height * width}, query.options());
}
// TORCH_CHECK(cur_batch_step % batch_step_ == 0, "batch_step must be divided by batch_size");
}
similarity_compute_forward(key, query, pos_weight, cur_batch_step,
key_channels, query_channels, height, width,
kernel_height, kernel_width, num_group, scale, no_define_value,
dilate, stride, in_height, in_width, sim_method, sim_buffer,
key_step * i, query_step * i);
// softmax
if (norm_method == 0) {
softmax_buffer = sim_buffer.softmax(1);
}
else {
AT_ERROR("Not implemented yet");
}
aggregation_forward(value, softmax_buffer, cur_batch_step,
value_channels, height, width, kernel_height, kernel_width,
num_group, dilate, stride, in_height, in_width, output, value_step * i, output_step * i);
}
return 1;
}
int local_relation_backward_cuda(
at::Tensor query,
at::Tensor key,
at::Tensor value,
at::Tensor pos_weight,
at::Tensor scale,
at::Tensor no_define_value,
at::Tensor output_grad,
at::Tensor query_grad,
at::Tensor key_grad,
at::Tensor value_grad,
at::Tensor pos_weight_grad,
const int kernel_height,
const int kernel_width,
const int num_group,
const int dilate,
const int stride,
const int batch_step,
const int norm_method,
const int sim_method)
{
query = query.contiguous();
key = key.contiguous();
value = value.contiguous();
pos_weight = pos_weight.contiguous();
output_grad = output_grad.contiguous();
query_grad = query_grad.contiguous();
key_grad = key_grad.contiguous();
value_grad = value_grad.contiguous();
pos_weight_grad = pos_weight_grad.contiguous();
const int query_channels = query.size(1);
const int key_channels = key.size(1);
const int value_channels = value.size(1);
const int batch_size = key.size(0);
const int height = query.size(2);
const int width = query.size(3);
const int in_height = key.size(2);
const int in_width = key.size(3);
const int key_per_group = query_channels / num_group;
const int batch_step_ = std::min(batch_size, batch_step);
const int sim_size = batch_step_ * num_group * kernel_height * kernel_width * height * width;
const int key_step = batch_step_ * key_channels * in_height * in_width;
const int query_step = batch_step_ * query_channels * height * width;
const int value_step = batch_step_ * value_channels * in_height * in_width;
const int output_step = batch_step_ * value_channels * height * width;
at::Tensor sim_buffer = at::zeros({batch_step_ * num_group, kernel_height * kernel_width, height * width},
query.options());
at::Tensor softmax_buffer = at::zeros({batch_step_ * num_group, kernel_height * kernel_width, height * width},
query.options());
at::Tensor sum_softmax_buffer = at::zeros({batch_step_ * num_group, height * width},
query.options());
at::Tensor sim_grad_buffer = at::zeros({batch_step_ * num_group, kernel_height * kernel_width, height * width},
query.options());
int M = (batch_size - 1) / batch_step_ + 1;
const int pos_weight_size = num_group * kernel_height * kernel_width;
for (int i = 0; i < M; ++i) {
int cur_batch_step = batch_step_;
if (i == M - 1) {
cur_batch_step = batch_size - (M - 1) * batch_step_;
if (cur_batch_step != batch_step_) {
sim_buffer = at::zeros({cur_batch_step * num_group, kernel_height * kernel_width, height * width}, query.options());
softmax_buffer = at::zeros({cur_batch_step * num_group, kernel_height * kernel_width, height * width},query.options());
sum_softmax_buffer = at::zeros({cur_batch_step * num_group, height * width}, query.options());
sim_grad_buffer = at::zeros({cur_batch_step * num_group, kernel_height * kernel_width, height * width}, query.options());
}
// TORCH_CHECK(cur_batch_step % batch_step_ == 0, "batch_step must be divided by batch_size");
}
similarity_compute_forward(key, query, pos_weight, cur_batch_step,
key_channels, query_channels, height, width,
kernel_height, kernel_width, num_group, scale, no_define_value,
dilate, stride, in_height, in_width, sim_method, sim_buffer,
key_step * i, query_step * i);
// softmax
if (norm_method == 0) {
softmax_buffer = sim_buffer.softmax(1);
}
else {
AT_ERROR("Not implemented yet");
}
aggregation_value_backward(softmax_buffer, output_grad, cur_batch_step,
value_channels, height, width, kernel_height, kernel_width,
num_group, dilate, stride, in_height, in_width, value_grad,
output_step * i, value_step * i);
aggregation_softmax_backward(value, output_grad, cur_batch_step,
value_channels, height, width, kernel_height, kernel_width,
num_group, dilate, stride, in_height, in_width, sim_buffer,
value_step * i, output_step * i);
if (norm_method == 0) {
sum_softmax_buffer = (softmax_buffer * sim_buffer).sum(1, true);
sim_grad_buffer = softmax_buffer * (sim_buffer - sum_softmax_buffer);
}
else {
AT_ERROR("Not implemented yet");
}
similarity_compute_backward(key, query, sim_grad_buffer, cur_batch_step,
key_channels, query_channels, height, width,
kernel_height, kernel_width, num_group, key_per_group, scale,
dilate, stride, in_height, in_width, sim_method, key_grad, query_grad,
key_step * i, query_step * i);
pos_weight_grad += sim_grad_buffer.view({cur_batch_step, num_group, kernel_height, kernel_width, height * width}).sum(4).sum(0);
}
return 1;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("local_relation_forward_cuda", &local_relation_forward_cuda,
"local relation forward (CUDA)");
m.def("local_relation_backward_cuda", &local_relation_backward_cuda,
"local relation backward (CUDA)");
}

File diff suppressed because it is too large Load Diff