Add Sound Encoder to Cosmos3#13911
Conversation
Signed-off-by: Maciej Bala <mbala@nvidia.com>
| def _disable_encoder(self): | ||
| self.encoder = None | ||
| self._encoder_available = False | ||
| self.register_to_config(encoder_enabled=False) | ||
|
|
||
| def _fix_state_dict_keys_on_load(self, state_dict: OrderedDict) -> None: | ||
| super()._fix_state_dict_keys_on_load(state_dict) | ||
| if self.encoder is not None and not any(key.startswith("encoder.") for key in state_dict): | ||
| self._disable_encoder() | ||
|
|
There was a problem hiding this comment.
why do we need these two methods?
There was a problem hiding this comment.
It's an extra safety net for checkpoints that do not have the encoder weights. We will update the main checkpoint to have encoder weights, but I think it's still fine to keep this method in case of e.g. cached local checkpoints. We don't want them to break if people don't need the encoder weights.
| return hidden_states | ||
|
|
||
|
|
||
| class Cosmos3AudioSnakeBeta(nn.Module): |
There was a problem hiding this comment.
It looks like the existing Snake1d module implements essentially the same logic as Cosmos3AudioSnakeBeta, could we use it as well for the encoder?
There was a problem hiding this comment.
The math should be the same, but we'd need a reshape on load, since Cosmos3AudioSnakeBeta has 1D parameters instead of 3D. Let me think about it for a bit.
There was a problem hiding this comment.
I kept the separate classes for native checkpoint loading, but shared a forward implementation
There was a problem hiding this comment.
Could we potentially reshape the encoder Snake alpha/beta weights to 3D in the scripts/convert_cosmos3_to_diffusers.py conversion script? I think this would allow us to reuse Snake1d for the encoder.
dg845
left a comment
There was a problem hiding this comment.
Thanks for the PR! Left an initial design review :).
Signed-off-by: Maciej Bala <mbala@nvidia.com>
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
| alpha = self.alpha if not self.logscale else torch.exp(self.alpha) | ||
| beta = self.beta if not self.logscale else torch.exp(self.beta) | ||
| @staticmethod | ||
| def _forward(hidden_states, alpha, beta, logscale): |
There was a problem hiding this comment.
Since Snake1d is # Copied from the Oobleck audio VAE for the Cosmos 3 decoder, I don't think we should modify it here (and this breaks the CI check that the implementations are synced). I think it would be better (for example) to convert the encoder Snake weights from 1D to 3D when converting the checkpoint, as suggested in #13911 (comment).
| from diffusers.models.autoencoders.autoencoder_oobleck import OobleckDiagonalGaussianDistribution | ||
|
|
||
|
|
||
| def _get_tiny_cosmos3_audio_tokenizer() -> Cosmos3AVAEAudioTokenizer: |
There was a problem hiding this comment.
Would it be possible to refactor the tests to use the standard diffusers model test config + test mixins? For reference, here is what the LTX-2 audio VAE tests do:
So for example we would move the _get_tiny_cosmos3_audio_tokenizer logic into the get_init_dict method of a new Cosmos3AVAEAudioTokenizerTesterConfig class. I think we would still keep the Cosmos 3-specific tests below.
dg845
left a comment
There was a problem hiding this comment.
Thanks for the changes! Left some follow up comments.
What does this PR do?
Fixes # (issue)
Before submitting
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.