Coverage Control Library
TrainModel Class Reference

Train a model using PyTorch.

Public Member Functions

 __init__ (self, torch.nn.Module model, torch.utils.data.DataLoader train_loader, torch.utils.data.DataLoader val_loader, torch.optim.Optimizer optimizer, torch.nn.Module criterion, int num_epochs, torch.device device, str model_dir)
 Initialize the model trainer.
 
None load_saved_model_dict (self, str model_file)
 Load the saved model dictionary.
 
None load_saved_model (self, str model_file)
 Load the saved model.
 
None load_saved_optimizer (self, str optimizer_path)
 Load the saved optimizer.
 
None train (self)
 Train the model.
 
float train_epoch (self)
 Train the model in batches for one epoch.
 
float validate_epoch (self, torch.utils.data.DataLoader data_loader)
 Validate the model in batches for one epoch.
 
float test (self, torch.utils.data.DataLoader test_loader)
 Test the model in batches.
 

Public Attributes

 model = model
 
 train_loader = train_loader
 
 val_loader = val_loader
 
 optimizer = optimizer
 
 criterion = criterion
 
 num_epochs = num_epochs
 
 device = device
 
 model_dir = model_dir
 
 start_time = time.time()
 

Detailed Description

Train a model using PyTorch.

Definition at line 39 of file trainer.py.

Constructor & Destructor Documentation

◆ __init__()

__init__ ( self,
torch.nn.Module model,
torch.utils.data.DataLoader train_loader,
torch.utils.data.DataLoader val_loader,
torch.optim.Optimizer optimizer,
torch.nn.Module criterion,
int num_epochs,
torch.device device,
str model_dir )

Initialize the model trainer.

Parameters
model: the torch.nn.Module to train
train_loader: loader for the training data
val_loader: loader for the validation data
optimizer: optimizer for the model
criterion: loss function
num_epochs: number of epochs to train for
device: device on which to train the model
model_dir: directory in which to save the model

Definition at line 54 of file trainer.py.
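The trainer receives its data as torch.utils.data.DataLoader objects, which yield the dataset in mini-batches. The hypothetical iter_batches helper below is a framework-free illustration only, not the library's code. It mimics how a DataLoader configured with batch_size, shuffle, and drop_last splits a dataset, and that batched iteration is what the training, validation, and test methods consume.

```python
import random
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def iter_batches(data: Sequence[T], batch_size: int, shuffle: bool = False,
                 drop_last: bool = False, seed: int = 0) -> Iterator[List[T]]:
    """Hypothetical helper: yield mini-batches of `data`, roughly like a DataLoader."""
    indices = list(range(len(data)))
    if shuffle:
        # A fixed seed keeps this sketch reproducible; a real DataLoader with
        # shuffle=True draws a new permutation every epoch.
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        chunk = indices[start:start + batch_size]
        if drop_last and len(chunk) < batch_size:
            break  # discard the final, smaller batch
        yield [data[i] for i in chunk]


print([len(b) for b in iter_batches(list(range(10)), batch_size=4)])  # [4, 4, 2]
print([len(b) for b in iter_batches(list(range(10)), batch_size=4, drop_last=True)])  # [4, 4]
```

The smaller final batch in the first call affects how per-batch losses should be averaged.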

Member Function Documentation

◆ load_saved_model()

None load_saved_model ( self,
str model_file )

Load the saved model.

Parameters
model_file: path to the saved model

Definition at line 90 of file trainer.py.

◆ load_saved_model_dict()

None load_saved_model_dict ( self,
str model_file )

Load the saved model dictionary.

Parameters
model_file: path to the saved model dictionary file

Definition at line 81 of file trainer.py.
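The names suggest a split between two methods. load_saved_model_dict probably restores a parameter dictionary (a PyTorch state_dict) into the already-constructed model, while load_saved_model probably loads a whole pickled model object. The page does not show either method body, so treat this as an assumption. In PyTorch, the state-dict pattern is usually torch.save(model.state_dict(), path) to save and model.load_state_dict(torch.load(path, map_location=device)) to restore. The stdlib sketch below reproduces that round trip with json. TinyModel and its methods are hypothetical stand-ins, not PyTorch classes.

```python
import json
import os
import tempfile


class TinyModel:
    """Hypothetical stand-in for a torch.nn.Module with two scalar parameters."""

    def __init__(self, w: float = 0.0, b: float = 0.0):
        self.w = w
        self.b = b

    def __call__(self, x: float) -> float:
        return self.w * x + self.b

    def state_dict(self) -> dict:
        # Only parameter values are stored, not the class definition.
        return {"w": self.w, "b": self.b}

    def load_state_dict(self, state: dict) -> None:
        self.w = state["w"]
        self.b = state["b"]


model_file = os.path.join(tempfile.mkdtemp(), "model_dict.json")
with open(model_file, "w") as f:
    json.dump(TinyModel(w=2.0, b=-1.0).state_dict(), f)

# The architecture is rebuilt first, then the saved parameters are copied in.
restored = TinyModel()
with open(model_file) as f:
    restored.load_state_dict(json.load(f))

print(restored(3.0))  # 5.0
```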

◆ load_saved_optimizer()

None load_saved_optimizer ( self,
str optimizer_path )

Load the saved optimizer.

Parameters
optimizer_path: path to the saved optimizer state

Definition at line 99 of file trainer.py.
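Restoring the optimizer matters when training is resumed. Optimizers such as SGD with momentum or Adam keep internal buffers, and resetting those buffers to zero changes the next update. In PyTorch, restoring is normally done with optimizer.load_state_dict(torch.load(optimizer_path)); this page does not show whether load_saved_optimizer does exactly that. The hypothetical MomentumSGD class below is a stdlib stand-in that shows why the saved state is needed.

```python
import copy


class MomentumSGD:
    """Hypothetical stand-in for torch.optim.SGD(momentum=...) on one scalar parameter."""

    def __init__(self, lr: float = 0.1, momentum: float = 0.9):
        self.lr = lr
        self.momentum = momentum
        self.velocity = 0.0  # internal buffer that a checkpoint must preserve

    def step(self, param: float, grad: float) -> float:
        self.velocity = self.momentum * self.velocity + grad
        return param - self.lr * self.velocity

    def state_dict(self) -> dict:
        return {"lr": self.lr, "momentum": self.momentum, "velocity": self.velocity}

    def load_state_dict(self, state: dict) -> None:
        self.lr = state["lr"]
        self.momentum = state["momentum"]
        self.velocity = state["velocity"]


opt = MomentumSGD()
param = 1.0
for _ in range(3):
    param = opt.step(param, grad=1.0)
saved = copy.deepcopy(opt.state_dict())  # what a checkpoint would hold

resumed = MomentumSGD()
resumed.load_state_dict(saved)
fresh = MomentumSGD()  # same hyperparameters, but the velocity starts at 0

continued = opt.step(param, 1.0)
from_ckpt = resumed.step(param, 1.0)
from_scratch = fresh.step(param, 1.0)
print(from_ckpt == continued)     # True
print(from_scratch == continued)  # False
```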

◆ test()

float test ( self,
torch.utils.data.DataLoader test_loader )

Test the model in batches.

Parameters
test_loader: data loader for the test data
Returns
test loss

Definition at line 270 of file trainer.py.
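The page says only that test returns the test loss. In any batched evaluation it is worth checking how the per-batch losses are combined. When the last batch is smaller, the plain mean of the batch means differs from the mean over all samples. The sketch below contrasts the two and assumes a mean-reduced criterion (PyTorch losses default to reduction='mean'). Which convention test actually uses is not stated here, and average_losses is a hypothetical helper.

```python
from typing import List, Tuple


def average_losses(batches: List[Tuple[float, int]]) -> Tuple[float, float]:
    """Hypothetical helper: combine (mean_batch_loss, batch_size) pairs two ways."""
    # Unweighted: every batch counts equally, whatever its size.
    per_batch = sum(loss for loss, _ in batches) / len(batches)
    # Weighted: undo each batch mean, then divide by the total number of samples.
    total = sum(n for _, n in batches)
    per_sample = sum(loss * n for loss, n in batches) / total
    return per_batch, per_sample


# Two full batches of 4 samples and a final batch of 2.
print(average_losses([(1.0, 4), (1.0, 4), (4.0, 2)]))  # (2.0, 1.6)
```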

◆ train()

None train ( self)

Train the model.

Definition at line 106 of file trainer.py.
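The page does not describe what train does beyond training the model. A common structure fits the other members (num_epochs, train_epoch, validate_epoch, model_dir), so it is a plausible reading: run one training pass and one validation pass per epoch, and save a checkpoint whenever the validation loss improves. The function below sketches that control flow with injected callables. run_training and its arguments are hypothetical, and the real method may differ, for example by saving every epoch.

```python
from typing import Callable, List


def run_training(num_epochs: int,
                 train_epoch: Callable[[], float],
                 validate_epoch: Callable[[], float],
                 save_checkpoint: Callable[[int], None]) -> float:
    """Hypothetical epoch loop that keeps the best validation loss seen so far."""
    best_val = float("inf")
    for epoch in range(num_epochs):
        train_loss = train_epoch()
        val_loss = validate_epoch()
        if val_loss < best_val:
            best_val = val_loss
            save_checkpoint(epoch)  # hypothetically: write model state into model_dir
        print(f"epoch {epoch}: train {train_loss:.3f}, val {val_loss:.3f}")
    return best_val


# Scripted losses stand in for real training; validation stops improving after epoch 1.
train_losses = iter([0.9, 0.6, 0.4, 0.3])
val_losses = iter([1.0, 0.7, 0.8, 0.75])
saved_epochs: List[int] = []
best = run_training(4, lambda: next(train_losses), lambda: next(val_losses),
                    saved_epochs.append)
print(best, saved_epochs)  # 0.7 [0, 1]
```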

◆ train_epoch()

float train_epoch ( self)

Train the model in batches for one epoch.

Returns
training loss

Definition at line 174 of file trainer.py.
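train_epoch presumably performs the usual PyTorch step for each batch: optimizer.zero_grad(), a forward pass, criterion(...), loss.backward(), and optimizer.step(). It then returns a loss aggregated over the epoch. The body is not shown, so the stdlib sketch below only mirrors that sequence. It uses a one-parameter linear model with a hand-computed mean-squared-error gradient in place of autograd. train_one_epoch and the data are hypothetical.

```python
from typing import List, Tuple

Batch = List[Tuple[float, float]]  # (input, target) pairs


def train_one_epoch(w: float, batches: List[Batch], lr: float) -> Tuple[float, float]:
    """Hypothetical epoch: one gradient step per batch; returns (new_w, mean batch loss)."""
    losses = []
    for batch in batches:
        # Forward pass and mean-squared-error loss for the model y = w * x.
        errors = [w * x - y for x, y in batch]
        loss = sum(e * e for e in errors) / len(batch)
        # Gradient of the loss w.r.t. w (what loss.backward() would compute).
        grad = sum(2 * e * x for e, (x, _) in zip(errors, batch)) / len(batch)
        w -= lr * grad  # plain SGD update, the optimizer.step() equivalent
        losses.append(loss)
    return w, sum(losses) / len(losses)


# Data generated by y = 3x, split into two batches.
data = [[(1.0, 3.0), (2.0, 6.0)], [(3.0, 9.0), (4.0, 12.0)]]
w = 0.0
for epoch in range(5):
    w, epoch_loss = train_one_epoch(w, data, lr=0.02)
    print(f"epoch {epoch}: loss {epoch_loss:.4f}, w {w:.4f}")
```

With this learning rate, w approaches 3 and the epoch loss shrinks every epoch.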

◆ validate_epoch()

float validate_epoch ( self,
torch.utils.data.DataLoader data_loader )

Validate the model in batches for one epoch.

Parameters
data_loader: data loader for the validation data
Returns
validation loss

Definition at line 229 of file trainer.py.
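Validation differs from training in that no parameters are updated. In PyTorch, the usual approach is to call model.eval() and wrap the loop in torch.no_grad(), then switch back with model.train() before the next training epoch. The page does not confirm that validate_epoch does this. The stdlib sketch below keeps only the essential property: it computes a loss over all batches without touching the weight. evaluate is a hypothetical helper.

```python
from typing import List, Tuple

Batch = List[Tuple[float, float]]  # (input, target) pairs


def evaluate(w: float, batches: List[Batch]) -> float:
    """Hypothetical validation pass: mean-squared error of y = w * x, no updates."""
    total_loss, total_count = 0.0, 0
    for batch in batches:
        # Forward pass only; nothing here modifies w.
        total_loss += sum((w * x - y) ** 2 for x, y in batch)
        total_count += len(batch)
    return total_loss / total_count


val_data = [[(1.0, 3.0), (2.0, 6.0)], [(5.0, 15.0)]]
print(evaluate(2.0, val_data))  # 10.0
```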

Member Data Documentation

◆ criterion

criterion = criterion

Definition at line 69 of file trainer.py.

◆ device

device = device

Definition at line 71 of file trainer.py.

◆ model

model = model

Definition at line 65 of file trainer.py.

◆ model_dir

model_dir = model_dir

Definition at line 72 of file trainer.py.

◆ num_epochs

num_epochs = num_epochs

Definition at line 70 of file trainer.py.

◆ optimizer

optimizer = optimizer

Definition at line 68 of file trainer.py.

◆ start_time

start_time = time.time()

Definition at line 73 of file trainer.py.
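start_time is set with time.time() when the trainer is constructed, presumably so that elapsed wall-clock time can be reported during training; the page does not show where it is read. A minimal sketch of such reporting follows, and format_elapsed is a hypothetical helper. For measuring intervals, time.monotonic() and time.perf_counter() are unaffected by system clock adjustments, which makes them a somewhat safer choice than time.time().

```python
import time


def format_elapsed(start_time: float, now: float) -> str:
    """Hypothetical helper: render the seconds between two timestamps as H:MM:SS."""
    seconds = max(0, int(now - start_time))  # clamp in case the clock went backwards
    hours, rem = divmod(seconds, 3600)
    minutes, secs = divmod(rem, 60)
    return f"{hours}:{minutes:02d}:{secs:02d}"


start_time = time.time()  # recorded the same way as the trainer's attribute
print(format_elapsed(start_time, time.time()))
print(format_elapsed(0.0, 3725.0))  # 1:02:05
```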

◆ train_loader

train_loader = train_loader

Definition at line 66 of file trainer.py.

◆ val_loader

val_loader = val_loader

Definition at line 67 of file trainer.py.