import torch
import torch.nn as nn

from .attention import (
    single_head_full_attention,
    single_head_full_attention_1d,
    single_head_split_window_attention,
    single_head_split_window_attention_1d,
)
from .utils import generate_shift_window_attn_mask, generate_shift_window_attn_mask_1d


class TransformerLayer(nn.Module):
    def __init__(
        self,
        d_model=128,
        nhead=1,
        no_ffn=False,
        ffn_dim_expansion=4,
    ):
        super().__init__()

        self.dim = d_model
        self.nhead = nhead
        self.no_ffn = no_ffn

        # linear projections for query / key / value (single-head attention)
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)

        self.merge = nn.Linear(d_model, d_model, bias=False)

        self.norm1 = nn.LayerNorm(d_model)

        # optional FFN (typically disabled for the self-attention layer and
        # enabled for the cross-attention layer)
        if not self.no_ffn:
            in_channels = d_model * 2
            self.mlp = nn.Sequential(
                nn.Linear(in_channels, in_channels * ffn_dim_expansion, bias=False),
                nn.GELU(),
                nn.Linear(in_channels * ffn_dim_expansion, d_model, bias=False),
            )

            self.norm2 = nn.LayerNorm(d_model)

    def forward(
        self,
        source,
        target,
        height=None,
        width=None,
        shifted_window_attn_mask=None,
        shifted_window_attn_mask_1d=None,
        attn_type="swin",
        with_shift=False,
        attn_num_splits=None,
    ):
        # source, target: [B, L, C]; attn_num_splits must be an int (>= 1)
        # whenever a swin-style attn_type is used
        query, key, value = source, target, target

        # self-attention when source and target are numerically identical
        is_self_attn = (query - key).abs().max() < 1e-6

        # single-head attention
        query = self.q_proj(query)  # [B, L, C]
        key = self.k_proj(key)  # [B, L, C]
        value = self.v_proj(value)  # [B, L, C]

        if attn_type == "swin" and attn_num_splits > 1:
            if self.nhead > 1:
                # only single-head attention is implemented
                raise NotImplementedError
            else:
                # 2d shifted-window (swin) attention
                message = single_head_split_window_attention(
                    query,
                    key,
                    value,
                    num_splits=attn_num_splits,
                    with_shift=with_shift,
                    h=height,
                    w=width,
                    attn_mask=shifted_window_attn_mask,
                )

        elif attn_type == "self_swin2d_cross_1d":
            # self-attn: 2d swin (or full 2d), cross-attn: full 1d
            if self.nhead > 1:
                raise NotImplementedError
            else:
                if is_self_attn:
                    if attn_num_splits > 1:
                        message = single_head_split_window_attention(
                            query,
                            key,
                            value,
                            num_splits=attn_num_splits,
                            with_shift=with_shift,
                            h=height,
                            w=width,
                            attn_mask=shifted_window_attn_mask,
                        )
                    else:
                        # full 2d attention
                        message = single_head_full_attention(query, key, value)
                else:
                    # cross attention: full 1d attention
                    message = single_head_full_attention_1d(
                        query,
                        key,
                        value,
                        h=height,
                        w=width,
                    )

        elif attn_type == "self_swin2d_cross_swin1d":
            # self-attn: 2d swin (or full 2d), cross-attn: 1d swin (or full 1d)
            if self.nhead > 1:
                raise NotImplementedError
            else:
                if is_self_attn:
                    if attn_num_splits > 1:
                        # self attention with shifted 2d windows
                        message = single_head_split_window_attention(
                            query,
                            key,
                            value,
                            num_splits=attn_num_splits,
                            with_shift=with_shift,
                            h=height,
                            w=width,
                            attn_mask=shifted_window_attn_mask,
                        )
                    else:
                        # full 2d attention
                        message = single_head_full_attention(query, key, value)
                else:
                    if attn_num_splits > 1:
                        assert shifted_window_attn_mask_1d is not None
                        # cross attention with shifted 1d windows
                        message = single_head_split_window_attention_1d(
                            query,
                            key,
                            value,
                            num_splits=attn_num_splits,
                            with_shift=with_shift,
                            h=height,
                            w=width,
                            attn_mask=shifted_window_attn_mask_1d,
                        )
                    else:
                        # full 1d attention
                        message = single_head_full_attention_1d(
                            query,
                            key,
                            value,
                            h=height,
                            w=width,
                        )

        else:
            # any other attn_type falls back to full 2d attention
            message = single_head_full_attention(query, key, value)

        message = self.merge(message)  # [B, L, C]
        message = self.norm1(message)

        if not self.no_ffn:
            message = self.mlp(torch.cat([source, message], dim=-1))
            message = self.norm2(message)

        return source + message
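
# A minimal usage sketch of TransformerLayer (not part of the original module; the
# shapes below are illustrative assumptions). The layer operates on flattened feature
# tokens of shape [B, H*W, C]; it performs self-attention when source and target are
# the same tensor and cross-attention otherwise. Any attn_type other than the swin
# variants handled above falls through to full attention, so "full" is used here
# purely for illustration:
#
#     layer = TransformerLayer(d_model=128, nhead=1)
#     x = torch.randn(2, 48 * 64, 128)            # tokens of frame 0
#     y = torch.randn(2, 48 * 64, 128)            # tokens of frame 1
#     out_self = layer(x, x, attn_type="full")    # self-attention, [2, 3072, 128]
#     out_cross = layer(x, y, attn_type="full")   # cross-attention, [2, 3072, 128]

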
class TransformerBlock(nn.Module):
    """self attention + cross attention + FFN."""

    def __init__(
        self,
        d_model=128,
        nhead=1,
        ffn_dim_expansion=4,
    ):
        super().__init__()

        self.self_attn = TransformerLayer(
            d_model=d_model,
            nhead=nhead,
            no_ffn=True,
            ffn_dim_expansion=ffn_dim_expansion,
        )

        self.cross_attn_ffn = TransformerLayer(
            d_model=d_model,
            nhead=nhead,
            ffn_dim_expansion=ffn_dim_expansion,
        )

    def forward(
        self,
        source,
        target,
        height=None,
        width=None,
        shifted_window_attn_mask=None,
        shifted_window_attn_mask_1d=None,
        attn_type="swin",
        with_shift=False,
        attn_num_splits=None,
    ):
        # source, target: [B, L, C]

        # self attention
        source = self.self_attn(
            source,
            source,
            height=height,
            width=width,
            shifted_window_attn_mask=shifted_window_attn_mask,
            attn_type=attn_type,
            with_shift=with_shift,
            attn_num_splits=attn_num_splits,
        )

        # cross attention and FFN
        source = self.cross_attn_ffn(
            source,
            target,
            height=height,
            width=width,
            shifted_window_attn_mask=shifted_window_attn_mask,
            shifted_window_attn_mask_1d=shifted_window_attn_mask_1d,
            attn_type=attn_type,
            with_shift=with_shift,
            attn_num_splits=attn_num_splits,
        )

        return source


class FeatureTransformer(nn.Module):
    def __init__(
        self,
        num_layers=6,
        d_model=128,
        nhead=1,
        ffn_dim_expansion=4,
    ):
        super().__init__()

        self.d_model = d_model
        self.nhead = nhead

        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    d_model=d_model,
                    nhead=nhead,
                    ffn_dim_expansion=ffn_dim_expansion,
                )
                for _ in range(num_layers)
            ]
        )

        # Xavier initialization for all weight matrices
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(
        self,
        feature0,
        feature1,
        attn_type="swin",
        attn_num_splits=None,
        **kwargs,
    ):
        # feature0, feature1: [B, C, H, W]
        b, c, h, w = feature0.shape
        assert self.d_model == c

        # flatten the spatial dimensions into tokens
        feature0 = feature0.flatten(-2).permute(0, 2, 1)  # [B, H*W, C]
        feature1 = feature1.flatten(-2).permute(0, 2, 1)  # [B, H*W, C]

if "swin" in attn_type and attn_num_splits > 1: |
|
|
|
window_size_h = h // attn_num_splits |
|
window_size_w = w // attn_num_splits |
|
|
|
|
|
shifted_window_attn_mask = generate_shift_window_attn_mask( |
|
input_resolution=(h, w), |
|
window_size_h=window_size_h, |
|
window_size_w=window_size_w, |
|
shift_size_h=window_size_h // 2, |
|
shift_size_w=window_size_w // 2, |
|
device=feature0.device, |
|
) |
|
else: |
|
shifted_window_attn_mask = None |
|
|
|
|
|
if "swin1d" in attn_type and attn_num_splits > 1: |
|
window_size_w = w // attn_num_splits |
|
|
|
|
|
shifted_window_attn_mask_1d = generate_shift_window_attn_mask_1d( |
|
input_w=w, |
|
window_size_w=window_size_w, |
|
shift_size_w=window_size_w // 2, |
|
device=feature0.device, |
|
) |
|
else: |
|
shifted_window_attn_mask_1d = None |
|
|
|
|
|
        # concatenate feature0 and feature1 in the batch dimension so that
        # self- and cross-attention for both frames are computed in a single pass
        concat0 = torch.cat((feature0, feature1), dim=0)  # [2B, H*W, C]
        concat1 = torch.cat((feature1, feature0), dim=0)  # [2B, H*W, C]

        for i, layer in enumerate(self.layers):
            concat0 = layer(
                concat0,
                concat1,
                height=h,
                width=w,
                attn_type=attn_type,
                with_shift="swin" in attn_type and attn_num_splits > 1 and i % 2 == 1,
                attn_num_splits=attn_num_splits,
                shifted_window_attn_mask=shifted_window_attn_mask,
                shifted_window_attn_mask_1d=shifted_window_attn_mask_1d,
            )

            # update the cross-attention target with the newly computed features,
            # swapping the two halves of the batch
            concat1 = torch.cat(concat0.chunk(chunks=2, dim=0)[::-1], dim=0)

        feature0, feature1 = concat0.chunk(chunks=2, dim=0)  # [B, H*W, C] each

        # reshape back to feature maps
        feature0 = feature0.view(b, h, w, c).permute(0, 3, 1, 2).contiguous()  # [B, C, H, W]
        feature1 = feature1.view(b, h, w, c).permute(0, 3, 1, 2).contiguous()  # [B, C, H, W]

        return feature0, feature1
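

# Minimal usage sketch (not part of the original module): the shapes and attention
# settings below are illustrative assumptions, not values prescribed by this file.
# Because of the relative imports above, run it as a module from the package root
# (e.g. `python -m <package>.<this_module>`).
if __name__ == "__main__":
    transformer = FeatureTransformer(num_layers=6, d_model=128, nhead=1, ffn_dim_expansion=4)

    # a pair of feature maps, e.g. CNN features of two frames
    feat0 = torch.randn(2, 128, 48, 64)  # [B, C, H, W]
    feat1 = torch.randn(2, 128, 48, 64)

    # swin-style shifted-window attention with a 2x2 window split
    # (H and W must be divisible by attn_num_splits)
    out0, out1 = transformer(feat0, feat1, attn_type="swin", attn_num_splits=2)
    print(out0.shape, out1.shape)  # torch.Size([2, 128, 48, 64]) twice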