multi gpu eval #2938
Replies: 3 comments
-
I don't think that most models support it, because SentenceTransformer doesn't support this: UKPLab/sentence-transformers#2869
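That said, for plain encoding outside the mteb evaluation loop, sentence-transformers does expose a multi-process pool that spreads batches over several GPUs. A minimal sketch (the model name here is just a placeholder):

```python
from sentence_transformers import SentenceTransformer

if __name__ == "__main__":
    model = SentenceTransformer("Qwen/Qwen3-Embedding-0.6B")  # placeholder model
    # Starts one worker process per visible CUDA device
    # (or pass target_devices=["cuda:0", "cuda:1"] explicitly).
    pool = model.start_multi_process_pool()
    embeddings = model.encode_multi_process(["first sentence", "second sentence"], pool)
    model.stop_multi_process_pool(pool)
    print(embeddings.shape)
```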
-
Hi @riyajatar37003 (converted this to a discussion instead of an issue). While the core logic in …
-
I tested Qwen3-Embedding-0.6B with multi-GPU support using a custom model class:

```python
import mteb
from mteb.encoder_interface import PromptType
import numpy as np
import torch
from torch import Tensor
from transformers import AutoTokenizer, AutoModel


def last_token_pool(last_hidden_states: Tensor,
                    attention_mask: Tensor) -> Tensor:
    # With left padding, the last position holds the final token of every sequence.
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return last_hidden_states[:, -1]
    else:
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_states.shape[0]
        return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]


def get_detailed_instruct(task_description: str, query: str) -> str:
    return f'Instruct: {task_description}\nQuery:{query}'


class CustomModel:
    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-Embedding-0.6B', padding_side='left')
        # device_map="auto" shards the model across all visible GPUs.
        self.model = AutoModel.from_pretrained(
            'Qwen/Qwen3-Embedding-0.6B',
            attn_implementation="flash_attention_2",
            torch_dtype=torch.float16,
            device_map="auto",
        )
        self.task = "Given a web search query, retrieve relevant passages that answer the query"

    def encode(
        self,
        sentences: list[str],
        task_name: str,
        prompt_type: PromptType | None = None,
        **kwargs,
    ) -> np.ndarray:
        """Encodes the given sentences using the encoder.

        Args:
            sentences: The sentences to encode.
            task_name: The name of the task.
            prompt_type: The prompt type to use.
            **kwargs: Additional arguments to pass to the encoder.

        Returns:
            The encoded sentences.
        """
        max_length = 8192
        all_embeddings = []
        # Queries get the instruction prefix and a larger batch size;
        # corpus documents are long, so they use a smaller one.
        if prompt_type == "query":
            sentences = [get_detailed_instruct(self.task, sentence) for sentence in sentences]
            batch_size = kwargs.pop('query_batch_size', 32)
        else:
            batch_size = kwargs.pop('corpus_batch_size', 1)
        # Process sentences in batches
        for i in range(0, len(sentences), batch_size):
            batch_sentences = sentences[i:i + batch_size]
            # Tokenize the input texts
            batch_dict = self.tokenizer(
                batch_sentences,
                padding=True,
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
            )
            batch_dict = batch_dict.to(self.model.device)
            with torch.no_grad():
                outputs = self.model(**batch_dict)
            embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
            # Move embeddings to CPU and convert to numpy, then append
            all_embeddings.append(embeddings.cpu().float().numpy())
            torch.cuda.empty_cache()  # Clear GPU memory after processing each batch
        # Concatenate all batch embeddings
        return np.concatenate(all_embeddings, axis=0)


model_name = "Qwen3-Embedding-0.6B"  # used only for the output folder
model = CustomModel()
# benchmark = mteb.get_benchmark("MTEB(kor, v1)")
# benchmark = mteb.get_tasks(tasks=["KLUE-STS"])
# benchmark = mteb.get_tasks(tasks=["KLUE-TC"])
# benchmark = mteb.get_tasks(tasks=["KorSTS"])
# benchmark = mteb.get_tasks(tasks=["Ko-StrategyQA"])
# benchmark = mteb.get_tasks(languages=["kor"], tasks=["MIRACLReranking"])
benchmark = mteb.get_tasks(languages=["kor"], tasks=["MIRACLRetrieval"])
evaluation = mteb.MTEB(tasks=benchmark)
encode_kwargs = {
    "query_batch_size": 64,
    "corpus_batch_size": 2,
}
results = evaluation.run(model, output_folder=f"results_MTEB(kor,v1)/{model_name}", verbosity=3, encode_kwargs=encode_kwargs)
```

If you have more GPUs, increase corpus_batch_size!
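Note that device_map="auto" gives you model parallelism: the layers are sharded across the visible GPUs, so the extra memory shows up as headroom for larger batches rather than as parallel replicas. A quick way to sanity-check the placement (a sketch, assuming the CustomModel above):

```python
model = CustomModel()
# hf_device_map shows which device accelerate assigned each module to.
print(model.model.hf_device_map)
```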
-
Hi,
is there any example of how to run the eval on multiple GPUs? I could not find anything.