 #include <stdio.h>

 #include <algorithm>  // std::min
+#include <cstdio>
 #include <vector>

 #include "compression/compress.h"
@@ -610,7 +611,7 @@ class VitAttention {
   }

   // TODO(philculliton): transition fully to MatMul.
-  HWY_NOINLINE void DotSoftmaxWeightedSum() {
+  HWY_NOINLINE void DotSoftmaxWeightedSumMatrix() {
     const size_t qkv_dim = layer_config_.qkv_dim;
     const size_t heads = layer_config_.heads;
     HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA");
@@ -669,6 +670,44 @@ class VitAttention {
     }
   }

+  HWY_NOINLINE void DotSoftmaxWeightedSum() {
+    const size_t qkv_dim = layer_config_.qkv_dim;
+    const size_t heads = layer_config_.heads;
+    HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA");
+    const size_t seq_len = activations_.seq_len;
+    const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
+    PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
+
+    // Compute Q.K, softmax, and weighted V.
+    pool_.Run(0, layer_config_.heads * num_tokens_,
+              [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
+                const size_t head = task % layer_config_.heads;
+                const size_t token = task / layer_config_.heads;
+                // Compute Q.K scores, which are "logits" stored in head_att.
+                float* HWY_RESTRICT q =
+                    activations_.q.Batch(token) + head * 3 * qkv_dim;
+                MulByConst(query_scale, q, qkv_dim);
+                float* HWY_RESTRICT head_att =
+                    activations_.att.Batch(token) + head * activations_.seq_len;
+                for (size_t i = 0; i < seq_len; ++i) {
+                  float* HWY_RESTRICT k =
+                      activations_.q.Batch(i) + head * 3 * qkv_dim + qkv_dim;
+                  head_att[i] = Dot(q, k, qkv_dim);  // score = q.k
+                }
+                // SoftMax yields "probabilities" in head_att.
+                Softmax(head_att, seq_len);
+                // Compute weighted sum of v into att_out.
+                float* HWY_RESTRICT att_out =
+                    activations_.att_out.Batch(token) + head * qkv_dim;
+                hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
+                for (size_t i = 0; i < seq_len; ++i) {
+                  float* HWY_RESTRICT v = activations_.q.Batch(i) +
+                                          head * 3 * qkv_dim + 2 * qkv_dim;
+                  MulByConstAndAdd(head_att[i], v, att_out, qkv_dim);
+                }
+              });
+  }
+
   // Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and
   // head_dim (`qkv_dim`) into output (`att_sums`).
   HWY_NOINLINE void SumHeads() {
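For readers less familiar with the packed layout in the restored function: each token's Q, K, and V vectors for a given head live contiguously at offsets 0, qkv_dim, and 2 * qkv_dim within its head * 3 * qkv_dim slice of activations_.q, which is why all three are indexed from the same buffer. A minimal scalar sketch of the same per-head computation follows; the standalone name AttendOneHead and the flat-buffer signature are illustrative, not from the codebase, and scaling the dot product here is equivalent to the kernel's in-place MulByConst on q.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Scalar reference for one attention head over a packed QKV buffer: for each
// query position t, out_t = sum_i softmax(q_t . k_i / sqrt(d))_i * v_i.
// `qkv` holds seq_len rows of [q | k | v], each 3 * qkv_dim floats wide,
// mirroring the head * 3 * qkv_dim offsets used by the kernel above.
void AttendOneHead(const float* qkv, size_t seq_len, size_t qkv_dim,
                   float* out /* seq_len * qkv_dim */) {
  const float scale = 1.0f / std::sqrt(static_cast<float>(qkv_dim));
  std::vector<float> att(seq_len);
  for (size_t t = 0; t < seq_len; ++t) {
    const float* q = qkv + t * 3 * qkv_dim;  // Q slice of row t.
    float max_logit = -1e30f;
    for (size_t i = 0; i < seq_len; ++i) {
      const float* k = qkv + i * 3 * qkv_dim + qkv_dim;  // K slice of row i.
      float dot = 0.0f;
      for (size_t d = 0; d < qkv_dim; ++d) dot += q[d] * k[d];
      att[i] = dot * scale;  // Same as pre-scaling q by query_scale.
      max_logit = std::max(max_logit, att[i]);
    }
    // Numerically stable softmax over the logits.
    float sum = 0.0f;
    for (size_t i = 0; i < seq_len; ++i) {
      att[i] = std::exp(att[i] - max_logit);
      sum += att[i];
    }
    const float inv_sum = 1.0f / sum;
    // Weighted sum of V rows, matching MulByConstAndAdd in the kernel.
    float* o = out + t * qkv_dim;
    for (size_t d = 0; d < qkv_dim; ++d) o[d] = 0.0f;
    for (size_t i = 0; i < seq_len; ++i) {
      const float* v = qkv + i * 3 * qkv_dim + 2 * qkv_dim;  // V slice.
      const float w = att[i] * inv_sum;
      for (size_t d = 0; d < qkv_dim; ++d) o[d] += w * v[d];
    }
  }
}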
@@ -695,7 +734,11 @@ class VitAttention {

   HWY_INLINE void operator()() {
     ComputeQKV();
-    DotSoftmaxWeightedSum();
+    if (activations_.weights_config.wrapping == PromptWrapping::GEMMA_VLM) {
+      DotSoftmaxWeightedSumMatrix();
+    } else {
+      DotSoftmaxWeightedSum();
+    }
     SumHeads();
   }

@@ -1177,11 +1220,13 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs<T>& weights,
                    weights.vit_encoder_norm_bias.data_scale1(),
                    activations.x.All(), vit_model_dim);

-  activations.x = AvgPool4x4(activations.x);
+  if (weights.weights_config.wrapping == PromptWrapping::GEMMA_VLM) {
+    activations.x = AvgPool4x4(activations.x);

-  // Apply soft embedding norm before input projection.
-  RMSNormInplace(weights.mm_embed_norm.data_scale1(), activations.x.All(),
-                 vit_model_dim);
+    // Apply soft embedding norm before input projection.
+    RMSNormInplace(weights.mm_embed_norm.data_scale1(), activations.x.All(),
+                   vit_model_dim);
+  }

   // Apply head embedding into image_tokens of size of the LLM kModelDim.
   MatMul(ConstMatFromBatch(activations.x.BatchSize(), activations.x),
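The new GEMMA_VLM branch pools the ViT token grid 4x4 before the soft embedding norm and the input projection, reducing the image token count 16x while keeping the embedding width. A minimal sketch of non-overlapping 4x4 average pooling follows; it assumes the tokens form a square grid in row-major order, and AvgPool4x4Sketch is an illustrative stand-in since the real AvgPool4x4 signature and memory layout may differ.

#include <cstddef>
#include <vector>

// Sketch: non-overlapping 4x4 average pooling over a square grid of ViT
// tokens. `in` holds grid * grid tokens, row-major, each `dim` floats wide;
// the result has (grid / 4) * (grid / 4) tokens of the same width.
std::vector<float> AvgPool4x4Sketch(const std::vector<float>& in, size_t grid,
                                    size_t dim) {
  const size_t out_grid = grid / 4;
  std::vector<float> out(out_grid * out_grid * dim, 0.0f);
  for (size_t oy = 0; oy < out_grid; ++oy) {
    for (size_t ox = 0; ox < out_grid; ++ox) {
      float* o = &out[(oy * out_grid + ox) * dim];
      // Accumulate the 16 tokens in this 4x4 window, then take the mean.
      for (size_t dy = 0; dy < 4; ++dy) {
        for (size_t dx = 0; dx < 4; ++dx) {
          const size_t y = oy * 4 + dy, x = ox * 4 + dx;
          const float* token = &in[(y * grid + x) * dim];
          for (size_t d = 0; d < dim; ++d) o[d] += token[d];
        }
      }
      for (size_t d = 0; d < dim; ++d) o[d] *= 1.0f / 16.0f;
    }
  }
  return out;
}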