API Reference

This page is generated directly from Python docstrings.

Triton Interface

flash_sparse_attn.ops.triton.interface

flash_dense_attn_func(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, is_causal: bool = False, softmax_scale: Optional[float] = None, query_scale: Optional[torch.Tensor] = None, key_scale: Optional[torch.Tensor] = None, value_scale: Optional[torch.Tensor] = None, window_size: Tuple[Optional[int], Optional[int]] = (None, None), is_quant: bool = False, is_split_kv: bool = False, pack_gqa: bool = False, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, is_autotune: bool = False, skip_checks: bool = False, return_lse: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]

Flash dense attention function that computes the attention output and optionally the logsumexp.

Parameters:

Name Type Description Default
query Tensor

Query tensor of shape [batch_size, seqlen_q, num_heads, head_dim].

required
key Tensor

Key tensor of shape [batch_size, seqlen_k, num_kv_heads, head_dim].

required
value Tensor

Value tensor of shape [batch_size, seqlen_k, num_kv_heads, head_dim].

required
is_causal bool

Whether to apply a causal mask.

False
softmax_scale Optional[float]

Optional scaling factor for the softmax. If None, defaults to 1/sqrt(head_dim).

None
query_scale Optional[Tensor]

Optional per-tensor scale for FP8 query dequantization.

None
key_scale Optional[Tensor]

Optional per-tensor scale for FP8 key dequantization.

None
value_scale Optional[Tensor]

Optional per-tensor scale for FP8 value dequantization.

None
window_size Tuple[Optional[int], Optional[int]]

Optional tuple (window_size_q, window_size_k) for local attention. If None, no local masking is applied.

(None, None)
is_quant bool

Whether to quantize inputs to FP8 for attention computation. If True, query_scale, key_scale, and value_scale must be provided or will be computed from the input tensors.

False
is_split_kv bool

Whether to enable split-KV scheduling to improve GPU occupancy.

False
pack_gqa bool

Whether to pack grouped-query attention (GQA), processing query heads that share a KV head together.

False
out Optional[Tensor]

Optional preallocated output tensor with shape [batch_size, seqlen_q, num_heads, head_dim].

None
lse Optional[Tensor]

Optional preallocated logsumexp tensor with shape [batch_size, num_heads, seqlen_q].

None
is_autotune bool

Whether to use Triton autotuner for kernel launch configuration.

False
skip_checks bool

Whether to skip input validation checks to reduce launch overhead.

False
return_lse bool

Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out.

False

Returns:

Type Description
Union[Tensor, Tuple[Tensor, Tensor]]

If return_lse is False, returns out with shape [batch_size, seqlen_q, num_heads, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [batch_size, num_heads, seqlen_q].
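As a usage illustration, here is a minimal sketch of a batched call. The import path and keyword names come from the signature above; the tensor sizes and float16 dtype are illustrative assumptions, and the attention call itself requires a CUDA device.

```python
# Minimal sketch of flash_dense_attn_func; sizes and dtype are illustrative.
try:
    import torch
    _run_gpu = torch.cuda.is_available()
except ImportError:  # keep the sketch importable without torch
    _run_gpu = False

batch_size, seqlen, num_heads, num_kv_heads, head_dim = 2, 1024, 8, 2, 64

# Documented default: softmax_scale=None falls back to 1/sqrt(head_dim).
default_softmax_scale = head_dim ** -0.5

if _run_gpu:
    from flash_sparse_attn.ops.triton.interface import flash_dense_attn_func

    q = torch.randn(batch_size, seqlen, num_heads, head_dim,
                    device="cuda", dtype=torch.float16)
    k = torch.randn(batch_size, seqlen, num_kv_heads, head_dim,
                    device="cuda", dtype=torch.float16)
    v = torch.randn_like(k)

    # Causal attention, returning both the output and the logsumexp.
    out, lse = flash_dense_attn_func(q, k, v, is_causal=True, return_lse=True)
    assert out.shape == (batch_size, seqlen, num_heads, head_dim)
    assert lse.shape == (batch_size, num_heads, seqlen)
```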

flash_dense_attn_with_kvcache_func(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, softmax_scale: Optional[float] = None, query_scale: Optional[torch.Tensor] = None, key_scale: Optional[torch.Tensor] = None, value_scale: Optional[torch.Tensor] = None, window_size: Tuple[Optional[int], Optional[int]] = (None, None), is_quant: bool = False, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, is_autotune: bool = False, skip_checks: bool = False, return_lse: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]

Flash dense attention function for decoding with KV cache that computes the attention output and optionally the logsumexp.

Parameters:

Name Type Description Default
query Tensor

Query tensor of shape [batch_size, num_heads, head_dim].

required
key Tensor

Key tensor of shape [batch_size, seqlen_k, num_kv_heads, head_dim].

required
value Tensor

Value tensor of shape [batch_size, seqlen_k, num_kv_heads, head_dim].

required
softmax_scale Optional[float]

Optional scaling factor for the softmax. If None, defaults to 1/sqrt(head_dim).

None
query_scale Optional[Tensor]

Optional per-tensor scale for FP8 query dequantization.

None
key_scale Optional[Tensor]

Optional per-tensor scale for FP8 key dequantization.

None
value_scale Optional[Tensor]

Optional per-tensor scale for FP8 value dequantization.

None
window_size Tuple[Optional[int], Optional[int]]

Optional tuple (window_size_q, window_size_k) for local attention. If None, no local masking is applied.

(None, None)
is_quant bool

Whether the inputs are quantized in FP8. If True, query_scale, key_scale, and value_scale must be provided for dequantization.

False
out Optional[Tensor]

Optional preallocated output tensor with shape [batch_size, num_heads, head_dim].

None
lse Optional[Tensor]

Optional preallocated logsumexp tensor with shape [batch_size, num_heads].

None
is_autotune bool

Whether to use Triton autotuner for kernel launch configuration.

False
skip_checks bool

Whether to skip input validation checks to reduce launch overhead.

False
return_lse bool

Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out.

False

Returns:

Type Description
Union[Tensor, Tuple[Tensor, Tensor]]

If return_lse is False, returns out with shape [batch_size, num_heads, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [batch_size, num_heads].
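The decode path differs from the batched path in one important way: the query carries no seqlen_q dimension. A hedged sketch, with illustrative sizes and dtype (a CUDA device is required for the actual call):

```python
# Minimal sketch of a single-token decode step against a KV cache.
try:
    import torch
    _run_gpu = torch.cuda.is_available()
except ImportError:  # keep the sketch importable without torch
    _run_gpu = False

batch_size, cache_len, num_heads, num_kv_heads, head_dim = 4, 2048, 16, 4, 128

# With GQA, each KV head serves num_heads // num_kv_heads query heads.
gqa_group_size = num_heads // num_kv_heads

if _run_gpu:
    from flash_sparse_attn.ops.triton.interface import (
        flash_dense_attn_with_kvcache_func,
    )

    # Note: the query is [batch_size, num_heads, head_dim] -- no seqlen_q axis.
    q = torch.randn(batch_size, num_heads, head_dim,
                    device="cuda", dtype=torch.float16)
    k_cache = torch.randn(batch_size, cache_len, num_kv_heads, head_dim,
                          device="cuda", dtype=torch.float16)
    v_cache = torch.randn_like(k_cache)

    out = flash_dense_attn_with_kvcache_func(q, k_cache, v_cache)
    assert out.shape == (batch_size, num_heads, head_dim)
```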

flash_dense_attn_varlen_func(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor, max_seqlen_q: int, max_seqlen_k: int, is_causal: bool = False, softmax_scale: Optional[float] = None, query_scale: Optional[torch.Tensor] = None, key_scale: Optional[torch.Tensor] = None, value_scale: Optional[torch.Tensor] = None, window_size: Tuple[Optional[int], Optional[int]] = (None, None), is_quant: bool = False, seqused_q: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, is_split_kv: bool = False, pack_gqa: bool = False, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, is_autotune: bool = False, skip_checks: bool = False, return_lse: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]

Flash dense attention function for variable-length sequences that computes the attention output and optionally the logsumexp.

Parameters:

Name Type Description Default
query Tensor

Query tensor of shape [total_seqlen_q, num_heads_q, head_dim].

required
key Tensor

Key tensor of shape [total_seqlen_k, num_heads_kv, head_dim].

required
value Tensor

Value tensor of shape [total_seqlen_k, num_heads_kv, head_dim].

required
cu_seqlens_q Tensor

Cumulative sequence lengths for queries, shape [batch_size + 1].

required
cu_seqlens_k Tensor

Cumulative sequence lengths for keys/values, shape [batch_size + 1].

required
max_seqlen_q int

Maximum sequence length for queries.

required
max_seqlen_k int

Maximum sequence length for keys/values.

required
is_causal bool

Whether to apply a causal mask.

False
softmax_scale Optional[float]

Optional scaling factor for the softmax. If None, defaults to 1/sqrt(head_dim).

None
query_scale Optional[Tensor]

Optional per-tensor scale for FP8 query dequantization.

None
key_scale Optional[Tensor]

Optional per-tensor scale for FP8 key dequantization.

None
value_scale Optional[Tensor]

Optional per-tensor scale for FP8 value dequantization.

None
window_size Tuple[Optional[int], Optional[int]]

Optional tuple (window_size_q, window_size_k) for local attention. If None, no local masking is applied.

(None, None)
is_quant bool

Whether to quantize inputs to FP8 for attention computation. If True, query_scale, key_scale, and value_scale must be provided or will be computed from the input tensors.

False
seqused_q Optional[Tensor]

Optional tensor of shape [batch_size] giving the actual used sequence length of each query segment. If provided, overrides cu_seqlens_q for masking.

None
seqused_k Optional[Tensor]

Optional tensor of shape [batch_size] giving the actual used sequence length of each key/value segment. If provided, overrides cu_seqlens_k for masking.

None
is_split_kv bool

Whether to enable split-KV scheduling to improve GPU occupancy.

False
pack_gqa bool

Whether to pack grouped-query attention (GQA), processing query heads that share a KV head together.

False
out Optional[Tensor]

Optional preallocated output tensor with shape [total_seqlen_q, num_heads_q, head_dim].

None
lse Optional[Tensor]

Optional preallocated logsumexp tensor with shape [total_seqlen_q, num_heads_q].

None
is_autotune bool

Whether to use Triton autotuner for kernel launch configuration.

False
skip_checks bool

Whether to skip input validation checks to reduce launch overhead.

False
return_lse bool

Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out.

False

Returns:

Type Description
Union[Tensor, Tuple[Tensor, Tensor]]

If return_lse is False, returns out with shape [total_seqlen_q, num_heads_q, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [total_seqlen_q, num_heads_q].
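For the varlen path, sequences of different lengths are packed back to back along the first axis, and cu_seqlens_* hold the exclusive prefix sums of the per-sequence lengths with a leading 0 (hence shape [batch_size + 1]). A hedged sketch, with illustrative lengths and dtype:

```python
# Packing two sequences of different lengths for flash_dense_attn_varlen_func.
try:
    import torch
    _run_gpu = torch.cuda.is_available()
except ImportError:  # keep the sketch importable without torch
    _run_gpu = False

seqlens = [512, 384]                       # per-sequence lengths
cu_seqlens = [0]
for n in seqlens:
    cu_seqlens.append(cu_seqlens[-1] + n)  # prefix sums with leading 0
total_seqlen = cu_seqlens[-1]
max_seqlen = max(seqlens)

num_heads, num_kv_heads, head_dim = 8, 2, 64

if _run_gpu:
    from flash_sparse_attn.ops.triton.interface import flash_dense_attn_varlen_func

    cu = torch.tensor(cu_seqlens, device="cuda", dtype=torch.int32)
    q = torch.randn(total_seqlen, num_heads, head_dim,
                    device="cuda", dtype=torch.float16)
    k = torch.randn(total_seqlen, num_kv_heads, head_dim,
                    device="cuda", dtype=torch.float16)
    v = torch.randn_like(k)

    out = flash_dense_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu, cu_seqlens_k=cu,
        max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
        is_causal=True,
    )
    assert out.shape == (total_seqlen, num_heads, head_dim)
```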

flash_dense_attn_varlen_with_kvcache_func(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, cu_seqlens_k: torch.Tensor, max_seqlen_k: int, softmax_scale: Optional[float] = None, query_scale: Optional[torch.Tensor] = None, key_scale: Optional[torch.Tensor] = None, value_scale: Optional[torch.Tensor] = None, window_size: Tuple[Optional[int], Optional[int]] = (None, None), is_quant: bool = False, seqused_k: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, is_autotune: bool = False, skip_checks: bool = False, return_lse: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]

Flash dense attention function for variable-length decoding with KV cache that computes the attention output and optionally the logsumexp.

Parameters:

Name Type Description Default
query Tensor

Query tensor of shape [batch_size, num_heads_q, head_dim].

required
key Tensor

Key tensor of shape [total_seqlen_k, num_heads_kv, head_dim].

required
value Tensor

Value tensor of shape [total_seqlen_k, num_heads_kv, head_dim].

required
cu_seqlens_k Tensor

Cumulative sequence lengths for keys/values, shape [batch_size + 1].

required
max_seqlen_k int

Maximum sequence length for keys/values.

required
softmax_scale Optional[float]

Optional scaling factor for the softmax. If None, defaults to 1/sqrt(head_dim).

None
query_scale Optional[Tensor]

Optional per-tensor scale for FP8 query dequantization.

None
key_scale Optional[Tensor]

Optional per-tensor scale for FP8 key dequantization.

None
value_scale Optional[Tensor]

Optional per-tensor scale for FP8 value dequantization.

None
window_size Tuple[Optional[int], Optional[int]]

Optional tuple (window_size_q, window_size_k) for local attention. If None, no local masking is applied.

(None, None)
is_quant bool

Whether the inputs are quantized in FP8. If True, query_scale, key_scale, and value_scale must be provided for dequantization.

False
seqused_k Optional[Tensor]

Optional tensor indicating the actual sequence lengths for keys/values.

None
out Optional[Tensor]

Optional preallocated output tensor with shape [batch_size, num_heads_q, head_dim].

None
lse Optional[Tensor]

Optional preallocated logsumexp tensor with shape [batch_size, num_heads_q].

None
is_autotune bool

Whether to use Triton autotuner for kernel launch configuration.

False
skip_checks bool

Whether to skip input validation checks to reduce launch overhead.

False
return_lse bool

Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out.

False

Returns:

Type Description
Union[Tensor, Tuple[Tensor, Tensor]]

If return_lse is False, returns out with shape [batch_size, num_heads_q, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [batch_size, num_heads_q].

flash_sparse_attn_func(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, is_causal: bool = False, softmax_scale: Optional[float] = None, query_scale: Optional[torch.Tensor] = None, key_scale: Optional[torch.Tensor] = None, value_scale: Optional[torch.Tensor] = None, softmax_threshold: Optional[float] = None, window_size: Tuple[Optional[int], Optional[int]] = (None, None), is_quant: bool = False, is_split_kv: bool = False, pack_gqa: bool = False, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, is_autotune: bool = False, skip_checks: bool = False, return_lse: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]

Flash sparse attention function that computes the attention output and optionally the logsumexp.

Parameters:

Name Type Description Default
query Tensor

Query tensor of shape [batch_size, seqlen_q, num_heads, head_dim].

required
key Tensor

Key tensor of shape [batch_size, seqlen_k, num_kv_heads, head_dim].

required
value Tensor

Value tensor of shape [batch_size, seqlen_k, num_kv_heads, head_dim].

required
is_causal bool

Whether to apply a causal mask.

False
softmax_scale Optional[float]

Optional scaling factor for the softmax. If None, defaults to 1/sqrt(head_dim).

None
query_scale Optional[Tensor]

Optional per-tensor scale for FP8 query dequantization.

None
key_scale Optional[Tensor]

Optional per-tensor scale for FP8 key dequantization.

None
value_scale Optional[Tensor]

Optional per-tensor scale for FP8 value dequantization.

None
softmax_threshold Optional[float]

Optional threshold for the sparse softmax. If None, defaults to head_dim / seqlen_k.

None
window_size Tuple[Optional[int], Optional[int]]

Optional tuple (window_size_q, window_size_k) for local attention. If None, no local masking is applied.

(None, None)
is_quant bool

Whether to quantize inputs to FP8 for attention computation. If True, query_scale, key_scale, and value_scale must be provided or will be computed from the input tensors.

False
is_split_kv bool

Whether to enable split-KV scheduling to improve GPU occupancy.

False
pack_gqa bool

Whether to pack grouped-query attention (GQA), processing query heads that share a KV head together.

False
out Optional[Tensor]

Optional preallocated output tensor with shape [batch_size, seqlen_q, num_heads, head_dim].

None
lse Optional[Tensor]

Optional preallocated logsumexp tensor with shape [batch_size, num_heads, seqlen_q].

None
is_autotune bool

Whether to use Triton autotuner for kernel launch configuration.

False
skip_checks bool

Whether to skip input validation checks to reduce launch overhead.

False
return_lse bool

Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out.

False

Returns:

Type Description
Union[Tensor, Tuple[Tensor, Tensor]]

If return_lse is False, returns out with shape [batch_size, seqlen_q, num_heads, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [batch_size, num_heads, seqlen_q].
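The sparse variant adds one knob over the dense path: softmax_threshold, which per the parameter table defaults to head_dim / seqlen_k when left as None. A hedged sketch showing that default made explicit (sizes and dtype are illustrative; the call requires a CUDA device):

```python
# Minimal sketch of flash_sparse_attn_func with an explicit softmax_threshold.
try:
    import torch
    _run_gpu = torch.cuda.is_available()
except ImportError:  # keep the sketch importable without torch
    _run_gpu = False

batch_size, seqlen_q, seqlen_k, num_heads, head_dim = 2, 1024, 1024, 8, 64

# Documented default when softmax_threshold is None.
softmax_threshold = head_dim / seqlen_k

if _run_gpu:
    from flash_sparse_attn.ops.triton.interface import flash_sparse_attn_func

    q = torch.randn(batch_size, seqlen_q, num_heads, head_dim,
                    device="cuda", dtype=torch.float16)
    k = torch.randn(batch_size, seqlen_k, num_heads, head_dim,
                    device="cuda", dtype=torch.float16)
    v = torch.randn_like(k)

    out = flash_sparse_attn_func(
        q, k, v,
        is_causal=True,
        softmax_threshold=softmax_threshold,  # or None to use the default
    )
    assert out.shape == (batch_size, seqlen_q, num_heads, head_dim)
```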

flash_sparse_attn_with_kvcache_func(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, softmax_scale: Optional[float] = None, softmax_threshold: Optional[float] = None, query_scale: Optional[torch.Tensor] = None, key_scale: Optional[torch.Tensor] = None, value_scale: Optional[torch.Tensor] = None, window_size: Tuple[Optional[int], Optional[int]] = (None, None), is_quant: bool = False, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, is_autotune: bool = False, skip_checks: bool = False, return_lse: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]

Flash sparse attention function for decoding with KV cache that computes the attention output and optionally the logsumexp.

Parameters:

Name Type Description Default
query Tensor

Query tensor of shape [batch_size, num_heads, head_dim].

required
key Tensor

Key tensor of shape [batch_size, seqlen_k, num_kv_heads, head_dim].

required
value Tensor

Value tensor of shape [batch_size, seqlen_k, num_kv_heads, head_dim].

required
softmax_scale Optional[float]

Optional scaling factor for the softmax. If None, defaults to 1/sqrt(head_dim).

None
softmax_threshold Optional[float]

Optional threshold for the sparse softmax. If None, defaults to head_dim / seqlen_k.

None
query_scale Optional[Tensor]

Optional per-tensor scale for FP8 query dequantization.

None
key_scale Optional[Tensor]

Optional per-tensor scale for FP8 key dequantization.

None
value_scale Optional[Tensor]

Optional per-tensor scale for FP8 value dequantization.

None
window_size Tuple[Optional[int], Optional[int]]

Optional tuple (window_size_q, window_size_k) for local attention. If None, no local masking is applied.

(None, None)
is_quant bool

Whether the inputs are quantized in FP8. If True, query_scale, key_scale, and value_scale must be provided for dequantization.

False
out Optional[Tensor]

Optional preallocated output tensor with shape [batch_size, num_heads, head_dim].

None
lse Optional[Tensor]

Optional preallocated logsumexp tensor with shape [batch_size, num_heads].

None
is_autotune bool

Whether to use Triton autotuner for kernel launch configuration.

False
skip_checks bool

Whether to skip input validation checks to reduce launch overhead.

False
return_lse bool

Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out.

False

Returns:

Type Description
Union[Tensor, Tuple[Tensor, Tensor]]

If return_lse is False, returns out with shape [batch_size, num_heads, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [batch_size, num_heads].

flash_sparse_attn_varlen_func(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor, max_seqlen_q: int, max_seqlen_k: int, is_causal: bool = False, softmax_scale: Optional[float] = None, query_scale: Optional[torch.Tensor] = None, key_scale: Optional[torch.Tensor] = None, value_scale: Optional[torch.Tensor] = None, softmax_threshold: Optional[float] = None, window_size: Tuple[Optional[int], Optional[int]] = (None, None), is_quant: bool = False, seqused_q: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, is_split_kv: bool = False, pack_gqa: bool = False, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, is_autotune: bool = False, skip_checks: bool = False, return_lse: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]

Flash sparse attention function for variable-length sequences that computes the attention output and optionally the logsumexp.

Parameters:

Name Type Description Default
query Tensor

Query tensor of shape [total_seqlen_q, num_heads_q, head_dim].

required
key Tensor

Key tensor of shape [total_seqlen_k, num_heads_kv, head_dim].

required
value Tensor

Value tensor of shape [total_seqlen_k, num_heads_kv, head_dim].

required
cu_seqlens_q Tensor

Cumulative sequence lengths for queries, shape [batch_size + 1].

required
cu_seqlens_k Tensor

Cumulative sequence lengths for keys/values, shape [batch_size + 1].

required
max_seqlen_q int

Maximum sequence length for queries.

required
max_seqlen_k int

Maximum sequence length for keys/values.

required
is_causal bool

Whether to apply a causal mask.

False
softmax_scale Optional[float]

Optional scaling factor for the softmax. If None, defaults to 1/sqrt(head_dim).

None
query_scale Optional[Tensor]

Optional per-tensor scale for FP8 query dequantization.

None
key_scale Optional[Tensor]

Optional per-tensor scale for FP8 key dequantization.

None
value_scale Optional[Tensor]

Optional per-tensor scale for FP8 value dequantization.

None
softmax_threshold Optional[float]

Optional threshold for the sparse softmax. If None, defaults to head_dim / max_seqlen_k.

None
window_size Tuple[Optional[int], Optional[int]]

Optional tuple (window_size_q, window_size_k) for local attention. If None, no local masking is applied.

(None, None)
is_quant bool

Whether to quantize inputs to FP8 for attention computation. If True, query_scale, key_scale, and value_scale must be provided or will be computed from the input tensors.

False
seqused_q Optional[Tensor]

Optional tensor of shape [batch_size] giving the actual used sequence length of each query segment. If provided, overrides cu_seqlens_q for masking.

None
seqused_k Optional[Tensor]

Optional tensor of shape [batch_size] giving the actual used sequence length of each key/value segment. If provided, overrides cu_seqlens_k for masking.

None
is_split_kv bool

Whether to enable split-KV scheduling to improve GPU occupancy.

False
pack_gqa bool

Whether to pack grouped-query attention (GQA), processing query heads that share a KV head together.

False
out Optional[Tensor]

Optional preallocated output tensor with shape [total_seqlen_q, num_heads_q, head_dim].

None
lse Optional[Tensor]

Optional preallocated logsumexp tensor with shape [total_seqlen_q, num_heads_q].

None
is_autotune bool

Whether to use Triton autotuner for kernel launch configuration.

False
skip_checks bool

Whether to skip input validation checks to reduce launch overhead.

False
return_lse bool

Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out.

False

Returns:

Type Description
Union[Tensor, Tuple[Tensor, Tensor]]

If return_lse is False, returns out with shape [total_seqlen_q, num_heads_q, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [total_seqlen_q, num_heads_q].

flash_sparse_attn_varlen_with_kvcache_func(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, cu_seqlens_k: torch.Tensor, max_seqlen_k: int, softmax_scale: Optional[float] = None, softmax_threshold: Optional[float] = None, query_scale: Optional[torch.Tensor] = None, key_scale: Optional[torch.Tensor] = None, value_scale: Optional[torch.Tensor] = None, window_size: Tuple[Optional[int], Optional[int]] = (None, None), is_quant: bool = False, seqused_k: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, is_autotune: bool = False, skip_checks: bool = False, return_lse: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]

Flash sparse attention function for variable-length decoding with KV cache that computes the attention output and optionally the logsumexp.

Parameters:

Name Type Description Default
query Tensor

Query tensor of shape [batch_size, num_heads_q, head_dim].

required
key Tensor

Key tensor of shape [total_seqlen_k, num_heads_kv, head_dim].

required
value Tensor

Value tensor of shape [total_seqlen_k, num_heads_kv, head_dim].

required
cu_seqlens_k Tensor

Cumulative sequence lengths for keys/values, shape [batch_size + 1].

required
max_seqlen_k int

Maximum sequence length for keys/values.

required
softmax_scale Optional[float]

Optional scaling factor for the softmax. If None, defaults to 1/sqrt(head_dim).

None
softmax_threshold Optional[float]

Optional threshold for the sparse softmax. If None, defaults to head_dim / max_seqlen_k.

None
query_scale Optional[Tensor]

Optional per-tensor scale for FP8 query dequantization.

None
key_scale Optional[Tensor]

Optional per-tensor scale for FP8 key dequantization.

None
value_scale Optional[Tensor]

Optional per-tensor scale for FP8 value dequantization.

None
window_size Tuple[Optional[int], Optional[int]]

Optional tuple (window_size_q, window_size_k) for local attention. If None, no local masking is applied.

(None, None)
is_quant bool

Whether the inputs are quantized in FP8. If True, query_scale, key_scale, and value_scale must be provided for dequantization.

False
seqused_k Optional[Tensor]

Optional tensor indicating the actual sequence lengths for keys/values.

None
out Optional[Tensor]

Optional preallocated output tensor with shape [batch_size, num_heads_q, head_dim].

None
lse Optional[Tensor]

Optional preallocated logsumexp tensor with shape [batch_size, num_heads_q].

None
is_autotune bool

Whether to use Triton autotuner for kernel launch configuration.

False
skip_checks bool

Whether to skip input validation checks to reduce launch overhead.

False
return_lse bool

Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out.

False

Returns:

Type Description
Union[Tensor, Tuple[Tensor, Tensor]]

If return_lse is False, returns out with shape [batch_size, num_heads_q, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [batch_size, num_heads_q].

flash_gated_attn_func(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, alpha: torch.Tensor, delta: torch.Tensor, is_causal: bool = False, softmax_scale: Optional[float] = None, query_scale: Optional[torch.Tensor] = None, key_scale: Optional[torch.Tensor] = None, value_scale: Optional[torch.Tensor] = None, softmax_threshold: Optional[float] = None, gate_threshold: Optional[float] = None, is_logsigmoid_gate: bool = True, is_adapt_gate: bool = True, window_size: Tuple[Optional[int], Optional[int]] = (None, None), is_quant: bool = False, is_split_kv: bool = False, pack_gqa: bool = False, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, is_autotune: bool = False, skip_checks: bool = False, return_lse: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]

Flash gated attention function that computes the attention output and optionally the logsumexp.

Parameters:

Name Type Description Default
query Tensor

Query tensor of shape [batch_size, seqlen_q, num_heads, head_dim].

required
key Tensor

Key tensor of shape [batch_size, seqlen_k, num_kv_heads, head_dim].

required
value Tensor

Value tensor of shape [batch_size, seqlen_k, num_kv_heads, head_dim].

required
alpha Tensor

Tensor of shape [batch_size, seqlen_q, num_heads] representing the sparsity pattern for queries.

required
delta Tensor

Tensor of shape [batch_size, seqlen_k, num_kv_heads] representing the sparsity pattern for keys/values.

required
is_causal bool

Whether to apply a causal mask.

False
softmax_scale Optional[float]

Optional scaling factor for the softmax. If None, defaults to 1/sqrt(head_dim).

None
query_scale Optional[Tensor]

Optional per-tensor scale for FP8 query dequantization.

None
key_scale Optional[Tensor]

Optional per-tensor scale for FP8 key dequantization.

None
value_scale Optional[Tensor]

Optional per-tensor scale for FP8 value dequantization.

None
softmax_threshold Optional[float]

Optional threshold for the sparse softmax.

None
gate_threshold Optional[float]

Optional threshold for the sparsity gate.

None
is_logsigmoid_gate bool

Whether to use a log-sigmoid function for the sparsity gate. If False, uses a linear function.

True
is_adapt_gate bool

Whether to adapt the gate threshold based on sequence length.

True
window_size Tuple[Optional[int], Optional[int]]

Optional tuple (window_size_q, window_size_k) for local attention. If None, no local masking is applied.

(None, None)
is_quant bool

Whether to quantize inputs to FP8 for attention computation. If True, query_scale, key_scale, and value_scale must be provided or will be computed from the input tensors.

False
is_split_kv bool

Whether to enable split-KV scheduling to improve GPU occupancy.

False
pack_gqa bool

Whether to pack grouped-query attention (GQA), processing query heads that share a KV head together.

False
out Optional[Tensor]

Optional preallocated output tensor with shape [batch_size, seqlen_q, num_heads, head_dim].

None
lse Optional[Tensor]

Optional preallocated logsumexp tensor with shape [batch_size, num_heads, seqlen_q].

None
is_autotune bool

Whether to use Triton autotuner for kernel launch configuration.

False
skip_checks bool

Whether to skip input validation checks to reduce launch overhead.

False
return_lse bool

Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out.

False

Returns:

Type Description
Union[Tensor, Tuple[Tensor, Tensor]]

If return_lse is False, returns out with shape [batch_size, seqlen_q, num_heads, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [batch_size, num_heads, seqlen_q].
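The gated variant takes two extra tensors, alpha and delta, whose shapes mirror the query and key/value head layouts respectively (shapes taken from the parameter table above). A hedged sketch with illustrative sizes and dtype (the call requires a CUDA device):

```python
# Minimal sketch of flash_gated_attn_func with its alpha/delta gate inputs.
try:
    import torch
    _run_gpu = torch.cuda.is_available()
except ImportError:  # keep the sketch importable without torch
    _run_gpu = False

batch_size, seqlen, num_heads, num_kv_heads, head_dim = 2, 512, 8, 2, 64

# Shape relationship documented above, independent of the device:
alpha_shape = (batch_size, seqlen, num_heads)     # per-query-head gate scores
delta_shape = (batch_size, seqlen, num_kv_heads)  # per-kv-head gate scores

if _run_gpu:
    from flash_sparse_attn.ops.triton.interface import flash_gated_attn_func

    q = torch.randn(batch_size, seqlen, num_heads, head_dim,
                    device="cuda", dtype=torch.float16)
    k = torch.randn(batch_size, seqlen, num_kv_heads, head_dim,
                    device="cuda", dtype=torch.float16)
    v = torch.randn_like(k)

    # With is_logsigmoid_gate=True (the default), gate scores pass through
    # a log-sigmoid before thresholding.
    alpha = torch.randn(*alpha_shape, device="cuda", dtype=torch.float16)
    delta = torch.randn(*delta_shape, device="cuda", dtype=torch.float16)

    out = flash_gated_attn_func(q, k, v, alpha, delta, is_causal=True)
    assert out.shape == (batch_size, seqlen, num_heads, head_dim)
```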

flash_gated_attn_with_kvcache_func(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, alpha: torch.Tensor, delta: torch.Tensor, softmax_scale: Optional[float] = None, softmax_threshold: Optional[float] = None, gate_threshold: Optional[float] = None, is_logsigmoid_gate: bool = True, query_scale: Optional[torch.Tensor] = None, key_scale: Optional[torch.Tensor] = None, value_scale: Optional[torch.Tensor] = None, window_size: Tuple[Optional[int], Optional[int]] = (None, None), is_quant: bool = False, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, is_autotune: bool = False, skip_checks: bool = False, return_lse: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]

Flash gated attention function for decoding with KV cache that computes the attention output and optionally the logsumexp.

Parameters:

Name Type Description Default
query Tensor

Query tensor of shape [batch_size, num_heads, head_dim].

required
key Tensor

Key tensor of shape [batch_size, seqlen_k, num_kv_heads, head_dim].

required
value Tensor

Value tensor of shape [batch_size, seqlen_k, num_kv_heads, head_dim].

required
alpha Tensor

Tensor of shape [batch_size, num_heads] representing the sparsity pattern for queries.

required
delta Tensor

Tensor of shape [batch_size, seqlen_k, num_kv_heads] representing the sparsity pattern for keys/values.

required
softmax_scale Optional[float]

Optional scaling factor for the softmax. If None, defaults to 1/sqrt(head_dim).

None
softmax_threshold Optional[float]

Optional threshold for the sparse softmax.

None
gate_threshold Optional[float]

Optional threshold for the sparsity gate.

None
is_logsigmoid_gate bool

Whether to use a log-sigmoid function for the sparsity gate. If False, uses a linear function.

True
query_scale Optional[Tensor]

Optional per-tensor scale for FP8 query dequantization.

None
key_scale Optional[Tensor]

Optional per-tensor scale for FP8 key dequantization.

None
value_scale Optional[Tensor]

Optional per-tensor scale for FP8 value dequantization.

None
window_size Tuple[Optional[int], Optional[int]]

Optional tuple (window_size_q, window_size_k) for local attention. If None, no local masking is applied.

(None, None)
is_quant bool

Whether the inputs are quantized in FP8. If True, query_scale, key_scale, and value_scale must be provided for dequantization.

False
out Optional[Tensor]

Optional preallocated output tensor with shape [batch_size, num_heads, head_dim].

None
lse Optional[Tensor]

Optional preallocated logsumexp tensor with shape [batch_size, num_heads].

None
is_autotune bool

Whether to use Triton autotuner for kernel launch configuration.

False
skip_checks bool

Whether to skip input validation checks to reduce launch overhead.

False
return_lse bool

Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out.

False

Returns:

Type Description
Union[Tensor, Tuple[Tensor, Tensor]]

If return_lse is False, returns out with shape [batch_size, num_heads, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [batch_size, num_heads].

flash_gated_attn_varlen_func(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, alpha: torch.Tensor, delta: torch.Tensor, cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor, max_seqlen_q: int, max_seqlen_k: int, is_causal: bool = False, softmax_scale: Optional[float] = None, query_scale: Optional[torch.Tensor] = None, key_scale: Optional[torch.Tensor] = None, value_scale: Optional[torch.Tensor] = None, softmax_threshold: Optional[float] = None, gate_threshold: Optional[float] = None, is_logsigmoid_gate: bool = True, is_adapt_gate: bool = True, window_size: Tuple[Optional[int], Optional[int]] = (None, None), is_quant: bool = False, seqused_q: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, is_split_kv: bool = False, pack_gqa: bool = False, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, is_autotune: bool = False, skip_checks: bool = False, return_lse: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]

Flash gated attention function for variable-length sequences that computes the attention output and optionally the logsumexp.

Parameters:

Name Type Description Default
query Tensor

Query tensor of shape [total_seqlen_q, num_heads_q, head_dim].

required
key Tensor

Key tensor of shape [total_seqlen_k, num_heads_kv, head_dim].

required
value Tensor

Value tensor of shape [total_seqlen_k, num_heads_kv, head_dim].

required
alpha Tensor

Tensor of shape [total_seqlen_q, num_heads_q] representing the sparsity pattern for queries.

required
delta Tensor

Tensor of shape [total_seqlen_k, num_heads_kv] representing the sparsity pattern for keys/values.

required
cu_seqlens_q Tensor

Cumulative sequence lengths for queries, shape [batch_size + 1].

required
cu_seqlens_k Tensor

Cumulative sequence lengths for keys/values, shape [batch_size + 1].

required
max_seqlen_q int

Maximum sequence length for queries.

required
max_seqlen_k int

Maximum sequence length for keys/values.

required
is_causal bool

Whether to apply a causal mask.

False
softmax_scale Optional[float]

Optional scaling factor for the softmax. If None, defaults to 1/sqrt(head_dim).

None
query_scale Optional[Tensor]

Optional per-tensor scale for FP8 query dequantization.

None
key_scale Optional[Tensor]

Optional per-tensor scale for FP8 key dequantization.

None
value_scale Optional[Tensor]

Optional per-tensor scale for FP8 value dequantization.

None
softmax_threshold Optional[float]

Optional threshold for the sparse softmax.

None
gate_threshold Optional[float]

Optional threshold for the sparsity gate.

None
is_logsigmoid_gate bool

Whether to use a log-sigmoid function for the sparsity gate. If False, uses a linear function.

True
is_adapt_gate bool

Whether to adapt the gate threshold based on sequence length.

True
window_size Tuple[Optional[int], Optional[int]]

Optional tuple (window_size_q, window_size_k) for local attention. If None, no local masking is applied.

(None, None)
is_quant bool

Whether to quantize inputs to FP8 for attention computation. If True, query_scale, key_scale, and value_scale must be provided or will be computed from the input tensors.

False
seqused_q Optional[Tensor]

Optional tensor of shape [batch_size] indicating the actual number of query tokens used per sequence. If provided, overrides the lengths implied by cu_seqlens_q for masking.

None
seqused_k Optional[Tensor]

Optional tensor of shape [batch_size] indicating the actual number of key/value tokens used per sequence. If provided, overrides the lengths implied by cu_seqlens_k for masking.

None
is_split_kv bool

Whether to split the KV sequence across program instances to improve GPU occupancy.

False
pack_gqa bool

Whether to pack grouped-query attention (GQA) query heads that share a KV head into a single block for better efficiency.

False
out Optional[Tensor]

Optional preallocated output tensor with shape [total_seqlen_q, num_heads_q, head_dim].

None
lse Optional[Tensor]

Optional preallocated logsumexp tensor with shape [total_seqlen_q, num_heads_q].

None
is_autotune bool

Whether to use Triton autotuner for kernel launch configuration.

False
skip_checks bool

Whether to skip input validation checks to reduce launch overhead.

False
return_lse bool

Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out.

False

Returns:

Type Description
Union[Tensor, Tuple[Tensor, Tensor]]

If return_lse is False, returns out with shape [total_seqlen_q, num_heads_q, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [total_seqlen_q, num_heads_q].
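The varlen API takes packed (unpadded) tensors plus cumulative sequence lengths. A minimal sketch, using made-up per-sequence lengths, of how cu_seqlens_q / cu_seqlens_k and max_seqlen_* relate to the packed layout:

```python
import numpy as np

# Hypothetical per-sequence token counts for a batch of 3 variable-length sequences.
seqlens_q = np.array([3, 5, 2])
seqlens_k = np.array([4, 5, 6])

# cu_seqlens_* is an exclusive prefix sum with a leading 0, shape [batch_size + 1].
cu_seqlens_q = np.concatenate([[0], np.cumsum(seqlens_q)])   # [0, 3, 8, 10]
cu_seqlens_k = np.concatenate([[0], np.cumsum(seqlens_k)])   # [0, 4, 9, 15]
max_seqlen_q = int(seqlens_q.max())                          # 5
max_seqlen_k = int(seqlens_k.max())                          # 6

# Packed query tensor: all sequences concatenated along dim 0 (no padding).
num_heads_q, head_dim = 4, 16
query = np.zeros((int(seqlens_q.sum()), num_heads_q, head_dim))

# Sequence i occupies rows cu_seqlens_q[i]:cu_seqlens_q[i+1] of the packed tensor.
q_seq1 = query[cu_seqlens_q[1]:cu_seqlens_q[2]]              # second sequence, 5 rows
```

The key and value tensors follow the same convention with cu_seqlens_k indexing into their first dimension.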

flash_gated_attn_varlen_with_kvcache_func(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, alpha: torch.Tensor, delta: torch.Tensor, cu_seqlens_k: torch.Tensor, max_seqlen_k: int, softmax_scale: Optional[float] = None, softmax_threshold: Optional[float] = None, gate_threshold: Optional[float] = None, is_logsigmoid_gate: bool = True, query_scale: Optional[torch.Tensor] = None, key_scale: Optional[torch.Tensor] = None, value_scale: Optional[torch.Tensor] = None, window_size: Tuple[Optional[int], Optional[int]] = (None, None), is_quant: bool = False, seqused_k: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, is_autotune: bool = False, skip_checks: bool = False, return_lse: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]

Flash gated attention function for variable-length decoding with KV cache that computes the attention output and optionally the logsumexp.

Parameters:

Name Type Description Default
query Tensor

Query tensor of shape [batch_size, num_heads_q, head_dim].

required
key Tensor

Key tensor of shape [total_seqlen_k, num_heads_kv, head_dim].

required
value Tensor

Value tensor of shape [total_seqlen_k, num_heads_kv, head_dim].

required
alpha Tensor

Tensor of shape [batch_size, num_heads_q] representing the sparsity pattern for queries.

required
delta Tensor

Tensor of shape [total_seqlen_k, num_heads_kv] representing the sparsity pattern for keys/values.

required
cu_seqlens_k Tensor

Cumulative sequence lengths for keys/values, shape [batch_size + 1].

required
max_seqlen_k int

Maximum sequence length for keys/values.

required
softmax_scale Optional[float]

Optional scaling factor for the softmax. If None, defaults to 1/sqrt(head_dim).

None
softmax_threshold Optional[float]

Optional threshold for the sparse softmax.

None
gate_threshold Optional[float]

Optional threshold for the sparsity gate.

None
is_logsigmoid_gate bool

Whether to use a log-sigmoid function for the sparsity gate. If False, uses a linear function.

True
query_scale Optional[Tensor]

Optional per-tensor scale for FP8 query dequantization.

None
key_scale Optional[Tensor]

Optional per-tensor scale for FP8 key dequantization.

None
value_scale Optional[Tensor]

Optional per-tensor scale for FP8 value dequantization.

None
window_size Tuple[Optional[int], Optional[int]]

Optional tuple (window_size_q, window_size_k) for local attention. If None, no local masking is applied.

(None, None)
is_quant bool

Whether the inputs are quantized in FP8. If True, query_scale, key_scale, and value_scale must be provided for dequantization.

False
seqused_k Optional[Tensor]

Optional tensor of shape [batch_size] indicating the actual sequence lengths for keys/values.

None
out Optional[Tensor]

Optional preallocated output tensor with shape [batch_size, num_heads_q, head_dim].

None
lse Optional[Tensor]

Optional preallocated logsumexp tensor with shape [batch_size, num_heads_q].

None
is_autotune bool

Whether to use Triton autotuner for kernel launch configuration.

False
skip_checks bool

Whether to skip input validation checks to reduce launch overhead.

False
return_lse bool

Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out.

False

Returns:

Type Description
Union[Tensor, Tuple[Tensor, Tensor]]

If return_lse is False, returns out with shape [batch_size, num_heads_q, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [batch_size, num_heads_q].
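For orientation, here is a plain NumPy sketch of the decode-step math these shapes imply: each single-token query attends to its own slice of the packed KV cache, selected by cu_seqlens_k. This is reference math only, under the simplifying assumption num_heads_q == num_heads_kv; it ignores the gating (alpha/delta), windowing, and quantization options, and is not the Triton kernel.

```python
import numpy as np

def decode_attention_reference(q, k, v, cu_seqlens_k, softmax_scale=None):
    """Reference decode attention over a packed KV cache.

    q: [batch_size, num_heads, head_dim]
    k, v: [total_seqlen_k, num_heads, head_dim]
    Returns out [batch_size, num_heads, head_dim], lse [batch_size, num_heads].
    """
    if softmax_scale is None:
        softmax_scale = 1.0 / np.sqrt(q.shape[-1])
    batch_size, num_heads, _ = q.shape
    out = np.empty_like(q)
    lse = np.empty((batch_size, num_heads))
    for b in range(batch_size):
        # This sequence's KV slice in the packed cache.
        kb = k[cu_seqlens_k[b]:cu_seqlens_k[b + 1]]          # [len_b, heads, dim]
        vb = v[cu_seqlens_k[b]:cu_seqlens_k[b + 1]]
        scores = np.einsum("hd,khd->hk", q[b], kb) * softmax_scale
        m = scores.max(axis=-1, keepdims=True)
        p = np.exp(scores - m)
        denom = p.sum(axis=-1, keepdims=True)
        lse[b] = (m + np.log(denom)).squeeze(-1)
        out[b] = np.einsum("hk,khd->hd", p / denom, vb)
    return out, lse
```

Because the softmax weights for each query sum to one, feeding in a constant value tensor returns that same constant, which makes for a quick sanity check.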