diff --git a/deeplay/applications/autoencoders/wae.py b/deeplay/applications/autoencoders/wae.py index c0d3ff2d..65042a11 100644 --- a/deeplay/applications/autoencoders/wae.py +++ b/deeplay/applications/autoencoders/wae.py @@ -126,6 +126,46 @@ def training_step(self, batch, batch_idx): return sum(loss.values()) + def validation_step(self, batch, batch_idx): + x, y = self.val_preprocess(batch) + y_hat, z = self(x) + rec_loss, mmd_loss = self.compute_loss(y_hat, y, z) + loss = { + "rec_loss": rec_loss, + "mmd_loss": mmd_loss, + } + for name, v in loss.items(): + self.log( + f"val_{name}", + v, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + ) + + return sum(loss.values()) + + def test_step(self, batch, batch_idx): + x, y = self.test_preprocess(batch) + y_hat, z = self(x) + rec_loss, mmd_loss = self.compute_loss(y_hat, y, z) + loss = { + "rec_loss": rec_loss, + "mmd_loss": mmd_loss, + } + for name, v in loss.items(): + self.log( + f"test_{name}", + v, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + ) + + return sum(loss.values()) + def compute_IMQ(self, x1, x2): # Inverse MultiQuadratic kernel C = 2 * self.latent_dim * self.z_var