Skip to content

LR Schedulers#

All schedulers in thunder are subclasses of torch.optim.lr_scheduler.LRScheduler. However, unlike regular torch schedulers, they do not require an optimizer at construction time; the optimizer is bound later.

Usage#

We will use Switch as an example.

from thunder.policy import Switch

scheduler = Switch({10: 0.001, 20: 0.001 / 10})
We have just created a policy, but it still needs an optimizer before it can work.
Here is how to bind one and use the assembled scheduler:
optimizer = Adam(...)
scheduler(optimizer) # binds optimizer to scheduler
# or 
# scheduler = scheduler(optimizer)
# You can also retrieve optimizer:
opt = scheduler.optimizer
After an optimizer has been bound, the policy instance works just like a regular torch scheduler.

Initial LR#

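All mapping-based schedulers accept an lr_init parameter (default 1e-3). When the optimizer is bound, lr_init is written into every param group, overriding the optimizer's own learning rate. It is the learning rate at step 0 unless the mapping overrides it: Switch and Multiply do so when the mapping has a key 0, while Schedule always uses the value returned by its callable.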

Reference#

thunder.policy.Multiply #

Bases: MappingPolicy

Multiplies the learning rate by the factor specified in mapping. Example:

    sch = Multiply({1: 0.1, 4: 0.3})
if the initial learning rate is 1e-3, the learning rate will be: 1e-3, 1e-4, 1e-4, 1e-4, 3e-5, ...

Parameters#

mapping: Union[List[Dict[int, float]], Dict[int, float]] Maps epoch to factor; between the listed epochs the learning rate is kept unchanged. lr_init: Union[List[float], float] Initial learning rate for each group of parameters.

Source code in thunder/policy.py
class Multiply(MappingPolicy):
    """
    Multiplies the learning rate by the factor specified in `mapping`.
    Example:
        ```python
            sch = Multiply({1: 0.1, 4: 0.3})
        ```
        if the initial learning rate is 1e-3, the learning rate will be: 1e-3, 1e-4, 1e-4, 1e-4, 3e-5, ...

    Parameters
    ----------
    mapping: Union[List[Dict[int, float]], Dict[int, float]]
        Maps epoch to factor. Epochs missing from the mapping use a factor of 1,
        so the learning rate is kept unchanged between the listed epochs.
    lr_init: Union[List[float], float]
        Initial learning rate for each group of parameters.
    """
    mapping: Union[List[Dict[int, float]], Dict[int, float]]

    def get_lr(self) -> List[float]:
        """Returns the current lr of each param group scaled by this epoch's factor."""
        epoch = self.last_epoch
        new_lrs = []
        # every param group has its own mapping; epochs not listed act as factor 1
        for group, factors in zip_equal(self.optimizer.param_groups, self.current_mapping):
            new_lrs.append(group["lr"] * factors.get(epoch, 1))
        return new_lrs

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Restores the scheduler state by delegating to the base-class implementation."""
        super().load_state_dict(state_dict)

thunder.policy.Schedule #

Bases: MappingPolicy

Assigns learning rate values received from callable mapping. Example:

sch = Schedule(np.cos)
lr will have values of np.cos(epoch_number)

Parameters#

mapping: Union[List[Callable], Callable] Maps epoch to value. lr_init: Union[List[float], float] Initial learning rate for each group of parameters.

Source code in thunder/policy.py
class Schedule(MappingPolicy):
    """
    Assigns learning rate values returned by a callable mapping.
    Example:
        ```python
        sch = Schedule(np.cos)
        ```
        lr will have values of np.cos(epoch_number)

    Parameters
    ----------
    mapping: Union[List[Callable], Callable]
        Maps epoch to learning rate value.
    lr_init: Union[List[float], float]
        Initial learning rate for each group of parameters.
    """
    mapping: Union[List[Callable], Callable]

    def get_lr(self) -> List[float]:
        """Evaluates the callable of every param group at the current epoch."""
        evaluate_all = juxt(self.current_mapping)
        return evaluate_all(self.last_epoch)

    def state_dict(self) -> Dict[str, Any]:
        """
        Returns the scheduler state with the callables (``mapping`` and
        ``current_mapping``) left out; presumably because arbitrary callables
        such as lambdas cannot be serialized — TODO confirm.
        """
        callables = ("mapping", "current_mapping")
        return self.prepare_state_dict(*callables)

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Restores the scheduler state by delegating to the base-class implementation."""
        super().load_state_dict(state_dict)

thunder.policy.Switch #

Bases: MappingPolicy

Assigns learning rate values received from dict mapping. Example:

sch = Switch({0: 1e-4, 2: 1e-10})
lr: 1e-4, 1e-4, 1e-10, 1e-10, ...

Parameters#

mapping: Union[List[Dict[int, float]], Dict[int, float]] Maps specified epochs to specified values, preserving the learning rate between those epochs. lr_init: Union[List[float], float] Initial learning rate for each group of parameters.

Source code in thunder/policy.py
class Switch(MappingPolicy):
    """
    Assigns learning rate values taken from a dict mapping.
    Example:
        ```python
        sch = Switch({0: 1e-4, 2: 1e-10})
        ```
        lr: 1e-4, 1e-4, 1e-10, 1e-10, ...

    Parameters
    ----------
    mapping: Union[List[Dict[int, float]], Dict[int, float]]
        Maps specified epochs to specified values; for epochs not in the mapping
        the current learning rate is preserved.
    lr_init: Union[List[float], float]
        Initial learning rate for each group of parameters.
    """
    mapping: Union[List[Dict[int, float]], Dict[int, float]]

    def get_lr(self) -> List[float]:
        """Returns the mapped value for this epoch, or the group's current lr if unmapped."""
        epoch = self.last_epoch
        lrs = []
        for group, schedule in zip_equal(self.optimizer.param_groups, self.current_mapping):
            lrs.append(schedule.get(epoch, group["lr"]))
        return lrs

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Restores the scheduler state by delegating to the base-class implementation."""
        super().load_state_dict(state_dict)

Base classes#

thunder.policy #

Policy #

Bases: LRScheduler

Policy base class.

Source code in thunder/policy.py
class Policy(LRScheduler, metaclass=ABCMeta):
    """
    Policy base class.

    Unlike a plain torch scheduler, a policy is created without an optimizer;
    the optimizer is bound afterwards via ``set_optimizer`` or by calling the policy.
    """
    def __init__(self) -> None:
        # LRScheduler.__init__ needs an optimizer, so it is deferred to set_optimizer.
        pass

    def __call__(self, optimizer: Optimizer) -> Policy:
        """
        Binds ``optimizer`` and returns the policy itself, so both
        ``policy(optimizer)`` and ``policy = policy(optimizer)`` work.
        """
        self.set_optimizer(optimizer)
        return self

    def set_optimizer(self, optimizer: Optimizer) -> None:
        """Assigns optimizer to the scheduler by running the deferred LRScheduler initialization."""
        super().__init__(optimizer)

    @abstractmethod
    def get_lr(self) -> List[float]:
        """
        Computes the new learning rate of every param group.

        Returns
        -------
        List[float]
        """

    def prepare_state_dict(self, *keys: str) -> Dict[str, Any]:
        """
        Builds the scheduler's state dict from its attributes, leaving out the
        optimizer and every attribute named in ``keys``.
        This only prepares the dict; nothing is saved.

        Parameters
        ----------
        keys: str
            Names of attributes to be excluded from the state dict.

        Returns
        -------
        Dict[str, Any]
        """
        excluded = ("optimizer", *keys)
        state = {}
        for name, value in self.__dict__.items():
            if name not in excluded:
                state[name] = value
        return state

    @abstractmethod
    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """
        Restores the scheduler from a state dict by copying its entries onto the instance.

        Parameters
        ----------
        state_dict: Dict[str, Any]
            State dict of the scheduler.
        """
        attributes = self.__dict__
        attributes.update(state_dict)

get_lr() abstractmethod #

Computes the new learning rate of every param group.

Returns:

List[float]

Source code in thunder/policy.py
@abstractmethod
def get_lr(self) -> List[float]:
    """
    Computes the new learning rate of every param group.

    Abstract: subclasses must override it; this base body returns ``None``.

    Returns
    -------
    List[float]
    """

load_state_dict(state_dict) abstractmethod #

Loads the state dict of the scheduler.

Parameters:

state_dict: Dict[str, Any] State dict of the scheduler.

Source code in thunder/policy.py
@abstractmethod
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
    """
    Restores the scheduler from a state dict by copying its entries onto the instance.

    Parameters
    ----------
    state_dict: Dict[str, Any]
        State dict of the scheduler.
    """
    attributes = self.__dict__
    attributes.update(state_dict)

prepare_state_dict(*keys) #

Creates a state dict of the scheduler, excluding the optimizer and the specified keys. Note that this method does not save the state dict; it only prepares it.

Parameters:

keys: str Names of attributes to be excluded from the state dict.

Returns#

Dict[str, Any]

Source code in thunder/policy.py
def prepare_state_dict(self, *keys: str) -> Dict[str, Any]:
    """
    Builds the scheduler's state dict from its attributes, leaving out the
    optimizer and every attribute named in ``keys``.
    This only prepares the dict; nothing is saved.

    Parameters
    ----------
    keys: str
        Names of attributes to be excluded from the state dict.

    Returns
    -------
    Dict[str, Any]
    """
    excluded = ("optimizer", *keys)
    state = {}
    for name, value in self.__dict__.items():
        if name not in excluded:
            state[name] = value
    return state

set_optimizer(optimizer) #

Assigns an optimizer to the scheduler.

Source code in thunder/policy.py
def set_optimizer(self, optimizer: Optimizer) -> None:
    """
    Assigns an optimizer to the scheduler.

    Runs the parent scheduler's initializer with ``optimizer``; that call is
    deferred from construction so policies can be created without an optimizer.
    """
    super().__init__(optimizer)

MappingPolicy #

Bases: Policy

Source code in thunder/policy.py
class MappingPolicy(Policy, metaclass=ABCMeta):
    """
    Base class for policies driven by a per-param-group mapping
    (a dict or a callable from epoch/step number to a learning-rate value).
    """

    def __init__(self, mapping, lr_init: Union[List[float], float] = 1e-3):
        """
        Base class for policy with mapping. Mapping can be a dict or a function
        (or a list of those in case of multiple param groups).
        Mapping is the binding between epoch or step number and learning rate value.

        Parameters
        ----------
        mapping
            Binding of epoch or step number and learning rate. A single dict or
            callable is deep-copied once per param group; a list must contain
            exactly one entry per param group.
        lr_init: Union[List[float], float]
            Initial learning rate for each group of parameters. A single real
            number is used for every param group.
        """
        # The ``current_*`` attributes hold the per-param-group values; they stay
        # ``None`` until an optimizer is bound in ``set_optimizer``.
        self.current_mapping = None
        self.mapping = mapping

        self.current_lr_init = None
        self.lr_init = lr_init

        super().__init__()

    def set_optimizer(self, optimizer: Optimizer) -> None:
        """
        Expands ``mapping`` and ``lr_init`` to one entry per param group, writes
        the initial learning rates into ``optimizer`` and binds it to the policy.

        Parameters
        ----------
        optimizer: Optimizer
            Optimizer whose param groups will be scheduled.

        Raises
        ------
        ValueError
            If the number of mappings or of initial learning rates differs from
            the number of param groups. The policy state is left untouched then.
        """
        import numbers

        n_groups = len(optimizer.param_groups)

        current_mapping = self.mapping
        if isinstance(self.mapping, dict) or callable(self.mapping):
            # one independent copy per group so param groups do not share a mapping object
            current_mapping = [deepcopy(self.mapping) for _ in optimizer.param_groups]

        current_lr_init = self.lr_init
        # numbers.Real accepts int/float as before, plus e.g. numpy floating scalars
        if isinstance(self.lr_init, numbers.Real):
            current_lr_init = [self.lr_init for _ in optimizer.param_groups]

        # Validate before mutating anything, so a failed bind leaves the policy intact.
        if len(current_mapping) != n_groups:
            raise ValueError(f"Got {len(current_mapping)} mappings and {n_groups} param groups")

        if len(current_lr_init) != n_groups:
            raise ValueError(f"Got {len(current_lr_init)} lr_init and {n_groups} param groups")

        self.current_mapping = current_mapping
        self.current_lr_init = current_lr_init

        for lr_init, param_group in zip_equal(current_lr_init, optimizer.param_groups):
            param_group["lr"] = lr_init

        super().set_optimizer(optimizer)

    def __repr__(self) -> str:
        # Show the per-group values once an optimizer is bound, the raw inputs otherwise.
        mapping = self.current_mapping if self.current_mapping is not None else self.mapping
        lr_init = self.current_lr_init if self.current_lr_init is not None else self.lr_init
        return f"{self.__class__.__name__}({mapping=}, {lr_init=})"

__init__(mapping, lr_init=0.001) #

Base class for policies with a mapping. The mapping can be a dict or a function (or a list of those in case of multiple param groups). It binds an epoch or step number to a learning rate value.

Parameters:

mapping Binding of epoch or step number and learning rate. lr_init: Union[List[float], float] Initial learning rate for each group of parameters.

Source code in thunder/policy.py
def __init__(self, mapping, lr_init: Union[List[float], float] = 1e-3):
    """
    Initializes a policy driven by a mapping.

    The mapping can be a dict or a function (or a list of those when there are
    several param groups); it binds an epoch or step number to a learning rate value.

    Parameters
    ----------
    mapping
        Binding of epoch or step number and learning rate.
    lr_init: Union[List[float], float]
        Initial learning rate for each group of parameters.
    """
    # Raw inputs are stored as given; the per-group ``current_*`` values start
    # as None and are filled in once an optimizer is bound.
    self.current_mapping, self.mapping = None, mapping
    self.current_lr_init, self.lr_init = None, lr_init

    super().__init__()