# API Reference

This page is generated directly from Python docstrings.

## Triton Interface

### `flash_sparse_attn.ops.triton.interface`
#### `flash_dense_attn_func`

```python
flash_dense_attn_func(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    is_causal: bool = False,
    softmax_scale: Optional[float] = None,
    window_size: Tuple[Optional[int], Optional[int]] = (None, None),
    return_lse: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
```

Flash dense attention function that computes the attention output and optionally the logsumexp.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `query` | `Tensor` | Query tensor of shape `[batch_size, seqlen_q, num_heads, head_dim]`. | *required* |
| `key` | `Tensor` | Key tensor of shape `[batch_size, seqlen_k, num_kv_heads, head_dim]`. | *required* |
| `value` | `Tensor` | Value tensor of shape `[batch_size, seqlen_k, num_kv_heads, head_dim]`. | *required* |
| `is_causal` | `bool` | Whether to apply a causal mask. | `False` |
| `softmax_scale` | `Optional[float]` | Scaling factor for the softmax. If `None`, defaults to `1/sqrt(head_dim)`. | `None` |
| `window_size` | `Tuple[Optional[int], Optional[int]]` | Tuple `(window_size_q, window_size_k)` for local attention. If `None`, no local masking is applied. | `(None, None)` |
| `return_lse` | `bool` | Whether to return the logsumexp tensor for numerical stability analysis. If `True`, returns a tuple `(out, lse)`; if `False`, returns only `out`. | `False` |

Returns:

| Type | Description |
|---|---|
| `Union[Tensor, Tuple[Tensor, Tensor]]` | If `return_lse` is `False`, returns `out` with shape `[batch_size, seqlen_q, num_heads, head_dim]`. If `return_lse` is `True`, returns a tuple `(out, lse)`, where `lse` has shape `[batch_size, num_heads, seqlen_q]`. |
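The math that `out` and `lse` represent can be sketched in plain NumPy for a single `(batch, head)` slice. This is a reference for the documented semantics (scaled dot-product softmax attention plus the per-row logsumexp), not the Triton kernel itself; the `reference_attention` helper is illustrative only.

```python
import numpy as np

def reference_attention(q, k, v, scale=None, causal=False):
    """Pure-NumPy reference for one (batch, head) slice.

    q: [seqlen_q, head_dim], k/v: [seqlen_k, head_dim].
    Returns (out, lse) with the per-head shapes documented above.
    """
    head_dim = q.shape[-1]
    if scale is None:
        scale = 1.0 / np.sqrt(head_dim)          # documented default
    scores = (q @ k.T) * scale                   # [seqlen_q, seqlen_k]
    if causal:
        sq, sk = scores.shape
        # right-aligned causal mask: query i may attend keys j <= i + (sk - sq)
        masked = np.triu(np.ones((sq, sk), dtype=bool), k=sk - sq + 1)
        scores = np.where(masked, -np.inf, scores)
    # numerically stable softmax plus logsumexp
    m = scores.max(axis=-1, keepdims=True)
    p = np.exp(scores - m)
    denom = p.sum(axis=-1, keepdims=True)
    out = (p / denom) @ v                        # [seqlen_q, head_dim]
    lse = (m + np.log(denom)).squeeze(-1)        # [seqlen_q]
    return out, lse

rng = np.random.default_rng(0)
q = rng.standard_normal((4, 8)).astype(np.float32)
k = rng.standard_normal((6, 8)).astype(np.float32)
v = rng.standard_normal((6, 8)).astype(np.float32)
out, lse = reference_attention(q, k, v, causal=True)
```

The `lse` row values are exactly the log of each softmax denominator (after the max-shift), which is why they are useful for numerical-stability analysis and for recombining partial attention results.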
#### `flash_dense_attn_varlen_func`

```python
flash_dense_attn_varlen_func(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    is_causal: bool = False,
    softmax_scale: Optional[float] = None,
    window_size: Tuple[Optional[int], Optional[int]] = (None, None),
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    return_lse: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
```

Flash dense attention function for variable-length sequences that computes the attention output and optionally the logsumexp.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `query` | `Tensor` | Query tensor of shape `[total_seqlen_q, num_heads_q, head_dim]`. | *required* |
| `key` | `Tensor` | Key tensor of shape `[total_seqlen_k, num_heads_kv, head_dim]`. | *required* |
| `value` | `Tensor` | Value tensor of shape `[total_seqlen_k, num_heads_kv, head_dim]`. | *required* |
| `cu_seqlens_q` | `Tensor` | Cumulative sequence lengths for queries, shape `[batch_size + 1]`. | *required* |
| `cu_seqlens_k` | `Tensor` | Cumulative sequence lengths for keys/values, shape `[batch_size + 1]`. | *required* |
| `max_seqlen_q` | `int` | Maximum sequence length for queries. | *required* |
| `max_seqlen_k` | `int` | Maximum sequence length for keys/values. | *required* |
| `is_causal` | `bool` | Whether to apply a causal mask. | `False` |
| `softmax_scale` | `Optional[float]` | Scaling factor for the softmax. If `None`, defaults to `1/sqrt(head_dim)`. | `None` |
| `window_size` | `Tuple[Optional[int], Optional[int]]` | Tuple `(window_size_q, window_size_k)` for local attention. If `None`, no local masking is applied. | `(None, None)` |
| `seqused_q` | `Optional[Tensor]` | Optional tensor of shape `[total_seqlen_q]` indicating the actual sequence lengths for queries. If provided, overrides `cu_seqlens_q` for masking. | `None` |
| `seqused_k` | `Optional[Tensor]` | Optional tensor of shape `[total_seqlen_k]` indicating the actual sequence lengths for keys/values. If provided, overrides `cu_seqlens_k` for masking. | `None` |
| `return_lse` | `bool` | Whether to return the logsumexp tensor for numerical stability analysis. If `True`, returns a tuple `(out, lse)`; if `False`, returns only `out`. | `False` |

Returns:

| Type | Description |
|---|---|
| `Union[Tensor, Tuple[Tensor, Tensor]]` | If `return_lse` is `False`, returns `out` with shape `[total_seqlen_q, num_heads_q, head_dim]`. If `return_lse` is `True`, returns a tuple `(out, lse)`, where `lse` has shape `[total_seqlen_q, num_heads_q]`. |
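The varlen layout packs all sequences of a batch back-to-back along the first dimension, and `cu_seqlens_*` are the exclusive prefix sums that delimit them. A small sketch of how these arguments relate (the lengths here are made up for illustration):

```python
import numpy as np

# Hypothetical per-sequence lengths for a batch of 3 packed sequences.
seqlens = [5, 2, 7]

# cu_seqlens is the exclusive prefix sum, shape [batch_size + 1].
cu_seqlens = np.concatenate([[0], np.cumsum(seqlens)])  # [0, 5, 7, 14]

# Number of rows in the packed [total_seqlen, num_heads, head_dim] tensor.
total_seqlen = int(cu_seqlens[-1])

# Passed as max_seqlen_q / max_seqlen_k.
max_seqlen = max(seqlens)

# Sequence b of a packed tensor x is the slice
#   x[cu_seqlens[b] : cu_seqlens[b + 1]]
spans = [(int(cu_seqlens[b]), int(cu_seqlens[b + 1])) for b in range(len(seqlens))]
```

The same construction is used for both the query side (`cu_seqlens_q`, `max_seqlen_q`) and the key/value side (`cu_seqlens_k`, `max_seqlen_k`).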
#### `flash_sparse_attn_func`

```python
flash_sparse_attn_func(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    is_causal: bool = False,
    softmax_scale: Optional[float] = None,
    softmax_threshold: Optional[float] = None,
    window_size: Tuple[Optional[int], Optional[int]] = (None, None),
    return_lse: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
```

Flash sparse attention function that computes the attention output and optionally the logsumexp.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `query` | `Tensor` | Query tensor of shape `[batch_size, seqlen_q, num_heads, head_dim]`. | *required* |
| `key` | `Tensor` | Key tensor of shape `[batch_size, seqlen_k, num_kv_heads, head_dim]`. | *required* |
| `value` | `Tensor` | Value tensor of shape `[batch_size, seqlen_k, num_kv_heads, head_dim]`. | *required* |
| `is_causal` | `bool` | Whether to apply a causal mask. | `False` |
| `softmax_scale` | `Optional[float]` | Scaling factor for the softmax. If `None`, defaults to `1/sqrt(head_dim)`. | `None` |
| `softmax_threshold` | `Optional[float]` | Threshold for the sparse softmax. If `None`, defaults to `head_dim / seqlen_k`. | `None` |
| `window_size` | `Tuple[Optional[int], Optional[int]]` | Tuple `(window_size_q, window_size_k)` for local attention. If `None`, no local masking is applied. | `(None, None)` |
| `return_lse` | `bool` | Whether to return the logsumexp tensor for numerical stability analysis. If `True`, returns a tuple `(out, lse)`; if `False`, returns only `out`. | `False` |

Returns:

| Type | Description |
|---|---|
| `Union[Tensor, Tuple[Tensor, Tensor]]` | If `return_lse` is `False`, returns `out` with shape `[batch_size, seqlen_q, num_heads, head_dim]`. If `return_lse` is `True`, returns a tuple `(out, lse)`, where `lse` has shape `[batch_size, num_heads, seqlen_q]`. |
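The docstring only says that `softmax_threshold` defaults to `head_dim / seqlen_k`; it does not spell out the sparsification rule. One plausible reading, sketched below purely for intuition, is that softmax weights falling below the threshold are treated as zero, so the corresponding key/value blocks can be skipped. The `thresholded_softmax` helper is an assumption, not the kernel's actual criterion:

```python
import numpy as np

def thresholded_softmax(scores, threshold):
    """ILLUSTRATIVE ONLY: softmax followed by zeroing weights below
    `threshold`. The Triton kernel's exact sparsification rule may differ;
    this sketch just shows the role a threshold of head_dim / seqlen_k
    would play."""
    m = scores.max(axis=-1, keepdims=True)
    p = np.exp(scores - m)
    p /= p.sum(axis=-1, keepdims=True)
    return np.where(p < threshold, 0.0, p)

rng = np.random.default_rng(0)
seqlen_k, head_dim = 16, 8
scores = rng.standard_normal((4, seqlen_k)).astype(np.float32)
threshold = head_dim / seqlen_k        # documented default
w = thresholded_softmax(scores, threshold)
sparsity = (w == 0).mean()             # fraction of weights dropped
```

Under this interpretation, a larger `head_dim / seqlen_k` ratio drops more weights; consult the kernel source before relying on the exact rule.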
#### `flash_sparse_attn_varlen_func`

```python
flash_sparse_attn_varlen_func(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    is_causal: bool = False,
    softmax_scale: Optional[float] = None,
    softmax_threshold: Optional[float] = None,
    window_size: Tuple[Optional[int], Optional[int]] = (None, None),
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    return_lse: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
```

Flash sparse attention function for variable-length sequences that computes the attention output and optionally the logsumexp.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `query` | `Tensor` | Query tensor of shape `[total_seqlen_q, num_heads_q, head_dim]`. | *required* |
| `key` | `Tensor` | Key tensor of shape `[total_seqlen_k, num_heads_kv, head_dim]`. | *required* |
| `value` | `Tensor` | Value tensor of shape `[total_seqlen_k, num_heads_kv, head_dim]`. | *required* |
| `cu_seqlens_q` | `Tensor` | Cumulative sequence lengths for queries, shape `[batch_size + 1]`. | *required* |
| `cu_seqlens_k` | `Tensor` | Cumulative sequence lengths for keys/values, shape `[batch_size + 1]`. | *required* |
| `max_seqlen_q` | `int` | Maximum sequence length for queries. | *required* |
| `max_seqlen_k` | `int` | Maximum sequence length for keys/values. | *required* |
| `is_causal` | `bool` | Whether to apply a causal mask. | `False` |
| `softmax_scale` | `Optional[float]` | Scaling factor for the softmax. If `None`, defaults to `1/sqrt(head_dim)`. | `None` |
| `softmax_threshold` | `Optional[float]` | Threshold for the sparse softmax. If `None`, defaults to `head_dim / max_seqlen_k`. | `None` |
| `window_size` | `Tuple[Optional[int], Optional[int]]` | Tuple `(window_size_q, window_size_k)` for local attention. If `None`, no local masking is applied. | `(None, None)` |
| `seqused_q` | `Optional[Tensor]` | Optional tensor of shape `[total_seqlen_q]` indicating the actual sequence lengths for queries. If provided, overrides `cu_seqlens_q` for masking. | `None` |
| `seqused_k` | `Optional[Tensor]` | Optional tensor of shape `[total_seqlen_k]` indicating the actual sequence lengths for keys/values. If provided, overrides `cu_seqlens_k` for masking. | `None` |
| `return_lse` | `bool` | Whether to return the logsumexp tensor for numerical stability analysis. If `True`, returns a tuple `(out, lse)`; if `False`, returns only `out`. | `False` |

Returns:

| Type | Description |
|---|---|
| `Union[Tensor, Tuple[Tensor, Tensor]]` | If `return_lse` is `False`, returns `out` with shape `[total_seqlen_q, num_heads_q, head_dim]`. If `return_lse` is `True`, returns a tuple `(out, lse)`, where `lse` has shape `[total_seqlen_q, num_heads_q]`. |
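The docstring does not pin down the exact `window_size` convention. A common one (e.g. in FlashAttention-style kernels) is a sliding window around the right-aligned diagonal; the sketch below assumes `window_size_q` bounds how far behind the diagonal a key may sit and `window_size_k` how far ahead, with `None` meaning unbounded on that side. Verify against the kernel before relying on this interpretation:

```python
import numpy as np

def local_mask(seqlen_q, seqlen_k, window_q, window_k):
    """Sliding-window mask sketch, ASSUMING window_size_q / window_size_k
    bound the distance behind / ahead of the right-aligned diagonal.
    Returns a boolean [seqlen_q, seqlen_k] array of allowed positions."""
    offset = seqlen_k - seqlen_q                 # right-align the diagonal
    i = np.arange(seqlen_q)[:, None]
    j = np.arange(seqlen_k)[None, :]
    allowed = np.ones((seqlen_q, seqlen_k), dtype=bool)
    if window_q is not None:
        allowed &= j >= i + offset - window_q    # not too far in the past
    if window_k is not None:
        allowed &= j <= i + offset + window_k    # not too far in the future
    return allowed

# Each query sees itself plus one key back, none ahead.
mask = local_mask(4, 4, window_q=1, window_k=0)
```

Combined with `is_causal=True`, the forward bound becomes redundant, since the causal mask already forbids attending ahead of the diagonal.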
#### `flash_gated_attn_func`

```python
flash_gated_attn_func(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    alpha: torch.Tensor,
    delta: torch.Tensor,
    is_causal: bool = False,
    softmax_scale: Optional[float] = None,
    softmax_threshold: Optional[float] = None,
    gate_threshold: Optional[float] = None,
    is_logsigmoid_gate: bool = True,
    is_adapt_gate: bool = True,
    window_size: Tuple[Optional[int], Optional[int]] = (None, None),
    return_lse: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
```

Flash gated attention function that computes the attention output and optionally the logsumexp.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `query` | `Tensor` | Query tensor of shape `[batch_size, seqlen_q, num_heads, head_dim]`. | *required* |
| `key` | `Tensor` | Key tensor of shape `[batch_size, seqlen_k, num_kv_heads, head_dim]`. | *required* |
| `value` | `Tensor` | Value tensor of shape `[batch_size, seqlen_k, num_kv_heads, head_dim]`. | *required* |
| `alpha` | `Tensor` | Tensor of shape `[batch_size, num_heads, seqlen_q]` representing the sparsity pattern for queries. | *required* |
| `delta` | `Tensor` | Tensor of shape `[batch_size, num_kv_heads, seqlen_k]` representing the sparsity pattern for keys/values. | *required* |
| `is_causal` | `bool` | Whether to apply a causal mask. | `False` |
| `softmax_scale` | `Optional[float]` | Scaling factor for the softmax. If `None`, defaults to `1/sqrt(head_dim)`. | `None` |
| `softmax_threshold` | `Optional[float]` | Threshold for the sparse softmax. | `None` |
| `gate_threshold` | `Optional[float]` | Threshold for the sparsity gate. | `None` |
| `is_logsigmoid_gate` | `bool` | Whether to use a log-sigmoid function for the sparsity gate. If `False`, uses a linear function. | `True` |
| `is_adapt_gate` | `bool` | Whether to adapt the gate threshold based on sequence length. | `True` |
| `window_size` | `Tuple[Optional[int], Optional[int]]` | Tuple `(window_size_q, window_size_k)` for local attention. If `None`, no local masking is applied. | `(None, None)` |
| `return_lse` | `bool` | Whether to return the logsumexp tensor for numerical stability analysis. If `True`, returns a tuple `(out, lse)`; if `False`, returns only `out`. | `False` |

Returns:

| Type | Description |
|---|---|
| `Union[Tensor, Tuple[Tensor, Tensor]]` | If `return_lse` is `False`, returns `out` with shape `[batch_size, seqlen_q, num_heads, head_dim]`. If `return_lse` is `True`, returns a tuple `(out, lse)`, where `lse` has shape `[batch_size, num_heads, seqlen_q]`. |
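The docstring says only that the gate maps its inputs through a log-sigmoid (or a linear function when `is_logsigmoid_gate=False`) and compares them against `gate_threshold`; the full gating formula is not documented. Purely to illustrate the log-sigmoid option, here is a numerically stable implementation of `log(sigmoid(x))` applied to hypothetical `alpha` values, with a made-up threshold:

```python
import numpy as np

def gate_scores(x, use_logsigmoid=True):
    """ILLUSTRATIVE interpretation of is_logsigmoid_gate: map raw gate
    inputs (e.g. alpha / delta) through log-sigmoid, or leave them linear.
    The kernel's actual gating formula is not spelled out in the docstring."""
    if not use_logsigmoid:
        return x
    # stable log(sigmoid(x)): -log1p(exp(-x)) for x >= 0, x - log1p(exp(x)) otherwise
    return np.where(x >= 0,
                    -np.log1p(np.exp(-np.abs(x))),
                    x - np.log1p(np.exp(-np.abs(x))))

alpha = np.array([-2.0, 0.0, 2.0])
g = gate_scores(alpha)            # log-sigmoid gate values, all <= 0
keep = g > np.log(0.5)            # hypothetical gate_threshold of log(0.5)
```

Note that log-sigmoid is monotonic and always non-positive, so any `gate_threshold` used against it would naturally be non-positive as well.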
#### `flash_gated_attn_varlen_func`

```python
flash_gated_attn_varlen_func(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    alpha: torch.Tensor,
    delta: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    is_causal: bool = False,
    softmax_scale: Optional[float] = None,
    softmax_threshold: Optional[float] = None,
    gate_threshold: Optional[float] = None,
    is_logsigmoid_gate: bool = True,
    is_adapt_gate: bool = True,
    window_size: Tuple[Optional[int], Optional[int]] = (None, None),
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    return_lse: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
```

Flash gated attention function for variable-length sequences that computes the attention output and optionally the logsumexp.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `query` | `Tensor` | Query tensor of shape `[total_seqlen_q, num_heads_q, head_dim]`. | *required* |
| `key` | `Tensor` | Key tensor of shape `[total_seqlen_k, num_heads_kv, head_dim]`. | *required* |
| `value` | `Tensor` | Value tensor of shape `[total_seqlen_k, num_heads_kv, head_dim]`. | *required* |
| `alpha` | `Tensor` | Tensor of shape `[num_heads_q, total_seqlen_q]` representing the sparsity pattern for queries. | *required* |
| `delta` | `Tensor` | Tensor of shape `[num_heads_kv, total_seqlen_k]` representing the sparsity pattern for keys/values. | *required* |
| `cu_seqlens_q` | `Tensor` | Cumulative sequence lengths for queries, shape `[batch_size + 1]`. | *required* |
| `cu_seqlens_k` | `Tensor` | Cumulative sequence lengths for keys/values, shape `[batch_size + 1]`. | *required* |
| `max_seqlen_q` | `int` | Maximum sequence length for queries. | *required* |
| `max_seqlen_k` | `int` | Maximum sequence length for keys/values. | *required* |
| `is_causal` | `bool` | Whether to apply a causal mask. | `False` |
| `softmax_scale` | `Optional[float]` | Scaling factor for the softmax. If `None`, defaults to `1/sqrt(head_dim)`. | `None` |
| `softmax_threshold` | `Optional[float]` | Threshold for the sparse softmax. | `None` |
| `gate_threshold` | `Optional[float]` | Threshold for the sparsity gate. | `None` |
| `is_logsigmoid_gate` | `bool` | Whether to use a log-sigmoid function for the sparsity gate. If `False`, uses a linear function. | `True` |
| `is_adapt_gate` | `bool` | Whether to adapt the gate threshold based on sequence length. | `True` |
| `window_size` | `Tuple[Optional[int], Optional[int]]` | Tuple `(window_size_q, window_size_k)` for local attention. If `None`, no local masking is applied. | `(None, None)` |
| `seqused_q` | `Optional[Tensor]` | Optional tensor of shape `[total_seqlen_q]` indicating the actual sequence lengths for queries. If provided, overrides `cu_seqlens_q` for masking. | `None` |
| `seqused_k` | `Optional[Tensor]` | Optional tensor of shape `[total_seqlen_k]` indicating the actual sequence lengths for keys/values. If provided, overrides `cu_seqlens_k` for masking. | `None` |
| `return_lse` | `bool` | Whether to return the logsumexp tensor for numerical stability analysis. If `True`, returns a tuple `(out, lse)`; if `False`, returns only `out`. | `False` |

Returns:

| Type | Description |
|---|---|
| `Union[Tensor, Tuple[Tensor, Tensor]]` | If `return_lse` is `False`, returns `out` with shape `[total_seqlen_q, num_heads_q, head_dim]`. If `return_lse` is `True`, returns a tuple `(out, lse)`, where `lse` has shape `[total_seqlen_q, num_heads_q]`. |