Skip to content

vllm.models.deepseek_v4.attention

DeepseekV4 MLA Attention Layer

_resolve_dsv4_backend

_resolve_dsv4_backend(vllm_config: VllmConfig | None)

Return the explicitly requested DSv4 sparse backend enum, or None if no backend was requested.

Source code in vllm/models/deepseek_v4/attention.py
def _resolve_dsv4_backend(vllm_config: VllmConfig | None):
    """Return ``vllm_config.attention_config.backend`` if it is set, else None.

    Every hop is optional: a missing ``vllm_config``, a missing or ``None``
    ``attention_config``, or a missing ``backend`` attribute all yield None.
    """
    attention_cfg = (
        None
        if vllm_config is None
        else getattr(vllm_config, "attention_config", None)
    )
    if attention_cfg is None:
        return None
    return getattr(attention_cfg, "backend", None)

_resolve_dsv4_kv_cache_dtype

_resolve_dsv4_kv_cache_dtype(
    backend,
    kv_cache_dtype: str,
    cache_config: CacheConfig | None,
) -> tuple[str, dtype]

Map (backend, --kv-cache-dtype) to (cache_dtype_str, torch_dtype).

FlashInfer V4 reads a contiguous 512-wide KV row (bf16 or per-tensor FP8 E4M3); FlashMLA V4 reads the legacy UE8M0 paged layout (uint8 / `fp8_ds_mla`). For FlashMLA, any fp8 `--kv-cache-dtype` is normalized to the canonical `fp8_ds_mla` string, which is also written back onto `cache_config` so that the page-size specs pick the 576-byte layout.

Source code in vllm/models/deepseek_v4/attention.py
def _resolve_dsv4_kv_cache_dtype(
    backend,
    kv_cache_dtype: str,
    cache_config: CacheConfig | None,
) -> tuple[str, torch.dtype]:
    """Map ``(backend, --kv-cache-dtype)`` to ``(cache_dtype_str, torch_dtype)``.

    FlashInfer V4 reads a contiguous 512-wide KV row (bf16 or per-tensor FP8
    E4M3); FlashMLA V4 reads the legacy UE8M0 paged layout (uint8 /
    ``fp8_ds_mla``).  For FlashMLA the canonical ``fp8_ds_mla`` string is
    written back onto ``cache_config`` so the page-size specs pick the 576B
    layout.

    Args:
        backend: The explicitly requested attention backend enum (may be
            ``None``). Only ``FLASHINFER_MLA_SPARSE_DSV4`` is special-cased;
            every other value takes the FlashMLA / ROCm Aiter path.
        kv_cache_dtype: The user-facing ``--kv-cache-dtype`` string.
        cache_config: Optional cache config. On the FlashMLA path its
            ``cache_dtype`` is overwritten with ``"fp8_ds_mla"`` whenever the
            requested dtype string is normalized.

    Returns:
        ``(cache_dtype_str, torch_dtype)``:

        * FlashInfer: the input string unchanged, paired with
          ``torch.float8_e4m3fn`` for ``fp8*`` strings and ``torch.bfloat16``
          otherwise.
        * FlashMLA: ``("fp8_ds_mla", torch.uint8)``.

    Raises:
        ValueError: If the FlashMLA path is selected with a non-fp8
            ``kv_cache_dtype``. (Previously an ``assert``, which is skipped
            under ``python -O`` and would let e.g. ``"auto"`` be silently
            rewritten to ``fp8_ds_mla``.)
    """
    from vllm.v1.attention.backends.registry import AttentionBackendEnum

    if backend == AttentionBackendEnum.FLASHINFER_MLA_SPARSE_DSV4:
        if kv_cache_dtype.startswith("fp8"):
            # NOTE(review): every fp8 variant (e.g. "fp8_e5m2", "fp8_ds_mla")
            # is mapped to E4M3 here and the string is passed through
            # unchanged. Confirm downstream consumers of the string agree with
            # an E4M3, contiguous-row layout for those variants.
            return kv_cache_dtype, torch.float8_e4m3fn
        # auto / bfloat16 -> contiguous BF16 cache.
        return kv_cache_dtype, torch.bfloat16

    # FlashMLA (and ROCm Aiter): legacy UE8M0 paged uint8 cache.
    # Explicit raise instead of `assert`: this validates user input and must
    # not be disabled by running Python with -O.
    if not kv_cache_dtype.startswith("fp8"):
        raise ValueError(
            "DeepseekV4 FlashMLA sparse backend only supports fp8 kv-cache, "
            f"got {kv_cache_dtype}"
        )
    if kv_cache_dtype != "fp8_ds_mla":
        # Normalize any fp8 alias to the canonical string and write it back so
        # the cache config (and hence page-size computation) agrees with it.
        if cache_config is not None:
            cache_config.cache_dtype = "fp8_ds_mla"
        kv_cache_dtype = "fp8_ds_mla"
        logger.info_once("Using DeepSeek's fp8_ds_mla KV cache format.")
    return kv_cache_dtype, torch.uint8

_select_v4_sparse_impl

_select_v4_sparse_impl(
    vllm_config: VllmConfig | None = None,
) -> type[DeepseekV4SparseMLAAttentionImpl]

Pick the V4 sparse MLA impl class.

An explicit --attention-backend FLASHINFER_MLA_SPARSE_DSV4 selects the FlashInfer TRTLLM-gen path; otherwise the platform default (FlashMLA on NVIDIA, ROCm Aiter on AMD) is used.

Source code in vllm/models/deepseek_v4/attention.py
def _select_v4_sparse_impl(
    vllm_config: VllmConfig | None = None,
) -> "type[DeepseekV4SparseMLAAttentionImpl]":
    """Choose the DeepseekV4 sparse MLA implementation class.

    Selection order:

    1. ``--attention-backend FLASHINFER_MLA_SPARSE_DSV4`` explicitly requested
       -> FlashInfer TRTLLM-gen implementation.
    2. Otherwise, when running on ROCm -> ROCm Aiter implementation.
    3. Otherwise -> FlashMLA implementation (the NVIDIA default).

    The chosen backend is reported once through ``logger.info_once``.

    Args:
        vllm_config: Optional config, consulted through
            ``_resolve_dsv4_backend`` for an explicitly requested backend.

    Returns:
        The implementation class itself (not an instance).
    """
    from vllm.v1.attention.backends.registry import AttentionBackendEnum

    requested = _resolve_dsv4_backend(vllm_config)

    # Only the module for the branch actually taken is imported, presumably so
    # platform-specific kernels are not loaded unless selected.
    if requested == AttentionBackendEnum.FLASHINFER_MLA_SPARSE_DSV4:
        from vllm.models.deepseek_v4.nvidia.flashinfer_sparse import (
            DeepseekV4FlashInferMLASparseImpl as impl_cls,
        )

        message = "Using FLASHINFER_MLA_SPARSE_DSV4 backend."
    elif current_platform.is_rocm():
        from vllm.models.deepseek_v4.amd.rocm import (
            DeepseekV4ROCMAiterMLASparseImpl as impl_cls,
        )

        message = "Using ROCM_FLASHMLA_SPARSE_DSV4 backend."
    else:
        from vllm.models.deepseek_v4.nvidia.flashmla import (
            DeepseekV4FlashMLASparseImpl as impl_cls,
        )

        message = "Using FLASHMLA_SPARSE_DSV4 backend."

    logger.info_once(message)
    return impl_cls