venkatesh-thiru's picture
Upload model
670e5e8 verified
raw
history blame
No virus
672 Bytes
from transformers import PreTrainedModel
from .AUNet import AUNet
from .AUNetConfig import AUNetConfig
import torch
class s2l8hModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around an ``AUNet`` network.

    Subclassing ``PreTrainedModel`` lets the network be saved and loaded through
    the standard ``save_pretrained`` / ``from_pretrained`` machinery. The
    architecture hyper-parameters are read from an ``AUNetConfig``.
    """

    # Config class the transformers loading machinery instantiates for this model.
    config_class = AUNetConfig

    def __init__(self, config: AUNetConfig):
        """Build the wrapped AUNet from the hyper-parameters stored in *config*.

        Args:
            config: configuration whose ``in_channels``, ``out_channels``,
                ``depth``, ``spatial_attention``, ``growth_factor``,
                ``interp_mode``, ``up_mode`` and ``ca_layer`` attributes are
                forwarded verbatim as keyword arguments to ``AUNet``.
        """
        super().__init__(config)
        self.model = AUNet(
            in_channels=config.in_channels,
            out_channels=config.out_channels,
            depth=config.depth,
            spatial_attention=config.spatial_attention,
            growth_factor=config.growth_factor,
            interp_mode=config.interp_mode,
            up_mode=config.up_mode,
            ca_layer=config.ca_layer,
        )

    def forward(self, MS, PAN):
        """Run the wrapped AUNet on ``(MS, PAN)`` and return its output.

        Args:
            MS: first network input, passed through unchanged. Presumably the
                multispectral image tensor — TODO confirm against AUNet.forward.
            PAN: second network input, passed through unchanged. Presumably the
                panchromatic image tensor — TODO confirm against AUNet.forward.

        Returns:
            Whatever ``AUNet`` returns for these inputs.
        """
        # Call the submodule itself rather than ``.forward()`` so that
        # registered forward pre-/post-hooks run. Calling ``.forward()``
        # directly silently skips them.
        # NOTE(review): assumes AUNet is a torch.nn.Module. It must be one for
        # its parameters to be registered on this model — confirm.
        return self.model(MS, PAN)