[NPUW] Add block-based KV cache support for HFA and Pyramid attention#35014
Draft
intelgaoxiong wants to merge 1 commit into openvinotoolkit:master from
Conversation
Extend Host Flash Attention (HFA) and Pyramid attention to operate with the block-split KV cache produced by SplitKVCacheIntoBlocks.

Section 1 — Shared infrastructure:
- util.hpp/cpp: rename isPastKeyValuesKey/Value to isPastKeyParam/isPastValueParam; add isPastKeyParamContiguous / isPastValueParamContiguous for non-block contexts
- sdpa_utils.hpp/cpp: new file; extract shared SDPA parameter utilities (previously duplicated between pyramid_attention and host_flash_attention)
- attention.hpp: extend SDPAIndices with past_key_blocks/past_value_blocks vectors; extend the Attention struct with per-variant block indices for Pyramid

Section 2 — Host Flash Attention:
- host_flash_attention.cpp/hpp: loop over all Concat inputs in build_sdpa_param_mapping() to collect _past_key_block_indices / _past_value_block_indices; switch #include from pyramid_attention to sdpa_utils
- base_sync_infer_request.cpp: replace scalar past_key/past_value checks with an is_past_kv() lambda that searches the block-index vectors

Section 3 — Pyramid Attention:
- pyramid_attention.cpp/hpp: add an is_block_split path in process_pyramid_model() that shrinks each pyramid-variant Concat to keep only idx past blocks; add a collect_concat_block_indices() helper; populate past_key/value_block_*_indices
- base_sync_infer_request.cpp: add block_mode and a bind_block_ports() lambda in bind_pyramid_attention_inputs()
- just_sync_infer_request.cpp: add share_kv_block_buffers() for pyramid variants
- partitioning/patterns/sdpa.cpp: relax the Concat input-count check to support multi-block inputs

Signed-off-by: intelgaoxiong <xiong.gao@intel.com>
Details:
Extends Host Flash Attention (HFA) and Pyramid attention to operate with the block-split KV cache produced by SplitKVCacheIntoBlocks (Part 1/4).

Section 1 — Shared infrastructure:
- util.hpp/cpp: rename isPastKeyValuesKey/Value to isPastKeyParam/isPastValueParam; add isPastKeyParamContiguous / isPastValueParamContiguous
- sdpa_utils.hpp/cpp: new file extracting shared SDPA parameter utilities (previously duplicated between HFA and Pyramid)
- attention.hpp: extend SDPAIndices with past_key_blocks/past_value_blocks vectors; extend the Attention struct with per-variant block indices

Section 2 — Host Flash Attention:
- host_flash_attention.cpp: loop over all Concat inputs in build_sdpa_param_mapping() to collect _past_key_block_indices / _past_value_block_indices
- base_sync_infer_request.cpp: replace scalar past_key/past_value checks with an is_past_kv() lambda

Section 3 — Pyramid Attention:
- pyramid_attention.cpp: add an is_block_split path in process_pyramid_model() that shrinks each pyramid-variant Concat to idx past blocks; add a collect_concat_block_indices() helper
- base_sync_infer_request.cpp: add block_mode and a bind_block_ports() lambda in bind_pyramid_attention_inputs()
- just_sync_infer_request.cpp: share_kv_block_buffers() for pyramid variants
- partitioning/patterns/sdpa.cpp: relax the Concat input-count check for multi-block inputs

This is part 3/4 of the block-based KV cache feature split.
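The collect_concat_block_indices() idea — looping over every Concat input rather than assuming a single past-KV input — can be sketched abstractly. The ConcatInput stand-in and the parameter-naming convention below are illustrative assumptions only; the real code walks ov::Node input ports, not name strings.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for one Concat input: just the name of the node
// feeding that port (the real code inspects the producing ov::Node).
struct ConcatInput {
    std::string producer_name;
};

// Assumed naming convention, for illustration only: past-key block
// parameters contain "past_key_values" and ".key.block." in their name.
inline bool is_past_key_block(const std::string& name) {
    return name.find("past_key_values") != std::string::npos &&
           name.find(".key.block.") != std::string::npos;
}

// Sketch of collect_concat_block_indices(): visit all Concat inputs and
// record which positions carry past-key blocks, so a multi-block Concat
// (block_0, block_1, ..., new_key) is handled like the old two-input one.
inline std::vector<std::size_t> collect_concat_block_indices(
        const std::vector<ConcatInput>& inputs) {
    std::vector<std::size_t> block_indices;
    for (std::size_t i = 0; i < inputs.size(); ++i) {
        if (is_past_key_block(inputs[i].producer_name)) {
            block_indices.push_back(i);
        }
    }
    return block_indices;
}
```

This is also why the Concat input-count check in partitioning/patterns/sdpa.cpp has to be relaxed: with block splitting, a KV Concat legitimately has more than two inputs.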
Tickets:
AI Assistance: