Coverage Control Library
TrainModel Class Reference

Train a model using PyTorch.

Public Member Functions

 __init__ (self, torch.nn.Module model, torch.utils.data.DataLoader train_loader, torch.utils.data.DataLoader val_loader, torch.optim.Optimizer optimizer, torch.nn.Module criterion, int epochs, torch.device device, str model_file, str optimizer_file)
 Initialize the model trainer.
 
None load_saved_model_dict (self, str model_path)
 Load the saved model state dictionary.
 
None load_saved_model (self, str model_path)
 Load the saved model.
 
None load_saved_optimizer (self, str optimizer_path)
 Load the saved optimizer.
 
None train (self)
 Train the model.
 
float TrainEpoch (self)
 Train the model for one epoch, in batches.
 
float validate_epoch (self, torch.utils.data.DataLoader data_loader)
 Validate the model in batches.
 
float Test (self, torch.utils.data.DataLoader test_loader)
 Test the model in batches.
 

Public Attributes

 model
 
 train_loader
 
 val_loader
 
 optimizer
 
 criterion
 
 epochs
 
 device
 
 model_file
 
 optimizer_file
 

Detailed Description

Train a model using PyTorch.

Definition at line 38 of file trainer.py.

Constructor & Destructor Documentation

◆ __init__()

__init__ ( self,
torch.nn.Module model,
torch.utils.data.DataLoader train_loader,
torch.utils.data.DataLoader val_loader,
torch.optim.Optimizer optimizer,
torch.nn.Module criterion,
int epochs,
torch.device device,
str model_file,
str optimizer_file )

Initialize the model trainer.

Parameters
model: the model to train (torch.nn.Module)
train_loader: loader for the training data
val_loader: loader for the validation data
optimizer: optimizer for the model
criterion: loss function
epochs: number of epochs to train for
device: device on which to train the model
model_file: file to which the model is saved
optimizer_file: file to which the optimizer is saved

Definition at line 54 of file trainer.py.
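A minimal sketch of preparing the constructor's arguments, assuming a toy regression task. The dataset, model architecture, hyperparameters, and file names below are hypothetical. The import path of `TrainModel` is not given on this page, so the constructor call itself is left as a comment.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical toy data: 64 training and 16 validation samples, 4 features each.
train_set = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
val_set = TensorDataset(torch.randn(16, 4), torch.randn(16, 1))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1)).to(device)

train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
val_loader = DataLoader(val_set, batch_size=16)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()
epochs = 10

# The arguments map onto the documented signature in this order
# (hypothetical file names; TrainModel import path not shown on this page):
# trainer = TrainModel(model, train_loader, val_loader, optimizer, criterion,
#                      epochs, device, "model.pt", "optimizer.pt")
```

Note that the optimizer is built from `model.parameters()` after the model has been moved to `device`, so both refer to the same parameter tensors.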

Member Function Documentation

◆ load_saved_model()

None load_saved_model ( self,
str model_path )

Load the saved model.

Parameters
model_path: path to the saved model

Definition at line 91 of file trainer.py.
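This page does not say how the file read by `load_saved_model` is produced. Because a separate `load_saved_model_dict` exists, one plausible reading, assumed here, is that this method loads a whole pickled `nn.Module` written with `torch.save(model, path)`. The sketch below shows that round trip. `weights_only=False` is passed because recent PyTorch versions default `torch.load` to weights-only unpickling, which rejects full module objects; only load such files from trusted sources.

```python
import os
import tempfile

import torch
from torch import nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 2))

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pt")
    # Save the entire module object (architecture and weights) via pickle.
    torch.save(model, path)
    # Unpickle it back; weights_only=False permits arbitrary objects,
    # so only do this with trusted files.
    restored = torch.load(path, weights_only=False)

# The restored module has identical weights, so it produces identical outputs.
x = torch.randn(3, 4)
same_output = torch.equal(model(x), restored(x))
```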

◆ load_saved_model_dict()

None load_saved_model_dict ( self,
str model_path )

Load the saved model state dictionary.

Parameters
model_path: path to the saved model

Definition at line 82 of file trainer.py.
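Judging by its name, `load_saved_model_dict` most likely restores a state dictionary into the existing `model` rather than replacing the object. That is an assumption; this page only says "Load the saved model." A minimal sketch of a state-dict round trip with standard PyTorch calls:

```python
import os
import tempfile

import torch
from torch import nn

torch.manual_seed(0)
trained = nn.Linear(4, 2)
fresh = nn.Linear(4, 2)  # same architecture, different random weights

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model_dict.pt")
    # Save only the parameters and buffers, keyed by name.
    torch.save(trained.state_dict(), path)
    # Copy the saved tensors into an already-constructed model.
    # map_location="cpu" lets a GPU-written checkpoint load on a CPU-only machine.
    fresh.load_state_dict(torch.load(path, map_location="cpu"))

weights_match = torch.equal(trained.weight, fresh.weight) and torch.equal(
    trained.bias, fresh.bias
)
```

Unlike pickling the whole module, a state dict does not require the model class to be importable at load time, but the caller must first construct a model whose parameter names and shapes match.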

◆ load_saved_optimizer()

None load_saved_optimizer ( self,
str optimizer_path )

Load the saved optimizer.

Parameters
optimizer_path: path to the saved optimizer

Definition at line 100 of file trainer.py.
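This page does not state whether `optimizer_file` holds a full optimizer object or its `state_dict()`. Assuming the state-dict convention, a save/restore round trip looks like the sketch below. The saved state refers to parameters by their position in the optimizer's parameter groups, so the restored optimizer must be built over the same parameters in the same order.

```python
import os
import tempfile

import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Take one step so the optimizer holds per-parameter state (momentum buffers).
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "optimizer.pt")
    torch.save(optimizer.state_dict(), path)
    # Rebuild an optimizer over the same parameters, then restore its state.
    restored = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    restored.load_state_dict(torch.load(path))

restored_buffer = restored.state[model.weight]["momentum_buffer"]
original_buffer = optimizer.state[model.weight]["momentum_buffer"]
```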

◆ Test()

float Test ( self,
torch.utils.data.DataLoader test_loader )

Test the model in batches.

Parameters
test_loader: data loader for the test data
Returns
test loss

Definition at line 269 of file trainer.py.
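A minimal sketch of building a held-out `test_loader` and computing a mean test loss, assuming the test set is carved out of one dataset with a random split. The split sizes and toy data are hypothetical, and the loop is a plausible shape for `Test` (evaluation mode, no gradients), not its confirmed implementation.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset, random_split

torch.manual_seed(0)
dataset = TensorDataset(torch.randn(100, 4), torch.randn(100, 1))

# Hold out 20 samples that are never used for training or validation.
train_set, val_set, test_set = random_split(dataset, [70, 10, 20])
test_loader = DataLoader(test_set, batch_size=8)

model = nn.Linear(4, 1)
criterion = nn.MSELoss()

model.eval()  # switch dropout / batch-norm layers to inference behaviour
total, batches = 0.0, 0
with torch.no_grad():  # no autograd graph is built during evaluation
    for inputs, targets in test_loader:
        total += criterion(model(inputs), targets).item()
        batches += 1
test_loss = total / batches  # mean of per-batch losses
```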

◆ train()

None train ( self)

Train the model.

Definition at line 107 of file trainer.py.
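A sketch of the outer loop a `train` method of this shape could run, assuming it alternates one training pass and one validation pass per epoch and writes `model_file` and `optimizer_file` whenever the validation loss improves. That checkpoint policy is an assumption; this page only says the files are used to save the model and optimizer. To isolate the loop logic, the per-epoch passes are replaced by scripted loss values passed in as hypothetical callables `train_epoch` and `validate_epoch`.

```python
import math
import os
import tempfile

import torch
from torch import nn


def train(model, optimizer, train_epoch, validate_epoch, epochs, model_file, optimizer_file):
    """Outer loop sketch: one training and one validation pass per epoch."""
    best_val = math.inf
    saved_epochs = []
    for epoch in range(epochs):
        train_loss = train_epoch()
        val_loss = validate_epoch()
        print(f"epoch {epoch}: train {train_loss:.3f}, val {val_loss:.3f}")
        # Assumed policy: checkpoint only when validation loss improves.
        if val_loss < best_val:
            best_val = val_loss
            torch.save(model.state_dict(), model_file)
            torch.save(optimizer.state_dict(), optimizer_file)
            saved_epochs.append(epoch)
    return best_val, saved_epochs


# Scripted losses stand in for real TrainEpoch / validate_epoch calls.
train_losses = iter([1.2, 0.9, 0.7, 0.6])
val_losses = iter([1.0, 0.8, 0.9, 0.5])

model = nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters())

with tempfile.TemporaryDirectory() as tmp:
    best_val, saved_epochs = train(
        model,
        optimizer,
        train_epoch=lambda: next(train_losses),
        validate_epoch=lambda: next(val_losses),
        epochs=4,
        model_file=os.path.join(tmp, "model.pt"),
        optimizer_file=os.path.join(tmp, "optimizer.pt"),
    )
```

With these scripted values, checkpoints are written after epochs 0, 1 and 3; epoch 2 is skipped because its validation loss (0.9) is worse than the best so far (0.8).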

◆ TrainEpoch()

float TrainEpoch ( self)

Train the model for one epoch, in batches.

Returns
training loss

Definition at line 169 of file trainer.py.
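A minimal sketch of one training epoch in the standard PyTorch pattern (zero gradients, forward, loss, backward, step), assuming the returned "training loss" is the mean of the per-batch losses. The function name `train_epoch`, the toy data, and the hyperparameters are hypothetical.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_epoch(model, loader, optimizer, criterion, device):
    """One pass over `loader`, updating weights after every batch."""
    model.train()  # enable training behaviour of dropout / batch-norm layers
    total_loss = 0.0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()  # clear gradients left over from the previous batch
        loss = criterion(model(inputs), targets)
        loss.backward()  # compute gradients for this batch
        optimizer.step()  # apply the parameter update
        total_loss += loss.item()  # .item() converts to a Python float
    return total_loss / len(loader)  # mean of per-batch losses


torch.manual_seed(0)
x = torch.randn(64, 4)
y = x.sum(dim=1, keepdim=True)  # toy linear target
loader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
criterion = nn.MSELoss()
device = torch.device("cpu")

losses = [train_epoch(model, loader, optimizer, criterion, device) for _ in range(5)]
```

Averaging per-batch means gives a smaller final batch the same weight as a full one; here 64 samples split evenly into batches of 16, so the distinction does not arise.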

◆ validate_epoch()

float validate_epoch ( self,
torch.utils.data.DataLoader data_loader )

Validate the model in batches.

Parameters
data_loader: data loader for the validation data
Returns
validation loss

Definition at line 226 of file trainer.py.
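A sketch of a validation pass with gradients disabled and the model in evaluation mode. It weights each batch loss by the batch size so that a short final batch does not skew the average; whether `validate_epoch` does this is not stated on this page, so treat the weighting as an assumption. The toy data are hypothetical.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def validate_epoch(model, loader, criterion, device):
    """Mean per-sample loss over `loader`, with no weight updates."""
    model.eval()
    total_loss, n_samples = 0.0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            batch_loss = criterion(model(inputs), targets)  # mean over this batch
            total_loss += batch_loss.item() * inputs.size(0)
            n_samples += inputs.size(0)
    return total_loss / n_samples


torch.manual_seed(0)
x, y = torch.randn(10, 3), torch.randn(10, 1)
model = nn.Linear(3, 1)
criterion = nn.MSELoss()

# 10 samples with batch_size=4 gives batches of 4, 4 and 2.
loader = DataLoader(TensorDataset(x, y), batch_size=4)
val_loss = validate_epoch(model, loader, criterion, torch.device("cpu"))

# For a mean-reduced loss, the size-weighted average equals the full-batch loss.
with torch.no_grad():
    full_batch_loss = criterion(model(x), y).item()
```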

Member Data Documentation

◆ criterion

criterion

Definition at line 70 of file trainer.py.

◆ device

device

Definition at line 72 of file trainer.py.

◆ epochs

epochs

Definition at line 71 of file trainer.py.

◆ model

model

Definition at line 66 of file trainer.py.

◆ model_file

model_file

Definition at line 73 of file trainer.py.

◆ optimizer

optimizer

Definition at line 69 of file trainer.py.

◆ optimizer_file

optimizer_file

Definition at line 74 of file trainer.py.

◆ train_loader

train_loader

Definition at line 67 of file trainer.py.

◆ val_loader

val_loader

Definition at line 68 of file trainer.py.