forked from jasperzuallaert/BasicPLMUsage
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path example.py
More file actions
39 lines (30 loc) · 1.32 KB
/
example.py
File metadata and controls
39 lines (30 loc) · 1.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import language_models as lm
# Example: embed a couple of protein sequences with several protein language
# models from `language_models` and print the resulting embedding/mask shapes.

example_sequences = [
    'MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG',
    'KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE',
]

# (label, runner) pairs, run in this order. Each runner takes
# (sequences, device) and returns an (embedding, mask) pair.
# `lm` is only looked up when a runner is called, so defining this table does
# not load any model. A tuple keeps the table immutable; it is also used as a
# default argument below.
MODEL_RUNNERS = (
    ('ESM-1b', lambda seqs, dev: lm.run_esm(seqs, 'ESM-1b', device=dev)),
    ('ESM-2-3B', lambda seqs, dev: lm.run_esm(seqs, 'ESM-2-3B', device=dev)),
    ('ESM-2-650M', lambda seqs, dev: lm.run_esm(seqs, 'ESM-2-650M', device=dev)),
    ('ProtTransT5_XL_UniRef50', lambda seqs, dev: lm.run_prottransxl(seqs, device=dev)),
    ('Ankh_base', lambda seqs, dev: lm.run_ankh(seqs, 'base', device=dev)),
    ('Ankh_large', lambda seqs, dev: lm.run_ankh(seqs, 'large', device=dev)),
)


def select_device(preferred='cuda'):
    """Return the device string to run the models on.

    Non-CUDA requests (e.g. ``'cpu'``) are returned unchanged. A CUDA request
    falls back to ``'cpu'`` when torch reports no usable GPU, so the example
    still runs on CPU-only machines. If torch cannot be imported, *preferred*
    is returned unchanged, and any error surfaces where the model is loaded.
    """
    if not preferred.startswith('cuda'):
        return preferred
    try:
        import torch  # local import: only needed to probe GPU availability
    except ImportError:
        return preferred
    return preferred if torch.cuda.is_available() else 'cpu'


def run_models(sequences, device, runners=MODEL_RUNNERS):
    """Run every (label, runner) pair on *sequences* and print the shapes.

    Args:
        sequences: list of amino-acid strings passed to each runner.
        device: device string forwarded to each runner.
        runners: iterable of (label, callable(sequences, device)) pairs; each
            callable must return an (embedding, mask) pair exposing ``.shape``.

    Returns:
        dict mapping each label to ``(embedding.shape, mask.shape)``.
    """
    shapes = {}
    for label, runner in runners:
        print(f'Running {label}...')
        embedding, mask = runner(sequences, device)
        print(f'Embedding shape: {embedding.shape}')
        print(f'Mask shape: {mask.shape}')
        shapes[label] = (embedding.shape, mask.shape)
    return shapes


if __name__ == '__main__':
    run_models(example_sequences, select_device('cuda'))