# Swin-Transformer/ops/local_relation/local_relation_func.py
# 2021-04-13 00:44:08 +08:00
#
# 103 lines
# 3.7 KiB
# Python

# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2019 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Han Hu, Jiarui Xu
# Modified by Ze Liu
# --------------------------------------------------------
import torch
from torch.autograd import Function
from torch.nn.modules.utils import _pair
from . import local_relation_cuda
class LocalRelationFunction(Function):
    """Autograd ``Function`` that dispatches to the ``local_relation_cuda`` kernels.

    Only CUDA tensors are supported; CPU tensors raise ``NotImplementedError``.
    """

    @staticmethod
    def forward(ctx,
                query,
                key,
                value,
                pos_weight,
                kernel_size,
                groups,
                stride=1,
                dilation=1,
                scale=1.,
                no_define_value=-100.,
                norm_method=0,
                sim_method=0,
                batch_step=32):
        """Run the forward CUDA kernel.

        Args:
            ctx: autograd context; receives the hyper-parameters and the
                tensors needed by ``backward``.
            query, key, value: 4D tensors (``None`` is skipped by the
                validation below).
            pos_weight: tensor forwarded unchanged to the kernel.
            kernel_size: int or pair, normalised with ``_pair``.
            groups, stride, dilation, norm_method, sim_method: forwarded
                unchanged to the kernel.
            scale, no_define_value: Python scalars, wrapped into 1-element
                tensors on ``query``'s device/dtype before the kernel call.
            batch_step: upper bound for the batch chunk handed to the kernel;
                clamped to ``query.shape[0]``.

        Returns:
            Tensor of shape
            ``(query.size(0), value.size(1), query.size(2), query.size(3))``.

        Raises:
            ValueError: if any of ``query``/``key``/``value`` is not 4D.
            NotImplementedError: if any of ``query``/``key``/``value`` is not
                a CUDA tensor.
        """
        tensors = (query, key, value)
        for tensor in tensors:
            if tensor is not None and tensor.dim() != 4:
                raise ValueError(
                    "Expected 4D tensor as input, got {}D tensor instead.".format(
                        tensor.dim()))

        # Check every input explicitly. The previous code tested the variable
        # leaked from the validation loop, i.e. only ``value``, so a CPU
        # ``query``/``key`` could slip through to the CUDA kernel.
        if any(tensor is not None and not tensor.is_cuda for tensor in tensors):
            raise NotImplementedError(
                "local_relation is only implemented for CUDA tensors")

        ctx.kernel_size = _pair(kernel_size)
        ctx.groups = groups
        ctx.stride = stride
        ctx.dilation = dilation
        ctx.scale = scale
        ctx.no_define_value = no_define_value
        ctx.norm_method = norm_method
        ctx.sim_method = sim_method
        ctx.batch_step = batch_step
        ctx.save_for_backward(query, key, value, pos_weight)

        output = query.new_empty(
            LocalRelationFunction._output_size(query, value))
        scale_tensor = query.new_tensor([ctx.scale])
        no_define_value_tensor = query.new_tensor([ctx.no_define_value])

        batch_step = min(ctx.batch_step, query.shape[0])
        # The kernel receives kernel_size[1] before kernel_size[0], and
        # dilation before stride.
        local_relation_cuda.local_relation_forward_cuda(
            query, key, value, pos_weight, scale_tensor, no_define_value_tensor,
            output, ctx.kernel_size[1], ctx.kernel_size[0], ctx.groups,
            ctx.dilation, ctx.stride, batch_step, ctx.norm_method, ctx.sim_method)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Run the backward CUDA kernel.

        Returns a 13-tuple matching ``forward``'s inputs: gradients for
        ``query``, ``key``, ``value`` and ``pos_weight`` followed by ``None``
        for the nine non-tensor arguments. The four gradients stay ``None``
        when none of the tensor inputs requires a gradient.

        Raises:
            NotImplementedError: if ``grad_output`` is not a CUDA tensor.
        """
        query, key, value, pos_weight = ctx.saved_tensors
        grad_query = grad_key = grad_value = grad_pos_weight = None

        # NOTE(review): forward wraps these scalars as 1-element 1-D tensors,
        # whereas here they are 0-dim. Kept as-is because the extension's
        # expectations are not visible from this file -- confirm and unify.
        scale_tensor = query.new_tensor(ctx.scale)
        no_define_value_tensor = query.new_tensor(ctx.no_define_value)

        if not grad_output.is_cuda:
            raise NotImplementedError

        batch_step = min(ctx.batch_step, query.shape[0])
        # One kernel call fills all four gradients, so it runs as soon as any
        # of the tensor inputs needs a gradient.
        if any(ctx.needs_input_grad[:4]):
            grad_query = torch.zeros_like(query)
            grad_key = torch.zeros_like(key)
            grad_value = torch.zeros_like(value)
            grad_pos_weight = torch.zeros_like(pos_weight)
            local_relation_cuda.local_relation_backward_cuda(
                query, key, value, pos_weight,
                scale_tensor, no_define_value_tensor, grad_output,
                grad_query, grad_key, grad_value, grad_pos_weight,
                ctx.kernel_size[1], ctx.kernel_size[0],
                ctx.groups, ctx.dilation, ctx.stride, batch_step,
                ctx.norm_method, ctx.sim_method)

        return (grad_query, grad_key, grad_value, grad_pos_weight, None, None, None,
                None, None, None, None, None, None)

    @staticmethod
    def _output_size(query, value):
        """Return the output shape: ``query``'s dims 0, 2, 3 with ``value``'s dim 1."""
        output_size = (query.size(0), value.size(1), query.size(2), query.size(3))
        return output_size
# Functional entry point: local_relation(query, key, value, pos_weight,
# kernel_size, groups, ...) invokes LocalRelationFunction through autograd.
local_relation = LocalRelationFunction.apply