Coverage Control Library
Loading...
Searching...
No Matches
trainer.py
Go to the documentation of this file.
1# This file is part of the CoverageControl library
2#
3# Author: Saurav Agarwal
4# Contact: sauravag@seas.upenn.edu, agr.saurav1@gmail.com
5# Repository: https://github.com/KumarRobotics/CoverageControl
6#
7# Copyright (c) 2024, Saurav Agarwal
8#
9# The CoverageControl library is free software: you can redistribute it and/or
10# modify it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or (at your
12# option) any later version.
13#
14# The CoverageControl library is distributed in the hope that it will be
15# useful, but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
17# Public License for more details.
18#
19# You should have received a copy of the GNU General Public License along with
20# CoverageControl library. If not, see <https://www.gnu.org/licenses/>.
21
22"""
23Train a model using pytorch
24"""
25
26import time
27from copy import deepcopy
28
29import torch
30
31__all__ = ["TrainModel"]
32
33
34
class TrainModel:
    """
    Train a model using pytorch

    Bundles a model with its data loaders, optimizer and loss function, and
    runs the train / validate / test loops. Loss histories and model
    snapshots are written as ``.pt`` files into ``model_dir``.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        train_loader: torch.utils.data.DataLoader,
        val_loader: torch.utils.data.DataLoader,
        optimizer: torch.optim.Optimizer,
        criterion: torch.nn.Module,
        num_epochs: int,
        device: torch.device,
        model_dir: str,
    ):
        """
        Initialize the model trainer

        Args:
            model: torch.nn.Module
            train_loader: loader for the training data
            val_loader: loader for the validation data; may be None, in which
                case validation and the best-validation snapshot are skipped
            optimizer: optimizer for the model
            criterion: loss function
            num_epochs: number of epochs
            device: device to train the model (batches are moved here; the
                model itself is not moved by this class)
            model_dir: dir to save the model
        """
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.optimizer = optimizer
        self.criterion = criterion
        self.num_epochs = num_epochs
        self.device = device
        self.model_dir = model_dir
        self.start_time = time.time()

    def load_saved_model_dict(self, model_file: str) -> None:
        """
        Load a saved state dict into the current model

        Args:
            model_file: model file containing a state dict
        """
        # map_location lets a checkpoint saved on another device (e.g. a GPU
        # that is not present here) be loaded for self.device.
        state_dict = torch.load(model_file, map_location=self.device)
        self.model.load_state_dict(state_dict)

    def load_saved_model(self, model_file: str) -> None:
        """
        Load a fully pickled model, replacing self.model

        Args:
            model_file: model path
        """
        # A whole pickled nn.Module cannot be read with weights_only=True (the
        # torch.load default since torch 2.6). Unpickling can execute
        # arbitrary code: only load files from trusted sources.
        self.model = torch.load(
            model_file, map_location=self.device, weights_only=False
        )

    def load_saved_optimizer(self, optimizer_path: str) -> None:
        """
        Load a saved optimizer

        If the file holds a state dict (as produced by
        ``optimizer.state_dict()``), it is restored into the existing
        optimizer. Otherwise the loaded object replaces self.optimizer.

        Args:
            optimizer_path: optimizer path
        """
        # weights_only=False is needed for pickled optimizer objects; as above,
        # only load trusted files.
        loaded = torch.load(
            optimizer_path, map_location=self.device, weights_only=False
        )

        if isinstance(loaded, dict):
            # Restoring in place keeps the optimizer bound to self.model's
            # parameters; assigning a dict to self.optimizer would break
            # zero_grad()/step() in train_epoch.
            self.optimizer.load_state_dict(loaded)
        else:
            # NOTE(review): an unpickled optimizer references the parameter
            # tensors it was pickled with, not necessarily self.model's.
            self.optimizer = loaded

    # Train in batches, save the best model using the validation set
    def train(self) -> None:
        """
        Train the model for num_epochs epochs

        Per epoch:
        - run train_epoch;
        - if a validation loader is set, run validate_epoch on it;
        - save the loss histories (train_loss.pt / val_loss.pt).

        Every 5th epoch (0-based index divisible by 5), a checkpoint
        model_epoch<index>.pt is written with the epoch, model and optimizer
        state dicts, and the training loss.

        After the last epoch, the state dict with the lowest validation loss
        is saved to model.pt, and the one with the lowest training loss to
        model_train.pt. Each file is written only if such a snapshot exists.
        """
        best_val_loss = float("Inf")
        best_train_loss = float("Inf")

        train_loss_history = []
        val_loss_history = []
        start_time = time.time()

        best_model_state_dict = None
        best_train_model_state_dict = None

        for epoch in range(self.num_epochs):
            # Training
            train_loss = self.train_epoch()
            train_loss_history.append(train_loss)
            torch.save(train_loss_history, self.model_dir + "/train_loss.pt")
            print(f"Epoch: {epoch + 1}/{self.num_epochs} ",
                  f"Training Loss: {train_loss:.3e} ")

            # Validation
            if self.val_loader is not None:
                val_loss = self.validate_epoch(self.val_loader)
                val_loss_history.append(val_loss)
                torch.save(val_loss_history, self.model_dir + "/val_loss.pt")

                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    # deepcopy: state_dict() returns references to the live
                    # parameters, which later optimizer steps would overwrite.
                    best_model_state_dict = deepcopy(self.model.state_dict())
                print(f"Epoch: {epoch + 1}/{self.num_epochs} ",
                      f"Validation Loss: {val_loss:.3e} ",
                      f"Best Validation Loss: {best_val_loss:.3e}")

            if train_loss < best_train_loss:
                best_train_loss = train_loss
                best_train_model_state_dict = deepcopy(self.model.state_dict())

            if epoch % 5 == 0:
                torch.save({"epoch": epoch,
                            "model_state_dict": self.model.state_dict(),
                            "optimizer_state_dict": self.optimizer.state_dict(),
                            "loss": train_loss},
                           self.model_dir + "/model_epoch" + str(epoch) + ".pt")

        # Without a validation loader (or with zero epochs) there is no
        # best-validation snapshot. Skip the file instead of writing a
        # model.pt that only contains None.
        if best_model_state_dict is not None:
            torch.save(best_model_state_dict, self.model_dir + "/model.pt")

        if best_train_model_state_dict is not None:
            torch.save(best_train_model_state_dict,
                       self.model_dir + "/model_train.pt")
        elapsed_time = time.time() - start_time
        # Print elapsed time in minutes
        print(f"Elapsed time: {elapsed_time / 60:.2f} minutes")

    def _prepare_batch(self, data, target):
        """Move a (data, target) batch to self.device; flatten a 3-D target to 2-D."""
        data, target = data.to(self.device), target.to(self.device)

        if target.dim() == 3:
            # (d0, d1, d2) -> (d0 * d1, d2). Presumably (batch, robots,
            # features) -- TODO confirm against the dataset.
            target = target.view(-1, target.shape[-1])

        return data, target

    # Train the model in batches
    def train_epoch(self) -> float:
        """
        Train the model for one pass over self.train_loader

        Returns:
            training loss: the per-batch loss averaged over samples, with
            each batch weighted by data.size(0)
        """
        train_loss = 0.0
        num_dataset = 0

        self.model.train()

        for batch_idx, (data, target) in enumerate(self.train_loader):
            data, target = self._prepare_batch(data, target)

            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)

            if batch_idx % 10 == 0:
                print(f"Batch: {batch_idx}, Loss: {loss.item():.3e} ")

            loss.backward()
            self.optimizer.step()

            train_loss += loss.item() * data.size(0)
            num_dataset += data.size(0)

        return train_loss / num_dataset

    # Validate the model in batches
    def validate_epoch(self, data_loader: torch.utils.data.DataLoader) -> float:
        """
        Evaluate the model on data_loader

        Runs in eval mode under torch.no_grad(); parameters are not updated.

        Args:
            data_loader: data loader to evaluate on (validation or test data)

        Returns:
            loss averaged over samples, with each batch weighted by
            data.size(0)
        """
        val_loss = 0.0
        num_dataset = 0

        self.model.eval()

        with torch.no_grad():
            # Iterate the loader that was passed in (previously this always
            # used self.val_loader, so test() reported the validation loss).
            for data, target in data_loader:
                data, target = self._prepare_batch(data, target)
                output = self.model(data)
                loss = self.criterion(output, target)

                val_loss += loss.item() * data.size(0)
                num_dataset += data.size(0)

        return val_loss / num_dataset

    # Test the model in batches
    def test(self, test_loader: torch.utils.data.DataLoader) -> float:
        """
        Evaluate the model on test data and save the loss to test_loss.pt

        Args:
            test_loader: data loader for the test data

        Returns:
            test loss
        """
        test_loss = self.validate_epoch(test_loader)
        print("Test Loss: {:.3e} ".format(test_loss))
        torch.save(test_loss, self.model_dir + "/test_loss.pt")

        return test_loss
275
float test(self, torch.utils.data.DataLoader test_loader)
Test the model in batches.
Definition trainer.py:270
float validate_epoch(self, torch.utils.data.DataLoader data_loader)
Validate the model in batches.
Definition trainer.py:229
__init__(self, torch.nn.Module model, torch.utils.data.DataLoader train_loader, torch.utils.data.DataLoader val_loader, torch.optim.Optimizer optimizer, torch.nn.Module criterion, int num_epochs, torch.device device, str model_dir)
Initialize the model trainer.
Definition trainer.py:64
None load_saved_model_dict(self, str model_file)
Load the saved model.
Definition trainer.py:81
None load_saved_model(self, str model_file)
Load the saved model.
Definition trainer.py:90
None load_saved_optimizer(self, str optimizer_path)
Load the saved optimizer.
Definition trainer.py:99
float train_epoch(self)
Train the model in batches.
Definition trainer.py:174