import torch
from transformers import PreTrainedModel

from AUNet import AUNet
from AUNetConfig import AUNetConfig


class s2l8hModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around an ``AUNet`` network.

    Every constructor argument of the wrapped ``AUNet`` is read from the
    ``AUNetConfig`` passed in. ``config_class`` ties this model to that
    config type for the Transformers loading machinery.
    """

    config_class = AUNetConfig

    def __init__(self, config):
        """Build the wrapped ``AUNet`` from ``config``.

        Args:
            config: An ``AUNetConfig``. Its ``in_channels``, ``out_channels``,
                ``depth``, ``spatial_attention``, ``growth_factor``,
                ``interp_mode``, ``up_mode`` and ``ca_layer`` attributes are
                forwarded unchanged as keyword arguments to ``AUNet``.
        """
        super().__init__(config)
        self.model = AUNet(
            in_channels=config.in_channels,
            out_channels=config.out_channels,
            depth=config.depth,
            spatial_attention=config.spatial_attention,
            growth_factor=config.growth_factor,
            interp_mode=config.interp_mode,
            up_mode=config.up_mode,
            ca_layer=config.ca_layer,
        )

    def forward(self, MS, PAN):
        """Run the wrapped ``AUNet`` on ``(MS, PAN)``.

        Args:
            MS: First input to ``AUNet``. Presumably the multispectral image
                of a pansharpening pair; TODO confirm against ``AUNet``.
            PAN: Second input to ``AUNet``. Presumably the panchromatic
                image; TODO confirm against ``AUNet``.

        Returns:
            Whatever ``AUNet`` returns for these inputs.
        """
        # Invoke the submodule instance, not ``self.model.forward``.
        # ``nn.Module.__call__`` is what runs registered forward hooks and
        # pre-hooks, e.g. device-placement hooks added by accelerate.
        # Calling ``.forward`` directly silently skips them.
        return self.model(MS, PAN)