Skip to content

Commit f909c4a

Browse files
committed
add <=1 requirement to bce weighted
1 parent 765299e commit f909c4a

2 files changed

Lines changed: 7 additions & 4 deletions

File tree

chebai/loss/bce_weighted.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ def __init__(
3333
self.data_extractor = data_extractor
3434

3535
assert (
36-
isinstance(beta, float) and beta > 0.0
37-
), f"Beta parameter must be a float with value greater than 0.0, for loss class {self.__class__.__name__}."
36+
isinstance(beta, float) and beta >= 0.0 and beta <= 1.0
37+
), f"Beta parameter must be a float with value between 0 and 1, for loss class {self.__class__.__name__}."
3838

3939
assert (
4040
self.data_extractor is not None
@@ -63,13 +63,16 @@ def set_pos_weight(self, input: torch.Tensor) -> None:
6363
print(
6464
f"Computing loss-weights based on v{self.data_extractor.chebi_version} dataset (beta={self.beta})"
6565
)
66+
print(f"loading: {self.data_extractor.processed_file_names[0]}")
6667
complete_labels = torch.concat(
6768
[
6869
torch.stack(
6970
[
7071
torch.Tensor(row["labels"])
7172
for row in self.data_extractor.load_processed_data(
72-
filename=file_name
73+
filename=os.path.join(
74+
self.data_extractor.processed_dir, file_name
75+
)
7376
)
7477
]
7578
)

configs/loss/bce_weighted.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
class_path: chebai.loss.bce_weighted.BCEWeighted
22
init_args:
3-
beta: 0.99
3+
beta: 0.99 # this is the default weight, change this factor to increase/decrease the weighting effect

0 commit comments

Comments (0)