Allow loading via AutoModel

#2
by tomaarsen - opened

Hello!

Pull Request overview

  • Allow loading via AutoModel

Details

I wanted to experiment with loading this model with just AutoModel:

from transformers import AutoTokenizer, AutoModel

max_seq_length = 8192
testing_string = "Every morning, I make a cup of coffee to start my day."
model = AutoModel.from_pretrained(
    "togethercomputer/m2-bert-80M-8k-retrieval",
    trust_remote_code=True
)

# The model was trained with the bert-base-uncased tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    "bert-base-uncased",
    model_max_length=max_seq_length
)
# BatchEncoding containing input_ids and attention_mask
inputs = tokenizer(
    [testing_string],
    return_tensors="pt",
    padding="max_length",
    return_token_type_ids=False,
    truncation=True,
    max_length=max_seq_length
)

encoder_outputs, pooled_output = model(**inputs)
print(encoder_outputs.shape, pooled_output.shape)

But I ran into this issue:

You are using a model of type m2_bert to instantiate a model of type bert. This is not supported for all configurations of models and can yield errors.
Traceback (most recent call last):
  File "c:\code\m2-bert-80M-8k-retrieval\demo.py", line 5, in <module>
    model = AutoModel.from_pretrained(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tom\.conda\envs\sentence-transformers\Lib\site-packages\transformers\models\auto\auto_factory.py", line 519, in from_pretrained
    raise ValueError(
ValueError: Unrecognized configuration class <class 'transformers_modules.togethercomputer.m2-bert-80M-8k-retrieval.66ea5d6b12ab7e3d332bba708d76f83ce2909b2e.configuration_bert.BertConfig'> for this kind of AutoModel: AutoModel.
Model type should be one of AlbertConfig, AlignConfig, AltCLIPConfig, ASTConfig, AutoformerConfig, BarkConfig, BartConfig, BeitConfig, BertConfig, BertGenerationConfig, BigBirdConfig, BigBirdPegasusConfig, BioGptConfig, BitConfig, BlenderbotConfig, BlenderbotSmallConfig, BlipConfig, Blip2Config, BloomConfig, BridgeTowerConfig, CamembertConfig, CanineConfig, ChineseCLIPConfig, ClapConfig, CLIPConfig, CLIPSegConfig, CodeGenConfig, ConditionalDetrConfig, ConvBertConfig, ConvNextConfig, ConvNextV2Config, CpmAntConfig, CTRLConfig, CvtConfig, Data2VecAudioConfig, Data2VecTextConfig, Data2VecVisionConfig, DebertaConfig, DebertaV2Config, DecisionTransformerConfig, DeformableDetrConfig, DeiTConfig, DetaConfig, DetrConfig, DinatConfig, Dinov2Config, DistilBertConfig, DonutSwinConfig, DPRConfig, DPTConfig, EfficientFormerConfig, EfficientNetConfig, ElectraConfig, EncodecConfig, ErnieConfig, ErnieMConfig, EsmConfig, FalconConfig, FlaubertConfig, FlavaConfig, FNetConfig, FocalNetConfig, FSMTConfig, FunnelConfig, GitConfig, GLPNConfig, GPT2Config, GPT2Config, GPTBigCodeConfig, GPTNeoConfig, GPTNeoXConfig, GPTNeoXJapaneseConfig, GPTJConfig, GPTSanJapaneseConfig, GraphormerConfig, GroupViTConfig, HubertConfig, IBertConfig, IdeficsConfig, ImageGPTConfig, InformerConfig, JukeboxConfig, LayoutLMConfig, LayoutLMv2Config, LayoutLMv3Config, LEDConfig, LevitConfig, LiltConfig, LlamaConfig, LongformerConfig, LongT5Config, LukeConfig, LxmertConfig, M2M100Config, MarianConfig, MarkupLMConfig, Mask2FormerConfig, MaskFormerConfig, MaskFormerSwinConfig, MBartConfig, MCTCTConfig, MegaConfig, MegatronBertConfig, MgpstrConfig, MobileBertConfig, MobileNetV1Config, MobileNetV2Config, MobileViTConfig, MobileViTV2Config, MPNetConfig, MptConfig, MraConfig, MT5Config, MvpConfig, NatConfig, NezhaConfig, NllbMoeConfig, NystromformerConfig, OneFormerConfig, OpenLlamaConfig, OpenAIGPTConfig, OPTConfig, OwlViTConfig, PegasusConfig, PegasusXConfig, PerceiverConfig, PLBartConfig, PoolFormerConfig, 
ProphetNetConfig, PvtConfig, QDQBertConfig, ReformerConfig, RegNetConfig, RemBertConfig, ResNetConfig, RetriBertConfig, RobertaConfig, RobertaPreLayerNormConfig, RoCBertConfig, RoFormerConfig, RwkvConfig, SamConfig, SegformerConfig, SEWConfig, SEWDConfig, Speech2TextConfig, SpeechT5Config, SplinterConfig, SqueezeBertConfig, SwiftFormerConfig, SwinConfig, Swin2SRConfig, Swinv2Config, SwitchTransformersConfig, T5Config, TableTransformerConfig, TapasConfig, TimeSeriesTransformerConfig, TimesformerConfig, TimmBackboneConfig, TrajectoryTransformerConfig, TransfoXLConfig, TvltConfig, UMT5Config, UniSpeechConfig, UniSpeechSatConfig, VanConfig, VideoMAEConfig, ViltConfig, VisionTextDualEncoderConfig, VisualBertConfig, ViTConfig, ViTHybridConfig, ViTMAEConfig, ViTMSNConfig, VivitConfig, Wav2Vec2Config, Wav2Vec2ConformerConfig, WavLMConfig, WhisperConfig, XCLIPConfig, XGLMConfig, XLMConfig, XLMProphetNetConfig, XLMRobertaConfig, XLMRobertaXLConfig, XLNetConfig, XmodConfig, YolosConfig, YosoConfig.

In short, the custom BertConfig is not registered with AutoModel, so AutoModel does not know which model class to instantiate. This PR fixes that by also adding an AutoModel entry to the auto_map in config.json.
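For reference, the change amounts to one extra entry in the `auto_map` section of `config.json` (a sketch: the `AutoConfig` path matches the traceback above, but the exact module and class names for `AutoModel` must match the modeling file actually shipped in this repository):

```json
{
  "auto_map": {
    "AutoConfig": "configuration_bert.BertConfig",
    "AutoModel": "modeling_bert.BertModel"
  }
}
```

With this entry in place, `AutoModel.from_pretrained(..., trust_remote_code=True)` resolves the custom model class from the repository's own code instead of falling back to the built-in model mapping, which is what raised the `Unrecognized configuration class` error.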

After this PR

The above script should now return:

You are using a model of type m2_bert to instantiate a model of type bert. This is not supported for all configurations of models and can yield errors.
-- Bidirectional: True
-- Using Long Conv Residual: True
-- Hyena w: 10
-- Hyena w mod: 1
-- Hyena filter order: 128
-- Hyena filter dropout: 0.2
-- Hyena filter wd: 0.1
-- Hyena filter emb dim: 5
-- Hyena filter lr: 0.001
-- Hyena filter lr pos emb: 1e-05
torch.Size([1, 8192, 768]) torch.Size([1, 768])

🎉

  • Tom Aarsen
tomaarsen changed pull request status to open
Together org

Thank you!!

danfu09 changed pull request status to merged
