@@ -5,16 +5,16 @@ Subject: [PATCH] override SWALR.state_dict and load_state_dict (#163122)
55
66Fixes #163105
77
8- - Add typing_extensions.override
98- Use _set_anneal_func to set anneal function
109- Implement state_dict and load_state_dict for SWALR excluding optimizer and anneal_func
1110
1211Signed-off-by: Azure Linux Security Servicing Account <azurelinux-security@microsoft.com>
1312Upstream-reference: AI Backport of https://github.com/pytorch/pytorch/commit/167ad09be5af5c52666759412a3804068c6955d1.patch
13+
1414---
1515 test/test_optim.py | 16 ++++++++++++++++
16- torch/optim/swa_utils.py | 37 ++ +++++++++++++++++++++++++++++++----
17- 2 files changed, 49 insertions(+), 4 deletions(-)
16+ torch/optim/swa_utils.py | 35 +++++++++++++++++++++++++++++++----
17+ 2 files changed, 47 insertions(+), 4 deletions(-)
1818
1919diff --git a/test/test_optim.py b/test/test_optim.py
2020index 1608478b..d3dd4567 100644
@@ -44,17 +44,17 @@ index 1608478b..d3dd4567 100644
4444 class SWATestDNN(torch.nn.Module):
4545 def __init__(self, input_features):
4646diff --git a/torch/optim/swa_utils.py b/torch/optim/swa_utils.py
47- index dda4b8ad..d18084e2 100644
47+ index dda4b8ad..abd8128f 100644
4848--- a/torch/optim/swa_utils.py
4949+++ b/torch/optim/swa_utils.py
50- @@ -2,6 +2,7 @@ import itertools
51- import math
52- from copy import deepcopy
53- import warnings
54- + from typing_extensions import override
55-
50+ @@ -6,6 +6,7 @@ import warnings
5651 import torch
5752 from torch.nn import Module
53+ from torch.optim.lr_scheduler import LRScheduler
54+ + from typing import Any, Literal
55+
56+ __all__ = ['AveragedModel', 'update_bn', 'SWALR']
57+
5858@@ -247,10 +248,7 @@ class SWALR(LRScheduler):
5959 if anneal_strategy not in ['cos', 'linear']:
6060 raise ValueError("anneal_strategy must by one of 'cos' or 'linear', "
@@ -67,7 +67,7 @@ index dda4b8ad..d18084e2 100644
6767 if not isinstance(anneal_epochs, int) or anneal_epochs < 0:
6868 raise ValueError(f"anneal_epochs must be equal or greater than 0, got {anneal_epochs}")
6969 self.anneal_epochs = anneal_epochs
70- @@ -296,3 +294,34 @@ class SWALR(LRScheduler):
70+ @@ -296,3 +294,32 @@ class SWALR(LRScheduler):
7171 alpha = self.anneal_func(t)
7272 return [group['swa_lr'] * alpha + lr * (1 - alpha)
7373 for group, lr in zip(self.optimizer.param_groups, prev_lrs)]
@@ -79,7 +79,6 @@ index dda4b8ad..d18084e2 100644
7979+ else:
8080+ self.anneal_func = self._linear_anneal
8181+
82- + @override
8382+ def state_dict(self) -> dict[str, Any]:
8483+ """Return the state of the scheduler as a :class:`dict`.
8584+
@@ -92,7 +91,6 @@ index dda4b8ad..d18084e2 100644
9291+ if key not in ("optimizer", "anneal_func")
9392+ }
9493+
95- + @override
9694+ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
9795+ """Load the scheduler's state.
9896+
@@ -102,6 +100,7 @@ index dda4b8ad..d18084e2 100644
102100+ """
103101+ self.__dict__.update(state_dict)
104102+ self._set_anneal_func(self._anneal_strategy)
103+ \ No newline at end of file
105104- -
1061052.45.4
107106
0 commit comments