File size: 826 Bytes
40ed350
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from transformers import PretrainedConfig
from typing import List


class AUNetConfig(PretrainedConfig):
    """Hugging Face configuration object holding AUNet hyperparameters.

    Each named constructor argument is stored verbatim as an instance
    attribute of the same name. Any remaining keyword arguments are forwarded
    unchanged to ``PretrainedConfig.__init__``.

    Args:
        in_channels: Stored as ``self.in_channels`` (default 7).
            NOTE(review): presumably the number of input channels; confirm.
        out_channels: Stored as ``self.out_channels`` (default 6).
            NOTE(review): presumably the number of output channels; confirm.
        depth: Stored as ``self.depth`` (default 5).
            NOTE(review): meaning (e.g. number of levels) is defined by the
            model, not here.
        spatial_attention: Stored as ``self.spatial_attention``. The default
            is the *string* ``'None'``, not the ``None`` object.
        growth_factor: Stored as ``self.growth_factor`` (default 6).
        interp_mode: Stored as ``self.interp_mode`` (default ``'bicubic'``).
        up_mode: Stored as ``self.up_mode`` (default ``'upsample'``).
        ca_layer: Stored as ``self.ca_layer`` (default ``False``).
            NOTE(review): presumably toggles a channel-attention layer;
            confirm against the model.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    # Identifier registered with transformers' auto classes. The value does
    # not match the class name. NOTE(review): confirm it is the intended key.
    model_type = "s2l8hModel"

    def __init__(
        self,
        in_channels: int = 7,
        out_channels: int = 6,
        depth: int = 5,
        spatial_attention: str = 'None',
        growth_factor: int = 6,
        interp_mode: str = 'bicubic',
        up_mode: str = 'upsample',
        ca_layer: bool = False,
        **kwargs,
    ):
        # Ordered mapping of attribute name -> value. Dict insertion order
        # keeps the assignments in the same sequence as the signature.
        hyperparams = {
            "in_channels": in_channels,
            "out_channels": out_channels,
            "depth": depth,
            "spatial_attention": spatial_attention,
            "growth_factor": growth_factor,
            "interp_mode": interp_mode,
            "up_mode": up_mode,
            "ca_layer": ca_layer,
        }
        for attr_name, attr_value in hyperparams.items():
            setattr(self, attr_name, attr_value)

        # The base initializer runs last and handles the generic config
        # fields supplied through kwargs.
        super().__init__(**kwargs)