Skip to content

Commit 75ade76

Browse files
committed
make compilation optional, return None for empty attention graph
1 parent a4365a6 commit 75ade76

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

chebifier/prediction_models/electra_predictor.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Optional
2+
13
import numpy as np
24

35
from .nn_predictor import NNPredictor
@@ -40,7 +42,7 @@ def __init__(self, model_name: str, ckpt_path: str, **kwargs):
4042
f"Initialised Electra model {self.model_name} (device: {self.predictor.device})"
4143
)
4244

43-
def explain_smiles(self, smiles) -> dict:
45+
def explain_smiles(self, smiles) -> Optional[dict]:
4446
from chebai.preprocessing.reader import EMBEDDING_OFFSET
4547

4648
# Add dummy labels because the collate function requires them.
@@ -69,4 +71,6 @@ def explain_smiles(self, smiles) -> dict:
6971
]
7072
for a in result["attentions"]
7173
]
74+
if len(graphs) == 0:
75+
return None
7276
return {"graphs": graphs}

chebifier/prediction_models/nn_predictor.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,13 @@ def __init__(
2020
):
2121
super().__init__(model_name, **kwargs)
2222
self.batch_size = kwargs.get("batch_size", None)
23+
# compile_model will run the model in eager mode, which gives better performance, but does not return intermediate states
24+
# such as attention weights. Therefore, ELECTRA attention graphs will only work with compile_model=False.
25+
compile_model = kwargs.get("compile_model", True)
2326
# If batch_size is not provided, it will be set to default batch size used during training in Predictor
24-
self.predictor: Predictor = Predictor(ckpt_path, self.batch_size)
27+
self.predictor: Predictor = Predictor(
28+
ckpt_path, self.batch_size, compile_model=compile_model
29+
)
2530

2631
@modelwise_smiles_lru_cache.batch_decorator
2732
def predict_smiles_list(self, smiles_list: list[str]) -> list:
@@ -51,4 +56,5 @@ def calculate_results(self, batch):
5156
dat = self.predictor._model._process_batch(
5257
collator(batch).to(self.predictor.device), 0
5358
)
59+
5460
return self.predictor._model(dat, **dat["model_kwargs"])

0 commit comments

Comments
 (0)