
vLLM Wrapper

Note

vLLM currently supports only a limited number of models, and many of its implementations differ subtly from the default implementations in mteb. For the full list of supported models, refer to the vLLM documentation.

Installation

If you're using CUDA, you can run

pip install "mteb[vllm]"

or, with uv:

uv pip install "mteb[vllm]"

For other architectures, please refer to the vLLM installation guide.

Usage

To use vLLM with MTEB, wrap your model in the appropriate wrapper: VllmEncoderWrapper for embedding models or VllmCrossEncoderWrapper for cross-encoder (reranker) models.

Note

You must update your Python code to guard usage of vllm behind an if __name__ == '__main__': block. For example, instead of this:

import vllm

llm = vllm.LLM(...)
try this instead:
if __name__ == '__main__':
    import vllm

    llm = vllm.LLM(...)

See the vLLM troubleshooting guide for more details.

To evaluate an embedding model, use VllmEncoderWrapper:

import mteb
from mteb.models.vllm_wrapper import VllmEncoderWrapper

def run_vllm_encoder():
    """Evaluate a model on specified MTEB tasks using vLLM for inference."""
    encoder = VllmEncoderWrapper(model="intfloat/e5-small")
    return mteb.evaluate(
        encoder,
        mteb.get_task("STS12"),
    )

if __name__ == "__main__":
    results = run_vllm_encoder()
    print(results)

To evaluate a cross-encoder (reranker), use VllmCrossEncoderWrapper:

import mteb
from mteb.models.vllm_wrapper import VllmCrossEncoderWrapper

def run_vllm_crossencoder():
    """Evaluate a model on specified MTEB tasks using vLLM for inference."""
    cross_encoder = VllmCrossEncoderWrapper(model="cross-encoder/ms-marco-MiniLM-L-6-v2")
    return mteb.evaluate(
        cross_encoder,
        mteb.get_task("AskUbuntuDupQuestions"),
    )


if __name__ == "__main__":
    results = run_vllm_crossencoder()
    print(results)

Why is vLLM fast?

Half-Precision Inference

By default, vLLM uses Flash Attention, which only supports float16 and bfloat16 but not float32. vLLM does not optimize inference performance for float32.

The throughput using float16 is approximately four times that of float32.

[Figure: latency per step (ms, log scale) vs. throughput (requests/s) for the Sentence Transformers (ST) and vLLM backends; curves toward the lower right are better.]

Note

Format     Bits   Exponent   Fraction
float32    32     8          23
float16    16     5          10
bfloat16   16     8          7

If the model weights are stored in float32:

  • vLLM uses float16 by default when running inference on a float32 model. This preserves numerical precision in most cases, since float16 retains relatively more fraction bits. However, because of its smaller exponent (only 5 bits), some models (e.g., the Gemma family) risk producing NaN values. vLLM maintains a list of models known to produce NaNs and uses bfloat16 for those by default.
  • Using bfloat16 for inference avoids the NaN risk because its exponent matches float32 at 8 bits. However, with only 7 fraction bits, numerical precision decreases noticeably.
  • Using float32 for inference incurs no precision loss but is about four times slower than float16/bfloat16.

If model weights are stored in float16 or bfloat16, vLLM defaults to using the original dtype for inference.
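If you need to control this explicitly, the wrapper exposes a dtype argument that is forwarded to the vLLM engine. A minimal sketch, reusing the e5-small model from the Usage section (choosing bfloat16 here is illustrative):

import mteb
from mteb.models.vllm_wrapper import VllmEncoderWrapper

if __name__ == "__main__":
    # Force bfloat16 to avoid NaN risks, or use "float32" for full precision
    # at roughly 4x the inference cost.
    encoder = VllmEncoderWrapper(model="intfloat/e5-small", dtype="bfloat16")
    results = mteb.evaluate(encoder, mteb.get_task("STS12"))
    print(results)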

Quantization: As open-source large models advance, ever-larger models are being fine-tuned for tasks like embedding and reranking, so exploring quantization methods to accelerate inference and reduce GPU memory usage may become necessary.
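If you go that route, vLLM's quantization options can be reached through the wrapper, since extra keyword arguments are forwarded to the engine. A hedged sketch; the model name and scheme below are illustrative and must be supported by the checkpoint and hardware:

import mteb
from mteb.models.vllm_wrapper import VllmEncoderWrapper

if __name__ == "__main__":
    # `quantization` is a vLLM engine argument passed through **kwargs.
    encoder = VllmEncoderWrapper(
        model="Qwen/Qwen3-Embedding-0.6B",  # illustrative model name
        quantization="fp8",                 # or "awq", "gptq", ... depending on the checkpoint
    )
    results = mteb.evaluate(encoder, mteb.get_task("STS12"))
    print(results)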

Unpadding

By default, Sentence Transformers pads all inputs in a batch to the length of the longest one, which is inefficient. vLLM avoids padding entirely during inference.

[Figure: latency per step (ms, log scale) vs. throughput (requests/s) for the Sentence Transformers (ST) and vLLM backends on inputs of varied length; curves toward the lower right are better.]

Sentence Transformers (ST) suffers a noticeable drop in speed when handling requests with varied input lengths, whereas vLLM does not.

Others

For models using bidirectional attention, such as BERT, vLLM offers a range of performance optimizations:

  • Optimized CUDA kernels, including FlashAttention and FlashInfer integration
  • CUDA Graphs and torch.compile support to reduce overhead and accelerate execution
  • Support for tensor, pipeline, data, and expert parallelism for distributed inference
  • Multiple quantization schemes—GPTQ, AWQ, AutoRound, INT4, INT8, and FP8—for efficient deployment
  • Continuous batching of incoming requests to maximize throughput

For causal attention models, such as the Qwen3 reranker, the following optimizations are also applicable (a short configuration sketch follows the list):

  • Efficient KV cache memory management via PagedAttention
  • Chunked prefill for improved memory handling during long-context processing
  • Prefix caching to accelerate repeated prompt processing
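Several of these engine features are exposed directly as wrapper arguments (enable_prefix_caching, max_model_len, gpu_memory_utilization, enforce_eager, ...). A hedged configuration sketch for a causal reranker; the model name is illustrative, and some rerankers need additional hf_overrides to load in vLLM:

import mteb
from mteb.models.vllm_wrapper import VllmCrossEncoderWrapper

if __name__ == "__main__":
    cross_encoder = VllmCrossEncoderWrapper(
        model="Qwen/Qwen3-Reranker-0.6B",  # illustrative; check the vLLM docs for any required hf_overrides
        enable_prefix_caching=True,        # reuse the KV cache for shared prompt prefixes
        max_model_len=4096,                # cap the context window to save KV-cache memory
        gpu_memory_utilization=0.85,
    )
    results = mteb.evaluate(cross_encoder, mteb.get_task("AskUbuntuDupQuestions"))
    print(results)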

vLLM’s optimizations are primarily designed for and most effective with causal language models (generative models). For the full list of features, refer to the vllm documentation.

API Reference

mteb.models.vllm_wrapper.VllmWrapperBase

Wrapper for vllm serving engine.

Source code in mteb/models/vllm_wrapper.py
class VllmWrapperBase:
    """Wrapper for vllm serving engine."""

    convert = "auto"
    mteb_model_meta: ModelMeta | None = None

    def __init__(
        self,
        model: str | ModelMeta,
        revision: str | None = None,
        *,
        trust_remote_code: bool = True,
        dtype: Dtype = "auto",
        head_dtype: Literal["model"] | Dtype | None = None,
        max_model_len: int | None = None,
        max_num_batched_tokens: int | None = None,
        max_num_seqs: int = 128,
        tensor_parallel_size: int = 1,
        enable_prefix_caching: bool | None = None,
        gpu_memory_utilization: float = 0.9,
        hf_overrides: dict[str, Any] | None = None,
        pooler_config: PoolerConfig | None = None,
        enforce_eager: bool = False,
        **kwargs: Any,
    ):
        """Wrapper for vllm serving engine.

        Args:
            model: model name string.
            revision: The revision of the model to use.
            trust_remote_code: Whether to trust remote code execution when loading the model.
                Should be True for models with custom code.
            dtype: Data type for model weights. "auto" will automatically select appropriate
                dtype based on hardware and model capabilities. vllm uses flash attention by
                default, which does not support fp32. Therefore, it defaults to using fp16 for
                inference on fp32 models. Testing has shown a relatively small drop in accuracy.
                You can manually opt for fp32, but inference speed will be very slow.
            head_dtype: "head" refers to the last Linear layer(s) of an LLMs, such as the score
                or classifier in a classification model. Uses fp32 for the head by default to
                gain extra precision.
            max_model_len: Maximum sequence length (context window) supported by the model.
                If None, uses the model's default maximum length.
            max_num_batched_tokens: Maximum number of tokens to process in a single batch.
                If None, automatically determined.
            max_num_seqs: Maximum number of sequences to process concurrently.
            tensor_parallel_size: Number of GPUs for tensor parallelism.
            enable_prefix_caching: Whether to enable KV cache sharing for common prompt prefixes.
                If None, uses the model's default setting.
            gpu_memory_utilization: Target GPU memory utilization ratio (0.0 to 1.0).
            hf_overrides: Dictionary mapping Hugging Face configuration keys to override values.
            pooler_config: Controls the behavior of output pooling in pooling models.
            enforce_eager: Whether to disable CUDA graph optimization and use eager execution.
            **kwargs: Additional arguments to pass to the vllm serving engine model.
        """
        requires_package(
            self,
            "vllm",
            "Wrapper for vllm serving engine",
            install_instruction="pip install mteb[vllm]",
        )

        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

        from vllm import LLM, EngineArgs

        hf_overrides = {} if hf_overrides is None else hf_overrides

        if head_dtype is not None:
            hf_overrides["head_dtype"] = head_dtype

        model_name = model if isinstance(model, str) else model.name

        if isinstance(model, ModelMeta):
            logger.info(
                "Using revision from model meta. Passed revision will be ignored"
            )
            revision = model.revision

        args = EngineArgs(
            model=model_name,
            revision=revision,
            runner="pooling",
            convert=self.convert,  # type: ignore[arg-type]
            max_model_len=max_model_len,
            max_num_batched_tokens=max_num_batched_tokens,
            max_num_seqs=max_num_seqs,
            tensor_parallel_size=tensor_parallel_size,
            enable_prefix_caching=enable_prefix_caching,
            gpu_memory_utilization=gpu_memory_utilization,
            hf_overrides=hf_overrides,
            pooler_config=pooler_config,
            enforce_eager=enforce_eager,
            trust_remote_code=trust_remote_code,
            dtype=dtype,
            **kwargs,
        )
        self.llm = LLM(**vars(args))

        if isinstance(model, str):
            self.mteb_model_meta = ModelMeta.from_hub(model=model, revision=revision)
        else:
            self.mteb_model_meta = model

        atexit.register(self.cleanup)

    def cleanup(self):
        """Clean up the VLLM distributed runtime environment and release GPU resources."""
        if self.llm is None:
            return

        from vllm.distributed import (  # type: ignore[import-not-found]
            cleanup_dist_env_and_memory,
        )

        self.llm = None
        gc.collect()
        cleanup_dist_env_and_memory()

    def __del__(self):
        try:
            self.cleanup()
        except Exception:
            pass

__init__(model, revision=None, *, trust_remote_code=True, dtype='auto', head_dtype=None, max_model_len=None, max_num_batched_tokens=None, max_num_seqs=128, tensor_parallel_size=1, enable_prefix_caching=None, gpu_memory_utilization=0.9, hf_overrides=None, pooler_config=None, enforce_eager=False, **kwargs)

Wrapper for vllm serving engine.

Parameters:

  • model (str | ModelMeta, required): model name string.
  • revision (str | None, default None): The revision of the model to use.
  • trust_remote_code (bool, default True): Whether to trust remote code execution when loading the model. Should be True for models with custom code.
  • dtype (Dtype, default 'auto'): Data type for model weights. "auto" automatically selects an appropriate dtype based on hardware and model capabilities. vLLM uses Flash Attention by default, which does not support fp32, so it defaults to fp16 for inference on fp32 models. Testing has shown a relatively small drop in accuracy. You can manually opt for fp32, but inference will be very slow.
  • head_dtype (Literal['model'] | Dtype | None, default None): "head" refers to the last Linear layer(s) of an LLM, such as the score or classifier head in a classification model. Uses fp32 for the head by default to gain extra precision.
  • max_model_len (int | None, default None): Maximum sequence length (context window) supported by the model. If None, uses the model's default maximum length.
  • max_num_batched_tokens (int | None, default None): Maximum number of tokens to process in a single batch. If None, automatically determined.
  • max_num_seqs (int, default 128): Maximum number of sequences to process concurrently.
  • tensor_parallel_size (int, default 1): Number of GPUs for tensor parallelism.
  • enable_prefix_caching (bool | None, default None): Whether to enable KV cache sharing for common prompt prefixes. If None, uses the model's default setting.
  • gpu_memory_utilization (float, default 0.9): Target GPU memory utilization ratio (0.0 to 1.0).
  • hf_overrides (dict[str, Any] | None, default None): Dictionary mapping Hugging Face configuration keys to override values.
  • pooler_config (PoolerConfig | None, default None): Controls the behavior of output pooling in pooling models.
  • enforce_eager (bool, default False): Whether to disable CUDA graph optimization and use eager execution.
  • **kwargs (Any, default {}): Additional arguments to pass to the vllm serving engine model.
Source code in mteb/models/vllm_wrapper.py
def __init__(
    self,
    model: str | ModelMeta,
    revision: str | None = None,
    *,
    trust_remote_code: bool = True,
    dtype: Dtype = "auto",
    head_dtype: Literal["model"] | Dtype | None = None,
    max_model_len: int | None = None,
    max_num_batched_tokens: int | None = None,
    max_num_seqs: int = 128,
    tensor_parallel_size: int = 1,
    enable_prefix_caching: bool | None = None,
    gpu_memory_utilization: float = 0.9,
    hf_overrides: dict[str, Any] | None = None,
    pooler_config: PoolerConfig | None = None,
    enforce_eager: bool = False,
    **kwargs: Any,
):
    """Wrapper for vllm serving engine.

    Args:
        model: model name string.
        revision: The revision of the model to use.
        trust_remote_code: Whether to trust remote code execution when loading the model.
            Should be True for models with custom code.
        dtype: Data type for model weights. "auto" will automatically select appropriate
            dtype based on hardware and model capabilities. vllm uses flash attention by
            default, which does not support fp32. Therefore, it defaults to using fp16 for
            inference on fp32 models. Testing has shown a relatively small drop in accuracy.
            You can manually opt for fp32, but inference speed will be very slow.
        head_dtype: "head" refers to the last Linear layer(s) of an LLMs, such as the score
            or classifier in a classification model. Uses fp32 for the head by default to
            gain extra precision.
        max_model_len: Maximum sequence length (context window) supported by the model.
            If None, uses the model's default maximum length.
        max_num_batched_tokens: Maximum number of tokens to process in a single batch.
            If None, automatically determined.
        max_num_seqs: Maximum number of sequences to process concurrently.
        tensor_parallel_size: Number of GPUs for tensor parallelism.
        enable_prefix_caching: Whether to enable KV cache sharing for common prompt prefixes.
            If None, uses the model's default setting.
        gpu_memory_utilization: Target GPU memory utilization ratio (0.0 to 1.0).
        hf_overrides: Dictionary mapping Hugging Face configuration keys to override values.
        pooler_config: Controls the behavior of output pooling in pooling models.
        enforce_eager: Whether to disable CUDA graph optimization and use eager execution.
        **kwargs: Additional arguments to pass to the vllm serving engine model.
    """
    requires_package(
        self,
        "vllm",
        "Wrapper for vllm serving engine",
        install_instruction="pip install mteb[vllm]",
    )

    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

    from vllm import LLM, EngineArgs

    hf_overrides = {} if hf_overrides is None else hf_overrides

    if head_dtype is not None:
        hf_overrides["head_dtype"] = head_dtype

    model_name = model if isinstance(model, str) else model.name

    if isinstance(model, ModelMeta):
        logger.info(
            "Using revision from model meta. Passed revision will be ignored"
        )
        revision = model.revision

    args = EngineArgs(
        model=model_name,
        revision=revision,
        runner="pooling",
        convert=self.convert,  # type: ignore[arg-type]
        max_model_len=max_model_len,
        max_num_batched_tokens=max_num_batched_tokens,
        max_num_seqs=max_num_seqs,
        tensor_parallel_size=tensor_parallel_size,
        enable_prefix_caching=enable_prefix_caching,
        gpu_memory_utilization=gpu_memory_utilization,
        hf_overrides=hf_overrides,
        pooler_config=pooler_config,
        enforce_eager=enforce_eager,
        trust_remote_code=trust_remote_code,
        dtype=dtype,
        **kwargs,
    )
    self.llm = LLM(**vars(args))

    if isinstance(model, str):
        self.mteb_model_meta = ModelMeta.from_hub(model=model, revision=revision)
    else:
        self.mteb_model_meta = model

    atexit.register(self.cleanup)

cleanup()

Clean up the VLLM distributed runtime environment and release GPU resources.

Source code in mteb/models/vllm_wrapper.py
def cleanup(self):
    """Clean up the VLLM distributed runtime environment and release GPU resources."""
    if self.llm is None:
        return

    from vllm.distributed import (  # type: ignore[import-not-found]
        cleanup_dist_env_and_memory,
    )

    self.llm = None
    gc.collect()
    cleanup_dist_env_and_memory()

Info

For all vLLM parameters, please refer to https://docs.vllm.ai/en/latest/configuration/engine_args/.
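The documented arguments above map onto vLLM's engine arguments, and anything not listed can be forwarded via **kwargs. A minimal sketch tuning a few of the documented options (the values are illustrative):

import mteb
from mteb.models.vllm_wrapper import VllmEncoderWrapper

if __name__ == "__main__":
    encoder = VllmEncoderWrapper(
        model="intfloat/e5-small",
        max_model_len=512,           # cap the context window
        gpu_memory_utilization=0.8,  # leave GPU memory headroom for other processes
        enforce_eager=True,          # skip CUDA graph capture for faster startup
    )
    results = mteb.evaluate(encoder, mteb.get_task("STS12"))
    print(results)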

mteb.models.vllm_wrapper.VllmEncoderWrapper

Bases: AbsEncoder, VllmWrapperBase

vLLM wrapper for Encoder models.

Parameters:

  • model (str | ModelMeta, required): model name string or ModelMeta.
  • revision (str | None, default None): The revision of the model to use.
  • prompt_dict (dict[str, str] | None, default None): A dictionary mapping task names to prompt strings.
  • use_instructions (bool, default False): Whether to use instructions from the prompt_dict. When False, values from prompt_dict are used as static prompts (prefixes). When True, values from prompt_dict are used as instructions to be formatted using the instruction_template.
  • instruction_template (str | Callable[[str, PromptType | None], str] | None, default None): A template or callable to format instructions. Can be a string with an '{instruction}' placeholder or a callable that takes the instruction and prompt type and returns a formatted string.
  • apply_instruction_to_documents (bool, default True): Whether to apply instructions to document prompts.
  • **kwargs (Any, default {}): Additional arguments to pass to the vllm serving engine model.
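For instruction-tuned embedding models, these arguments can be combined as below. A hedged sketch: the model name, template, and prompt string are illustrative and should follow the model's own documentation:

import mteb
from mteb.models.vllm_wrapper import VllmEncoderWrapper

if __name__ == "__main__":
    encoder = VllmEncoderWrapper(
        model="intfloat/e5-mistral-7b-instruct",   # illustrative instruction-tuned model
        use_instructions=True,
        instruction_template="Instruct: {instruction}\nQuery: ",
        apply_instruction_to_documents=False,      # only queries receive the instruction
        prompt_dict={"STS12": "Retrieve semantically similar text."},
    )
    results = mteb.evaluate(encoder, mteb.get_task("STS12"))
    print(results)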
Source code in mteb/models/vllm_wrapper.py
class VllmEncoderWrapper(AbsEncoder, VllmWrapperBase):
    """vLLM wrapper for Encoder models.

    Args:
        model: model name string or ModelMeta.
        revision: The revision of the model to use.
        prompt_dict: A dictionary mapping task names to prompt strings.
        use_instructions: Whether to use instructions from the prompt_dict.
            When False, values from prompt_dict are used as static prompts (prefixes).
            When True, values from prompt_dict are used as instructions to be formatted
            using the instruction_template.
        instruction_template: A template or callable to format instructions.
            Can be a string with '{instruction}' placeholder or a callable that takes
            the instruction and prompt type and returns a formatted string.
        apply_instruction_to_documents: Whether to apply instructions to documents prompts.
        **kwargs: Additional arguments to pass to the vllm serving engine model.
    """

    convert = "embed"

    def __init__(
        self,
        model: str | ModelMeta,
        revision: str | None = None,
        prompt_dict: dict[str, str] | None = None,
        use_instructions: bool = False,
        instruction_template: (
            str | Callable[[str, PromptType | None], str] | None
        ) = None,
        apply_instruction_to_documents: bool = True,
        **kwargs: Any,
    ):
        if use_instructions and instruction_template is None:
            raise ValueError(
                "To use instructions, an instruction_template must be provided. "
                "For example, `Instruction: {instruction}`"
            )

        if (
            isinstance(instruction_template, str)
            and "{instruction}" not in instruction_template
        ):
            raise ValueError(
                "Instruction template must contain the string '{instruction}'."
            )

        self.prompts_dict = prompt_dict
        self.use_instructions = use_instructions
        self.instruction_template = instruction_template
        self.apply_instruction_to_passages = apply_instruction_to_documents
        super().__init__(
            model,
            revision,
            **kwargs,
        )

    def encode(
        self,
        inputs: DataLoader[BatchedInput],
        *,
        task_metadata: TaskMetadata,
        hf_split: str,
        hf_subset: str,
        prompt_type: PromptType | None = None,
        **kwargs: Any,
    ) -> Array:
        """Encodes the given sentences using the encoder.

        Args:
            inputs: The sentences to encode.
            task_metadata: The metadata of the task. Sentence-transformers uses this to
                determine which prompt to use from a specified dictionary.
            prompt_type: The name type of prompt. (query or passage)
            hf_split: Split of current task
            hf_subset: Subset of current task
            **kwargs: Additional arguments to pass to the encoder.

        Returns:
            The encoded sentences.
        """
        prompt = ""
        if self.use_instructions and self.prompts_dict is not None:
            prompt = self.get_task_instruction(task_metadata, prompt_type)
        elif self.prompts_dict is not None:
            prompt_name = self.get_prompt_name(task_metadata, prompt_type)
            if prompt_name is not None:
                prompt = self.prompts_dict.get(prompt_name, "")

        if (
            self.use_instructions
            and self.apply_instruction_to_passages is False
            and prompt_type == PromptType.document
        ):
            logger.info(
                f"No instruction used, because prompt type = {prompt_type.document}"
            )
            prompt = ""
        else:
            logger.info(
                f"Using instruction: '{prompt}' for task: '{task_metadata.name}' prompt type: '{prompt_type}'"
            )

        prompts = [prompt + text for batch in inputs for text in batch["text"]]
        outputs = self.llm.encode(
            prompts, pooling_task="embed", truncate_prompt_tokens=-1
        )
        embeddings = torch.stack([output.outputs.data for output in outputs])
        return embeddings

encode(inputs, *, task_metadata, hf_split, hf_subset, prompt_type=None, **kwargs)

Encodes the given sentences using the encoder.

Parameters:

  • inputs (DataLoader[BatchedInput], required): The sentences to encode.
  • task_metadata (TaskMetadata, required): The metadata of the task. Sentence-transformers uses this to determine which prompt to use from a specified dictionary.
  • prompt_type (PromptType | None, default None): The type of prompt (query or passage).
  • hf_split (str, required): Split of the current task.
  • hf_subset (str, required): Subset of the current task.
  • **kwargs (Any, default {}): Additional arguments to pass to the encoder.

Returns:

  • Array: The encoded sentences.

Source code in mteb/models/vllm_wrapper.py
def encode(
    self,
    inputs: DataLoader[BatchedInput],
    *,
    task_metadata: TaskMetadata,
    hf_split: str,
    hf_subset: str,
    prompt_type: PromptType | None = None,
    **kwargs: Any,
) -> Array:
    """Encodes the given sentences using the encoder.

    Args:
        inputs: The sentences to encode.
        task_metadata: The metadata of the task. Sentence-transformers uses this to
            determine which prompt to use from a specified dictionary.
        prompt_type: The name type of prompt. (query or passage)
        hf_split: Split of current task
        hf_subset: Subset of current task
        **kwargs: Additional arguments to pass to the encoder.

    Returns:
        The encoded sentences.
    """
    prompt = ""
    if self.use_instructions and self.prompts_dict is not None:
        prompt = self.get_task_instruction(task_metadata, prompt_type)
    elif self.prompts_dict is not None:
        prompt_name = self.get_prompt_name(task_metadata, prompt_type)
        if prompt_name is not None:
            prompt = self.prompts_dict.get(prompt_name, "")

    if (
        self.use_instructions
        and self.apply_instruction_to_passages is False
        and prompt_type == PromptType.document
    ):
        logger.info(
            f"No instruction used, because prompt type = {prompt_type.document}"
        )
        prompt = ""
    else:
        logger.info(
            f"Using instruction: '{prompt}' for task: '{task_metadata.name}' prompt type: '{prompt_type}'"
        )

    prompts = [prompt + text for batch in inputs for text in batch["text"]]
    outputs = self.llm.encode(
        prompts, pooling_task="embed", truncate_prompt_tokens=-1
    )
    embeddings = torch.stack([output.outputs.data for output in outputs])
    return embeddings

mteb.models.vllm_wrapper.VllmCrossEncoderWrapper

Bases: VllmWrapperBase

vLLM wrapper for CrossEncoder models.

Source code in mteb/models/vllm_wrapper.py
class VllmCrossEncoderWrapper(VllmWrapperBase):
    """vLLM wrapper for CrossEncoder models."""

    convert = "classify"

    def __init__(
        self,
        model: str | ModelMeta,
        revision: str | None = None,
        query_prefix: str = "",
        document_prefix: str = "",
        **kwargs: Any,
    ):
        super().__init__(
            model,
            revision,
            **kwargs,
        )
        self.query_prefix = query_prefix
        self.document_prefix = document_prefix

    def predict(
        self,
        inputs1: DataLoader[BatchedInput],
        inputs2: DataLoader[BatchedInput],
        *,
        task_metadata: TaskMetadata,
        hf_split: str,
        hf_subset: str,
        prompt_type: PromptType | None = None,
        **kwargs: Any,
    ) -> Array:
        """Predicts relevance scores for pairs of inputs. Note that, unlike the encoder, the cross-encoder can compare across inputs.

        Args:
            inputs1: First Dataloader of inputs to encode. For reranking tasks, these are queries (for text only tasks `QueryDatasetType`).
            inputs2: Second Dataloader of inputs to encode. For reranking, these are documents (for text only tasks `RetrievalOutputType`).
            task_metadata: Metadata of the current task.
            hf_split: Split of current task, allows to know some additional information about current split.
                E.g. Current language
            hf_subset: Subset of current task. Similar to `hf_split` to get more information
            prompt_type: The name type of prompt. (query or passage)
            **kwargs: Additional arguments to pass to the cross-encoder.

        Returns:
            The predicted relevance scores for each inputs pair.
        """
        queries = [
            self.query_prefix + text for batch in inputs1 for text in batch["text"]
        ]
        corpus = [
            self.document_prefix + text for batch in inputs2 for text in batch["text"]
        ]
        # TODO: support score prompt

        outputs = self.llm.score(
            queries,
            corpus,
            truncate_prompt_tokens=-1,
            use_tqdm=False,
        )
        scores = np.array([output.outputs.score for output in outputs])
        return scores

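The constructor also accepts query_prefix and document_prefix, which are prepended to queries and documents before scoring. A minimal sketch; the prefix strings are illustrative and should follow the model card, if it recommends any:

import mteb
from mteb.models.vllm_wrapper import VllmCrossEncoderWrapper

if __name__ == "__main__":
    cross_encoder = VllmCrossEncoderWrapper(
        model="cross-encoder/ms-marco-MiniLM-L-6-v2",
        query_prefix="query: ",        # illustrative prefixes
        document_prefix="document: ",
    )
    results = mteb.evaluate(cross_encoder, mteb.get_task("AskUbuntuDupQuestions"))
    print(results)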
predict(inputs1, inputs2, *, task_metadata, hf_split, hf_subset, prompt_type=None, **kwargs)

Predicts relevance scores for pairs of inputs. Note that, unlike the encoder, the cross-encoder can compare across inputs.

Parameters:

  • inputs1 (DataLoader[BatchedInput], required): First DataLoader of inputs to encode. For reranking tasks, these are queries (for text-only tasks, QueryDatasetType).
  • inputs2 (DataLoader[BatchedInput], required): Second DataLoader of inputs to encode. For reranking, these are documents (for text-only tasks, RetrievalOutputType).
  • task_metadata (TaskMetadata, required): Metadata of the current task.
  • hf_split (str, required): Split of the current task; provides additional information about the current split, e.g. the current language.
  • hf_subset (str, required): Subset of the current task, similar to hf_split.
  • prompt_type (PromptType | None, default None): The type of prompt (query or passage).
  • **kwargs (Any, default {}): Additional arguments to pass to the cross-encoder.

Returns:

  • Array: The predicted relevance scores for each input pair.

Source code in mteb/models/vllm_wrapper.py
def predict(
    self,
    inputs1: DataLoader[BatchedInput],
    inputs2: DataLoader[BatchedInput],
    *,
    task_metadata: TaskMetadata,
    hf_split: str,
    hf_subset: str,
    prompt_type: PromptType | None = None,
    **kwargs: Any,
) -> Array:
    """Predicts relevance scores for pairs of inputs. Note that, unlike the encoder, the cross-encoder can compare across inputs.

    Args:
        inputs1: First Dataloader of inputs to encode. For reranking tasks, these are queries (for text only tasks `QueryDatasetType`).
        inputs2: Second Dataloader of inputs to encode. For reranking, these are documents (for text only tasks `RetrievalOutputType`).
        task_metadata: Metadata of the current task.
        hf_split: Split of current task, allows to know some additional information about current split.
            E.g. Current language
        hf_subset: Subset of current task. Similar to `hf_split` to get more information
        prompt_type: The name type of prompt. (query or passage)
        **kwargs: Additional arguments to pass to the cross-encoder.

    Returns:
        The predicted relevance scores for each inputs pair.
    """
    queries = [
        self.query_prefix + text for batch in inputs1 for text in batch["text"]
    ]
    corpus = [
        self.document_prefix + text for batch in inputs2 for text in batch["text"]
    ]
    # TODO: support score prompt

    outputs = self.llm.score(
        queries,
        corpus,
        truncate_prompt_tokens=-1,
        use_tqdm=False,
    )
    scores = np.array([output.outputs.score for output in outputs])
    return scores