LR Schedulers

All schedulers in thunder are subclasses of torch.optim.lr_scheduler.LRScheduler. However, unlike torch schedulers, they do not require an optimizer to be passed at initialization.


We will use Switch as an example.

from thunder.policy import Switch

switch = Switch({10: 0.001, 20: 0.001 / 10})
We have just created a policy, but to make it work, it still needs an optimizer.
Here is how it is assembled:
from torch.optim import Adam

optimizer = Adam(...)
switch(optimizer)  # binds the optimizer to the scheduler
# or
# switch = switch(optimizer)
# The bound optimizer can be retrieved:
opt = switch.optimizer
Once an optimizer is assigned, the policy instance works just like a regular torch scheduler.
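To illustrate the deferred binding described above, here is a minimal torch-free sketch. `DeferredSwitch` is a hypothetical, simplified class, not thunder's implementation; it assumes that calling a policy with an optimizer binds it and returns the policy itself (which is what `switch = switch(optimizer)` suggests). A `SimpleNamespace` with a `param_groups` list stands in for a torch optimizer.

```python
from types import SimpleNamespace
from typing import Dict


class DeferredSwitch:
    """Hypothetical, simplified Switch-like policy that receives its optimizer late."""

    def __init__(self, mapping: Dict[int, float], lr_init: float = 1e-3):
        self.mapping = mapping
        self.lr_init = lr_init
        self.optimizer = None  # no optimizer is needed at construction time
        self.last_epoch = -1

    def __call__(self, optimizer) -> "DeferredSwitch":
        # Bind the optimizer and start every param group from lr_init
        self.optimizer = optimizer
        for group in optimizer.param_groups:
            group["lr"] = self.lr_init
        self.step()  # move to epoch 0, as torch schedulers do when constructed
        return self  # returning self allows `switch = switch(optimizer)`

    def step(self) -> None:
        if self.optimizer is None:
            raise RuntimeError("assign an optimizer before calling step()")
        self.last_epoch += 1
        for group in self.optimizer.param_groups:
            # Switch to the mapped value at listed epochs, otherwise keep the current lr
            group["lr"] = self.mapping.get(self.last_epoch, group["lr"])


optimizer = SimpleNamespace(param_groups=[{"lr": 0.0}])  # stand-in for a torch optimizer
switch = DeferredSwitch({10: 0.001, 20: 0.001 / 10}, lr_init=0.01)
switch = switch(optimizer)

lrs = [optimizer.param_groups[0]["lr"]]  # lr at epoch 0
for _ in range(25):
    switch.step()
    lrs.append(optimizer.param_groups[0]["lr"])
```

With `lr_init=0.01`, `lrs` holds 0.01 for epochs 0 to 9, 0.001 for epochs 10 to 19 and 0.001 / 10 from epoch 20 on.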

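Initial LR

All schedulers have an lr_init parameter. If specified, it is used as the learning rate on step 0 (for MappingPolicy subclasses it defaults to 1e-3). It can be a single float, applied to every parameter group, or a list with one value per parameter group.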
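How lr_init reaches the optimizer can be sketched after the `set_optimizer` code of MappingPolicy in the Base classes section. `apply_lr_init` is a hypothetical helper, and plain dicts stand in for `optimizer.param_groups`; this is an illustration of the broadcasting rule, not thunder's API.

```python
from typing import Any, Dict, List, Union


def apply_lr_init(lr_init: Union[List[float], float], param_groups: List[Dict[str, Any]]) -> None:
    """Hypothetical helper: write lr_init into each param group as its starting lr."""
    # A single number is broadcast to every parameter group
    if isinstance(lr_init, (float, int)):
        lr_init = [lr_init for _ in param_groups]
    # A list must supply exactly one value per parameter group
    if len(lr_init) != len(param_groups):
        raise ValueError(f"Got {len(lr_init)} lr_init and {len(param_groups)} param groups")
    for value, group in zip(lr_init, param_groups):
        group["lr"] = value


groups = [{"lr": 0.1}, {"lr": 0.1}]  # plain dicts standing in for optimizer.param_groups
apply_lr_init(1e-3, groups)
print([g["lr"] for g in groups])  # [0.001, 0.001]
apply_lr_init([1e-3, 1e-4], groups)
print([g["lr"] for g in groups])  # [0.001, 0.0001]
```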


thunder.policy.Multiply

Bases: MappingPolicy

Multiplies the learning rate by the factor that the mapping assigns to the current epoch. Example:

    sch = Multiply({1: 0.1, 4: 0.3})
If the initial learning rate is 1e-3, the learning rate will be: 1e-3, 1e-4, 1e-4, 1e-4, 3e-5, ...


Parameters

mapping: Union[List[Dict[int, float]], Dict[int, float]]
Maps epochs to multiplication factors; between the specified epochs the learning rate keeps its last value.

lr_init: Union[List[float], float]
Initial learning rate for each group of parameters.

class Multiply(MappingPolicy):
    mapping: Union[List[Dict[int, float]], Dict[int, float]]

    def get_lr(self) -> List[float]:
        # Scale each group's current lr by the factor for this epoch (1 if the epoch is not mapped)
        return [
            param_group["lr"] * mapping.get(self.last_epoch, 1)
            for param_group, mapping in zip_equal(self.optimizer.param_groups, self.current_mapping)
        ]

    def load_state_dict(self, state_dict):
        ...  # body not shown
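The example sequence above can be reproduced with a short re-implementation of Multiply's `get_lr` logic for a single param group. `multiply_schedule` is a hypothetical helper that only mirrors the semantics described here; it does not use thunder or torch.

```python
from typing import Dict, List


def multiply_schedule(mapping: Dict[int, float], lr_init: float, n_epochs: int) -> List[float]:
    """Learning rate per epoch under Multiply-style semantics (illustrative re-implementation)."""
    lrs = []
    lr = lr_init
    for epoch in range(n_epochs):
        # Multiply by the factor mapped to this epoch; unmapped epochs keep the lr
        lr = lr * mapping.get(epoch, 1)
        lrs.append(lr)
    return lrs


print(multiply_schedule({1: 0.1, 4: 0.3}, lr_init=1e-3, n_epochs=6))
```

Because the factors compound, the value at epoch 4 is 1e-3 * 0.1 * 0.3 = 3e-5, up to floating-point rounding.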

thunder.policy.Schedule

Bases: MappingPolicy

Sets the learning rate to the value returned by the callable mapping for the current epoch. Example:

import numpy as np

sch = Schedule(np.cos)
The learning rate will take the values np.cos(epoch_number).


Parameters

mapping: Union[List[Callable], Callable]
Maps an epoch to a learning rate value.

lr_init: Union[List[float], float]
Initial learning rate for each group of parameters.

class Schedule(MappingPolicy):
    mapping: Union[List[Callable], Callable]

    def get_lr(self) -> List[float]:
        # Apply each group's callable to the current epoch
        return juxt(self.current_mapping)(self.last_epoch)

    def state_dict(self) -> Dict[str, Any]:
        # The callable mappings are excluded from the state dict
        return self.prepare_state_dict("mapping", "current_mapping")

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        ...  # body not shown
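Unlike Multiply, Schedule does not scale the current value: the callable's return value becomes the learning rate. The sketch below mimics that lookup for several param groups, one callable per group; `schedule_lrs`, `halving` and `constant` are hypothetical names used only for illustration. Note that a callable such as np.cos also returns negative values for some epochs, so in practice the callable should produce valid learning rates.

```python
from typing import Callable, List


def schedule_lrs(mappings: List[Callable[[int], float]], epoch: int) -> List[float]:
    """Schedule-style lookup (illustrative): each group's lr is its callable evaluated at epoch."""
    return [mapping(epoch) for mapping in mappings]


def halving(epoch: int) -> float:
    return 1e-3 * 0.5 ** epoch  # halve the lr every epoch


def constant(epoch: int) -> float:
    return 1e-4  # this group keeps a fixed lr


for epoch in range(4):
    print(epoch, schedule_lrs([halving, constant], epoch))
```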

thunder.policy.Switch

Bases: MappingPolicy

Sets the learning rate to the values given in a dict mapping at the specified epochs. Example:

sch = Switch({0: 1e-4, 2: 1e-10})
lr: 1e-4, 1e-4, 1e-10, 1e-10, ...


Parameters

mapping: Union[List[Dict[int, float]], Dict[int, float]]
Maps the specified epochs to learning rate values; between those epochs the learning rate is preserved.

lr_init: Union[List[float], float]
Initial learning rate for each group of parameters.

class Switch(MappingPolicy):
    mapping: Union[List[Dict[int, float]], Dict[int, float]]

    def get_lr(self) -> List[float]:
        # Use the mapped value if this epoch is in the mapping, otherwise keep the current lr
        return [
            mapping.get(self.last_epoch, param_group["lr"])
            for param_group, mapping in zip_equal(self.optimizer.param_groups, self.current_mapping)
        ]

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        ...  # body not shown
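The Switch example can be checked with a short re-implementation of its lookup rule for a single param group. `switch_schedule` is a hypothetical helper that mirrors the semantics described above; it is not part of thunder.

```python
from typing import Dict, List


def switch_schedule(mapping: Dict[int, float], lr_init: float, n_epochs: int) -> List[float]:
    """Learning rate per epoch under Switch-style semantics (illustrative re-implementation)."""
    lrs = []
    lr = lr_init
    for epoch in range(n_epochs):
        # Jump to the mapped value at listed epochs; otherwise keep the current lr
        lr = mapping.get(epoch, lr)
        lrs.append(lr)
    return lrs


print(switch_schedule({0: 1e-4, 2: 1e-10}, lr_init=1e-3, n_epochs=4))
# [0.0001, 0.0001, 1e-10, 1e-10]
```

Under these semantics, epoch 0 is in the mapping, so lr_init is overridden immediately; with the mapping from the introduction, {10: 0.001, 20: 0.001 / 10}, the lr_init value would be kept for epochs 0 through 9.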

Base classes

thunder.policy

Policy

Bases: _LRScheduler

Policy base class.

get_lr() abstractmethod

Computes the new learning rate values.

Returns

A list with the new learning rate for each parameter group.


load_state_dict(state_dict) abstractmethod

Loads the state dict of the scheduler.

Parameters

state_dict: Dict[str, Any]
State dict of the scheduler.

prepare_state_dict(*keys)

Creates the state dict of the scheduler, excluding the optimizer and the specified keys. Note that this method does not save the state dict; it is only useful for preparing it.

Parameters

keys: str
Names of attributes to be excluded from the state dict.

Returns

Dict[str, Any]
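A minimal sketch of what "excluding the optimizer and specified keys" can look like, assuming the state dict is built from the instance attributes (torch's own LRScheduler.state_dict builds it from `__dict__` minus the optimizer). `TinyPolicy` is hypothetical; its `state_dict` passes the same keys as Schedule.state_dict above.

```python
from typing import Any, Dict


class TinyPolicy:
    """Hypothetical stand-in used to illustrate prepare_state_dict-like behaviour."""

    def __init__(self, optimizer, mapping):
        self.optimizer = optimizer
        self.mapping = mapping
        self.current_mapping = [mapping]
        self.last_epoch = 3

    def prepare_state_dict(self, *keys: str) -> Dict[str, Any]:
        # Copy the instance attributes, skipping the optimizer and any excluded keys;
        # nothing is written to disk here
        excluded = {"optimizer", *keys}
        return {key: value for key, value in self.__dict__.items() if key not in excluded}

    def state_dict(self) -> Dict[str, Any]:
        # Drop the callable mappings, mirroring the keys Schedule.state_dict excludes
        return self.prepare_state_dict("mapping", "current_mapping")


policy = TinyPolicy(optimizer=object(), mapping=lambda epoch: 0.1)
print(policy.state_dict())  # {'last_epoch': 3}
print(sorted(policy.prepare_state_dict()))  # ['current_mapping', 'last_epoch', 'mapping']
```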

set_optimizer(optimizer)

Assigns an optimizer to the scheduler.

MappingPolicy

Bases: Policy

from abc import ABCMeta
from copy import deepcopy
from typing import List, Union

from torch.optim import Optimizer
# Policy and zip_equal are defined or imported elsewhere in thunder.policy


class MappingPolicy(Policy, metaclass=ABCMeta):
    def __init__(self, mapping, lr_init: Union[List[float], float] = 1e-3):
        """
        Base class for policy with mapping. Mapping can be a dict or a function
        (it should also be a list of the latter types in case of multiple param groups).
        Mapping is the binding between epoch or step number and learning rate value.
        Parameters
        mapping
            Binding of epoch or step number and learning rate.
        lr_init: Union[List[float], float]
            Initial learning rate for each group of parameters.
        """
        self.current_mapping = None
        self.mapping = mapping

        self.current_lr_init = None
        self.lr_init = lr_init


    def set_optimizer(self, optimizer: Optimizer) -> None:
        self.current_mapping = self.mapping
        if isinstance(self.mapping, dict) or callable(self.mapping):
            self.current_mapping = [deepcopy(self.mapping) for _ in optimizer.param_groups]

        self.current_lr_init = self.lr_init
        if isinstance(self.lr_init, (float, int)):
            self.current_lr_init = [self.lr_init for _ in optimizer.param_groups]

        if len(self.current_mapping) != len(optimizer.param_groups):
            raise ValueError(f"Got {len(self.current_mapping)} mappings and {len(optimizer.param_groups)} param groups")

        if len(self.current_lr_init) != len(optimizer.param_groups):
            raise ValueError(f"Got {len(self.current_lr_init)} lr_init and {len(optimizer.param_groups)} param groups")

        for lr_init, param_group in zip_equal(self.current_lr_init, optimizer.param_groups):
            param_group["lr"] = lr_init


    def __repr__(self) -> str:
        mapping = self.current_mapping if self.current_mapping else self.mapping
        lr_init = self.current_lr_init if self.current_lr_init is not None else self.lr_init
        return f"{self.__class__.__name__}({mapping=}, {lr_init=})"

__init__(mapping, lr_init=0.001)

Base class for policies driven by a mapping. The mapping can be a dict or a function; with multiple parameter groups it can also be a list of these, one per group. A single dict or function is copied for every parameter group when the optimizer is assigned. The mapping binds an epoch or step number to a learning rate value.

Parameters

mapping
Binding of epoch or step number and learning rate.

lr_init: Union[List[float], float]
Initial learning rate for each group of parameters.
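The per-group expansion performed in `set_optimizer` can be sketched on its own. `expand_mapping` and `GroupMapping` are hypothetical names; the function follows the logic of the source above (deep-copy a single dict or callable once per group, accept a list as-is, and raise ValueError on a count mismatch) but is not thunder's API.

```python
from copy import deepcopy
from typing import Callable, Dict, List, Union

GroupMapping = Union[Dict[int, float], Callable[[int], float]]


def expand_mapping(mapping: Union[GroupMapping, List[GroupMapping]], n_groups: int) -> List[GroupMapping]:
    """Per-group mappings, following the logic of MappingPolicy.set_optimizer (illustrative)."""
    # A single dict or callable is deep-copied once per parameter group
    if isinstance(mapping, dict) or callable(mapping):
        mappings = [deepcopy(mapping) for _ in range(n_groups)]
    else:
        mappings = list(mapping)  # a list is taken as one mapping per group
    if len(mappings) != n_groups:
        raise ValueError(f"Got {len(mappings)} mappings and {n_groups} param groups")
    return mappings


shared = {0: 1e-3, 5: 1e-4}
per_group = expand_mapping(shared, n_groups=2)
print(per_group)  # [{0: 0.001, 5: 0.0001}, {0: 0.001, 5: 0.0001}]
print(per_group[0] is per_group[1])  # False: each group gets its own copy
```

Copying per group means that a policy which mutates one group's mapping cannot silently affect another group.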

