34 lines
1.6 KiB
Python
34 lines
1.6 KiB
Python
# --------------------------------------------------------
|
|
# Swin Transformer
|
|
# Copyright (c) 2021 Microsoft
|
|
# Licensed under The MIT License [see LICENSE for details]
|
|
# Written by Ze Liu
|
|
# --------------------------------------------------------
|
|
|
|
from .swin_transformer import SwinTransformer
|
|
|
|
|
|
def build_model(config):
    """Build the model selected by ``config.MODEL.TYPE``.

    Only the ``'swin'`` type is supported. For it, a ``SwinTransformer`` is
    constructed from the following fields:
      * ``config.DATA.IMG_SIZE``
      * ``config.MODEL.NUM_CLASSES``, ``DROP_RATE`` and ``DROP_PATH_RATE``
      * the per-architecture settings under ``config.MODEL.SWIN``
      * ``config.TRAIN.USE_CHECKPOINT``

    Args:
        config: Configuration object with attribute-style access to the
            nested sections listed above (presumably a yacs-style config
            node; confirm against the caller).

    Returns:
        The constructed ``SwinTransformer`` instance.

    Raises:
        NotImplementedError: If ``config.MODEL.TYPE`` is not ``'swin'``.
    """
    model_type = config.MODEL.TYPE
    # Reject unsupported types up front so the happy path stays flat.
    if model_type != 'swin':
        raise NotImplementedError(f"Unknown model: {model_type}")

    swin_cfg = config.MODEL.SWIN
    return SwinTransformer(img_size=config.DATA.IMG_SIZE,
                           patch_size=swin_cfg.PATCH_SIZE,
                           in_chans=swin_cfg.IN_CHANS,
                           num_classes=config.MODEL.NUM_CLASSES,
                           embed_dim=swin_cfg.EMBED_DIM,
                           depths=swin_cfg.DEPTHS,
                           num_heads=swin_cfg.NUM_HEADS,
                           window_size=swin_cfg.WINDOW_SIZE,
                           mlp_ratio=swin_cfg.MLP_RATIO,
                           qkv_bias=swin_cfg.QKV_BIAS,
                           qk_scale=swin_cfg.QK_SCALE,
                           drop_rate=config.MODEL.DROP_RATE,
                           drop_path_rate=config.MODEL.DROP_PATH_RATE,
                           ape=swin_cfg.APE,
                           patch_norm=swin_cfg.PATCH_NORM,
                           use_checkpoint=config.TRAIN.USE_CHECKPOINT)
|