-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
47 lines (41 loc) · 1.52 KB
/
model.py
File metadata and controls
47 lines (41 loc) · 1.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# model
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments, DataCollatorWithPadding
from sklearn.metrics import f1_score
from config import model_name, max_length, batch_size, num_epochs, random_seed
def load_tokenizer():
    """Instantiate the tokenizer matching the configured checkpoint.

    Returns:
        A Hugging Face tokenizer loaded from ``model_name`` (see config).
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return tokenizer
def load_model(num_labels):
    """Instantiate a sequence-classification model from the configured checkpoint.

    Args:
        num_labels: Number of output classes for the classification head.

    Returns:
        A Hugging Face ``AutoModelForSequenceClassification`` instance.
    """
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=num_labels,
    )
    return model
def compute_metrics(pred):
    """Compute evaluation metrics for a Hugging Face ``Trainer``.

    Args:
        pred: An ``EvalPrediction``-like object exposing ``label_ids``
            (true class ids) and ``predictions`` (per-class logits/scores;
            assumed last axis is the class axis — typical for
            ``AutoModelForSequenceClassification`` outputs).

    Returns:
        dict with ``accuracy``, ``f1_macro`` and ``f1_micro``.
    """
    labels = pred.label_ids
    # axis=-1 instead of axis=1: identical for the usual (batch, num_labels)
    # output, but also correct if predictions ever carry extra leading axes.
    preds = np.argmax(pred.predictions, axis=-1)
    return {
        "accuracy": (preds == labels).mean(),
        "f1_macro": f1_score(labels, preds, average="macro"),
        "f1_micro": f1_score(labels, preds, average="micro")
    }
def create_trainer(model, tokenizer, train_dataset, val_dataset):
    """Assemble a ``Trainer`` configured for per-epoch evaluation.

    Evaluates and checkpoints every epoch, and restores the checkpoint with
    the best macro-F1 at the end of training. Batch size and epoch count
    come from the shared config module.

    Args:
        model: The sequence-classification model to fine-tune.
        tokenizer: Tokenizer used by the padding collator (and saved with
            checkpoints).
        train_dataset: Tokenized training split.
        val_dataset: Tokenized validation split used for per-epoch eval.

    Returns:
        A ready-to-run ``transformers.Trainer``.
    """
    args = TrainingArguments(
        output_dir="./results",
        logging_dir="./logs",
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        eval_strategy="epoch",
        save_strategy="epoch",
        # Keep the epoch with the highest macro-F1 as the final model.
        load_best_model_at_end=True,
        metric_for_best_model="eval_f1_macro",
        greater_is_better=True
    )
    collator = DataCollatorWithPadding(tokenizer)
    return Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
        data_collator=collator,
        compute_metrics=compute_metrics
    )