
Commit 1b1b63d

pculliton authored and copybara-github committed
Fix PaliGemma models.
PiperOrigin-RevId: 736483021
1 parent 0ff6b31 commit 1b1b63d

4 files changed: 59 additions & 8 deletions


gemma/configs.cc

Lines changed: 1 addition & 0 deletions
@@ -291,6 +291,7 @@ ModelConfig GetVitConfig(const ModelConfig& config) {
   vit_config.seq_len = config.vit_config.seq_len;
   vit_config.layer_configs = config.vit_config.layer_configs;
   vit_config.pool_dim = config.vit_config.pool_dim;
+  vit_config.wrapping = config.wrapping;
   // The Vit part does not have a vocabulary, the image patches are embedded.
   vit_config.vocab_size = 0;
   return vit_config;

gemma/gemma-inl.h

Lines changed: 51 additions & 6 deletions
@@ -20,6 +20,7 @@
 #include <stdio.h>

 #include <algorithm>  // std::min
+#include <cstdio>
 #include <vector>

 #include "compression/compress.h"
@@ -610,7 +611,7 @@ class VitAttention {
   }

   // TODO(philculliton): transition fully to MatMul.
-  HWY_NOINLINE void DotSoftmaxWeightedSum() {
+  HWY_NOINLINE void DotSoftmaxWeightedSumMatrix() {
     const size_t qkv_dim = layer_config_.qkv_dim;
     const size_t heads = layer_config_.heads;
     HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA");
@@ -669,6 +670,44 @@ class VitAttention {
     }
   }

+  HWY_NOINLINE void DotSoftmaxWeightedSum() {
+    const size_t qkv_dim = layer_config_.qkv_dim;
+    const size_t heads = layer_config_.heads;
+    HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA");
+    const size_t seq_len = activations_.seq_len;
+    const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
+    PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
+
+    // Compute Q.K, softmax, and weighted V.
+    pool_.Run(0, layer_config_.heads * num_tokens_,
+              [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
+                const size_t head = task % layer_config_.heads;
+                const size_t token = task / layer_config_.heads;
+                // Compute Q.K scores, which are "logits" stored in head_att.
+                float* HWY_RESTRICT q =
+                    activations_.q.Batch(token) + head * 3 * qkv_dim;
+                MulByConst(query_scale, q, qkv_dim);
+                float* HWY_RESTRICT head_att =
+                    activations_.att.Batch(token) + head * activations_.seq_len;
+                for (size_t i = 0; i < seq_len; ++i) {
+                  float* HWY_RESTRICT k =
+                      activations_.q.Batch(i) + head * 3 * qkv_dim + qkv_dim;
+                  head_att[i] = Dot(q, k, qkv_dim);  // score = q.k
+                }
+                // SoftMax yields "probabilities" in head_att.
+                Softmax(head_att, seq_len);
+                // Compute weighted sum of v into att_out.
+                float* HWY_RESTRICT att_out =
+                    activations_.att_out.Batch(token) + head * qkv_dim;
+                hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
+                for (size_t i = 0; i < seq_len; ++i) {
+                  float* HWY_RESTRICT v = activations_.q.Batch(i) +
+                                          head * 3 * qkv_dim + 2 * qkv_dim;
+                  MulByConstAndAdd(head_att[i], v, att_out, qkv_dim);
+                }
+              });
+  }
+
   // Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and
   // head_dim (`qkv_dim`) into output (`att_sums`).
   HWY_NOINLINE void SumHeads() {
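
For reference, the re-added scalar DotSoftmaxWeightedSum is standard per-head scaled dot-product attention over the fused QKV buffer. A plain-C++ sketch of the same math for a single head and query, without the Highway intrinsics, thread pool, or batched activation layout of the real code (vectors stand in for the activation rows):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Single-head scaled dot-product attention: score_i = (q . k_i) / sqrt(d),
// softmax over the scores, then a probability-weighted sum of the values.
std::vector<float> AttendOneHead(const std::vector<float>& q,
                                 const std::vector<std::vector<float>>& k,
                                 const std::vector<std::vector<float>>& v) {
  const std::size_t d = q.size();
  const std::size_t seq_len = k.size();
  const float scale = 1.0f / std::sqrt(static_cast<float>(d));

  // "Logits": one scaled dot product per sequence position.
  std::vector<float> att(seq_len);
  for (std::size_t i = 0; i < seq_len; ++i) {
    float dot = 0.0f;
    for (std::size_t j = 0; j < d; ++j) dot += q[j] * k[i][j];
    att[i] = dot * scale;
  }

  // Softmax; subtracting the max keeps exp() from overflowing.
  const float max_logit = *std::max_element(att.begin(), att.end());
  float sum = 0.0f;
  for (float& a : att) {
    a = std::exp(a - max_logit);
    sum += a;
  }
  for (float& a : att) a /= sum;

  // Weighted sum of values into the output vector.
  std::vector<float> out(d, 0.0f);
  for (std::size_t i = 0; i < seq_len; ++i) {
    for (std::size_t j = 0; j < d; ++j) out[j] += att[i] * v[i][j];
  }
  return out;
}

Scaling q in place (MulByConst) and reading k and v out of the fused q buffer at offsets qkv_dim and 2 * qkv_dim, as the real code does, are layout optimizations on top of this; the math is unchanged.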
@@ -695,7 +734,11 @@ class VitAttention {

   HWY_INLINE void operator()() {
     ComputeQKV();
-    DotSoftmaxWeightedSum();
+    if (activations_.weights_config.wrapping == PromptWrapping::GEMMA_VLM) {
+      DotSoftmaxWeightedSumMatrix();
+    } else {
+      DotSoftmaxWeightedSum();
+    }
     SumHeads();
   }

@@ -1177,11 +1220,13 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs<T>& weights,
                weights.vit_encoder_norm_bias.data_scale1(),
                activations.x.All(), vit_model_dim);

-  activations.x = AvgPool4x4(activations.x);
+  if (weights.weights_config.wrapping == PromptWrapping::GEMMA_VLM) {
+    activations.x = AvgPool4x4(activations.x);

-  // Apply soft embedding norm before input projection.
-  RMSNormInplace(weights.mm_embed_norm.data_scale1(), activations.x.All(),
-                 vit_model_dim);
+    // Apply soft embedding norm before input projection.
+    RMSNormInplace(weights.mm_embed_norm.data_scale1(), activations.x.All(),
+                   vit_model_dim);
+  }

   // Apply head embedding into image_tokens of size of the LLM kModelDim.
   MatMul(ConstMatFromBatch(activations.x.BatchSize(), activations.x),
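
The PrefillVit hunk gates the Gemma-3-VLM-only steps: average pooling of the patch tokens and the soft-embedding norm (mm_embed_norm) now run only for GEMMA_VLM, while PaliGemma keeps the full set of ViT tokens. A hedged sketch of what an AvgPool4x4-style reduction computes, assuming the tokens form a square grid pooled in non-overlapping 4x4 blocks (the real helper operates on gemma.cpp's batched activation type, not nested vectors):

#include <cstddef>
#include <vector>

// Average-pools an [n, n] grid of d-dimensional patch tokens (stored
// row-major as tokens[y * n + x]) down to [n/4, n/4] by averaging each
// 4x4 block, e.g. 64x64 = 4096 ViT tokens -> 16x16 = 256 image tokens.
// Assumes n % 4 == 0.
std::vector<std::vector<float>> AvgPool4x4Sketch(
    const std::vector<std::vector<float>>& tokens, std::size_t n) {
  const std::size_t d = tokens[0].size();
  const std::size_t out_n = n / 4;
  std::vector<std::vector<float>> out(out_n * out_n,
                                      std::vector<float>(d, 0.0f));
  for (std::size_t y = 0; y < n; ++y) {
    for (std::size_t x = 0; x < n; ++x) {
      std::vector<float>& o = out[(y / 4) * out_n + (x / 4)];
      const std::vector<float>& t = tokens[y * n + x];
      for (std::size_t j = 0; j < d; ++j) o[j] += t[j] / 16.0f;  // 4*4 block
    }
  }
  return out;
}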

gemma/run.cc

Lines changed: 4 additions & 1 deletion
@@ -217,8 +217,11 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
                timing_info);
   std::cout << "\n\n";

-  // Prepare for the next turn.
+  // Prepare for the next turn. Works only for PaliGemma.
   if (!args.multiturn || model.Info().wrapping == PromptWrapping::PALIGEMMA) {
+    abs_pos = 0;  // Start a new turn at position 0.
+    InitGenerator(args, gen);
+  } else {
     // The last token was either EOS, then it should be ignored because it is
     // never part of the dialog, see Table 5 in the Gemma-2 paper:
     // https://arxiv.org/pdf/2408.00118
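
The REPL now makes the turn boundary explicit: it either restarts from absolute position 0 (multiturn disabled, or PaliGemma, which has no multiturn dialog format) or continues the dialog at the current position. A simplified, self-contained rendering of that branch, with stand-in types rather than the real ReplGemma state, and InitGenerator modeled by reseeding the RNG:

#include <cstddef>
#include <random>

enum class PromptWrapping { PALIGEMMA, GEMMA_VLM, OTHER };

// Returns the absolute position at which the next turn should start.
std::size_t PrepareNextTurn(bool multiturn, PromptWrapping wrapping,
                            std::size_t abs_pos, std::mt19937& gen) {
  if (!multiturn || wrapping == PromptWrapping::PALIGEMMA) {
    gen.seed(std::random_device{}());  // stands in for InitGenerator(args, gen)
    return 0;  // start the new turn at position 0, discarding prior context
  }
  // Multiturn models continue at the current position; the real code
  // additionally handles a trailing EOS, which is never part of the dialog
  // (Gemma-2 paper, Table 5).
  return abs_pos;
}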

gemma/weights.h

Lines changed: 3 additions & 1 deletion
@@ -500,7 +500,9 @@ struct ModelWeightsPtrs {
     GEMMA_CALL_FUNC(vit_img_pos_embedding);
     GEMMA_CALL_FUNC(vit_img_head_bias);
     GEMMA_CALL_FUNC(vit_img_head_kernel);
-    GEMMA_CALL_FUNC(mm_embed_norm);
+
+    if (ptrs[0]->weights_config.wrapping == PromptWrapping::GEMMA_VLM)
+      GEMMA_CALL_FUNC(mm_embed_norm);
   }

   for (int layer_idx = 0; layer_idx < ptrs[0]->c_layers.size(); ++layer_idx) {
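
GEMMA_CALL_FUNC expands to an operation (load, allocate, and so on) applied to each named tensor, so guarding mm_embed_norm means PaliGemma checkpoints, which do not carry this Gemma-3-VLM tensor, no longer fail the traversal. A hedged sketch of the pattern with illustrative stand-ins (the real macro dispatches on tensor members of ModelWeightsPtrs, not on strings):

#include <cstdio>

enum class PromptWrapping { PALIGEMMA, GEMMA_VLM };

#define CALL_TENSOR(name) func(#name)

// Applies `func` to each ViT tensor name; mm_embed_norm is visited only for
// checkpoints whose wrapping mode says the tensor exists.
template <class Func>
void ForEachVitTensor(PromptWrapping wrapping, const Func& func) {
  CALL_TENSOR(vit_img_head_bias);
  CALL_TENSOR(vit_img_head_kernel);
  if (wrapping == PromptWrapping::GEMMA_VLM) CALL_TENSOR(mm_embed_norm);
}

int main() {
  // A PaliGemma-style traversal never touches mm_embed_norm.
  ForEachVitTensor(PromptWrapping::PALIGEMMA,
                   [](const char* name) { std::printf("visit %s\n", name); });
  return 0;
}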
