---
license: mit
language:
- en
tags:
- sentence-embedding
- sentence-similarity
- transformers
- feature-extraction
pipeline_tag: sentence-similarity
---

# Gemma-2B-Text-Embedding-cft

## Description

This is a fine-tuned version of [Gemma-2b-it](https://huggingface.co/google/gemma-2b-it) for text embedding tasks. The model is fine-tuned using contrastive fine-tuning together with LoRA on NLI datasets. The paper can be found [here](https://arxiv.org/abs/2408.00690).

## Base Model

[Gemma-2b-it](https://huggingface.co/google/gemma-2b-it)

## Usage

1. Clone the Gemma-2b-it repository

```bash
git clone https://huggingface.co/google/gemma-2b-it
```

2. Change a tokenizer setting in `tokenizer_config.json` so that an EOS token is appended to each input

```json
"add_eos_token": true
```

3. Use the model

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import numpy as np

class GemmaSentenceEmbedding:
    def __init__(self, model_path='google/gemma-2b-it', adapter_path=None):
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForCausalLM.from_pretrained(model_path,
                                                          torch_dtype=torch.bfloat16,
                                                          device_map='cuda',
                                                          trust_remote_code=True)
        if adapter_path is not None:
            # Load the fine-tuned LoRA adapter
            self.model.load_adapter(adapter_path)

    def get_last_hidden_state(self, text):
        inputs = self.tokenizer(text, return_tensors="pt").to('cuda')
        with torch.no_grad():
            # Take the final layer's hidden state at the last token as the sentence embedding
            out = self.model(**inputs, output_hidden_states=True).hidden_states[-1][0, -1, :]
        return out.squeeze().float().cpu().numpy()

    def encode(self, sentences: list[str], **kwargs) -> list[np.ndarray]:
        """
        Returns a list of embeddings for the given sentences.

        Args:
            sentences: List of sentences to encode

        Returns:
            List of embeddings for the given sentences
        """
        out = []
        for s in sentences:
            out.append(self.get_last_hidden_state(s))
        return out

# Point model_path at the local clone from steps 1-2 and load the fine-tuned LoRA adapter
gemma_sentence_embedding = GemmaSentenceEmbedding('./gemma-2b-it', 'trapoom555/Gemma-2B-Text-Embedding-cft')

example_sentences = ["I don't like apples", "I like apples"]

encoded_sentences = gemma_sentence_embedding.encode(example_sentences)

print(encoded_sentences)
```

## Training Details

| **Setting**             | **Value**         |
|-------------------------|-------------------|
| Loss                    | InfoNCE           |
| Batch Size              | 60                |
| InfoNCE Temperature     | 0.05              |
| Learning Rate           | 5e-05             |
| Warmup Steps            | 100               |
| Learning Rate Scheduler | CosineAnnealingLR |
| LoRA Rank               | 8                 |
| LoRA Alpha              | 32                |
| LoRA Dropout            | 0.1               |
| Training Precision      | bf16              |
| Max Epoch               | 1                 |
| GPU                     | RTX3090           |
| Num GPUs                | 4                 |

## Training Scripts

The training script for this model is available in this [GitHub repository](https://github.com/trapoom555/Language-Model-STS-CFT/tree/main).

## Checkpoints

We provide checkpoints every 500 training steps, which can be found [here](https://huggingface.co/trapoom555/Gemma-2B-Text-Embedding-cft-checkpoints).

## Evaluation Results

| **Benchmarks** | **Before cft** | **After cft** |
|----------------|----------------|---------------|
| STS12          | 43.83          | 75.80         |
| STS13          | 66.36          | 85.45         |
| STS14          | 49.57          | 80.08         |
| STS15          | 57.40          | 85.02         |
| STS16          | 70.13          | 83.33         |
| STS17          | 58.34          | 88.22         |
| STSBenchmark   | 57.36          | 85.61         |
| BIOSSES        | 48.67          | 73.83         |
| SICK-R         | 58.02          | 76.69         |
| **Overall**    | **56.63**      | **81.56**     |

## Contributors

Trapoom Ukarapol, Zhicheng Lee, Amy Xin

## Footnotes

This work is the final project of the Natural Language Processing (Spring 2024) course at Tsinghua University 🟣. We would like to express our sincere gratitude to this course!
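
## Similarity Example

Since the embeddings are intended for sentence-similarity tasks, the sketch below shows one way to compare the two example sentences from the Usage section. It is a minimal illustration, assuming the `encoded_sentences` variable from the Usage code above is in scope and `numpy` is installed; the exact score will depend on the checkpoint loaded.

```python
import numpy as np

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    # Cosine similarity between two embedding vectors
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# `encoded_sentences` comes from the Usage example above:
# ["I don't like apples", "I like apples"]
score = cosine_similarity(encoded_sentences[0], encoded_sentences[1])
print(f"Cosine similarity: {score:.4f}")
```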