-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathrosamodels.py
More file actions
482 lines (386 loc) · 21.9 KB
/
rosamodels.py
File metadata and controls
482 lines (386 loc) · 21.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
import torch
import math
import copy
from torch import nn
from torch.nn import functional as F
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaSdpaAttention, apply_rotary_pos_emb, repeat_kv
from transformers.models.qwen2.modeling_qwen2 import Qwen2SdpaAttention, Qwen2Attention
from transformers.models.gemma2.modeling_gemma2 import Gemma2Attention, sdpa_attention_forward, eager_attention_forward
from typing import List, Optional, Tuple, Union
from transformers import LlamaForCausalLM
from transformers.utils import logging
from transformers.cache_utils import Cache
from tqdm import tqdm
# Module-level logger used for the `warning_once` calls in the attention forwards.
# NOTE(review): `get_logger()` is called without a name, so this is presumably the
# transformers library root logger rather than a per-module one — confirm this is intended.
logger = logging.get_logger()
class LoraLinear(nn.Module):
    """Stand-alone low-rank linear map ``x -> lora_B(lora_A(x)) * (lora_alpha / rank)``.

    Only the low-rank product is computed; this module holds no base weight of its own.
    ``lora_B`` is zero-initialised, so a freshly constructed module outputs zeros.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
        rank: bottleneck dimension of the factorisation.
        lora_alpha: numerator of the output scaling factor ``lora_alpha / rank``.
    """

    def __init__(self, in_features, out_features, rank, lora_alpha=1):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        self.rank = rank
        self.lora_alpha = lora_alpha
        # Down-projection to `rank`, then up-projection back out; neither has a bias.
        self.lora_A = nn.Linear(in_features, rank, bias=False)
        self.lora_B = nn.Linear(rank, out_features, bias=False)
        self.scaling = lora_alpha / rank
        # Random A and zero B: the product starts at exactly zero but A still receives
        # a non-zero signal through B once B moves away from zero.
        nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the scaled low-rank map to ``x`` along its last dimension."""
        bottleneck = self.lora_A(x)
        return self.lora_B(bottleneck) * self.scaling
class LlamaFreqAttnwithEager(LlamaAttention):
    """Eager ``LlamaAttention`` with a learned, token-wise gate on selected RoPE channels.

    After RoPE is applied, the last ``low_dim // 2`` channels of each rotate-half of every
    query/key head (the channels the code treats as low-frequency) are multiplied by
    ``1 + alpha * silu(gate)``. The gate is ``shared_gate_proj(hidden_states)``. Under
    grouped-query attention, the key gate is additionally mapped to the key/value head
    count by ``gqa_proj``.

    Args:
        config: model config, forwarded to ``LlamaAttention``.
        layer_idx: layer index, forwarded to ``LlamaAttention`` and used for cache updates.
        low_ratio: fraction of ``head_dim`` to gate. The resulting channel count is
            rounded down to an even number because it is split over the two RoPE halves.
        alpha: gate strength.
        outargs: options object; ``use_lora_gate`` and ``lora_dim`` are read here.

    Raises:
        ValueError: if ``low_ratio`` selects fewer than two channels.
    """

    def __init__(self, config, layer_idx,
                 low_ratio: float,
                 alpha: float, outargs):
        super().__init__(config, layer_idx)
        self.outargs = outargs
        # An odd count cannot be split evenly over the two RoPE halves. Previously it
        # crashed in `_apply_gate`, or, for low_dim == 1, silently scaled every channel.
        self.low_dim = int(self.head_dim * low_ratio) // 2 * 2
        if self.low_dim < 2:
            raise ValueError(
                f"low_ratio={low_ratio} gates {self.low_dim} of head_dim={self.head_dim} "
                "channels; at least 2 are required"
            )
        self.alpha = alpha
        proj_out = self.num_heads * self.low_dim
        if self.outargs.use_lora_gate:
            self.shared_gate_proj = LoraLinear(config.hidden_size, proj_out, rank=self.outargs.lora_dim)
        else:
            self.shared_gate_proj = nn.Linear(config.hidden_size, proj_out, bias=False)
        if self.num_key_value_groups != 1:
            print("!!!!!#WARNING: GQA Using")
            # Maps the per-query-head gate down to one gate per key/value head.
            self.gqa_proj = LoraLinear(proj_out, proj_out // self.num_key_value_groups, rank=self.outargs.lora_dim)

    def _compute_gates(self, hidden_states):
        """Return ``(q_gate, k_gate)``, each laid out as ``(bsz, heads, q_len, low_dim)``.

        ``k_gate`` has the key/value head count when ``gqa_proj`` is in use.
        """
        bsz, q_len, _ = hidden_states.size()
        shared_gate = self.shared_gate_proj(hidden_states)  # (bsz, q_len, num_heads * low_dim)
        k_gate = self.gqa_proj(shared_gate) if self.num_key_value_groups != 1 else shared_gate
        # Split the feature axis per head, then move heads in front of the sequence axis,
        # exactly as for query/key states. Viewing directly as (bsz, -1, q_len, low_dim)
        # would reinterpret memory and mix different tokens' gates into each position
        # (including future tokens).
        q_gate = F.silu(shared_gate).view(bsz, q_len, -1, self.low_dim).transpose(1, 2)
        k_gate = F.silu(k_gate).view(bsz, q_len, -1, self.low_dim).transpose(1, 2)
        return q_gate, k_gate

    def _apply_gate(self, states, gate):
        """Scale the gated RoPE channels of ``states`` by ``1 + alpha * gate``.

        ``states`` is ``(bsz, heads, q_len, head_dim)`` and ``gate`` is
        ``(bsz, heads, q_len, low_dim)``. With ``k = low_dim // 2`` and
        ``half = head_dim // 2``:
        - the first ``k`` gate values scale channels ``[half - k, half)``;
        - the last ``k`` gate values scale channels ``[head_dim - k, head_dim)``;
        - all other channels are left unchanged.

        Returns a new tensor in ``states.dtype``; the input is not modified.
        """
        head_dim = states.shape[-1]
        half_dim = head_dim // 2
        k = self.low_dim // 2
        scale = 1.0 + self.alpha * gate
        lead = tuple(scale.shape[:-1])
        # Full-width multiplier: 1 everywhere except the gated tail of each RoPE half.
        multiplier = torch.cat(
            [
                scale.new_ones(lead + (half_dim - k,)),
                scale[..., :k],
                scale.new_ones(lead + (head_dim - half_dim - k,)),
                scale[..., k:],
            ],
            dim=-1,
        )
        return (states * multiplier).to(states.dtype)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Eager attention with gated RoPE queries/keys.

        Returns ``(attn_output, attn_weights or None, past_key_value)``.
        """
        bsz, q_len, _ = hidden_states.size()

        # (bsz, heads, q_len, head_dim); -1 infers the head count (tensor-parallel safe).
        query_states = self.q_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        q_gate, k_gate = self._compute_gates(hidden_states)

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Gate after RoPE so the cached keys already carry the gate.
        query_states = self._apply_gate(query_states, q_gate)
        key_states = self._apply_gate(key_states, k_gate)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights, past_key_value
class QwenFreqAttnwithEager(Qwen2Attention):
    """Eager ``Qwen2Attention`` with a learned, token-wise gate on selected RoPE channels.

    After RoPE is applied, the last ``low_dim // 2`` channels of each rotate-half of every
    query/key head (the channels the code treats as low-frequency) are multiplied by
    ``1 + alpha * silu(gate)``. The gate is ``shared_gate_proj(hidden_states)``. Under
    grouped-query attention, the key gate is additionally mapped to the key/value head
    count by ``gqa_proj``.

    Args:
        config: model config, forwarded to ``Qwen2Attention``.
        layer_idx: layer index, forwarded to ``Qwen2Attention`` and used for cache updates.
        low_ratio: fraction of ``head_dim`` to gate. The resulting channel count is
            rounded down to an even number because it is split over the two RoPE halves.
        alpha: gate strength.
        outargs: options object; ``use_lora_gate`` and ``lora_dim`` are read here.

    Raises:
        ValueError: if ``low_ratio`` selects fewer than two channels.
    """

    def __init__(self, config, layer_idx,
                 low_ratio: float,
                 alpha: float, outargs):
        super().__init__(config, layer_idx)
        self.outargs = outargs
        # An odd count cannot be split evenly over the two RoPE halves. Previously it
        # crashed in `_apply_gate`, or, for low_dim == 1, silently scaled every channel.
        self.low_dim = int(self.head_dim * low_ratio) // 2 * 2
        if self.low_dim < 2:
            raise ValueError(
                f"low_ratio={low_ratio} gates {self.low_dim} of head_dim={self.head_dim} "
                "channels; at least 2 are required"
            )
        self.alpha = alpha
        proj_out = self.num_heads * self.low_dim
        if self.outargs.use_lora_gate:
            self.shared_gate_proj = LoraLinear(config.hidden_size, proj_out, rank=self.outargs.lora_dim)
        else:
            self.shared_gate_proj = nn.Linear(config.hidden_size, proj_out, bias=False)
        if self.num_key_value_groups != 1:
            print("!!!!!#WARNING: GQA Using")
            # Maps the per-query-head gate down to one gate per key/value head.
            self.gqa_proj = LoraLinear(proj_out, proj_out // self.num_key_value_groups, rank=self.outargs.lora_dim)

    def _compute_gates(self, hidden_states):
        """Return ``(q_gate, k_gate)``, each laid out as ``(bsz, heads, q_len, low_dim)``.

        ``k_gate`` has the key/value head count when ``gqa_proj`` is in use.
        """
        bsz, q_len, _ = hidden_states.size()
        shared_gate = self.shared_gate_proj(hidden_states)  # (bsz, q_len, num_heads * low_dim)
        k_gate = self.gqa_proj(shared_gate) if self.num_key_value_groups != 1 else shared_gate
        # Split the feature axis per head, then move heads in front of the sequence axis,
        # exactly as for query/key states. Viewing directly as (bsz, -1, q_len, low_dim)
        # would reinterpret memory and mix different tokens' gates into each position
        # (including future tokens).
        q_gate = F.silu(shared_gate).view(bsz, q_len, -1, self.low_dim).transpose(1, 2)
        k_gate = F.silu(k_gate).view(bsz, q_len, -1, self.low_dim).transpose(1, 2)
        return q_gate, k_gate

    def _apply_gate(self, states, gate):
        """Scale the gated RoPE channels of ``states`` by ``1 + alpha * gate``.

        ``states`` is ``(bsz, heads, q_len, head_dim)`` and ``gate`` is
        ``(bsz, heads, q_len, low_dim)``. With ``k = low_dim // 2`` and
        ``half = head_dim // 2``:
        - the first ``k`` gate values scale channels ``[half - k, half)``;
        - the last ``k`` gate values scale channels ``[head_dim - k, head_dim)``;
        - all other channels are left unchanged.

        Returns a new tensor in ``states.dtype``; the input is not modified.
        """
        head_dim = states.shape[-1]
        half_dim = head_dim // 2
        k = self.low_dim // 2
        scale = 1.0 + self.alpha * gate
        lead = tuple(scale.shape[:-1])
        # Full-width multiplier: 1 everywhere except the gated tail of each RoPE half.
        multiplier = torch.cat(
            [
                scale.new_ones(lead + (half_dim - k,)),
                scale[..., :k],
                scale.new_ones(lead + (head_dim - half_dim - k,)),
                scale[..., k:],
            ],
            dim=-1,
        )
        return (states * multiplier).to(states.dtype)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Eager attention with gated RoPE queries/keys.

        Extra keyword arguments are accepted and ignored, matching the Llama variant.
        Returns ``(attn_output, attn_weights or None, past_key_value)``.
        """
        bsz, q_len, _ = hidden_states.size()

        # (bsz, heads, q_len, head_dim); -1 infers the head count (tensor-parallel safe).
        query_states = self.q_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        q_gate, k_gate = self._compute_gates(hidden_states)

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Gate after RoPE so the cached keys already carry the gate.
        query_states = self._apply_gate(query_states, q_gate)
        key_states = self._apply_gate(key_states, k_gate)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights, past_key_value
class GemmaFreqAttnwithEager(Gemma2Attention):
    """``Gemma2Attention`` (always eager) with a learned, token-wise gate on selected RoPE channels.

    After RoPE is applied, the last ``low_dim // 2`` channels of each rotate-half of every
    query/key head (the channels the code treats as low-frequency) are multiplied by
    ``1 + alpha * silu(gate)``. The gate is ``shared_gate_proj(hidden_states)``. Under
    grouped-query attention, the key gate is additionally mapped to the key/value head
    count by ``gqa_proj``. Attention itself is computed with Gemma2's
    ``eager_attention_forward``.

    Args:
        config: model config, forwarded to ``Gemma2Attention``.
        layer_idx: layer index, forwarded to ``Gemma2Attention`` and used for cache updates.
        low_ratio: fraction of ``head_dim`` to gate. The resulting channel count is
            rounded down to an even number because it is split over the two RoPE halves.
        alpha: gate strength.
        outargs: options object; ``use_lora_gate`` and ``lora_dim`` are read here.

    Raises:
        ValueError: if ``low_ratio`` selects fewer than two channels.
    """

    def __init__(self, config, layer_idx,
                 low_ratio: float,
                 alpha: float, outargs):
        super().__init__(config, layer_idx)
        self.outargs = outargs
        # An odd count cannot be split evenly over the two RoPE halves. Previously it
        # crashed in `_apply_gate`, or, for low_dim == 1, silently scaled every channel.
        self.low_dim = int(self.head_dim * low_ratio) // 2 * 2
        if self.low_dim < 2:
            raise ValueError(
                f"low_ratio={low_ratio} gates {self.low_dim} of head_dim={self.head_dim} "
                "channels; at least 2 are required"
            )
        self.alpha = alpha
        proj_out = self.num_heads * self.low_dim
        if self.outargs.use_lora_gate:
            self.shared_gate_proj = LoraLinear(config.hidden_size, proj_out, rank=self.outargs.lora_dim)
        else:
            self.shared_gate_proj = nn.Linear(config.hidden_size, proj_out, bias=False)
        if self.num_key_value_groups != 1:
            print("!!!!!#WARNING: GQA Using")
            # Maps the per-query-head gate down to one gate per key/value head.
            self.gqa_proj = LoraLinear(proj_out, proj_out // self.num_key_value_groups, rank=self.outargs.lora_dim)

    def _compute_gates(self, hidden_states):
        """Return ``(q_gate, k_gate)``, each laid out as ``(bsz, heads, q_len, low_dim)``.

        ``k_gate`` has the key/value head count when ``gqa_proj`` is in use.
        """
        bsz, q_len, _ = hidden_states.size()
        shared_gate = self.shared_gate_proj(hidden_states)  # (bsz, q_len, num_heads * low_dim)
        k_gate = self.gqa_proj(shared_gate) if self.num_key_value_groups != 1 else shared_gate
        # Split the feature axis per head, then move heads in front of the sequence axis,
        # exactly as for query/key states. Viewing directly as (bsz, -1, q_len, low_dim)
        # would reinterpret memory and mix different tokens' gates into each position
        # (including future tokens).
        q_gate = F.silu(shared_gate).view(bsz, q_len, -1, self.low_dim).transpose(1, 2)
        k_gate = F.silu(k_gate).view(bsz, q_len, -1, self.low_dim).transpose(1, 2)
        return q_gate, k_gate

    def _apply_gate(self, states, gate):
        """Scale the gated RoPE channels of ``states`` by ``1 + alpha * gate``.

        ``states`` is ``(bsz, heads, q_len, head_dim)`` and ``gate`` is
        ``(bsz, heads, q_len, low_dim)``. With ``k = low_dim // 2`` and
        ``half = head_dim // 2``:
        - the first ``k`` gate values scale channels ``[half - k, half)``;
        - the last ``k`` gate values scale channels ``[head_dim - k, head_dim)``;
        - all other channels are left unchanged.

        Returns a new tensor in ``states.dtype``; the input is not modified.
        """
        head_dim = states.shape[-1]
        half_dim = head_dim // 2
        k = self.low_dim // 2
        scale = 1.0 + self.alpha * gate
        lead = tuple(scale.shape[:-1])
        # Full-width multiplier: 1 everywhere except the gated tail of each RoPE half.
        multiplier = torch.cat(
            [
                scale.new_ones(lead + (half_dim - k,)),
                scale[..., :k],
                scale.new_ones(lead + (head_dim - half_dim - k,)),
                scale[..., k:],
            ],
            dim=-1,
        )
        return (states * multiplier).to(states.dtype)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Attention with gated RoPE queries/keys, computed via ``eager_attention_forward``.

        Returns ``(attn_output, attn_weights or None, past_key_value)``.

        Raises:
            NotImplementedError: if attention weights are requested while the config
                selects the ``sdpa`` or ``flash_attention_2`` implementation.
        """
        # Reject the unsupported combination before doing any work, so the KV cache is
        # not updated for a call that is going to fail anyway.
        if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]:
            raise NotImplementedError()

        bsz, q_len, _ = hidden_states.size()

        # (bsz, heads, q_len, head_dim); -1 infers the head count (tensor-parallel safe).
        query_states = self.q_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        q_gate, k_gate = self._compute_gates(hidden_states)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Gate after RoPE so the cached keys already carry the gate.
        query_states = self._apply_gate(query_states, q_gate)
        key_states = self._apply_gate(key_states, k_gate)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "sliding_window": self.sliding_window,
                "cache_position": cache_position,
            }
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attn_output, attn_weights = eager_attention_forward(
            self, query_states, key_states, value_states, attention_mask, output_attentions=output_attentions
        )

        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights, past_key_value
def collect_layer_scores(model):
    """Score each decoder layer by the gradient norm of its two layer-norm weights.

    For every layer in ``model.model.layers``, the squared gradients of
    ``input_layernorm.weight`` and ``post_attention_layernorm.weight`` are summed. Weights
    without a gradient are skipped. The square root of that sum (an L2 norm) is the
    layer's score.

    Returns:
        dict mapping layer index -> float score (0.0 when neither weight has a gradient).
    """
    scores = {}
    for layer_idx, layer in enumerate(model.model.layers):
        squared_sum = 0.0
        for norm in (layer.input_layernorm, layer.post_attention_layernorm):
            grad = getattr(norm.weight, 'grad', None)
            if grad is None:
                continue
            squared_sum += torch.sum(grad ** 2).item()
        scores[layer_idx] = math.sqrt(squared_sum)  # L2 norm
    return scores
def mask_layer_grads(model, topk_layers: set[int]):
    """Drop gate-projection gradients in every layer whose index is not in ``topk_layers``.

    A parameter of ``layer.self_attn`` counts as a gate parameter when its name contains
    one of the gate-projection module names. Its ``.grad`` is set to ``None``, so an
    optimizer step skips it. Layers listed in ``topk_layers`` are left untouched.
    """
    gate_markers = ('shared_gate_proj', 'q_gate_proj', 'k_gate_proj', 'a_gate_proj')
    for layer_idx, layer in enumerate(model.model.layers):
        if layer_idx in topk_layers:
            continue
        for pname, param in layer.self_attn.named_parameters():
            is_gate_param = any(marker in pname for marker in gate_markers)
            if is_gate_param and param.grad is not None:
                param.grad = None  # alternatively: param.grad.mul_(0.)
def convert_to_rosa_model(model, args=None):
    """Replace each decoder layer's attention with the matching RoSA gated attention.

    The attention class is chosen from ``args.model_name_or_path``, lower-cased and checked
    for the substrings ``llama``, ``qwen`` and ``gemma``, in that order. Existing attention
    weights are copied into the new module with ``strict=False``, because the gate
    parameters are new. Afterwards the gate projections are initialised so that the gate
    starts as a no-op.

    Args:
        model: wrapper exposing ``model.config`` and ``model.model.layers``, where each
            layer has a ``self_attn`` module.
        args: options object; ``model_name_or_path`` and ``low_ratio`` are read here and
            the whole object is forwarded to the attention classes as ``outargs``.

    Returns:
        The same ``model`` instance, modified in place.

    Raises:
        ValueError: if ``args`` is None.
        NotImplementedError: if the model family cannot be recognised from the name.
    """
    if args is None:
        raise ValueError("convert_to_rosa_model requires `args` (model_name_or_path, low_ratio, ...)")

    model_name = args.model_name_or_path.lower()
    if 'llama' in model_name:
        attn_cls = LlamaFreqAttnwithEager
    elif 'qwen' in model_name:
        attn_cls = QwenFreqAttnwithEager
    elif 'gemma' in model_name:
        attn_cls = GemmaFreqAttnwithEager
    else:
        # Must actually raise: falling through would leave no attention class to build.
        raise NotImplementedError(f"Unsupported model family: {args.model_name_or_path}")

    layers = model.model.layers
    for idx, layer in tqdm(enumerate(layers), total=len(layers)):
        new_attn = attn_cls(model.config, idx, low_ratio=args.low_ratio, alpha=0.1, outargs=args)
        # strict=False: the gate projections are absent from the original state dict.
        # NOTE(review): new_attn is built with default device/dtype; confirm the caller
        # moves/casts the model afterwards if it was loaded on GPU or in half precision.
        new_attn.load_state_dict(layer.self_attn.state_dict(), strict=False)
        layer.self_attn = new_attn

    gate_module_names = ('shared_gate_proj', 'q_gate_proj', 'k_gate_proj', 'a_gate_proj')
    for layer in layers:
        for module_name, module in layer.self_attn.named_modules():
            if module_name not in gate_module_names:
                continue
            for param_name, param in module.named_parameters():
                # For a LoraLinear gate, zeroing lora_A as well as lora_B makes both
                # gradients identically zero (dL/dA depends on B, dL/dB on A), so the gate
                # could never train. LoraLinear already zero-inits lora_B, which alone
                # makes the gate output zero at start.
                if 'lora_A' in param_name:
                    continue
                if 'weight' in param_name:
                    torch.nn.init.zeros_(param)
                if 'bias' in param_name:
                    torch.nn.init.zeros_(param)
                    print('bias')
    return model
def only_optimize_rosa_parameters(model):
    """Freeze everything except the RoSA gate, LoRA and layer-norm parameters.

    A parameter stays trainable (``requires_grad=True``) when its qualified name contains
    any of the marker substrings below; every other parameter is frozen. The name of each
    trainable parameter is printed.

    Returns:
        The same ``model`` instance, modified in place.
    """
    trainable_markers = (
        'shared_gate_proj', 'q_gate_proj', 'k_gate_proj', 'a_gate_proj',
        'lora_A', 'lora_B', 'layernorm',
    )
    print('Only_Optim_rosa')
    for pname, tensor in model.named_parameters():
        keep_trainable = any(marker in pname for marker in trainable_markers)
        tensor.requires_grad = keep_trainable
        if keep_trainable:
            print(pname)
    return model