Skip to content

Commit a41fbe2

Browse files
theraysmithcopybara-github
authored andcommitted
Increased max_tbatch_size to kMaxBatchSize. Gives 1.5x speed-up for prefill on both intel and AMD machines
Shrank intermediate arrays used in matmul to reduce memory use. PiperOrigin-RevId: 899532036
1 parent a29e2fc commit a41fbe2

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

gemma/gemma_args.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ struct RuntimeConfig {
134134

135135
// These defaults are overridden by InferenceArgs::CopyTo(*this):
136136
// Max tokens per batch during prefill.
137-
size_t prefill_tbatch_size = 256;
137+
size_t prefill_tbatch_size = kMaxBatchSize;
138138
// Max queries per batch (one token from each) during decode.
139139
size_t decode_qbatch_size = 16;
140140

@@ -225,7 +225,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
225225
visitor(max_generated_tokens, "max_generated_tokens", size_t{4096},
226226
"Maximum number of tokens to generate.");
227227

228-
visitor(prefill_tbatch_size, "prefill_tbatch", size_t{256},
228+
visitor(prefill_tbatch_size, "prefill_tbatch", size_t{4096},
229229
"Prefill: max tokens per batch.");
230230
visitor(decode_qbatch_size, "decode_qbatch", size_t{16},
231231
"Decode: max queries per batch.");

ops/matmul.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,12 @@ HWY_INLINE_VAR constexpr size_t kNR = 4;
5454
HWY_INLINE_VAR constexpr size_t kMaxMR = 4;
5555

5656
// For `MMTilesC`.
57-
HWY_INLINE_VAR constexpr size_t kMaxMC = 512;
58-
HWY_INLINE_VAR constexpr size_t kMaxNC = 16384;
57+
HWY_INLINE_VAR constexpr size_t kMaxMC = 256;
58+
HWY_INLINE_VAR constexpr size_t kMaxNC = 6 * 1024;
5959

6060
// Upper bound for per-worker B storage on the stack. Chosen such that one row
6161
// of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`.
62-
HWY_INLINE_VAR constexpr size_t kMaxKC = 8 * 1024;
62+
HWY_INLINE_VAR constexpr size_t kMaxKC = 6 * 1024;
6363

6464
// Policy classes for parallelism, implementing some of `Parallelism`.
6565

0 commit comments

Comments
 (0)