23 Train a model using pytorch
27from copy
import deepcopy
31__all__ = [
"TrainModel"]
37 Train a model using pytorch
43 model: torch.nn.Module,
44 train_loader: torch.utils.data.DataLoader,
45 val_loader: torch.utils.data.DataLoader,
46 optimizer: torch.optim.Optimizer,
47 criterion: torch.nn.Module,
53 Initialize the model trainer
56 model: torch.nn.Module
57 train_loader: loader for the training data
58 val_loader: loader for the validation data
59 optimizer: optimizer for the model
60 criterion: loss function
61 num_epochs: number of epochs
62 device: device to train the model
63 model_dir: dir to save the model
80 model_file: model file
82 self.
model.load_state_dict(torch.load(model_file))
89 model_file: model path
91 self.
model = torch.load(model_file)
95 Load the saved optimizer
98 optimizer_path: optimizer path
100 self.
optimizer = torch.load(optimizer_path)
103 def train(self) -> None:
108 best_val_loss = float(
"Inf")
109 best_train_loss = float(
"Inf")
112 train_loss_history = []
113 val_loss_history = []
114 start_time = time.time()
116 best_model_state_dict =
None
117 best_train_model_state_dict =
None
124 train_loss_history.append(train_loss)
125 torch.save(train_loss_history, self.
model_dir +
"/train_loss.pt")
127 print(f
"Epoch: {epoch + 1}/{self.num_epochs} ",
128 f
"Training Loss: {train_loss:.3e} ")
134 val_loss_history.append(val_loss)
135 torch.save(val_loss_history, self.
model_dir +
"/val_loss.pt")
139 if val_loss < best_val_loss:
140 best_val_loss = val_loss
141 best_model_state_dict = deepcopy(self.
model.state_dict())
144 print(f
"Epoch: {epoch + 1}/{self.num_epochs} ",
145 f
"Validation Loss: {val_loss:.3e} ",
146 f
"Best Validation Loss: {best_val_loss:.3e}")
148 if train_loss < best_train_loss:
149 best_train_loss = train_loss
150 best_train_model_state_dict = deepcopy(self.
model.state_dict())
155 torch.save({
"epoch": epoch,
156 "model_state_dict": self.
model.state_dict(),
157 "optimizer_state_dict": self.
optimizer.state_dict(),
159 self.
model_dir +
"/model_epoch" + str(epoch) +
".pt")
161 torch.save(best_model_state_dict, self.
model_dir +
"/model.pt")
162 torch.save(best_train_model_state_dict, self.
model_dir +
"/model_train.pt")
163 elapsed_time = time.time() - start_time
165 print(f
"Elapsed time: {elapsed_time / 60:.2f} minutes")
170 Train the model in batches
184 for batch_idx, (data, target)
in enumerate(self.
train_loader):
186 data, target = data.to(self.
device), target.to(self.
device)
188 if target.dim() == 3:
189 target = target.view(-1, target.shape[-1])
195 output = self.
model(data)
202 if batch_idx % 10 == 0:
203 print(f
"Batch: {batch_idx}, Loss: {loss:.3e} ")
212 train_loss += loss.item() * data.size(0)
213 num_dataset += data.size(0)
217 return train_loss / num_dataset
220 def validate_epoch(self, data_loader: torch.utils.data.DataLoader) -> float:
222 Validate the model in batches
225 data_loader: data loader for the validation data
238 with torch.no_grad():
239 for batch_idx, (data, target)
in enumerate(self.
val_loader):
241 data, target = data.to(self.
device), target.to(self.
device)
243 if target.dim() == 3:
244 target = target.view(-1, target.shape[-1])
247 output = self.
model(data)
253 val_loss += loss.item() * data.size(0)
254 num_dataset += data.size(0)
258 return val_loss / num_dataset
261 def test(self, test_loader: torch.utils.data.DataLoader) -> float:
263 Test the model in batches
266 test_loader: data loader for the test data
273 print(
"Test Loss: {:.3e} ".format(test_loss))
274 torch.save(test_loss, self.
model_dir +
"/test_loss.pt")
Train a model using PyTorch.
float test(self, torch.utils.data.DataLoader test_loader)
Evaluate the model on `test_loader` in batches, print the test loss, and save it to `test_loss.pt` in the model directory.
float validate_epoch(self, torch.utils.data.DataLoader data_loader)
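Evaluate the model in batches without gradient tracking and return the average per-sample loss. Note: the listed implementation iterates over `self.val_loader`, not the `data_loader` argument.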
__init__(self, torch.nn.Module model, torch.utils.data.DataLoader train_loader, torch.utils.data.DataLoader val_loader, torch.optim.Optimizer optimizer, torch.nn.Module criterion, int num_epochs, torch.device device, str model_dir)
Initialize the model trainer.
None load_saved_model_dict(self, str model_file)
Load saved weights (a state dict) from `model_file` into the existing model.
None train(self)
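Train the model for the configured number of epochs. Save the training and validation loss histories, epoch checkpoints (model and optimizer state), the weights with the best validation loss (`model.pt`), and the weights with the best training loss (`model_train.pt`).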
None load_saved_model(self, str model_file)
Replace the trainer's model with a complete model object loaded from `model_file`.
None load_saved_optimizer(self, str optimizer_path)
Replace the trainer's optimizer with an optimizer object loaded from `optimizer_path`.
float train_epoch(self)
Run one training epoch over `train_loader` in batches and return the average per-sample training loss.