Fix VarAutoEncoder reparameterize returning mu+std at inference (#8413)#8933
Fix VarAutoEncoder reparameterize returning mu+std at inference (#8413)#8933Shizoqua wants to merge 1 commit into
Conversation
VarAutoEncoder.reparameterize added the standard deviation to mu at eval time (`std.add_(mu)` with no noise term), so inference returned mu + std instead of the posterior mean. At inference the latent code should be mu; the random term belongs only to training (the reparameterization trick). Return mu directly when not training, and compute mu + eps * std out-of-place otherwise. VarFullyConnectedNet.reparameterize had the identical bug and is fixed the same way. Adds regression tests asserting eval is deterministic and equals mu while training stays stochastic. Fixes Project-MONAI#8413. Signed-off-by: Lanre Shittu <136805224+Shizoqua@users.noreply.github.com>
📝 WalkthroughWalkthrough
Estimated code review effort🎯 2 (Simple) | ⏱️ ~10 minutes 🚥 Pre-merge checks | ✅ 5✅ Passed checks (5 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
🧹 Nitpick comments (4)
monai/networks/nets/fullyconnectednet.py (1)
174-181: ⚡ Quick winMissing docstring for
reparameterizemethod.Per coding guidelines, all definitions require Google-style docstrings describing parameters, return values, and exceptions.
📝 Suggested docstring
def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor: + """Sample latent code using the reparameterization trick. + + At inference (eval mode), returns the posterior mean directly. During + training, samples z = mu + eps * std where eps ~ N(0, I). + + Args: + mu: Mean of the approximate posterior, shape (batch, latent_size). + logvar: Log-variance of the approximate posterior, same shape as mu. + + Returns: + Sampled latent code, same shape as mu. + """ if not self.training:🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@monai/networks/nets/fullyconnectednet.py` around lines 174 - 181, Add a Google-style docstring to the reparameterize method that documents the mu and logvar parameters as torch.Tensor inputs, describes the return value as a torch.Tensor, and explains the method's behavior including the reparameterization trick applied during training and the inference-time behavior when self.training is False. Place the docstring immediately after the method signature and before the method body.Source: Coding guidelines
tests/networks/nets/test_varautoencoder.py (1)
125-145: 💤 Low valueMissing docstring for test method.
Same as
test_fullyconnectednet.py. The inline comment documents intent well but a docstring is preferred per guidelines.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/networks/nets/test_varautoencoder.py` around lines 125 - 145, The test method test_reparameterize_eval_returns_mu currently has an inline comment describing its purpose instead of a proper docstring. Convert the inline comment at the beginning of the method (the lines explaining the test intent regarding eval mode returning mu and the regression test for `#8413`) into a docstring by placing it immediately after the method definition using triple quotes, following the same pattern used in test_fullyconnectednet.py.Source: Coding guidelines
monai/networks/nets/varautoencoder.py (1)
144-151: ⚡ Quick winMissing docstring for
reparameterizemethod.Same issue as
VarFullyConnectedNet. Docstring required per coding guidelines.📝 Suggested docstring
def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor: + """Sample latent code using the reparameterization trick. + + At inference (eval mode), returns the posterior mean directly. During + training, samples z = mu + eps * std where eps ~ N(0, I). + + Args: + mu: Mean of the approximate posterior, shape (batch, latent_size). + logvar: Log-variance of the approximate posterior, same shape as mu. + + Returns: + Sampled latent code, same shape as mu. + """ if not self.training:🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@monai/networks/nets/varautoencoder.py` around lines 144 - 151, The reparameterize method is missing a docstring as required by coding guidelines. Add a docstring to the reparameterize method that describes its purpose (implementing the reparameterization trick for VAE training), documents the input parameters (mu and logvar as torch.Tensor objects), explains the behavior difference between training and inference modes, and clearly specifies the return type and what the returned tensor represents.Source: Coding guidelines
tests/networks/nets/test_fullyconnectednet.py (1)
67-86: 💤 Low valueMissing docstring for test method.
Per coding guidelines, definitions require docstrings. The inline comment is helpful but a proper docstring is expected.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/networks/nets/test_fullyconnectednet.py` around lines 67 - 86, The test method test_vfc_reparameterize_eval_returns_mu is missing a proper docstring and instead relies on an inline comment. Convert the existing inline comment explaining the test's purpose into a formal docstring by placing it immediately after the method definition line using triple quotes, removing the inline comment after the conversion.Source: Coding guidelines
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Nitpick comments:
In `@monai/networks/nets/fullyconnectednet.py`:
- Around line 174-181: Add a Google-style docstring to the reparameterize method
that documents the mu and logvar parameters as torch.Tensor inputs, describes
the return value as a torch.Tensor, and explains the method's behavior including
the reparameterization trick applied during training and the inference-time
behavior when self.training is False. Place the docstring immediately after the
method signature and before the method body.
In `@monai/networks/nets/varautoencoder.py`:
- Around line 144-151: The reparameterize method is missing a docstring as
required by coding guidelines. Add a docstring to the reparameterize method that
describes its purpose (implementing the reparameterization trick for VAE
training), documents the input parameters (mu and logvar as torch.Tensor
objects), explains the behavior difference between training and inference modes,
and clearly specifies the return type and what the returned tensor represents.
In `@tests/networks/nets/test_fullyconnectednet.py`:
- Around line 67-86: The test method test_vfc_reparameterize_eval_returns_mu is
missing a proper docstring and instead relies on an inline comment. Convert the
existing inline comment explaining the test's purpose into a formal docstring by
placing it immediately after the method definition line using triple quotes,
removing the inline comment after the conversion.
In `@tests/networks/nets/test_varautoencoder.py`:
- Around line 125-145: The test method test_reparameterize_eval_returns_mu
currently has an inline comment describing its purpose instead of a proper
docstring. Convert the inline comment at the beginning of the method (the lines
explaining the test intent regarding eval mode returning mu and the regression
test for `#8413`) into a docstring by placing it immediately after the method
definition using triple quotes, following the same pattern used in
test_fullyconnectednet.py.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 9e91163b-f57c-4989-b7db-d6d818f586fa
📒 Files selected for processing (4)
monai/networks/nets/fullyconnectednet.pymonai/networks/nets/varautoencoder.pytests/networks/nets/test_fullyconnectednet.pytests/networks/nets/test_varautoencoder.py
Description
Fixes #8413.
VarAutoEncoder.reparameterizecomputedstd = exp(0.5*logvar), added the random noise term only when training, and then returnedstd.add_(mu). At inference (eval()),stdis still the standard deviation, so the method returnedmu + stdinstead of the posterior meanmu. The reparameterization trick's stochastic term should apply during training only; at inference the latent code should bemu.The fix returns
mudirectly when not training, and computesmu + eps * stdout-of-place otherwise (also avoids the in-placeadd_).VarFullyConnectedNet.reparameterize(monai/networks/nets/fullyconnectednet.py) contained the identical bug, so it is fixed the same way in this PR.Types of changes
Testing
Added regression tests asserting that in eval mode the returned latent equals
muand is deterministic, while training stays stochastic, for bothVarAutoEncoderandVarFullyConnectedNet.