How to finetune the model using multiple GPUs?

#45
by Space192 - opened

I'm trying to finetune the model with SentenceTransformer using FSDP, but I get errors about inconsistent device IDs.

Do you guys have a template Python script to finetune with FSDP?

Hello!

Although FSDP is possible (with some limits and requirements), it trains a bit slower than the simpler Distributed Data Parallel (DDP) approach. FSDP notably shards the model itself across all GPUs, but when the model is very small, this does not save a meaningful amount of memory and only introduces additional overhead from the extra communication between devices. In short, I would recommend using DDP instead of FSDP for smaller models like this one.

The SentenceTransformerTrainer supports DDP. You will have to do 2 things:

  • Wrap the bulk of your code in if __name__ == "__main__":, e.g.:
    from sentence_transformers import SentenceTransformer, SentenceTransformerTrainingArguments, SentenceTransformerTrainer
    # Other imports here
    
    def main():
        # Your training code here
    
    if __name__ == "__main__":
        main()
    
    This is fairly common practice when doing multi-GPU training.
  • Call your training script with e.g. torchrun --nproc_per_node=4 train_script.py or accelerate launch --num_processes 4 train_script.py (this assumes 4 GPUs; adjust it to your number of GPUs) instead of python train_script.py. Either command runs the training script with 4 processes instead of just one; a combined sketch follows right after this list.
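Putting both steps together, a minimal sketch of a DDP-ready script could look like the following (the model, dataset, and hyperparameters here are only illustrative; the full script later in this thread shows a more complete configuration):

from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import MultipleNegativesRankingLoss


def main():
    # Each process loads its own replica of the model (DDP replicates it per GPU)
    model = SentenceTransformer("jinaai/jina-embeddings-v2-base-en", trust_remote_code=True)
    train_dataset = load_dataset("sentence-transformers/natural-questions", split="train")
    loss = MultipleNegativesRankingLoss(model)
    args = SentenceTransformerTrainingArguments(
        output_dir="models/ddp-run",
        per_device_train_batch_size=8,  # per GPU; effective batch size is 8 * number of processes
    )
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        loss=loss,
    )
    trainer.train()


if __name__ == "__main__":
    main()

# Launch with e.g.: torchrun --nproc_per_node=4 train_script.py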

See https://sbert.net/docs/sentence_transformer/training/distributed.html for more documentation.

- Tom Aarsen

Well, the thing is that with the 8192 context length (I'm using all the tokens), it doesn't fit on a single H100 100GB with a batch size of 2 and bf16 enabled. (I'm using the script provided in this discussion: https://huggingface.co/jinaai/jina-embeddings-v2-base-en/discussions/24#667ee03d53387df99a6c8eaf)

So that's why I'm trying to use FSDP.

@tomaarsen
Here is the script I'm using if you want to try and reproduce the bug (I know I don't need FSDP for this dataset, but I do need it for the non-public dataset I'm actually using):

import logging
from datasets import load_dataset, Dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    SentenceTransformerModelCardData,
)
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.evaluation import InformationRetrievalEvaluator


def main():
    # 1. Load a model to finetune with 2. (Optional) model card data
    model = SentenceTransformer(
        "jinaai/jina-embeddings-v2-base-en",
        trust_remote_code=True,
        model_card_data=SentenceTransformerModelCardData(
            language="en",
            license="apache-2.0",
            model_name="jina-embeddings-v2-base-en trained on Natural Questions pairs",
        ),
    )
    model_name = "jina-v2-base-natural-questions"

    # 3. Load a dataset to finetune on
    dataset = load_dataset("sentence-transformers/natural-questions", split="train")
    dataset = dataset.add_column("id", range(len(dataset)))
    train_dataset: Dataset = dataset.select(range(90_000))
    eval_dataset: Dataset = dataset.select(range(90_000, len(dataset)))

    # 4. Define a loss function
    loss = MultipleNegativesRankingLoss(model)


    # 5. (Optional) Specify training arguments
    args = SentenceTransformerTrainingArguments(
        # Required parameter:
        output_dir=f"models/{model_name}",
        # Optional training parameters:
        num_train_epochs=1,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        learning_rate=2e-5,
        warmup_ratio=0.1,
        fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
        bf16=False,  # Set to True if you have a GPU that supports BF16
        batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
        # Optional tracking/debugging parameters:
        evaluation_strategy="steps",
        eval_steps=200,
        save_strategy="steps",
        save_steps=200,
        save_total_limit=2,
        logging_steps=200,
        report_to=None,
        logging_first_step=True,
        dataloader_drop_last=True,
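        # FSDP settings: fully shard parameters, gradients, and optimizer states across
        # GPUs, auto-wrapping each JinaBertLayer as its own FSDP unit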
        fsdp=["full_shard", "auto_wrap"],
        fsdp_config={"transformer_layer_cls_to_wrap": "JinaBertLayer"}
    )

    # 7. Create a trainer & train
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset.remove_columns("id"),
        eval_dataset=eval_dataset.remove_columns("id"),
        loss=loss,
    )
    trainer.train()

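    # Gather the sharded FSDP weights into a full state dict so the model can be saved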
    trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")

    trainer.save_model("output")

if __name__ == "__main__":
    main()
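(Presumably this script is launched with torchrun or accelerate launch as described above, e.g. torchrun --nproc_per_node=2 train_script.py, rather than plain python, since the FSDP sharding only applies across the launched processes.)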

Okay, so I did find a solution in the end!
You simply need to pass the local rank as the device when loading the model (this also needs an import os at the top of the script):

import os

model = SentenceTransformer(
    "jinaai/jina-embeddings-v2-base-en",
    trust_remote_code=True,
    model_card_data=SentenceTransformerModelCardData(
        language="en",
        license="apache-2.0",
        model_name="jina-embeddings-v2-base-en trained on Natural Questions pairs",
    ),
    device="cuda:" + os.environ["LOCAL_RANK"],
)
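For context: LOCAL_RANK is an environment variable set per process by torchrun and accelerate launch, so each process places the model on its own GPU, which is what resolves the inconsistent device ID error. A minimal sketch of the same idea, with a fallback so the script still runs under plain python (the variable name is just illustrative):

import os

from sentence_transformers import SentenceTransformer

# LOCAL_RANK is set by torchrun / accelerate launch; fall back to GPU 0 otherwise
local_rank = os.environ.get("LOCAL_RANK", "0")
model = SentenceTransformer(
    "jinaai/jina-embeddings-v2-base-en",
    trust_remote_code=True,
    device=f"cuda:{local_rank}",
)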
bwang0911 changed discussion status to closed
