Spaces:
Runtime error
Runtime error
File size: 2,128 Bytes
d4ab5ac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
from datamodules.utils import get_configs
from transformers import (
ConvNextConfig,
ConvNextForImageClassification,
PreTrainedModel,
ViTConfig,
ViTForImageClassification,
)
import argparse
import torch
def set_clf_head(base: PreTrainedModel, num_classes: int):
    """Set the classification head of the model in case of an output mismatch.

    The head is replaced by a newly constructed ``torch.nn.Linear`` only when its
    current ``out_features`` differs from ``num_classes``; otherwise the model is
    left untouched. The model is modified in place and nothing is returned.

    Args:
        base (PreTrainedModel): the model to modify; its ``classifier`` attribute must
            expose ``in_features``, ``out_features`` and ``weight`` (e.g. a linear layer)
        num_classes (int): the number of classes to use for the output layer
    """
    old_head = base.classifier
    if old_head.out_features == num_classes:
        # Head already has the requested width; nothing to do.
        return

    # Build the new head on the same device/dtype as the one it replaces, so a model
    # that was already moved or cast does not end up with mixed parameters.
    base.classifier = torch.nn.Linear(
        old_head.in_features,
        num_classes,
        device=old_head.weight.device,
        dtype=old_head.weight.dtype,
    )

    # Keep the label-count bookkeeping in sync with the new head; HuggingFace
    # classification models consult ``num_labels`` when computing the loss.
    # The attributes are probed first so duck-typed models without them still work.
    config = getattr(base, "config", None)
    if config is not None and hasattr(config, "num_labels"):
        config.num_labels = num_classes
    if hasattr(base, "num_labels"):
        base.num_labels = num_classes
def model_factory(
    args: argparse.Namespace,
    own_config: bool = False,
) -> PreTrainedModel:
    """A factory method for creating a HuggingFace model based on the command line args.

    Args:
        args (Namespace): the argparse Namespace object; ``args.base_model`` selects the
            architecture ("ViT" or "ConvNeXt") and ``args.from_pretrained``, when truthy,
            is passed to ``from_pretrained`` as the checkpoint to load
        own_config (bool): whether to create our own model config instead of a pretrained one;
            this is recommended when the model was pre-trained on another task with a different
            amount of classes for its classifier head. When True, the model is built from the
            config alone and this function does not load any pretrained weights

    Returns:
        a PreTrainedModel instance

    Raises:
        ValueError: if ``args.base_model`` is not one of the supported architectures
    """
    # Pick the config/model class pair for the requested architecture
    if args.base_model == "ViT":
        config_class, base_class = ViTConfig, ViTForImageClassification
    elif args.base_model == "ConvNeXt":
        config_class, base_class = ConvNextConfig, ConvNextForImageClassification
    else:
        raise ValueError(f"Unknown base model: {args.base_model}")

    # Get the model config; only the first element of the returned pair is used here
    model_cfg_args, _ = get_configs(args)

    if not own_config and args.from_pretrained:
        # Load the pretrained model, then adapt the classifier head if its output
        # size does not match the configured number of labels
        base = base_class.from_pretrained(args.from_pretrained)
        set_clf_head(base, model_cfg_args["num_labels"])
        return base

    # Otherwise build the model purely from our own config
    return base_class(config_class(**model_cfg_args))
|