compressed_attention.py (Hugging Face model mirror: MiniCPM4-MCP)
# coding=utf-8
# Copyright 2025 The OpenBMB Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, Tuple, Union
from collections import Counter
import torch
import triton
import triton.language as tl
import warnings
from torch import nn
def is_hopper_gpu():
if torch.cuda.is_available():
device_capability = torch.cuda.get_device_capability()
major, minor = device_capability
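        # compute capability major version 9 corresponds to NVIDIA Hopper (e.g., H100)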
return major == 9
return False
def get_compressed_seqlens(
cu_seqlens: torch.Tensor, kernel_size: int, kernel_stride: int
):
# compute seqlens after compression
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
y_seqlens = torch.floor((seqlens - kernel_size) / kernel_stride).to(torch.int32) + 1
# corner case, if sequence_length < kernel_size, no compression for this sequence
y_seqlens[seqlens < kernel_size] = 0
y_cu_seqlens = torch.zeros(
y_seqlens.shape[0] + 1, dtype=torch.int32, device=cu_seqlens.device
)
y_cu_seqlens[1:] = torch.cumsum(y_seqlens, dim=0)
return y_seqlens, y_cu_seqlens
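# Example: with kernel_size=32 and kernel_stride=16, cu_seqlens = [0, 100, 140]
# gives per-sequence lengths [100, 40]; the compressed lengths are
# floor((100 - 32) / 16) + 1 = 5 and floor((40 - 32) / 16) + 1 = 1,
# so the function returns y_seqlens = [5, 1] and y_cu_seqlens = [0, 5, 6].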
def get_num_warps_stages(head_dim, block_size, is_hopper_gpu):
"""
Returns recommended num_warps and num_stages for a Sparse Attention kernel in Triton.
Args:
head_dim (int): Size of the head dimension.
block_size (int): Size of the block in the attention matrix.
is_hopper_gpu (bool): True if Hopper GPU, False if Ampere GPU.
Returns:
tuple: (num_warps, num_stages) recommended values.
"""
# Determine if head_dim and block_size exceed 64
head_large = head_dim > 64
block_large = block_size > 64
if is_hopper_gpu:
# Hopper GPU recommendations
if head_large and block_large:
num_warps = 8
num_stages = 3
elif head_large or block_large:
num_warps = 4
num_stages = 3
else:
num_warps = 2
num_stages = 2
else:
# Ampere GPU recommendations
if head_large and block_large:
num_warps = 8
num_stages = 3
elif head_large or block_large:
num_warps = 8
num_stages = 3
else:
num_warps = 2
num_stages = 2
return num_warps, num_stages
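# Example: get_num_warps_stages(head_dim=128, block_size=128, is_hopper_gpu=True) returns (8, 3),
# while get_num_warps_stages(head_dim=64, block_size=64, is_hopper_gpu=False) returns (2, 2).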
IS_HOPPER_GPU = is_hopper_gpu()
@triton.jit
def forward_kernel(
q_ptr, # Q: n x h x d
k_ptr, # K: n x h x d
v_ptr, # V: n x h x d
o_ptr, # O: n x h x d
lse_ptr, # LSE: h x n
    # kernel size and stride used for compression
kernel_size,
kernel_stride,
# seqlens
cu_seqlens_q,
cu_seqlens_k,
# shape
NUM_KV_HEADS,
NUM_SHARE_Q_HEADS,
HEAD_DIM,
# sm_scale
sm_scale,
# stride
stride_qn,
stride_qh,
stride_qd,
stride_kn,
stride_kh,
stride_kd,
stride_vn,
stride_vh,
stride_vd,
stride_on,
stride_oh,
stride_od,
stride_lh,
stride_ln,
# META parameters
BLOCK_SIZE_Q: tl.constexpr, # q block size
BLOCK_SIZE_K: tl.constexpr, # k block size
BLOCK_SIZE_D: tl.constexpr,
):
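    # 1.44269504 = log2(e): scaling the logits by log2(e) lets the kernel use exp2/log2
    # below, which is cheaper on GPU and equivalent to the natural-base softmax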
qk_scale = sm_scale * 1.44269504
# get batch id and head id
pid_b = tl.program_id(0)
pid_h = tl.program_id(1)
pid_q = tl.program_id(2)
pid_kh = pid_h // NUM_SHARE_Q_HEADS
# get q k start and len after rmpad
q_start = tl.load(cu_seqlens_q + pid_b)
q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
k_start = tl.load(cu_seqlens_k + pid_b)
k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
    # skip the first kernel_size - 1 query positions: they do not attend to any compressed keys
q_start_in_seq = pid_q * BLOCK_SIZE_Q + kernel_size - 1
if q_start_in_seq >= q_len:
return
# init qkv pointer
q_ptrs = tl.make_block_ptr(
base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
shape=(q_len, HEAD_DIM),
strides=(stride_qn, stride_qd),
offsets=(q_start_in_seq, 0),
block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
order=(1, 0),
)
k_ptrs = tl.make_block_ptr(
base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
shape=(HEAD_DIM, k_len),
strides=(stride_kd, stride_kn),
offsets=(0, 0),
block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
order=(0, 1),
)
v_ptrs = tl.make_block_ptr(
base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
shape=(k_len, HEAD_DIM),
strides=(stride_vn, stride_vd),
offsets=(0, 0),
block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
order=(1, 0),
)
# load q
q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
# init statistics
off_q = tl.arange(0, BLOCK_SIZE_Q) + q_start_in_seq
off_k = tl.arange(0, BLOCK_SIZE_K) * kernel_stride + kernel_size - 1
m_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32)
lse_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32)
acc_o = tl.full((BLOCK_SIZE_Q, BLOCK_SIZE_D), 0, dtype=tl.float32)
# attention
lo = 0
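    # hi bounds the visible compressed keys: key j covers original positions up to
    # j * kernel_stride + kernel_size - 1, so the last query in this block
    # (q_start_in_seq + BLOCK_SIZE_Q - 1) can attend to keys j <= (q_start_in_seq + BLOCK_SIZE_Q - kernel_size) // kernel_stride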
hi = min(k_len, (q_start_in_seq + BLOCK_SIZE_Q - kernel_size) // kernel_stride + 1)
for i in range(lo, hi, BLOCK_SIZE_K):
i = tl.multiple_of(i, BLOCK_SIZE_K)
# load k
k = tl.load(k_ptrs, boundary_check=(1, 0), padding_option="zero")
# compute qk
qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
qk += tl.where(
off_q[:, None] >= (i * kernel_stride + off_k)[None, :], 0, float("-inf")
)
qk += tl.dot(q, k) * qk_scale
# compute m_ij and l_ij
m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
p = tl.exp2(qk - m_ij[:, None])
l_ij = tl.sum(p, axis=1)
# scale acc_o
acc_o_scale = tl.exp2(m_i - m_ij)
acc_o = acc_o * acc_o_scale[:, None]
# load v and update acc_o
v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
p = p.to(v.dtype)
acc_o += tl.dot(p, v)
# update statistics
m_i = m_ij
lse_i = m_ij + tl.math.log2(tl.exp2(lse_i - m_ij) + l_ij)
# update ptrs
k_ptrs = tl.advance(k_ptrs, (0, BLOCK_SIZE_K))
v_ptrs = tl.advance(v_ptrs, (BLOCK_SIZE_K, 0))
# final scale
acc_o = acc_o * tl.exp2(m_i - lse_i)[:, None]
# save output
o_ptrs = tl.make_block_ptr(
base=o_ptr + q_start * stride_on + pid_h * stride_oh,
shape=(q_len, HEAD_DIM),
strides=(stride_on, stride_od),
offsets=(q_start_in_seq, 0),
block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
order=(1, 0),
)
tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1))
# save lse
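    # lse is stored in base 2 (log2 of the softmax normalizer), consistent with the
    # exp2 arithmetic above; _get_attention_score expects this convention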
l_ptrs = lse_ptr + q_start * stride_ln + pid_h * stride_lh + off_q * stride_ln
tl.store(l_ptrs, lse_i, mask=off_q < q_len)
@triton.jit
def backward_sum_o_do(
o_ptr, # O: n x h x d
do_ptr, # dO: n x h x d
delta_ptr, # D: h x n
o_len,
HEAD_DIM,
stride_on,
stride_oh,
stride_od,
stride_don,
stride_doh,
stride_dod,
stride_dh,
stride_dn,
BLOCK_SIZE_O: tl.constexpr,
BLOCK_SIZE_D: tl.constexpr,
):
pid_n = tl.program_id(0)
pid_h = tl.program_id(1)
off_n = pid_n * BLOCK_SIZE_O + tl.arange(0, BLOCK_SIZE_O)
off_d = tl.arange(0, BLOCK_SIZE_D)
o = tl.load(
o_ptr
+ off_n[:, None] * stride_on
+ pid_h * stride_oh
+ off_d[None, :] * stride_od,
mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM),
other=0,
).to(tl.float32)
do = tl.load(
do_ptr
+ off_n[:, None] * stride_don
+ pid_h * stride_doh
+ off_d[None, :] * stride_dod,
mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM),
other=0,
).to(tl.float32)
delta = tl.sum(o * do, axis=1)
tl.store(
delta_ptr + pid_h * stride_dh + off_n * stride_dn, delta, mask=off_n < o_len
)
@triton.jit
def backward_dkdv(
q_ptr, # Q: n x qh x d
k_ptr, # K: n x kh x d
v_ptr, # V: n x kh x d
lse_ptr, # LSE: qh x n
d_ptr, # Delta: qh x n
do_ptr,
dk_ptr, # DK: sh x n x kh x d
dv_ptr, # DV: sh x n x kh x d
kernel_size,
kernel_stride,
# seqlens
cu_seqlens_q,
cu_seqlens_k,
# shape
NUM_KV_HEADS,
NUM_SHARE_Q_HEADS,
HEAD_DIM,
# sm_scale
sm_scale,
# stride
stride_qn,
stride_qh,
stride_qd,
stride_kn,
stride_kh,
stride_kd,
stride_vn,
stride_vh,
stride_vd,
stride_lh,
stride_ln,
stride_dh,
stride_dn,
stride_don,
stride_doh,
stride_dod,
stride_dks,
stride_dkn,
stride_dkh,
stride_dkd,
stride_dvs,
stride_dvn,
stride_dvh,
stride_dvd,
# META parameters
BLOCK_SIZE_Q: tl.constexpr, # q block size
BLOCK_SIZE_K: tl.constexpr, # k block size
BLOCK_SIZE_D: tl.constexpr,
):
qk_scale = sm_scale * 1.44269504
# get batch id and head id
pid_b = tl.program_id(0)
pid_h = tl.program_id(1)
pid_kh = pid_h // NUM_SHARE_Q_HEADS
pid_sh = pid_h % NUM_SHARE_Q_HEADS
pid_k = tl.program_id(2)
# get q k start and len after rmpad
q_start = tl.load(cu_seqlens_q + pid_b)
q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
k_start = tl.load(cu_seqlens_k + pid_b)
k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
if BLOCK_SIZE_K * pid_k >= k_len:
return
# init pointers
k_ptrs = tl.make_block_ptr(
base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
shape=(k_len, HEAD_DIM),
strides=(stride_kn, stride_kd),
offsets=(pid_k * BLOCK_SIZE_K, 0),
block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
order=(1, 0),
)
dk_ptrs = tl.make_block_ptr(
base=dk_ptr + k_start * stride_dkn + pid_kh * stride_dkh + pid_sh * stride_dks,
shape=(k_len, HEAD_DIM),
strides=(stride_dkn, stride_dkd),
offsets=(pid_k * BLOCK_SIZE_K, 0),
block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
order=(1, 0),
)
v_ptrs = tl.make_block_ptr(
base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
shape=(k_len, HEAD_DIM),
strides=(stride_vn, stride_vd),
offsets=(pid_k * BLOCK_SIZE_K, 0),
block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
order=(1, 0),
)
dv_ptrs = tl.make_block_ptr(
base=dv_ptr + k_start * stride_dvn + pid_kh * stride_dvh + pid_sh * stride_dvs,
shape=(k_len, HEAD_DIM),
strides=(stride_dvn, stride_dvd),
offsets=(pid_k * BLOCK_SIZE_K, 0),
block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
order=(1, 0),
)
# offsets
off_q = tl.arange(0, BLOCK_SIZE_Q)
off_k = (
pid_k * BLOCK_SIZE_K * kernel_stride
+ tl.arange(0, BLOCK_SIZE_K) * kernel_stride
+ kernel_size
- 1
)
# load k v and keep in SRAM
k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
# init dk dv
dk = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32)
dv = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32)
q_lo = pid_k * BLOCK_SIZE_K * kernel_stride + kernel_size - 1
q_ptrs = tl.make_block_ptr(
base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
shape=(HEAD_DIM, q_len),
strides=(stride_qd, stride_qn),
offsets=(0, q_lo),
block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_Q),
order=(0, 1),
)
do_ptrs = tl.make_block_ptr(
base=do_ptr + q_start * stride_don + pid_h * stride_doh,
shape=(HEAD_DIM, q_len),
strides=(stride_dod, stride_don),
offsets=(0, q_lo),
block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_Q),
order=(0, 1),
)
d_ptrs = tl.make_block_ptr(
base=d_ptr + q_start * stride_dn + pid_h * stride_dh,
shape=(1, q_len),
strides=(0, stride_dn),
offsets=(0, q_lo),
block_shape=(1, BLOCK_SIZE_Q),
order=(1, 0),
)
lse_ptrs = tl.make_block_ptr(
base=lse_ptr + q_start * stride_ln + pid_h * stride_lh,
shape=(1, q_len),
strides=(0, stride_ln),
offsets=(0, q_lo),
block_shape=(1, BLOCK_SIZE_Q),
order=(0, 1),
)
# loop for q blocks
for i in range(q_lo, q_len, BLOCK_SIZE_Q):
# load
q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero")
lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero")
d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero")
# compute qk
        # [BLOCK_SIZE_K, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIZE_Q] -> [BLOCK_SIZE_K, BLOCK_SIZE_Q]
qk = tl.where(off_k[:, None] <= (off_q + i)[None, :], float(0.0), float("-inf"))
qk += tl.dot(k, q) * qk_scale
# compute p, ds
        # [BLOCK_SIZE_K, BLOCK_SIZE_Q] - [1, BLOCK_SIZE_Q] -> [BLOCK_SIZE_K, BLOCK_SIZE_Q]
p = tl.exp2(qk - lse)
        # [BLOCK_SIZE_K, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIZE_Q] -> [BLOCK_SIZE_K, BLOCK_SIZE_Q]
dp = tl.dot(v, do)
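        # softmax backward identity: with delta = rowsum(o * do), dscore = p * (dp - delta);
        # sm_scale restores the q·k scaling applied in the forward pass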
ds = sm_scale * p * (dp - d)
# cast dtype
p = p.to(do.dtype)
ds = ds.to(q.dtype)
# update dk and dv
        # [BLOCK_SIZE_K, BLOCK_SIZE_Q] @ [BLOCK_SIZE_Q, HEAD_DIM] -> [BLOCK_SIZE_K, HEAD_DIM]
dk += tl.dot(ds, tl.trans(q))
dv += tl.dot(p, tl.trans(do))
# increment pointers
q_ptrs = tl.advance(q_ptrs, (0, BLOCK_SIZE_Q))
do_ptrs = tl.advance(do_ptrs, (0, BLOCK_SIZE_Q))
lse_ptrs = tl.advance(lse_ptrs, (0, BLOCK_SIZE_Q))
d_ptrs = tl.advance(d_ptrs, (0, BLOCK_SIZE_Q))
# save dk dv
tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), boundary_check=(0, 1))
tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), boundary_check=(0, 1))
@triton.jit
def backward_dq(
q_ptr, # Q: n x qh x d
k_ptr, # K: n x kh x d
v_ptr, # V: n x kh x d
lse_ptr, # LSE: qh x n
d_ptr, # Delta: qh x n
do_ptr,
dq_ptr,
kernel_size,
kernel_stride,
# seqlens
cu_seqlens_q,
cu_seqlens_k,
# shape
NUM_KV_HEADS,
NUM_SHARE_Q_HEADS,
HEAD_DIM,
# sm_scale
sm_scale,
# stride
stride_qn,
stride_qh,
stride_qd,
stride_kn,
stride_kh,
stride_kd,
stride_vn,
stride_vh,
stride_vd,
stride_lh,
stride_ln,
stride_dh,
stride_dn,
stride_don,
stride_doh,
stride_dod,
stride_dqn,
stride_dqh,
stride_dqd,
# META parameters
BLOCK_SIZE_Q: tl.constexpr, # q block size
BLOCK_SIZE_K: tl.constexpr, # k block size
BLOCK_SIZE_D: tl.constexpr,
):
qk_scale = sm_scale * 1.44269504
# get batch id and head id
pid_b = tl.program_id(0)
pid_h = tl.program_id(1)
pid_q = tl.program_id(2)
pid_kh = pid_h // NUM_SHARE_Q_HEADS
# get q k start and len after rmpad
q_start = tl.load(cu_seqlens_q + pid_b)
q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
k_start = tl.load(cu_seqlens_k + pid_b)
k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
    # skip the first kernel_size - 1 query positions: they do not attend to any compressed keys
q_start_in_seq = pid_q * BLOCK_SIZE_Q + kernel_size - 1
if q_start_in_seq >= q_len:
return
# init pointers
q_ptrs = tl.make_block_ptr(
base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
shape=(q_len, HEAD_DIM),
strides=(stride_qn, stride_qd),
offsets=(q_start_in_seq, 0),
block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
order=(1, 0),
)
dq_ptrs = tl.make_block_ptr(
base=dq_ptr + q_start * stride_dqn + pid_h * stride_dqh,
shape=(q_len, HEAD_DIM),
strides=(stride_dqn, stride_dqd),
offsets=(q_start_in_seq, 0),
block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
order=(1, 0),
)
k_ptrs = tl.make_block_ptr(
base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
shape=(k_len, HEAD_DIM),
strides=(stride_kn, stride_kd),
offsets=(0, 0),
block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
order=(1, 0),
)
v_ptrs = tl.make_block_ptr(
base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
shape=(HEAD_DIM, k_len),
strides=(stride_vd, stride_vn),
offsets=(0, 0),
block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
order=(0, 1),
)
do_ptrs = tl.make_block_ptr(
base=do_ptr + q_start * stride_don + pid_h * stride_doh,
shape=(q_len, HEAD_DIM),
strides=(stride_don, stride_dod),
offsets=(q_start_in_seq, 0),
block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
order=(1, 0),
)
d_ptrs = tl.make_block_ptr(
base=d_ptr + q_start * stride_dn + pid_h * stride_dh,
shape=(q_len, 1),
strides=(stride_dn, stride_dh),
offsets=(q_start_in_seq, 0),
block_shape=(BLOCK_SIZE_Q, 1),
order=(0, 1),
)
lse_ptrs = tl.make_block_ptr(
base=lse_ptr + q_start * stride_ln + pid_h * stride_lh,
shape=(q_len, 1),
strides=(stride_ln, stride_lh),
offsets=(q_start_in_seq, 0),
block_shape=(BLOCK_SIZE_Q, 1),
order=(0, 1),
)
# offsets
off_q = tl.arange(0, BLOCK_SIZE_Q) + q_start_in_seq
off_k = tl.arange(0, BLOCK_SIZE_K) * kernel_stride + kernel_size - 1
# load q, do, lse, delta, and keep in SRAM
q = tl.load(q_ptrs, boundary_check=(1, 0), padding_option="zero")
do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero")
lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero")
d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero")
# init dq
dq = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_D), dtype=tl.float32)
lo = 0
hi = min(k_len, (q_start_in_seq + BLOCK_SIZE_Q - kernel_size) // kernel_stride + 1)
for i in range(lo, hi, BLOCK_SIZE_K):
# load
k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
# compute qk
qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
qk += tl.where(
off_q[:, None] >= (i * kernel_stride + off_k)[None, :], 0, float("-inf")
)
qk += tl.dot(q, tl.trans(k)) * qk_scale
# compute p, ds
p = tl.exp2(qk - lse)
dp = tl.dot(do, v)
ds = sm_scale * p * (dp - d)
# cast dtype
ds = ds.to(q.dtype)
# update dq
dq += tl.dot(ds, k)
# increment pointers
k_ptrs = tl.advance(k_ptrs, (BLOCK_SIZE_K, 0))
v_ptrs = tl.advance(v_ptrs, (0, BLOCK_SIZE_K))
# save dq
tl.store(dq_ptrs, dq.to(dq_ptr.dtype.element_ty), boundary_check=(0, 1))
def _compressed_attention_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
kernel_size: int,
kernel_stride: int,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: torch.Tensor,
max_seqlen_k: torch.Tensor,
sm_scale: float,
):
# dtype check
assert k.dtype == q.dtype and v.dtype == q.dtype
assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
# shape
q_len, num_q_heads, head_dim = q.shape
k_len, num_k_heads, head_dim = k.shape
v_len, num_v_heads, head_dim = v.shape
batch_size = cu_seqlens_q.shape[0] - 1
assert k_len == v_len and q_len > k_len
# gqa
assert num_k_heads == num_v_heads
assert num_q_heads % num_k_heads == 0
num_share_q_heads = num_q_heads // num_k_heads
# output tensor
o = torch.zeros_like(q)
lse = torch.full(
(num_q_heads, q_len),
fill_value=-torch.inf,
dtype=torch.float32,
device=q.device,
)
# launch kernel
grid = lambda META: (
batch_size,
num_q_heads,
triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]),
)
BLOCK_SIZE_Q = 128
BLOCK_SIZE_K = 128
BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU)
forward_kernel[grid](
q,
k,
v,
o,
lse,
kernel_size,
kernel_stride,
cu_seqlens_q,
cu_seqlens_k,
num_k_heads,
num_share_q_heads,
head_dim,
sm_scale,
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
lse.stride(0),
lse.stride(1),
BLOCK_SIZE_Q=BLOCK_SIZE_Q,
BLOCK_SIZE_K=BLOCK_SIZE_K,
BLOCK_SIZE_D=BLOCK_SIZE_D,
num_warps=num_warps,
num_stages=num_stages,
)
return o, lse
def _compressed_attention_bwd(
o: torch.Tensor,
do: torch.Tensor,
lse: torch.Tensor,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
kernel_size: int,
kernel_stride: int,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: torch.Tensor,
max_seqlen_k: torch.Tensor,
sm_scale: float,
):
q_len, num_q_heads, head_dim = q.shape
k_len, num_k_heads, head_dim = k.shape
v_len, num_v_heads, head_dim = v.shape
o_len, num_o_heads, head_dim = o.shape
num_share_q_heads = num_q_heads // num_k_heads
# compute D
delta = torch.zeros([num_o_heads, o_len], device=o.device, dtype=torch.float32)
grid = lambda META: (triton.cdiv(o_len, META["BLOCK_SIZE_O"]), num_o_heads)
BLOCK_SIZE_O = 256
BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_O, IS_HOPPER_GPU)
backward_sum_o_do[grid](
o,
do,
delta,
o_len,
head_dim,
o.stride(0),
o.stride(1),
o.stride(2),
do.stride(0),
do.stride(1),
do.stride(2),
delta.stride(0),
delta.stride(1),
BLOCK_SIZE_O=BLOCK_SIZE_O,
BLOCK_SIZE_D=BLOCK_SIZE_D,
num_warps=num_warps,
num_stages=num_stages,
)
# compute dk dv
dk = torch.zeros(
num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype
)
dv = torch.zeros(
num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype
)
batch_size = cu_seqlens_q.shape[0] - 1
grid = lambda META: (
batch_size,
num_q_heads,
triton.cdiv(max_seqlen_k, META["BLOCK_SIZE_K"]),
)
BLOCK_SIZE_Q = 64
BLOCK_SIZE_K = 128
BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_K, IS_HOPPER_GPU)
backward_dkdv[grid](
q,
k,
v,
lse,
delta,
do,
dk,
dv,
kernel_size,
kernel_stride,
cu_seqlens_q,
cu_seqlens_k,
num_k_heads,
num_share_q_heads,
head_dim,
sm_scale,
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
lse.stride(0),
lse.stride(1),
delta.stride(0),
delta.stride(1),
do.stride(0),
do.stride(1),
do.stride(2),
dk.stride(0),
dk.stride(1),
dk.stride(2),
dk.stride(3),
dv.stride(0),
dv.stride(1),
dv.stride(2),
dv.stride(3),
BLOCK_SIZE_Q=BLOCK_SIZE_Q,
BLOCK_SIZE_K=BLOCK_SIZE_K,
BLOCK_SIZE_D=BLOCK_SIZE_D,
num_warps=num_warps,
num_stages=num_stages,
)
dk = dk.sum(0)
dv = dv.sum(0)
# compute dq
dq = torch.zeros_like(q)
grid = lambda META: (
batch_size,
num_q_heads,
triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]),
)
BLOCK_SIZE_Q = 128
BLOCK_SIZE_K = 64
num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU)
backward_dq[grid](
q,
k,
v,
lse,
delta,
do,
dq,
kernel_size,
kernel_stride,
cu_seqlens_q,
cu_seqlens_k,
num_k_heads,
num_share_q_heads,
head_dim,
sm_scale,
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
lse.stride(0),
lse.stride(1),
delta.stride(0),
delta.stride(1),
do.stride(0),
do.stride(1),
do.stride(2),
dq.stride(0),
dq.stride(1),
dq.stride(2),
BLOCK_SIZE_Q=BLOCK_SIZE_Q,
BLOCK_SIZE_K=BLOCK_SIZE_K,
BLOCK_SIZE_D=BLOCK_SIZE_D,
num_warps=num_warps,
num_stages=num_stages,
)
return dq, dk, dv
class CompressedAttention(torch.autograd.Function):
@staticmethod
def forward(
ctx,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
kernel_size: int,
kernel_stride: int,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: torch.Tensor,
max_seqlen_k: torch.Tensor,
sm_scale=None,
):
# dtype check
assert q.dtype == torch.bfloat16 or q.dtype == torch.float16
assert q.dtype == k.dtype and k.dtype == v.dtype
assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
# softmax scale
if sm_scale is None:
sm_scale = 1 / math.sqrt(q.shape[-1])
o, lse = _compressed_attention_fwd(
q,
k,
v,
kernel_size,
kernel_stride,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
sm_scale,
)
ctx.save_for_backward(q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k)
ctx.sm_scale = sm_scale
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_k = max_seqlen_k
ctx.kernel_size = kernel_size
ctx.kernel_stride = kernel_stride
return o, lse
@staticmethod
def backward(ctx, do: torch.Tensor, *args) -> Any:
q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
max_seqlen_q = ctx.max_seqlen_q
max_seqlen_k = ctx.max_seqlen_k
sm_scale = ctx.sm_scale
kernel_size = ctx.kernel_size
kernel_stride = ctx.kernel_stride
dq, dk, dv = _compressed_attention_bwd(
o,
do,
lse,
q,
k,
v,
kernel_size,
kernel_stride,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
sm_scale,
)
return dq, dk, dv, None, None, None, None, None, None, None
@triton.jit
def score_kernel(
q_ptr,
k_ptr,
lse_ptr,
s_ptr,
kernel_size,
kernel_stride,
# seqlens
cu_seqlens_q,
cu_seqlens_k,
# shape
NUM_KV_HEADS,
NUM_SHARE_Q_HEADS,
HEAD_DIM,
# sm_scale
sm_scale,
# stride
stride_qn,
stride_qh,
stride_qd,
stride_kn,
stride_kh,
stride_kd,
stride_lh,
stride_ln,
stride_sh,
stride_sq,
stride_sk,
# META parameters
BLOCK_SIZE_Q: tl.constexpr, # q block size
BLOCK_SIZE_K: tl.constexpr, # k block size
BLOCK_SIZE_D: tl.constexpr,
):
qk_scale = sm_scale * 1.44269504
# get batch id and head id
pid_bkh = tl.program_id(0)
pid_b = pid_bkh // NUM_KV_HEADS
pid_kh = pid_bkh % NUM_KV_HEADS
pid_q = tl.program_id(1)
pid_k = tl.program_id(2)
# get q k start and len after rmpad
q_start = tl.load(cu_seqlens_q + pid_b)
q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
k_start = tl.load(cu_seqlens_k + pid_b)
k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
if pid_q * BLOCK_SIZE_Q >= q_len or pid_k * BLOCK_SIZE_K >= k_len:
return
# init k pointer and load k
k_ptrs = tl.make_block_ptr(
base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
shape=(HEAD_DIM, k_len),
strides=(stride_kd, stride_kn),
offsets=(0, pid_k * BLOCK_SIZE_K),
block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
order=(0, 1),
)
k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
# offsets
off_q = tl.arange(0, BLOCK_SIZE_Q) + pid_q * BLOCK_SIZE_Q
off_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * BLOCK_SIZE_K
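    # compressed key j summarizes original positions [j * kernel_stride, j * kernel_stride + kernel_size);
    # a query may only attend to it once that whole window lies in its causal past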
causal_mask = off_q[:, None] >= (off_k * kernel_stride + kernel_size - 1)[None, :]
# init score
s = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
# loop over gqa heads
for h in range(NUM_SHARE_Q_HEADS):
pid_h = pid_kh * NUM_SHARE_Q_HEADS + h
q_ptrs = tl.make_block_ptr(
base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
shape=(q_len, HEAD_DIM),
strides=(stride_qn, stride_qd),
offsets=(pid_q * BLOCK_SIZE_Q, 0),
block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
order=(1, 0),
)
lse_ptrs = tl.make_block_ptr(
base=lse_ptr + q_start * stride_ln + pid_h * stride_lh,
shape=(q_len, 1),
strides=(stride_ln, stride_lh),
offsets=(pid_q * BLOCK_SIZE_Q, 0),
block_shape=(BLOCK_SIZE_Q, 1),
order=(0, 1),
)
# load q and lse
q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero")
# compute qk
qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
qk += tl.dot(q, k) * qk_scale
# compute score
s += tl.where(causal_mask, tl.exp2(qk - lse), 0)
# save output
s_ptrs = tl.make_block_ptr(
base=s_ptr + pid_kh * stride_sh + q_start * stride_sq,
shape=(q_len, k_len),
strides=(stride_sq, stride_sk),
offsets=(pid_q * BLOCK_SIZE_Q, pid_k * BLOCK_SIZE_K),
block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_K),
order=(1, 0),
)
tl.store(s_ptrs, s.to(s_ptr.dtype.element_ty), boundary_check=(0, 1))
def _get_attention_score(
q: torch.Tensor, # [total_query_len, num_q_heads, head_dim]
k: torch.Tensor, # [total_key_len, num_k_heads, head_dim]
lse: torch.Tensor, # [num_q_heads, total_query_len]
kernel_size: int,
kernel_stride: int,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
sm_scale: float,
) -> torch.Tensor:
# dtype check
assert q.dtype == torch.bfloat16 or q.dtype == torch.float16
assert q.dtype == k.dtype
assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
assert (
lse.dtype == torch.float32
) # lse here is log2(sum(exp(qk*scale))), not log(sum(exp(qk*scale)))
# shape
q_len, num_q_heads, head_dim = q.shape
k_len, num_k_heads, head_dim = k.shape
batch_size = cu_seqlens_q.shape[0] - 1
assert q_len > k_len
if sm_scale is None:
sm_scale = 1 / math.sqrt(head_dim)
# gqa
assert num_q_heads % num_k_heads == 0
num_share_q_heads = num_q_heads // num_k_heads
# init score
score = torch.zeros(
num_k_heads, q_len, max_seqlen_k, dtype=torch.float32, device=q.device
)
# launch kernel
grid = lambda META: (
batch_size * num_k_heads,
triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]),
triton.cdiv(max_seqlen_k, META["BLOCK_SIZE_K"]),
)
BLOCK_SIZE_Q = 128
BLOCK_SIZE_K = 128
BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
score_kernel[grid](
q,
k,
lse,
score,
kernel_size,
kernel_stride,
cu_seqlens_q,
cu_seqlens_k,
num_k_heads,
num_share_q_heads,
head_dim,
sm_scale,
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
lse.stride(0),
lse.stride(1),
score.stride(0),
score.stride(1),
score.stride(2),
BLOCK_SIZE_Q=BLOCK_SIZE_Q,
BLOCK_SIZE_K=BLOCK_SIZE_K,
BLOCK_SIZE_D=BLOCK_SIZE_D,
num_warps=8,
num_stages=3,
)
return score
@triton.jit
def _transform_score_kernel(
s_ptr, # score, shape: [num_heads, q_len, k_len]
bs_ptr, # block wise score: [num_heads, q_len, num_k_block]
offs,
cu_seqlens_q,
# shape
num_heads,
num_offs,
max_k_len,
max_blocks,
pad_len,
# kernel & block size
block_size,
block_stride, # block_size // kernel_stride
init_blocks,
local_blocks,
# stride
stride_sh,
stride_sq,
stride_sk,
stride_bsh,
stride_bsq,
stride_bsk,
BLOCK_SIZE_Q: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
BLOCK_SIZE_O: tl.constexpr,
):
pid_bh = tl.program_id(0)
pid_b = pid_bh // num_heads
pid_h = pid_bh % num_heads
pid_q = tl.program_id(1)
pid_k = tl.program_id(2)
q_start = tl.load(cu_seqlens_q + pid_b)
q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
k_start = pid_k * BLOCK_SIZE_K
if pid_q * BLOCK_SIZE_Q >= q_len:
return
# load weight
off_o = tl.arange(0, BLOCK_SIZE_O)
w = tl.load(offs + off_o, mask=off_o < num_offs, other=0)
# load score
off_q = pid_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)
off_k = (k_start + tl.arange(0, BLOCK_SIZE_K)) * block_stride - pad_len
off_k = off_k[None, :] + off_o[:, None]
s_ptrs = (
s_ptr
+ q_start * stride_sq
+ pid_h * stride_sh
+ off_q[:, None, None] * stride_sq
+ off_k[None, :, :] * stride_sk
)
# weighted sum, [BQ, BO, BK] * [1, BO, 1] -> [BQ, BO, BK] -> [BQ, BK]
s = tl.load(
s_ptrs,
mask=(off_q < q_len)[:, None, None] & (off_k >= 0) & (off_k < max_k_len),
other=0,
)
s = s * w[None, :, None]
s = tl.max(s, axis=1)
# init mask and local mask
off_bq = off_q // block_size
off_bk = tl.arange(0, BLOCK_SIZE_K)
s = tl.where(
# For local blocks: set to negative infinity (exclude from topk)
(off_bq[:, None] >= (off_bk + k_start)[None, :]) & (off_bq[:, None] < (off_bk + k_start)[None, :] + local_blocks),
float("-inf"),
s,
)
# Keep the original conditions for init_blocks and query location as infinity
s = tl.where(
(off_bk[None, :] < init_blocks - k_start)
# Force blocks where the query is located to have infinite score (always include in topk)
| (off_bq[:, None] == (off_bk + k_start)[None, :]),
float("inf"),
s,
)
# store block wise score
bs_ptrs = (
bs_ptr
+ q_start * stride_bsq
+ k_start * stride_bsk
+ pid_h * stride_bsh
+ off_q[:, None] * stride_bsq
+ off_bk[None, :] * stride_bsk
)
tl.store(
bs_ptrs,
s,
mask=(off_q < q_len)[:, None] & (off_bk < max_blocks - k_start)[None, :],
)
def transform_score(
score: torch.Tensor,
kernel_size: int,
kernel_stride: int,
block_size: int,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
init_blocks: int = 1,
local_blocks: int = 2,
) -> torch.Tensor:
num_k_heads, total_query_len, max_key_len = score.shape
batch_size = cu_seqlens_q.shape[0] - 1
pad_len = kernel_size // kernel_stride - 1
max_blocks = math.ceil(max_seqlen_q / block_size)
block_score = torch.zeros(
num_k_heads,
total_query_len,
max_blocks,
dtype=torch.float32,
device=score.device,
)
offs = (
torch.arange(kernel_size // kernel_stride, device=score.device)[:, None]
+ torch.arange(block_size // kernel_stride, device=score.device)[None, :]
).view(-1)
offs = torch.histc(offs, bins=offs.max() + 1, min=0, max=offs.max())
num_offs = int(offs.shape[0])
BLOCK_SIZE_K = min(128, triton.next_power_of_2(max_blocks))
BLOCK_SIZE_O = triton.next_power_of_2(num_offs)
BLOCK_SIZE_Q = 8
grid = (
num_k_heads * batch_size,
triton.cdiv(total_query_len, BLOCK_SIZE_Q),
triton.cdiv(max_blocks, BLOCK_SIZE_K),
)
_transform_score_kernel[grid](
score,
block_score,
        torch.ones_like(offs, dtype=offs.dtype, device=offs.device),  # all-ones weights: the kernel takes a max over offsets, so no weighting is needed
cu_seqlens_q,
num_k_heads,
offs.shape[0],
max_key_len,
max_blocks,
pad_len,
block_size,
block_size // kernel_stride,
init_blocks,
local_blocks,
score.stride(0),
score.stride(1),
score.stride(2),
block_score.stride(0),
block_score.stride(1),
block_score.stride(2),
BLOCK_SIZE_Q=BLOCK_SIZE_Q,
BLOCK_SIZE_K=BLOCK_SIZE_K,
BLOCK_SIZE_O=BLOCK_SIZE_O,
num_warps=8,
num_stages=3,
)
return block_score
def compressed_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
kernel_size: int,
kernel_stride: int,
block_size: int,
topk: int,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
sm_scale: float = None,
init_blocks: int = 1,
local_blocks: int = 2,
parallel_topk_compute: Union[str, bool] = "auto",
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Attention between query and compressed key and value. Compute attention output and topk block idx used in topk_sparse_attention.
Args:
q (torch.Tensor): shape [total_q_len, num_q_heads, head_dim]
k (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
v (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
kernel_size (int): kernel size in compress_key_value
kernel_stride (int): stride of compress_key_value
block_size (int): key value block size for topk sparse attention.
topk (int): number of blocks for each query.
cu_seqlens_q (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen.
cu_seqlens_k (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_k in flash_attn_func_varlen.
max_seqlen_q (int): max q len of the batch.
max_seqlen_k (int): max k len of the batch.
        sm_scale (float, optional): softmax scale. Defaults to None, meaning 1 / sqrt(head_dim).
init_blocks (int, optional): Number of init blocks for each query. Defaults to 1.
local_blocks (int, optional): Number of local blocks for each query. Defaults to 2.
        parallel_topk_compute (Union[str, bool], optional): whether to compute the topk block indices for all key-value heads in parallel.
            Set it to False only for very long sequences, to work around a known bug (to be fixed later). Defaults to "auto": True when the
            total sequence length is at most 32k tokens, False otherwise.
Returns:
Tuple[torch.Tensor, torch.Tensor]: attention output and topk_idx used in topk_sparse_attention
"""
if max_seqlen_q is None:
max_seqlen_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item()
if max_seqlen_k is None:
max_seqlen_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item()
attn_output, lse = CompressedAttention.apply(
q,
k,
v,
kernel_size,
kernel_stride,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
sm_scale,
)
# do not select topk index
if topk <= 0:
warnings.warn("topk <= 0, returned topk_idx will be None")
return attn_output, None
assert topk >= init_blocks #+ local_blocks
with torch.no_grad():
num_k_heads, num_q_heads = k.shape[1], q.shape[1]
num_shared_q_heads = num_q_heads // num_k_heads
batch_size = cu_seqlens_q.shape[0] - 1
q_idx = torch.cat(
[
torch.arange(cu_seqlens_q[i + 1] - cu_seqlens_q[i], device=q.device)
for i in range(batch_size)
],
dim=0,
)
q_idx = q_idx // block_size
# whether to use parallel version
if parallel_topk_compute == "auto":
parallel_topk_compute = cu_seqlens_q[-1] <= 32768
# parallel version
if parallel_topk_compute:
# recompute score
score = _get_attention_score(
q,
k,
lse,
kernel_size,
kernel_stride,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
sm_scale,
)
# transform score to block-wise score
score = transform_score(
score,
kernel_size,
kernel_stride,
block_size,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
init_blocks,
local_blocks,
)
# get topk
topk = min(topk, score.shape[-1])
topk_idx = score.topk(topk, dim=-1).indices.sort(-1).values
topk_idx[topk_idx >= q_idx[None, :, None]] = -1
topk_idx = topk_idx.to(torch.int32)
        # non-parallel version: works around a known bug when the sequence length is very long
# FIXME: need to fix later
else:
topk_idx_list = []
for h in range(num_k_heads):
# recompute score
score = _get_attention_score(
q[:, h * num_shared_q_heads : (h + 1) * num_shared_q_heads],
k[:, h : h + 1],
lse[h * num_shared_q_heads : (h + 1) * num_shared_q_heads],
kernel_size,
kernel_stride,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
sm_scale,
)
# transform score to block-wise score
score = transform_score(
score,
kernel_size,
kernel_stride,
block_size,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
init_blocks,
local_blocks,
)
# get topk
topk = min(topk, score.shape[-1])
topk_idx = score.topk(topk, dim=-1).indices.sort(-1).values
topk_idx[topk_idx >= q_idx[None, :, None]] = -1
topk_idx = topk_idx.to(torch.int32)
topk_idx_list.append(topk_idx)
topk_idx = torch.cat(topk_idx_list, dim=0)
return attn_output, topk_idx
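

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). It assumes a CUDA device with
# Triton available, and it feeds *random* tensors in place of real compressed
# key/value states, which in practice would come from a separate
# compress_key_value step using the same kernel_size / kernel_stride. All
# shapes and hyper-parameters below are made up for demonstration.
if __name__ == "__main__":
    if torch.cuda.is_available():
        torch.manual_seed(0)
        num_q_heads, num_kv_heads, head_dim = 8, 2, 64
        kernel_size, kernel_stride, block_size, topk = 32, 16, 64, 4
        # two variable-length sequences packed together (varlen layout)
        seqlens = torch.tensor([512, 300], dtype=torch.int32, device="cuda")
        cu_seqlens_q = torch.cat(
            [
                torch.zeros(1, dtype=torch.int32, device="cuda"),
                seqlens.cumsum(0).to(torch.int32),
            ]
        )
        max_seqlen_q = int(seqlens.max())
        # compressed sequence lengths derived from the original lengths
        _, cu_seqlens_k = get_compressed_seqlens(cu_seqlens_q, kernel_size, kernel_stride)
        max_seqlen_k = int((cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max())
        q = torch.randn(
            int(cu_seqlens_q[-1]), num_q_heads, head_dim, device="cuda", dtype=torch.bfloat16
        )
        k = torch.randn(
            int(cu_seqlens_k[-1]), num_kv_heads, head_dim, device="cuda", dtype=torch.bfloat16
        )
        v = torch.randn_like(k)
        attn_output, topk_idx = compressed_attention(
            q, k, v,
            kernel_size, kernel_stride, block_size, topk,
            cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
        )
        # attn_output: [total_q_len, num_q_heads, head_dim]
        # topk_idx: [num_kv_heads, total_q_len, topk], where -1 marks invalid blocks
        print(attn_output.shape, topk_idx.shape)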