File size: 670 Bytes
40ed350
b4f3d8a
 
40ed350
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from transformers import PreTrainedModel
from AUNet import AUNet
from AUNetConfig import AUNetConfig
import torch

class s2l8hModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around an ``AUNet`` network.

    ``config_class`` ties this model to ``AUNetConfig``, so ``transformers``
    knows which config class to instantiate for it. Every architecture
    hyper-parameter is read from the config and forwarded unchanged to the
    ``AUNet`` constructor; the wrapped network is stored as ``self.model``.
    """

    config_class = AUNetConfig

    def __init__(self, config):
        """Build the wrapped ``AUNet`` from ``config``.

        Args:
            config: An ``AUNetConfig`` (or compatible object) exposing
                ``in_channels``, ``out_channels``, ``depth``,
                ``spatial_attention``, ``growth_factor``, ``interp_mode``,
                ``up_mode`` and ``ca_layer``; each is passed to ``AUNet``
                under the same keyword name.
        """
        super().__init__(config)
        # NOTE(review): assumes AUNet is a torch.nn.Module, in which case this
        # assignment registers it as a submodule of this model — confirm.
        self.model = AUNet(
            in_channels=config.in_channels,
            out_channels=config.out_channels,
            depth=config.depth,
            spatial_attention=config.spatial_attention,
            growth_factor=config.growth_factor,
            interp_mode=config.interp_mode,
            up_mode=config.up_mode,
            ca_layer=config.ca_layer,
        )

    def forward(self, MS, PAN):
        """Run the wrapped ``AUNet`` on ``(MS, PAN)``.

        Args:
            MS: First positional input to ``AUNet`` (presumably the
                multispectral image — TODO confirm expected shape).
            PAN: Second positional input to ``AUNet`` (presumably the
                panchromatic image — TODO confirm expected shape).

        Returns:
            Whatever ``AUNet`` returns for these inputs.
        """
        # Invoke the module rather than ``.forward()`` directly so that any
        # forward / forward-pre hooks registered on ``self.model`` are run.
        return self.model(MS, PAN)