ThunderModule
ThunderModule inherits everything from LightningModule and implements
the essential methods for the most common training pipelines.
From Lightning to Thunder
Most common pipelines are implemented in Lightning in the following way:
from typing import Callable, Dict

from lightning import LightningModule
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

class Model(LightningModule):
    def __init__(self):
        super().__init__()
        self.architecture: nn.Module = ...
        self.metrics = ...  # something like Dict[str, Callable]

    def forward(self, *args, **kwargs):
        return self.architecture(*args, **kwargs)

    def criterion(self, x, y):
        ...

    def training_step(self, batch, batch_idx):
        x, y = batch
        return self.criterion(self(x), y)

    def validation_step(self, batch, batch_idx, dataloader_idx):
        # forward pass and metrics computation or output preservation
        ...

    def test_step(self, batch, batch_idx, dataloader_idx):
        # forward pass and metrics computation or output preservation
        ...

    def configure_optimizers(self):
        return [Adam(...)], [StepLR(...)]
ThunderModule offers an implementation of the steps shown above, so you only need to supply the building blocks:
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

from thunder import ThunderModule

architecture: nn.Module = ...
criterion = nn.CrossEntropyLoss()
optimizer = Adam(architecture.parameters())
scheduler = StepLR(optimizer, step_size=1)
model = ThunderModule(architecture, criterion, optimizer=optimizer, lr_scheduler=scheduler)
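Once constructed, the model trains like any other LightningModule. A minimal sketch, assuming train_loader and val_loader are your DataLoaders:

from lightning import Trainer

trainer = Trainer(max_epochs=10)
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)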
Configuring Optimizers
Lightning requires optimizers and learning rate policies to be defined inside the
configure_optimizers method; see the Lightning documentation on configure_optimizers
for extra information.
ThunderModule, by contrast, lets you pass the following configurations of
optimizers and learning rate schedulers directly:
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

architecture = nn.Linear(2, 2)

# No scheduling
optimizer = Adam(architecture.parameters())
model = ThunderModule(..., optimizer=optimizer)

# Defining both optimizer and scheduler
optimizer = Adam(architecture.parameters())
lr_scheduler = StepLR(optimizer, step_size=1)
model = ThunderModule(..., optimizer=optimizer, lr_scheduler=lr_scheduler)

# Defining no optimizer: the scheduler is already bound to one,
# so ThunderModule extracts the optimizer from the scheduler
lr_scheduler = StepLR(Adam(architecture.parameters()), step_size=1)
model = ThunderModule(..., lr_scheduler=lr_scheduler)
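In the last configuration the optimizer is recovered from the scheduler. A sketch of what configure_optimizers (shown in the Reference below) returns in that case:

optimizers, lr_schedulers = model.configure_optimizers()
assert optimizers[0] is lr_scheduler.optimizer
assert lr_schedulers[0] is lr_scheduler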
Multiple Optimizers
Thunder, just like Lightning, supports configurations with more than one optimizer. Such a configuration requires manual optimization (see Lightning's guide on manual optimization).
In thunder you can pass lists of optimizers and schedulers to ThunderModule:
class ThunderModuleManual(ThunderModule):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.automatic_optimization = False

optimizers = [Adam(module1.parameters()), Adam(module2.parameters())]
lr_schedulers = [StepLR(opt, step_size=1) for opt in optimizers]
model = ThunderModuleManual(..., optimizer=optimizers, lr_scheduler=lr_schedulers)
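With automatic optimization disabled, you also drive the optimizers yourself inside training_step. A minimal sketch using Lightning's manual optimization API (driving both optimizers with a single shared loss here is a hypothetical choice, not thunder's prescription):

class ThunderModuleManual(ThunderModule):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        opt1, opt2 = self.optimizers()
        x, y = batch
        loss = self.criterion(self(x), y)
        # zero, backprop and step each optimizer manually
        opt1.zero_grad()
        opt2.zero_grad()
        self.manual_backward(loss)
        opt1.step()
        opt2.step()
        return loss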
Thunder Policies
As shown above, torch schedulers require an optimizer to be passed to them before
they are given to ThunderModule. This is inconvenient, and torch schedulers also
lack some basic functionality.
You can use thunder policies just like torch schedulers:
from thunder.policy import Switch
optimizers = [Adam(module1.parameters()), Adam(module2.parameters())]
lr_schedulers = [Switch({1: 0.001}), Switch({2: 0.001})]
model = ThunderModuleManual(..., optimizer=optimizers, lr_scheduler=lr_schedulers)
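Policies can be created without an optimizer because configure_optimizers (see the Reference below) calls any callable scheduler with its optimizer before use. A plain factory function exploits the same mechanism; a minimal sketch, assuming StepLR is the behaviour you want (lazy_step_lr is a hypothetical helper, not part of thunder):

def lazy_step_lr(step_size, gamma=0.1):
    # returns a callable that is bound to its optimizer later,
    # inside configure_optimizers
    def bind(optimizer):
        return StepLR(optimizer, step_size=step_size, gamma=gamma)
    return bind

model = ThunderModuleManual(..., optimizer=optimizers, lr_scheduler=[lazy_step_lr(2), lazy_step_lr(3)])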
For extra information see Thunder Policies Docs.
Inference
During the inference step, ThunderModule uses Predictors to preprocess data and
to apply inverse transforms after the data has been passed through the model.
The default predictor is simply the identity function.
For more on predictors see Thunder Predictors Docs.
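As visible in inference_step in the Reference below, a predictor is called with a sequence of inputs and the module's predict function. A minimal identity-like sketch of that contract (the real BasePredictor interface may differ):

class IdentityPredictor:
    def __call__(self, inputs, predict_fn):
        # no preprocessing, no inverse transform: just run the model
        return [predict_fn(x) for x in inputs]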
Batch Transfer
ThunderModule transfers training batches to the device by default. During
inference, however, the batch remains on the device on which it was received
from the data loader. The transfer happens later, inside inference_step, which
is invoked by validation_step, test_step and predict_step.
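If you want inference batches moved to the device eagerly as well, you can override the hook and fall back to the plain Lightning behavior for every stage; a sketch (note this skips ThunderModule's numpy conversion during training):

class EagerTransferModule(ThunderModule):
    def transfer_batch_to_device(self, batch, device, dataloader_idx):
        # bypass ThunderModule's stage check and always transfer
        return super(ThunderModule, self).transfer_batch_to_device(batch, device, dataloader_idx)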
Reference
thunder.torch.core.ThunderModule
Bases: LightningModule
Source code in thunder/torch/core.py
class ThunderModule(LightningModule):
    def __init__(
        self,
        architecture: nn.Module,
        criterion: Callable,
        n_targets: int = 1,
        activation: Callable = identity,
        optimizer: Union[List[Optimizer], Optimizer] = None,
        lr_scheduler: Union[List[LRScheduler], LRScheduler] = None,
        predictor: BasePredictor = None,
        n_val_targets: int = None
    ):
        """
        Parameters
        ----------
        architecture: nn.Module
            Model architecture used to conduct forward pass.
        criterion: Callable
            Criterion to optimize.
        n_targets: int
            Number of target values in train and inference batches, if negative, then ...
        activation: Callable
            Final activation function for inference, identity by default.
        optimizer: Union[List[Optimizer], Optimizer]
            Optimizers.
        lr_scheduler: Union[List[LRScheduler], LRScheduler]
            Learning rate policies.
        predictor: BasePredictor
            Predictor for inference.
        n_val_targets: int
            Number of target values for inference, if set to None assumes value of `n_targets`.
        """
        super().__init__()
        self.architecture = architecture
        self.criterion = criterion
        self.n_targets = n_targets
        self.n_val_targets = n_targets if n_val_targets is None else n_val_targets
        self.activation = activation
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.predictor = predictor if predictor else Predictor()

    def transfer_batch_to_device(self, batch: Tuple, device: torch.device, dataloader_idx: int) -> Any:
        if self.trainer.state.stage != "train":
            return batch
        return super().transfer_batch_to_device(maybe_from_np(batch, device=device), device, dataloader_idx)

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        return self.architecture(*args, **kwargs)

    def training_step(self, batch: Tuple[Tensor, ...], batch_idx: int) -> STEP_OUTPUT:
        x, y = batch[:-self.n_targets], batch[-self.n_targets:]
        return self.criterion(self(*x), *y)

    def validation_step(self, batch: Tuple, batch_idx: int, dataloader_idx: int = 0) -> STEP_OUTPUT:
        return self.inference_step(batch, batch_idx, dataloader_idx)

    def test_step(self, batch: Tuple, batch_idx: int, dataloader_idx: int = 0) -> STEP_OUTPUT:
        return self.inference_step(batch, batch_idx, dataloader_idx)

    def predict_step(self, batch: Tuple, batch_idx: int, dataloader_idx: int = 0) -> Any:
        return self.inference_step(batch, batch_idx, dataloader_idx)

    def predict(self, x) -> STEP_OUTPUT:
        # TODO: do we need super(). ...?, also consider changing maybe_to_np to smth stricter
        x = maybe_from_np(x, device=self.device)
        if not isinstance(x, (list, tuple)):
            x = (x,)
        return to_np(self.activation(self(*x)))

    def inference_step(self, batch: Tuple, batch_idx: int, dataloader_idx: int = 0) -> Any:
        x, y = map(squeeze_first, (batch[:-self.n_val_targets], batch[-self.n_val_targets:]))
        return self.predictor([x], self.predict)[0], y

    def configure_optimizers(self) -> Tuple[List[Optimizer], List[LRScheduler]]:
        if not self.optimizer and not self.lr_scheduler:
            raise NotImplementedError(
                "You must specify optimizer or lr_scheduler, "
                "or implement configure_optimizers method"
            )

        _optimizers = list(collapse([self.optimizer]))
        _lr_schedulers = list(collapse([self.lr_scheduler]))

        max_len = max(len(_optimizers), len(_lr_schedulers))
        _optimizers = list(padded(_optimizers, None, max_len))
        _lr_schedulers = list(padded(_lr_schedulers, None, max_len))

        optimizers = []
        lr_schedulers = []
        for optimizer, lr_scheduler in zip_equal(_optimizers, _lr_schedulers):
            if callable(lr_scheduler):
                if optimizer is None:
                    raise ValueError("The scheduler demands an Optimizer, but received None")
                lr_scheduler = lr_scheduler(optimizer)

            optimizers.append(optimizer if lr_scheduler is None else lr_scheduler.optimizer)
            if lr_scheduler is not None:
                lr_schedulers.append(lr_scheduler)

        if len(optimizers) < len(lr_schedulers):
            raise ValueError(
                "The number of optimizers must be greater or equal to the number of "
                f"lr_schedulers, got {len(optimizers)} and {len(lr_schedulers)}\n"
                f"Optimizers: {optimizers}\n"
                f"Schedulers: {lr_schedulers}\n"
            )

        return optimizers, lr_schedulers
training_step(batch, batch_idx)
Source code in thunder/torch/core.py
def training_step(self, batch: Tuple[Tensor, ...], batch_idx: int) -> STEP_OUTPUT:
    x, y = batch[:-self.n_targets], batch[-self.n_targets:]
    return self.criterion(self(*x), *y)
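As an illustration of the n_targets convention (a worked example of the slicing above, with strings standing in for tensors):

batch = ("x", "y1", "y2")      # n_targets = 2
inputs, targets = batch[:-2], batch[-2:]
assert inputs == ("x",) and targets == ("y1", "y2")
# the loss is then computed as self.criterion(self(*inputs), *targets)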
validation_step(batch, batch_idx, dataloader_idx=0)
Source code in thunder/torch/core.py
def validation_step(self, batch: Tuple, batch_idx: int, dataloader_idx: int = 0) -> STEP_OUTPUT:
    return self.inference_step(batch, batch_idx, dataloader_idx)
test_step(batch, batch_idx, dataloader_idx=0)
Source code in thunder/torch/core.py
def test_step(self, batch: Tuple, batch_idx: int, dataloader_idx: int = 0) -> STEP_OUTPUT:
    return self.inference_step(batch, batch_idx, dataloader_idx)
predict_step(batch, batch_idx, dataloader_idx=0)
Source code in thunder/torch/core.py
def predict_step(self, batch: Tuple, batch_idx: int, dataloader_idx: int = 0) -> Any:
    return self.inference_step(batch, batch_idx, dataloader_idx)
inference_step(batch, batch_idx, dataloader_idx=0)
Source code in thunder/torch/core.py
def inference_step(self, batch: Tuple, batch_idx: int, dataloader_idx: int = 0) -> Any:
    x, y = map(squeeze_first, (batch[:-self.n_val_targets], batch[-self.n_val_targets:]))
    return self.predictor([x], self.predict)[0], y
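inference_step therefore yields a (prediction, targets) pair per batch. A worked example with n_val_targets = 1, assuming squeeze_first unwraps one-element sequences:

batch = ("x", "y")             # stand-ins for tensors
x, y = batch[:-1], batch[-1:]  # -> ("x",), ("y",)
# squeeze_first presumably unwraps these to "x" and "y", and the step
# returns (self.predictor([x], self.predict)[0], y)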