diff --git a/deeplay/applications/autoencoders/vae.py b/deeplay/applications/autoencoders/vae.py
index 5334e9c4..deadb2e4 100644
--- a/deeplay/applications/autoencoders/vae.py
+++ b/deeplay/applications/autoencoders/vae.py
@@ -119,6 +119,37 @@ def training_step(self, batch, batch_idx):
             )
         return tot_loss
 
+    def _shared_eval_step(self, batch, prefix, preprocess):
+        """Compute and log the VAE losses for one evaluation batch.
+
+        Shared by ``validation_step`` and ``test_step`` so the loss
+        computation and the logged metric set stay consistent between
+        the two hooks. Logs ``{prefix}_rec_loss``, ``{prefix}_KL`` and
+        ``{prefix}_total_loss``; returns the total (beta-weighted) loss.
+        """
+        x, y = preprocess(batch)
+        y_hat, mu, log_var = self(x)
+        rec_loss, KLD = self.compute_loss(y_hat, y, mu, log_var)
+        tot_loss = rec_loss + self.beta * KLD
+        losses = {"rec_loss": rec_loss, "KL": KLD, "total_loss": tot_loss}
+        for name, value in losses.items():
+            self.log(
+                f"{prefix}_{name}",
+                value,
+                # NOTE(review): on_step=True mirrors training_step; Lightning
+                # defaults to on_step=False in val/test — confirm per-step
+                # logging is intended here.
+                on_step=True,
+                on_epoch=True,
+                prog_bar=True,
+                logger=True,
+            )
+        return tot_loss
+
+    def validation_step(self, batch, batch_idx):
+        """Lightning validation hook; returns the total VAE loss."""
+        return self._shared_eval_step(batch, "val", self.val_preprocess)
+
+    def test_step(self, batch, batch_idx):
+        """Lightning test hook; returns the total VAE loss."""
+        return self._shared_eval_step(batch, "test", self.test_preprocess)
+
     def compute_loss(self, y_hat, y, mu, log_var):
         rec_loss = self.reconstruction_loss(y_hat, y)
         KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())