ThunderModule#

ThunderModule inherits everything from LightningModule and implements the methods most common training pipelines need.

From Lightning to Thunder#

Most common pipelines are implemented in Lightning in the following way:

from lightning import LightningModule
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

class Model(LightningModule):
    def __init__(self):
        super().__init__()
        self.architecture: nn.Module = ...
        self.metrics = ...  # e.g. Dict[str, Callable]

    def forward(self, *args, **kwargs):
        return self.architecture(*args, **kwargs)

    def criterion(self, x, y):
        ...

    def training_step(self, batch, batch_idx):
        x, y = batch
        return self.criterion(self(x), y)

    def validation_step(self, batch, batch_idx, dataloader_idx):
        # forward pass and metric computation, or collecting outputs
        ...

    def test_step(self, batch, batch_idx, dataloader_idx):
        # forward pass and metric computation, or collecting outputs
        ...

    def configure_optimizers(self):
        return Adam(...), StepLR(...)

ThunderModule provides ready-made implementations of the steps shown above, so the same pipeline reduces to:

from thunder import ThunderModule
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

architecture: nn.Module = ...
criterion = nn.CrossEntropyLoss()
optimizer = Adam(architecture.parameters())
scheduler = StepLR(optimizer, step_size=...)

model = ThunderModule(architecture, criterion,
                      optimizer=optimizer, lr_scheduler=scheduler)
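
Since ThunderModule is a LightningModule, the configured model trains with a regular Lightning Trainer. A minimal sketch, assuming train_loader and val_loader are DataLoaders defined elsewhere:

from lightning import Trainer

trainer = Trainer(max_epochs=10)
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)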

Configuring Optimizers#

For additional background, see the Lightning documentation on configure_optimizers.
Lightning requires optimizers and learning rate policies to be defined inside the configure_optimizers method.
ThunderModule instead lets you pass the following combinations of optimizers and learning rate schedulers directly to the constructor:

from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import LRScheduler  # stands in for any torch scheduler below

architecture = nn.Linear(2, 2)

No scheduling#

optimizer = Adam(architecture.parameters())
model = ThunderModule(..., optimizer=optimizer)

Defining optimizer and scheduler#

optimizer = Adam(architecture.parameters())
lr_scheduler = LRScheduler(optimizer)
model = ThunderModule(..., optimizer=optimizer, lr_scheduler=lr_scheduler)

Defining no optimizer#

If only a scheduler is passed, ThunderModule recovers the optimizer from the scheduler itself (every torch scheduler stores it as lr_scheduler.optimizer):

optimizer = Adam(architecture.parameters())
lr_scheduler = LRScheduler(optimizer)
model = ThunderModule(..., lr_scheduler=lr_scheduler)

Multiple Optimizers#

Thunder, just like Lightning, supports configurations with more than one optimizer. Such configurations require manual optimization; see the Lightning guide on manual optimization.

In Thunder you can pass lists of optimizers and schedulers to ThunderModule:

class ThunderModuleManual(ThunderModule):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.automatic_optimization = False


# module1, module2 and Scheduler are placeholders for your own submodules and scheduler class
optimizers = [Adam(module1.parameters()), Adam(module2.parameters())]
lr_schedulers = [Scheduler(opt) for opt in optimizers]

model = ThunderModuleManual(..., optimizer=optimizers, lr_scheduler=lr_schedulers)
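
With automatic optimization disabled, the training step has to drive the optimizers itself. Extending the class above, here is a minimal sketch of such a step, following Lightning's manual optimization API (optimizing both parameter groups with one shared loss is an assumption made for illustration):

class ThunderModuleManual(ThunderModule):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        # self.optimizers() returns the optimizers passed to ThunderModule
        opt1, opt2 = self.optimizers()
        loss = super().training_step(batch, batch_idx)

        opt1.zero_grad()
        opt2.zero_grad()
        self.manual_backward(loss)  # use instead of loss.backward() under Lightning
        opt1.step()
        opt2.step()
        return loss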

Thunder Policies#

As shown above, torch schedulers must be bound to their optimizer(s) before being handed to ThunderModule. This is inconvenient, and torch schedulers also lack some basic functionality.
Thunder policies can be used just like torch schedulers, except that ThunderModule binds them to the optimizers automatically:

from thunder.policy import Switch

optimizers = [Adam(module1.parameters()), Adam(module2.parameters())]
lr_schedulers = [Switch({1: 0.001}), Switch({2: 0.001})]

model = ThunderModuleManual(..., optimizer=optimizers, lr_scheduler=lr_schedulers)

For extra information see Thunder Policies Docs.
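
Under the hood (see configure_optimizers in the reference below), a policy is a callable that takes an optimizer and returns a scheduler bound to it, which is why it can be constructed without one. Roughly what ThunderModule does with a callable policy:

policy = Switch({1: 0.001})
scheduler = policy(optimizer)  # binding performed inside configure_optimizers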

Inference#

During the inference step, ThunderModule uses Predictors to preprocess data and to apply inverse transforms after the data has passed through the model. The default predictor is the identity.
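
Judging by the call site in inference_step (self.predictor([x], self.predict)), a predictor is called with a sequence of inputs and the model's predict function, and returns the corresponding outputs. Below is a hypothetical sketch of a custom predictor that pads each input and crops the prediction back; the two-argument call signature is inferred from that call site, and in real code you would subclass BasePredictor:

import numpy as np

class PadPredictor:
    def __call__(self, inputs, predict_fn):
        outputs = []
        for x in inputs:
            # preprocessing: pad the last axis of a (hypothetical) 2D input
            padded = np.pad(x, ((0, 0), (1, 1)))
            y = predict_fn(padded)  # forward pass, activation, conversion to numpy
            outputs.append(y[..., 1:-1])  # inverse transform: crop the padding back
        return outputs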

For more on predictors see Thunder Predictors Docs.

Batch Transfer#

By default, ThunderModule transfers training batches to the target device. During inference, however, the batch stays on whatever device it arrived on from the data loader: the transfer happens later, inside predict, which is reached from inference_step (itself invoked by validation_step, test_step and predict_step). This also allows inference batches to contain numpy arrays, which predict converts to tensors on the right device.

Reference#

thunder.torch.core.ThunderModule #

Bases: LightningModule

Source code in thunder/torch/core.py
class ThunderModule(LightningModule):
    def __init__(
            self,
            architecture: nn.Module,
            criterion: Callable,
            n_targets: int = 1,
            activation: Callable = identity,
            optimizer: Union[List[Optimizer], Optimizer] = None,
            lr_scheduler: Union[List[LRScheduler], LRScheduler] = None,
            predictor: BasePredictor = None,
            n_val_targets: int = None
    ):
        """
        Parameters
        ----------
        architecture: nn.Module
            Model architecture used to conduct forward pass.
        criterion: Callable
            Criterion to optimize.
        n_targets: int
            Number of target values in train and inference batches, if negative, then ...
        activation: Callable
            Final activation function for inference, identity by default.
        optimizer: Union[List[Optimizer], Optimizer]
            Optimizers.
        lr_scheduler: Union[List[LRScheduler], LRScheduler]
            Learning Rate policies.
        predictor: BasePredictor
            Predictor for inference.
        n_val_targets: int
            Number of target values for inference, if set to None assumes value of `n_targets`.
        """
        super().__init__()
        self.architecture = architecture
        self.criterion = criterion
        self.n_targets = n_targets
        self.n_val_targets = n_targets if n_val_targets is None else n_val_targets
        self.activation = activation
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.predictor = predictor if predictor else Predictor()

    def transfer_batch_to_device(self, batch: Tuple, device: torch.device, dataloader_idx: int) -> Any:
        if self.trainer.state.stage != "train":
            return batch
        return super().transfer_batch_to_device(maybe_from_np(batch, device=device), device, dataloader_idx)

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        return self.architecture(*args, **kwargs)

    def training_step(self, batch: Tuple[Tensor, ...], batch_idx: int) -> STEP_OUTPUT:
        x, y = batch[: -self.n_targets], batch[-self.n_targets:]
        return self.criterion(self(*x), *y)

    def validation_step(self, batch: Tuple, batch_idx: int, dataloader_idx: int = 0) -> STEP_OUTPUT:
        return self.inference_step(batch, batch_idx, dataloader_idx)

    def test_step(self, batch: Tuple, batch_idx: int, dataloader_idx: int = 0) -> STEP_OUTPUT:
        return self.inference_step(batch, batch_idx, dataloader_idx)

    def predict_step(self, batch: Tuple, batch_idx: int, dataloader_idx: int = 0) -> Any:
        return self.inference_step(batch, batch_idx, dataloader_idx)

    def predict(self, x) -> STEP_OUTPUT:
        # TODO: do we need super(). ...?, also consider changing maybe_to_np to smth stricter
        x = maybe_from_np(x, device=self.device)
        if not isinstance(x, (list, tuple)):
            x = (x,)
        return to_np(self.activation(self(*x)))

    def inference_step(self, batch: Tuple, batch_idx: int, dataloader_idx: int = 0) -> Any:
        x, y = map(squeeze_first, (batch[:-self.n_val_targets], batch[-self.n_val_targets:]))
        return self.predictor([x], self.predict)[0], y

    def configure_optimizers(self) -> Tuple[List[Optimizer], List[LRScheduler]]:
        if not self.optimizer and not self.lr_scheduler:
            raise NotImplementedError(
                "You must specify optimizer or lr_scheduler, "
                "or implement configure_optimizers method"
            )

        _optimizers = list(collapse([self.optimizer]))
        _lr_schedulers = list(collapse([self.lr_scheduler]))
        max_len = max(len(_optimizers), len(_lr_schedulers))
        _optimizers = list(padded(_optimizers, None, max_len))
        _lr_schedulers = list(padded(_lr_schedulers, None, max_len))

        optimizers = []
        lr_schedulers = []

        for optimizer, lr_scheduler in zip_equal(_optimizers, _lr_schedulers):
            if callable(lr_scheduler):
                if optimizer is None:
                    raise ValueError("The scheduler demands an Optimizer, but received None")
                lr_scheduler = lr_scheduler(optimizer)

            optimizers.append(optimizer if lr_scheduler is None else lr_scheduler.optimizer)
            if lr_scheduler is not None:
                lr_schedulers.append(lr_scheduler)

        if len(optimizers) < len(lr_schedulers):
            raise ValueError(
                "The number of optimizers must be greater or equal to the number of "
                f"lr_schedulers, got {len(optimizers)} and {len(lr_schedulers)}\n"
                f"Optimizers: f{optimizers}\n"
                f"Schedulers: f{lr_schedulers}\n"
            )

        return optimizers, lr_schedulers
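
To make the n_targets convention concrete, here is how the slicing in training_step splits a batch; the four-element batch and n_targets=2 below are purely illustrative:

batch = (x1, x2, y1, y2)  # two inputs followed by two targets
n_targets = 2

x, y = batch[:-n_targets], batch[-n_targets:]
# x == (x1, x2) and y == (y1, y2), so the loss is criterion(model(*x), *y)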
