Train a model using PyTorch.
| Returns | Member function | Description |
| | __init__ (self, torch.nn.Module model, torch.utils.data.DataLoader train_loader, torch.utils.data.DataLoader val_loader, torch.optim.Optimizer optimizer, torch.nn.Module criterion, int num_epochs, torch.device device, str model_dir) | Initialize the model trainer. |
| None | load_saved_model_dict (self, str model_file) | Load the saved model. |
| None | load_saved_model (self, str model_file) | Load the saved model. |
| None | load_saved_optimizer (self, str optimizer_path) | Load the saved optimizer. |
| None | train (self) | Train the model. |
| float | train_epoch (self) | Train the model in batches. |
| float | validate_epoch (self, torch.utils.data.DataLoader data_loader) | Validate the model in batches. |
| float | test (self, torch.utils.data.DataLoader test_loader) | Test the model in batches. |
Definition at line 39 of file trainer.py.
◆ __init__()
__init__ (self, torch.nn.Module model, torch.utils.data.DataLoader train_loader, torch.utils.data.DataLoader val_loader, torch.optim.Optimizer optimizer, torch.nn.Module criterion, int num_epochs, torch.device device, str model_dir)
Initialize the model trainer.
Parameters:
| model | model to train (torch.nn.Module) |
| train_loader | loader for the training data |
| val_loader | loader for the validation data |
| optimizer | optimizer for the model |
| criterion | loss function |
| num_epochs | number of epochs to train for |
| device | device on which to train the model |
| model_dir | directory in which to save the model |
Definition at line 54 of file trainer.py.
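As a rough illustration of how these constructor arguments fit together, the sketch below builds each one for a toy regression problem. `TrainerSketch` is a hypothetical stand-in, not the class documented here: it only stores the arguments under the attribute names listed on this page, and moving the model to the device in the constructor is an assumption. It does not set `start_time`, since this page does not say how that attribute is initialized.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


class TrainerSketch:
    """Hypothetical stand-in that mirrors the documented constructor arguments."""

    def __init__(self, model, train_loader, val_loader, optimizer,
                 criterion, num_epochs, device, model_dir):
        # Assumption: the model is moved to the target device up front.
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.optimizer = optimizer
        self.criterion = criterion
        self.num_epochs = num_epochs
        self.device = device
        self.model_dir = model_dir


# Toy data: 16 samples with 4 features and 1 regression target each.
dataset = TensorDataset(torch.randn(16, 4), torch.randn(16, 1))
model = torch.nn.Linear(4, 1)

trainer = TrainerSketch(
    model=model,
    train_loader=DataLoader(dataset, batch_size=8, shuffle=True),
    val_loader=DataLoader(dataset, batch_size=8),
    optimizer=torch.optim.SGD(model.parameters(), lr=0.01),
    criterion=torch.nn.MSELoss(),
    num_epochs=2,
    device=torch.device("cpu"),
    model_dir="checkpoints",  # illustrative directory name
)
```

The optimizer is created from `model.parameters()` before the trainer is constructed, so the same model instance has to be passed to both.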
◆ load_saved_model()
None load_saved_model (self, str model_file)
Load the saved model.
Definition at line 90 of file trainer.py.
◆ load_saved_model_dict()
None load_saved_model_dict (self, str model_file)
Load the saved model.
Definition at line 81 of file trainer.py.
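This page does not show how the method is implemented. A common PyTorch pattern for restoring a model from a saved state dictionary is sketched below, assuming the file was written with `torch.save(model.state_dict(), path)`. The helper name `load_model_state` and the file name are hypothetical.

```python
import os
import tempfile

import torch


def load_model_state(model, model_file, device):
    """Restore parameters from a state dict file (hypothetical helper)."""
    # map_location lets a checkpoint saved on one device load on another.
    state_dict = torch.load(model_file, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)


source = torch.nn.Linear(4, 1)  # model whose weights get saved
target = torch.nn.Linear(4, 1)  # freshly initialized model to restore into

with tempfile.TemporaryDirectory() as tmp_dir:
    model_file = os.path.join(tmp_dir, "model.pt")
    torch.save(source.state_dict(), model_file)
    load_model_state(target, model_file, torch.device("cpu"))

weights_match = torch.equal(source.weight, target.weight)
```

Saving only the state dictionary keeps the checkpoint independent of the Python class definition, at the cost of having to construct the model before loading.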
◆ load_saved_optimizer()
None load_saved_optimizer (self, str optimizer_path)
Load the saved optimizer.
Parameters:
| optimizer_path | path to the saved optimizer |
Definition at line 99 of file trainer.py.
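One plausible way to restore an optimizer from a path is shown below, assuming the file holds the result of `optimizer.state_dict()`. The actual storage format used by this class is not documented here, and `load_optimizer_state` is a hypothetical helper.

```python
import os
import tempfile

import torch


def load_optimizer_state(optimizer, optimizer_path):
    """Restore optimizer hyperparameters and state (hypothetical helper)."""
    optimizer.load_state_dict(torch.load(optimizer_path))


model = torch.nn.Linear(4, 1)
saved = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
restored = torch.optim.SGD(model.parameters(), lr=0.5)  # different settings

with tempfile.TemporaryDirectory() as tmp_dir:
    optimizer_path = os.path.join(tmp_dir, "optimizer.pt")
    torch.save(saved.state_dict(), optimizer_path)
    load_optimizer_state(restored, optimizer_path)

# The restored optimizer now carries the saved learning rate and momentum.
restored_lr = restored.param_groups[0]["lr"]
restored_momentum = restored.param_groups[0]["momentum"]
```

Restoring the optimizer alongside the model matters when resuming training, because values such as momentum buffers and the current learning rate live in the optimizer rather than in the model.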
◆ test()
float test (self, torch.utils.data.DataLoader test_loader)
Test the model in batches.
Parameters:
| test_loader | data loader for the test data |
Returns:
test loss
Definition at line 270 of file trainer.py.
◆ train()
◆ train_epoch()
float train_epoch (self)
Train the model in batches.
Returns:
training loss
Definition at line 174 of file trainer.py.
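A batched training pass that returns a single float might look like the sketch below. It is written as a free function, `train_one_epoch` (a hypothetical name), and it assumes the returned loss is the mean over batches. The documented method may aggregate the loss differently.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one pass over the loader and return the mean batch loss."""
    model.train()  # enable training behavior (dropout, batch norm updates)
    total_loss = 0.0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()  # clear gradients from the previous batch
        loss = criterion(model(inputs), targets)
        loss.backward()        # backpropagate
        optimizer.step()       # update parameters
        total_loss += loss.item()
    return total_loss / len(loader)


torch.manual_seed(0)
device = torch.device("cpu")
inputs = torch.randn(32, 4)
targets = inputs.sum(dim=1, keepdim=True)  # learnable linear target
loader = DataLoader(TensorDataset(inputs, targets), batch_size=8)

model = torch.nn.Linear(4, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
criterion = torch.nn.MSELoss()

first_loss = train_one_epoch(model, loader, optimizer, criterion, device)
for _ in range(20):
    last_loss = train_one_epoch(model, loader, optimizer, criterion, device)
```

Calling `loss.item()` converts each batch loss to a Python float, so the running total does not keep the autograd graph alive.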
◆ validate_epoch()
float validate_epoch (self, torch.utils.data.DataLoader data_loader)
Validate the model in batches.
Parameters:
| data_loader | data loader for the validation data |
Returns:
validation loss
Definition at line 229 of file trainer.py.
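Validation typically differs from training in two ways: the model is switched to evaluation mode and gradients are disabled. The sketch below shows that pattern as a free function, `evaluate` (a hypothetical name), assuming the returned value is the mean batch loss. The same loop shape would also fit a test pass over a test loader.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def evaluate(model, loader, criterion, device):
    """Return the mean batch loss without updating the model."""
    model.eval()  # disable dropout, use running batch norm statistics
    total_loss = 0.0
    with torch.no_grad():  # no gradient tracking during evaluation
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            total_loss += criterion(model(inputs), targets).item()
    return total_loss / len(loader)


device = torch.device("cpu")
model = torch.nn.Linear(4, 1).to(device)
# Zero the parameters so every prediction is exactly 0.
torch.nn.init.zeros_(model.weight)
torch.nn.init.zeros_(model.bias)

# All targets are 1, so the squared error of each prediction is 1.
dataset = TensorDataset(torch.randn(16, 4), torch.ones(16, 1))
val_loader = DataLoader(dataset, batch_size=8)

val_loss = evaluate(model, val_loader, torch.nn.MSELoss(), device)
```

Because no backward pass runs, the parameters keep `grad` set to `None` after this call.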
◆ criterion
◆ device
◆ model
◆ model_dir
◆ num_epochs
◆ optimizer
◆ start_time
◆ train_loader
train_loader = train_loader
◆ val_loader