API Reference

This page is generated directly from Python docstrings.

Triton Interface

flash_sparse_attn.ops.triton.interface

flash_dense_attn_func(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    is_causal: bool = False,
    softmax_scale: Optional[float] = None,
    window_size: Tuple[Optional[int], Optional[int]] = (None, None),
    return_lse: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]

Flash dense attention function that computes the attention output and optionally the logsumexp.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `query` | `Tensor` | Query tensor of shape `[batch_size, seqlen_q, num_heads, head_dim]`. | required |
| `key` | `Tensor` | Key tensor of shape `[batch_size, seqlen_k, num_kv_heads, head_dim]`. | required |
| `value` | `Tensor` | Value tensor of shape `[batch_size, seqlen_k, num_kv_heads, head_dim]`. | required |
| `is_causal` | `bool` | Whether to apply a causal mask. | `False` |
| `softmax_scale` | `Optional[float]` | Scaling factor for the softmax. If `None`, defaults to `1/sqrt(head_dim)`. | `None` |
| `window_size` | `Tuple[Optional[int], Optional[int]]` | Tuple `(window_size_q, window_size_k)` for local attention. If `None`, no local masking is applied. | `(None, None)` |
| `return_lse` | `bool` | Whether to return the logsumexp tensor for numerical-stability analysis. If `True`, returns a tuple `(out, lse)`; if `False`, returns only `out`. | `False` |

Returns:

| Type | Description |
| --- | --- |
| `Union[Tensor, Tuple[Tensor, Tensor]]` | If `return_lse` is `False`, `out` with shape `[batch_size, seqlen_q, num_heads, head_dim]`. If `return_lse` is `True`, a tuple `(out, lse)`, where `lse` has shape `[batch_size, num_heads, seqlen_q]`. |
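As a cross-check of the semantics above, here is a minimal NumPy reference for a single batch element and head (illustrative only: the Triton kernel never materializes the full score matrix, and the bottom-right alignment of the causal mask is assumed from the usual flash-attention convention, not stated by this docstring):

```python
import numpy as np

def dense_attn_reference(q, k, v, is_causal=False, softmax_scale=None):
    """Naive reference for (out, lse).

    q: [seqlen_q, head_dim]; k, v: [seqlen_k, head_dim]. A single batch
    element and head, for brevity; the real API adds batch and head axes.
    """
    head_dim = q.shape[-1]
    scale = softmax_scale if softmax_scale is not None else 1.0 / np.sqrt(head_dim)
    scores = (q @ k.T) * scale                      # [seqlen_q, seqlen_k]
    if is_causal:
        # Assumed convention: diagonals aligned at the bottom-right, so the
        # last query sees all keys even when seqlen_q != seqlen_k.
        sq, sk = scores.shape
        mask = np.tril(np.ones((sq, sk), dtype=bool), k=sk - sq)
        scores = np.where(mask, scores, -np.inf)
    m = scores.max(axis=-1, keepdims=True)          # row max, for stability
    p = np.exp(scores - m)
    denom = p.sum(axis=-1, keepdims=True)
    out = (p / denom) @ v                           # [seqlen_q, head_dim]
    lse = (m + np.log(denom)).squeeze(-1)           # logsumexp per query row
    return out, lse
```

`lse` is the row-wise log-sum-exp of the scaled scores; it is the quantity that makes the kernel's streaming softmax numerically stable.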

flash_dense_attn_varlen_func(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    is_causal: bool = False,
    softmax_scale: Optional[float] = None,
    window_size: Tuple[Optional[int], Optional[int]] = (None, None),
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    return_lse: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]

Flash dense attention function for variable-length sequences that computes the attention output and optionally the logsumexp.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `query` | `Tensor` | Query tensor of shape `[total_seqlen_q, num_heads_q, head_dim]`. | required |
| `key` | `Tensor` | Key tensor of shape `[total_seqlen_k, num_heads_kv, head_dim]`. | required |
| `value` | `Tensor` | Value tensor of shape `[total_seqlen_k, num_heads_kv, head_dim]`. | required |
| `cu_seqlens_q` | `Tensor` | Cumulative sequence lengths for queries, shape `[batch_size + 1]`. | required |
| `cu_seqlens_k` | `Tensor` | Cumulative sequence lengths for keys/values, shape `[batch_size + 1]`. | required |
| `max_seqlen_q` | `int` | Maximum sequence length for queries. | required |
| `max_seqlen_k` | `int` | Maximum sequence length for keys/values. | required |
| `is_causal` | `bool` | Whether to apply a causal mask. | `False` |
| `softmax_scale` | `Optional[float]` | Scaling factor for the softmax. If `None`, defaults to `1/sqrt(head_dim)`. | `None` |
| `window_size` | `Tuple[Optional[int], Optional[int]]` | Tuple `(window_size_q, window_size_k)` for local attention. If `None`, no local masking is applied. | `(None, None)` |
| `seqused_q` | `Optional[Tensor]` | Per-sequence count of query tokens actually in use, shape `[batch_size]`. If provided, overrides `cu_seqlens_q` for masking. | `None` |
| `seqused_k` | `Optional[Tensor]` | Per-sequence count of key/value tokens actually in use, shape `[batch_size]`. If provided, overrides `cu_seqlens_k` for masking. | `None` |
| `return_lse` | `bool` | Whether to return the logsumexp tensor for numerical-stability analysis. If `True`, returns a tuple `(out, lse)`; if `False`, returns only `out`. | `False` |

Returns:

| Type | Description |
| --- | --- |
| `Union[Tensor, Tuple[Tensor, Tensor]]` | If `return_lse` is `False`, `out` with shape `[total_seqlen_q, num_heads_q, head_dim]`. If `return_lse` is `True`, a tuple `(out, lse)`, where `lse` has shape `[total_seqlen_q, num_heads_q]`. |
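The packed layout works like this: all sequences are concatenated along the first axis, and sequence `b` occupies rows `cu_seqlens[b]:cu_seqlens[b+1]`. A small NumPy sketch (the `lengths` values below are made up for illustration; only the `cu_seqlens` convention itself comes from the signature):

```python
import numpy as np

lengths = [3, 5, 2]                    # hypothetical per-sequence lengths
cu_seqlens = np.concatenate([[0], np.cumsum(lengths)]).astype(np.int32)
max_seqlen = max(lengths)              # what max_seqlen_q/max_seqlen_k would be
num_heads, head_dim = 1, 8

rng = np.random.default_rng(0)
total = int(cu_seqlens[-1])
q = rng.standard_normal((total, num_heads, head_dim))  # packed [total_seqlen, heads, dim]
k = rng.standard_normal((total, num_heads, head_dim))
v = rng.standard_normal((total, num_heads, head_dim))

def varlen_attn_reference(q, k, v, cu_q, cu_k):
    """Loop-over-batch reference: each sequence attends only within itself."""
    out = np.empty_like(q)
    for b in range(len(cu_q) - 1):
        qs, qe = cu_q[b], cu_q[b + 1]
        ks, ke = cu_k[b], cu_k[b + 1]
        for h in range(q.shape[1]):
            s = q[qs:qe, h] @ k[ks:ke, h].T / np.sqrt(q.shape[-1])
            p = np.exp(s - s.max(axis=-1, keepdims=True))
            out[qs:qe, h] = (p / p.sum(axis=-1, keepdims=True)) @ v[ks:ke, h]
    return out

out = varlen_attn_reference(q, k, v, cu_seqlens, cu_seqlens)
```

The kernel needs `max_seqlen_q`/`max_seqlen_k` to size its launch grid; the reference loop does not, which is why they appear only as scalars in the API.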

flash_sparse_attn_func(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    is_causal: bool = False,
    softmax_scale: Optional[float] = None,
    softmax_threshold: Optional[float] = None,
    window_size: Tuple[Optional[int], Optional[int]] = (None, None),
    return_lse: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]

Flash sparse attention function that computes the attention output and optionally the logsumexp.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `query` | `Tensor` | Query tensor of shape `[batch_size, seqlen_q, num_heads, head_dim]`. | required |
| `key` | `Tensor` | Key tensor of shape `[batch_size, seqlen_k, num_kv_heads, head_dim]`. | required |
| `value` | `Tensor` | Value tensor of shape `[batch_size, seqlen_k, num_kv_heads, head_dim]`. | required |
| `is_causal` | `bool` | Whether to apply a causal mask. | `False` |
| `softmax_scale` | `Optional[float]` | Scaling factor for the softmax. If `None`, defaults to `1/sqrt(head_dim)`. | `None` |
| `softmax_threshold` | `Optional[float]` | Threshold for the sparse softmax. If `None`, defaults to `head_dim / seqlen_k`. | `None` |
| `window_size` | `Tuple[Optional[int], Optional[int]]` | Tuple `(window_size_q, window_size_k)` for local attention. If `None`, no local masking is applied. | `(None, None)` |
| `return_lse` | `bool` | Whether to return the logsumexp tensor for numerical-stability analysis. If `True`, returns a tuple `(out, lse)`; if `False`, returns only `out`. | `False` |

Returns:

| Type | Description |
| --- | --- |
| `Union[Tensor, Tuple[Tensor, Tensor]]` | If `return_lse` is `False`, `out` with shape `[batch_size, seqlen_q, num_heads, head_dim]`. If `return_lse` is `True`, a tuple `(out, lse)`, where `lse` has shape `[batch_size, num_heads, seqlen_q]`. |
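The docstring gives only the threshold's default, not the pruning rule itself. Purely as a hedged illustration of thresholded attention in general (an assumption, not this kernel's documented algorithm), one common reading drops post-softmax weights below `softmax_threshold` and renormalizes the survivors:

```python
import numpy as np

def thresholded_softmax(scores, threshold):
    """ASSUMED semantics, for illustration only: zero out attention weights
    below `threshold` after the softmax, then renormalize the survivors.
    Assumes at least one weight per row survives the cut."""
    p = np.exp(scores - scores.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)          # ordinary softmax weights
    p = np.where(p >= threshold, p, 0.0)        # prune sub-threshold entries
    return p / p.sum(axis=-1, keepdims=True)    # renormalize what remains
```

Whatever the exact rule, note that the default `head_dim / seqlen_k` equals `head_dim` times the uniform weight `1/seqlen_k`, so the cutoff automatically loosens as the context grows.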

flash_sparse_attn_varlen_func(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    is_causal: bool = False,
    softmax_scale: Optional[float] = None,
    softmax_threshold: Optional[float] = None,
    window_size: Tuple[Optional[int], Optional[int]] = (None, None),
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    return_lse: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]

Flash sparse attention function for variable-length sequences that computes the attention output and optionally the logsumexp.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `query` | `Tensor` | Query tensor of shape `[total_seqlen_q, num_heads_q, head_dim]`. | required |
| `key` | `Tensor` | Key tensor of shape `[total_seqlen_k, num_heads_kv, head_dim]`. | required |
| `value` | `Tensor` | Value tensor of shape `[total_seqlen_k, num_heads_kv, head_dim]`. | required |
| `cu_seqlens_q` | `Tensor` | Cumulative sequence lengths for queries, shape `[batch_size + 1]`. | required |
| `cu_seqlens_k` | `Tensor` | Cumulative sequence lengths for keys/values, shape `[batch_size + 1]`. | required |
| `max_seqlen_q` | `int` | Maximum sequence length for queries. | required |
| `max_seqlen_k` | `int` | Maximum sequence length for keys/values. | required |
| `is_causal` | `bool` | Whether to apply a causal mask. | `False` |
| `softmax_scale` | `Optional[float]` | Scaling factor for the softmax. If `None`, defaults to `1/sqrt(head_dim)`. | `None` |
| `softmax_threshold` | `Optional[float]` | Threshold for the sparse softmax. If `None`, defaults to `head_dim / max_seqlen_k`. | `None` |
| `window_size` | `Tuple[Optional[int], Optional[int]]` | Tuple `(window_size_q, window_size_k)` for local attention. If `None`, no local masking is applied. | `(None, None)` |
| `seqused_q` | `Optional[Tensor]` | Per-sequence count of query tokens actually in use, shape `[batch_size]`. If provided, overrides `cu_seqlens_q` for masking. | `None` |
| `seqused_k` | `Optional[Tensor]` | Per-sequence count of key/value tokens actually in use, shape `[batch_size]`. If provided, overrides `cu_seqlens_k` for masking. | `None` |
| `return_lse` | `bool` | Whether to return the logsumexp tensor for numerical-stability analysis. If `True`, returns a tuple `(out, lse)`; if `False`, returns only `out`. | `False` |

Returns:

| Type | Description |
| --- | --- |
| `Union[Tensor, Tuple[Tensor, Tensor]]` | If `return_lse` is `False`, `out` with shape `[total_seqlen_q, num_heads_q, head_dim]`. If `return_lse` is `True`, a tuple `(out, lse)`, where `lse` has shape `[total_seqlen_q, num_heads_q]`. |
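`seqused_k` matters when a sequence's `cu_seqlens` slot reserves more key/value rows than are currently valid, e.g. a partially filled KV cache. A hypothetical NumPy sketch of the masking effect, assuming `seqused_k` carries one in-use length per batch element:

```python
import numpy as np

# Hypothetical setup: one sequence whose cu_seqlens slot reserves 6 key
# rows, of which only the first seqused_k_b are valid.
rng = np.random.default_rng(1)
q = rng.standard_normal((4, 8))     # this sequence's queries
k = rng.standard_normal((6, 8))     # 6 reserved key slots
v = rng.standard_normal((6, 8))
seqused_k_b = 4                     # only the first 4 keys are in use

scores = q @ k.T / np.sqrt(8)
scores[:, seqused_k_b:] = -np.inf   # mask the unused tail of the slot
p = np.exp(scores - scores.max(axis=-1, keepdims=True))
out = (p / p.sum(axis=-1, keepdims=True)) @ v
```

Masking the tail with `-inf` gives those slots exactly zero attention weight, so the result matches attending over only the in-use keys.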

flash_gated_attn_func(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    alpha: torch.Tensor,
    delta: torch.Tensor,
    is_causal: bool = False,
    softmax_scale: Optional[float] = None,
    softmax_threshold: Optional[float] = None,
    gate_threshold: Optional[float] = None,
    is_logsigmoid_gate: bool = True,
    is_adapt_gate: bool = True,
    window_size: Tuple[Optional[int], Optional[int]] = (None, None),
    return_lse: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]

Flash gated attention function that computes the attention output and optionally the logsumexp.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `query` | `Tensor` | Query tensor of shape `[batch_size, seqlen_q, num_heads, head_dim]`. | required |
| `key` | `Tensor` | Key tensor of shape `[batch_size, seqlen_k, num_kv_heads, head_dim]`. | required |
| `value` | `Tensor` | Value tensor of shape `[batch_size, seqlen_k, num_kv_heads, head_dim]`. | required |
| `alpha` | `Tensor` | Tensor of shape `[batch_size, num_heads, seqlen_q]` representing the sparsity pattern for queries. | required |
| `delta` | `Tensor` | Tensor of shape `[batch_size, num_kv_heads, seqlen_k]` representing the sparsity pattern for keys/values. | required |
| `is_causal` | `bool` | Whether to apply a causal mask. | `False` |
| `softmax_scale` | `Optional[float]` | Scaling factor for the softmax. If `None`, defaults to `1/sqrt(head_dim)`. | `None` |
| `softmax_threshold` | `Optional[float]` | Threshold for the sparse softmax. | `None` |
| `gate_threshold` | `Optional[float]` | Threshold for the sparsity gate. | `None` |
| `is_logsigmoid_gate` | `bool` | Whether to use a log-sigmoid function for the sparsity gate. If `False`, uses a linear function. | `True` |
| `is_adapt_gate` | `bool` | Whether to adapt the gate threshold based on sequence length. | `True` |
| `window_size` | `Tuple[Optional[int], Optional[int]]` | Tuple `(window_size_q, window_size_k)` for local attention. If `None`, no local masking is applied. | `(None, None)` |
| `return_lse` | `bool` | Whether to return the logsumexp tensor for numerical-stability analysis. If `True`, returns a tuple `(out, lse)`; if `False`, returns only `out`. | `False` |

Returns:

| Type | Description |
| --- | --- |
| `Union[Tensor, Tuple[Tensor, Tensor]]` | If `return_lse` is `False`, `out` with shape `[batch_size, seqlen_q, num_heads, head_dim]`. If `return_lse` is `True`, a tuple `(out, lse)`, where `lse` has shape `[batch_size, num_heads, seqlen_q]`. |
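The docstring does not spell out how `alpha` and `delta` enter the computation. Purely as an assumption, to make the per-query/per-key shapes concrete, here is one common construction for sigmoid-style gates: per-key gates are added to the logits in log-sigmoid form (which reweights keys inside the softmax), and per-query gates scale the output. This is an illustration of gated attention in general, not the kernel's documented algorithm:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def logsigmoid(x):
    # Numerically stable log(sigmoid(x)) = -log(1 + exp(-x))
    return -np.logaddexp(0.0, -x)

def gated_attn_reference(q, k, v, alpha, delta):
    """ASSUMED gating scheme, single head for brevity.
    alpha: [seqlen_q] per-query gate logits; delta: [seqlen_k] per-key
    gate logits (the real API adds batch and head axes)."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    # Adding logsigmoid(delta_j) to column j multiplies key j's softmax
    # weight by sigmoid(delta_j): very negative delta suppresses the key.
    s = s + logsigmoid(delta)[None, :]
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)
    # Per-query gates scale each query's output row by sigmoid(alpha_i).
    return p @ v * sigmoid(alpha)[:, None]
```

Under this reading, `gate_threshold` would decide which gate values count as "closed" so the kernel can skip work entirely, but that pruning step is omitted here.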

flash_gated_attn_varlen_func(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    alpha: torch.Tensor,
    delta: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    is_causal: bool = False,
    softmax_scale: Optional[float] = None,
    softmax_threshold: Optional[float] = None,
    gate_threshold: Optional[float] = None,
    is_logsigmoid_gate: bool = True,
    is_adapt_gate: bool = True,
    window_size: Tuple[Optional[int], Optional[int]] = (None, None),
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    return_lse: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]

Flash gated attention function for variable-length sequences that computes the attention output and optionally the logsumexp.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `query` | `Tensor` | Query tensor of shape `[total_seqlen_q, num_heads_q, head_dim]`. | required |
| `key` | `Tensor` | Key tensor of shape `[total_seqlen_k, num_heads_kv, head_dim]`. | required |
| `value` | `Tensor` | Value tensor of shape `[total_seqlen_k, num_heads_kv, head_dim]`. | required |
| `alpha` | `Tensor` | Tensor of shape `[num_heads_q, total_seqlen_q]` representing the sparsity pattern for queries. | required |
| `delta` | `Tensor` | Tensor of shape `[num_heads_kv, total_seqlen_k]` representing the sparsity pattern for keys/values. | required |
| `cu_seqlens_q` | `Tensor` | Cumulative sequence lengths for queries, shape `[batch_size + 1]`. | required |
| `cu_seqlens_k` | `Tensor` | Cumulative sequence lengths for keys/values, shape `[batch_size + 1]`. | required |
| `max_seqlen_q` | `int` | Maximum sequence length for queries. | required |
| `max_seqlen_k` | `int` | Maximum sequence length for keys/values. | required |
| `is_causal` | `bool` | Whether to apply a causal mask. | `False` |
| `softmax_scale` | `Optional[float]` | Scaling factor for the softmax. If `None`, defaults to `1/sqrt(head_dim)`. | `None` |
| `softmax_threshold` | `Optional[float]` | Threshold for the sparse softmax. | `None` |
| `gate_threshold` | `Optional[float]` | Threshold for the sparsity gate. | `None` |
| `is_logsigmoid_gate` | `bool` | Whether to use a log-sigmoid function for the sparsity gate. If `False`, uses a linear function. | `True` |
| `is_adapt_gate` | `bool` | Whether to adapt the gate threshold based on sequence length. | `True` |
| `window_size` | `Tuple[Optional[int], Optional[int]]` | Tuple `(window_size_q, window_size_k)` for local attention. If `None`, no local masking is applied. | `(None, None)` |
| `seqused_q` | `Optional[Tensor]` | Per-sequence count of query tokens actually in use, shape `[batch_size]`. If provided, overrides `cu_seqlens_q` for masking. | `None` |
| `seqused_k` | `Optional[Tensor]` | Per-sequence count of key/value tokens actually in use, shape `[batch_size]`. If provided, overrides `cu_seqlens_k` for masking. | `None` |
| `return_lse` | `bool` | Whether to return the logsumexp tensor for numerical-stability analysis. If `True`, returns a tuple `(out, lse)`; if `False`, returns only `out`. | `False` |

Returns:

| Type | Description |
| --- | --- |
| `Union[Tensor, Tuple[Tensor, Tensor]]` | If `return_lse` is `False`, `out` with shape `[total_seqlen_q, num_heads_q, head_dim]`. If `return_lse` is `True`, a tuple `(out, lse)`, where `lse` has shape `[total_seqlen_q, num_heads_q]`. |