Train a model using PyTorch.
| Return type | Member function | Description |
|---|---|---|
| | __init__ (self, torch.nn.Module model, torch.utils.data.DataLoader train_loader, torch.utils.data.DataLoader val_loader, torch.optim.Optimizer optimizer, torch.nn.Module criterion, int num_epochs, torch.device device, str model_dir) | Initialize the model trainer. |
| None | load_saved_model_dict (self, str model_file) | Load the saved model. |
| None | load_saved_model (self, str model_file) | Load the saved model. |
| None | load_saved_optimizer (self, str optimizer_path) | Load the saved optimizer. |
| None | train (self) | Train the model. |
| float | train_epoch (self) | Train the model in batches. |
| float | validate_epoch (self, torch.utils.data.DataLoader data_loader) | Validate the model in batches. |
| float | test (self, torch.utils.data.DataLoader test_loader) | Test the model in batches. |
Train a model using PyTorch.
Definition at line 39 of file trainer.py.
◆ __init__()
__init__ (self,
          torch.nn.Module model,
          torch.utils.data.DataLoader train_loader,
          torch.utils.data.DataLoader val_loader,
          torch.optim.Optimizer optimizer,
          torch.nn.Module criterion,
          int num_epochs,
          torch.device device,
          str model_dir)
Initialize the model trainer.
Parameters:
- model: the model to train (a torch.nn.Module)
- train_loader: loader for the training data
- val_loader: loader for the validation data
- optimizer: optimizer for the model
- criterion: loss function
- num_epochs: number of training epochs
- device: device on which to train the model
- model_dir: directory in which to save the model
Definition at line 54 of file trainer.py.
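The class name is not shown on this page, so the sketch below uses a hypothetical stand-in, TrainerSketch, that mirrors the documented __init__ signature and stores each argument under the attribute names listed in the member data section. How start_time is set and whether model_dir is created on construction are assumptions, as are the toy data, model, and hyperparameters in the hypothetical helper build_example_trainer.

```python
import os
import tempfile
import time

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class TrainerSketch:
    """Hypothetical stand-in that mirrors the documented __init__ signature."""

    def __init__(self, model: nn.Module, train_loader: DataLoader,
                 val_loader: DataLoader, optimizer: torch.optim.Optimizer,
                 criterion: nn.Module, num_epochs: int,
                 device: torch.device, model_dir: str) -> None:
        # Store every argument under the attribute names listed in the
        # member data section of this page.
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.optimizer = optimizer
        self.criterion = criterion
        self.num_epochs = num_epochs
        self.device = device
        self.model_dir = model_dir
        # Assumption: start_time is a wall-clock timestamp taken at construction.
        self.start_time = time.time()
        # Assumption: the checkpoint directory is created if it is missing.
        os.makedirs(model_dir, exist_ok=True)


def build_example_trainer(model_dir: str) -> TrainerSketch:
    """Hypothetical helper: assemble toy arguments of the documented types."""
    torch.manual_seed(0)
    # 64 random samples with 4 features and 1 regression target each.
    x, y = torch.randn(64, 4), torch.randn(64, 1)
    train_loader = DataLoader(TensorDataset(x[:48], y[:48]), batch_size=16, shuffle=True)
    val_loader = DataLoader(TensorDataset(x[48:], y[48:]), batch_size=16)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Move the model to the device before creating the optimizer.
    model = nn.Linear(4, 1).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    return TrainerSketch(model, train_loader, val_loader, optimizer,
                         nn.MSELoss(), num_epochs=5, device=device,
                         model_dir=model_dir)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        trainer = build_example_trainer(os.path.join(tmp, "checkpoints"))
        print(trainer.device, trainer.num_epochs, os.path.isdir(trainer.model_dir))
```

The model is moved to its device before the optimizer is built, which is the order the torch.optim documentation recommends.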
◆ load_saved_model()
None load_saved_model (self, str model_file)
Load the saved model.
Parameters:
- model_file: path to the saved model file
Definition at line 90 of file trainer.py.
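The page does not say how the model is stored. Assuming, from the method name, that load_saved_model restores a module saved whole with torch.save(model, path), a minimal free-function sketch might look like this. The real method presumably assigns the result to self.model rather than returning it, and the explicit device argument is an assumption. It also assumes a PyTorch version whose torch.load accepts the weights_only keyword.

```python
import os
import tempfile

import torch
from torch import nn


def load_saved_model(model_file: str, device: torch.device) -> nn.Module:
    """Hypothetical: restore a module that was saved whole with torch.save(model, ...)."""
    # Unpickling a complete nn.Module needs weights_only=False on recent
    # PyTorch releases, which default to weights_only=True. Only do this
    # for files you trust, since unpickling can run arbitrary code.
    model = torch.load(model_file, map_location=device, weights_only=False)
    return model.to(device)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "model.pt")
        torch.save(nn.Linear(3, 2), path)  # saves the whole module object
        print(load_saved_model(path, torch.device("cpu")))
```

Saving a whole module ties the file to the module's class definition, which must be importable when the file is loaded.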
◆ load_saved_model_dict()
None load_saved_model_dict (self, str model_file)
Load the saved model.
Parameters:
- model_file: path to the saved model file
Definition at line 81 of file trainer.py.
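Assuming, from its name, that load_saved_model_dict expects a file holding model.state_dict() (the usual PyTorch checkpoint format), a sketch would copy the saved tensors into an already-constructed module. The free-function signature, with the model and device passed explicitly instead of read from the trainer, is hypothetical.

```python
import os
import tempfile

import torch
from torch import nn


def load_saved_model_dict(model: nn.Module, model_file: str,
                          device: torch.device) -> None:
    """Hypothetical: copy a saved state_dict into an existing module."""
    # map_location places the loaded tensors on the target device.
    state_dict = torch.load(model_file, map_location=device)
    # Copies parameters and buffers in place; keys and shapes must match.
    model.load_state_dict(state_dict)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "model_state.pt")
        torch.save(nn.Linear(3, 2).state_dict(), path)
        fresh = nn.Linear(3, 2)
        load_saved_model_dict(fresh, path, torch.device("cpu"))
        print(fresh.weight)
```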
◆ load_saved_optimizer()
None load_saved_optimizer (self, str optimizer_path)
Load the saved optimizer.
Parameters:
- optimizer_path: path to the saved optimizer state
Definition at line 99 of file trainer.py.
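A sketch of load_saved_optimizer, assuming optimizer_path points to a file written with torch.save(optimizer.state_dict(), path). The receiving optimizer must be built over the same parameters, in the same order, as the one that was saved. The free-function form and the map_location choice are assumptions.

```python
import os
import tempfile

import torch
from torch import nn


def load_saved_optimizer(optimizer: torch.optim.Optimizer,
                         optimizer_path: str) -> None:
    """Hypothetical: restore state saved with torch.save(optimizer.state_dict(), ...)."""
    # Load onto the CPU first; load_state_dict then moves state tensors such
    # as momentum buffers to the device of the parameter they belong to.
    state_dict = torch.load(optimizer_path, map_location="cpu")
    # Restores per-parameter state and the saved hyperparameters (e.g. lr).
    optimizer.load_state_dict(state_dict)


if __name__ == "__main__":
    model = nn.Linear(3, 1)
    opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    model(torch.randn(8, 3)).pow(2).mean().backward()
    opt.step()  # creates a momentum buffer for each parameter
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "optimizer.pt")
        torch.save(opt.state_dict(), path)
        resumed = torch.optim.SGD(model.parameters(), lr=0.5, momentum=0.9)
        load_saved_optimizer(resumed, path)
        print(resumed.param_groups[0]["lr"])
```

Restoring optimizer state alongside model weights matters for optimizers with internal state (momentum, Adam moments); without it, resumed training starts those buffers from zero.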
◆ test()
float test (self, torch.utils.data.DataLoader test_loader)
Test the model in batches.
Parameters:
- test_loader: data loader for the test data

Returns:
- test loss
Definition at line 270 of file trainer.py.
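A minimal sketch of the evaluation loop that test() plausibly runs, assuming (input, target) batches and a criterion that returns the batch-mean loss. It is named run_test here, and its explicit model, criterion, and device arguments stand in for the trainer's attributes. Weighting each batch by its size is a design choice this page does not specify: it makes the result the mean loss per sample even when the last batch is smaller.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def run_test(model: nn.Module, criterion: nn.Module, test_loader: DataLoader,
             device: torch.device) -> float:
    """Hypothetical stand-in for test(): mean loss per sample over test_loader."""
    model.eval()  # dropout off, batch norm uses running statistics
    total_loss, total_samples = 0.0, 0
    with torch.no_grad():  # no autograd graph is needed for evaluation
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            batch_size = inputs.size(0)
            # criterion is assumed to return the mean loss over the batch,
            # so multiply back by the batch size before accumulating.
            total_loss += criterion(model(inputs), targets).item() * batch_size
            total_samples += batch_size
    return total_loss / total_samples


if __name__ == "__main__":
    model = nn.Linear(2, 1)
    loader = DataLoader(TensorDataset(torch.randn(10, 2), torch.randn(10, 1)), batch_size=4)
    print(run_test(model, nn.MSELoss(), loader, torch.device("cpu")))
```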
◆ train()
None train (self)
Train the model.
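This page gives no detail for train() beyond its one-line summary. One plausible shape, sketched here under stated assumptions, loops for num_epochs, runs a training pass and a validation pass each epoch, and saves the weights with the lowest validation loss into model_dir. The helper run_epoch, the file name best_model.pt, and the checkpoint policy are all hypothetical.

```python
import math
import os
import tempfile

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def run_epoch(model, loader, criterion, device, optimizer=None) -> float:
    """One pass over loader: trains when an optimizer is given, else only evaluates."""
    training = optimizer is not None
    model.train(training)  # train mode for training, eval mode otherwise
    total = 0.0
    with torch.set_grad_enabled(training):
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            loss = criterion(model(inputs), targets)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total += loss.item()
    return total / len(loader)  # mean of per-batch losses


def train(model, train_loader, val_loader, optimizer, criterion,
          num_epochs, device, model_dir) -> None:
    """Hypothetical epoch loop with best-validation-loss checkpointing."""
    best_val = math.inf
    for epoch in range(num_epochs):
        train_loss = run_epoch(model, train_loader, criterion, device, optimizer)
        val_loss = run_epoch(model, val_loader, criterion, device)
        print(f"epoch {epoch + 1}/{num_epochs}: train={train_loss:.4f} val={val_loss:.4f}")
        # Assumed policy: overwrite the checkpoint whenever validation loss improves.
        if val_loss < best_val:
            best_val = val_loss
            torch.save(model.state_dict(), os.path.join(model_dir, "best_model.pt"))


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(32, 3)
    y = x.sum(dim=1, keepdim=True)
    model = nn.Linear(3, 1)
    with tempfile.TemporaryDirectory() as tmp:
        train(model,
              DataLoader(TensorDataset(x, y), batch_size=8, shuffle=True),
              DataLoader(TensorDataset(x, y), batch_size=8),
              torch.optim.SGD(model.parameters(), lr=0.05),
              nn.MSELoss(), 3, torch.device("cpu"), tmp)
```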
◆ train_epoch()
float train_epoch (self)
Train the model in batches.
Returns:
- training loss
Definition at line 174 of file trainer.py.
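A sketch of one training epoch as train_epoch() might implement it, assuming (input, target) batches from the loader and a criterion that returns a scalar loss. The explicit arguments replace the trainer's attributes, and reporting the mean of the per-batch losses is an assumption about what "training loss" means here.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_epoch(model: nn.Module, train_loader: DataLoader,
                optimizer: torch.optim.Optimizer, criterion: nn.Module,
                device: torch.device) -> float:
    """Hypothetical stand-in for train_epoch(): one pass over train_loader."""
    model.train()  # enable dropout and batch-norm statistic updates
    running_loss = 0.0
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()                     # clear gradients from the last batch
        loss = criterion(model(inputs), targets)  # forward pass
        loss.backward()                           # backpropagate
        optimizer.step()                          # update the parameters
        running_loss += loss.item()
    return running_loss / len(train_loader)       # mean of per-batch losses


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(64, 3)
    y = x @ torch.tensor([[1.0], [-2.0], [0.5]])  # noiseless linear target
    loader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)
    model = nn.Linear(3, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    for epoch in range(5):
        print(epoch, train_epoch(model, loader, optimizer, nn.MSELoss(), torch.device("cpu")))
```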
◆ validate_epoch()
float validate_epoch (self, torch.utils.data.DataLoader data_loader)
Validate the model in batches.
Parameters:
- data_loader: data loader for the validation data

Returns:
- validation loss
Definition at line 229 of file trainer.py.
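Because validate_epoch() takes its loader as an argument, the same routine can score any data split. A hedged sketch, assuming (input, target) batches: eval mode and torch.no_grad() keep evaluation from updating parameters or building an autograd graph. Returning the mean of per-batch losses is an assumption, and the explicit model, criterion, and device arguments stand in for the trainer's attributes.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def validate_epoch(model: nn.Module, data_loader: DataLoader,
                   criterion: nn.Module, device: torch.device) -> float:
    """Hypothetical stand-in for validate_epoch(): loss over data_loader, no updates."""
    model.eval()  # dropout off, batch norm uses running statistics
    running_loss = 0.0
    with torch.no_grad():  # gradients are neither computed nor stored
        for inputs, targets in data_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            running_loss += criterion(model(inputs), targets).item()
    return running_loss / len(data_loader)  # mean of per-batch losses


if __name__ == "__main__":
    model = nn.Linear(2, 1)
    loader = DataLoader(TensorDataset(torch.randn(12, 2), torch.randn(12, 1)), batch_size=5)
    print(validate_epoch(model, loader, nn.MSELoss(), torch.device("cpu")))
```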
◆ criterion
◆ device
◆ model
◆ model_dir
◆ num_epochs
◆ optimizer
◆ start_time
◆ train_loader
train_loader = train_loader
◆ val_loader