Coverage Control Library
Loading...
Searching...
No Matches
cnn.py
1# This file is part of the CoverageControl library
2#
3# Author: Saurav Agarwal
4# Contact: sauravag@seas.upenn.edu, agr.saurav1@gmail.com
5# Repository: https://github.com/KumarRobotics/CoverageControl
6#
7# Copyright (c) 2024, Saurav Agarwal
8#
9# The CoverageControl library is free software: you can redistribute it and/or
10# modify it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or (at your
12# option) any later version.
13#
14# The CoverageControl library is distributed in the hope that it will be
15# useful, but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
17# Public License for more details.
18#
19# You should have received a copy of the GNU General Public License along with
20# CoverageControl library. If not, see <https://www.gnu.org/licenses/>.
21
22"""
23Implements an architecture consisting of a multi-layer CNN followed by an MLP, according to parameters specified in the input config
24"""
25import torch
26from torch_geometric.nn import MLP
27
28from .cnn_backbone import CNNBackBone
29from .config_parser import CNNConfigParser
30
31__all__ = ["CNN"]
32
33
34
class CNN(torch.nn.Module, CNNConfigParser):
    """
    Multi-layer CNN backbone followed by an MLP and a final linear layer.

    The layer sizes come from attributes populated by ``self.parse(config)``;
    this class reads ``self.latent_size`` and ``self.output_dim``.
    """

    def __init__(self, config: dict):
        """
        Build the network from the given configuration.

        Args:
            config: Configuration dictionary handed to ``self.parse``.
        """
        super().__init__()
        self.parse(config)

        # NOTE(review): the statement creating the backbone was missing from
        # the extracted listing, yet ``forward`` requires ``self.cnn_backbone``.
        # Presumably CNNBackBone accepts the same config dict -- confirm
        # against the repository.
        self.cnn_backbone = CNNBackBone(config)

        # NOTE(review): the channel list was elided in the extracted listing
        # (four entries). The first and last widths must equal latent_size to
        # connect the backbone output to ``self.linear``; the two hidden
        # widths (2 * latent_size) are an assumption -- confirm against the
        # repository, since they determine checkpoint compatibility.
        self.mlp = MLP(
            [
                self.latent_size,
                2 * self.latent_size,
                2 * self.latent_size,
                self.latent_size,
            ]
        )
        # Fixed: the listing read ``self.latent_sizelatent_size`` (duplicated
        # identifier), which would raise AttributeError at construction time.
        self.linear = torch.nn.Linear(self.latent_size, self.output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the network

        The input is passed through the CNN backbone, the MLP and the final
        linear layer, in that order.

        Args:
            x: Input tensor

        Returns:
            Output tensor
        """
        x = self.cnn_backbone(x)
        x = self.mlp(x)
        x = self.linear(x)

        return x

    def load_cpp_model(self, model_path: str) -> None:
        """
        Loads a model saved in cpp jit format

        Args:
            model_path: Path of the TorchScript file passed to ``torch.jit.load``
        """
        jit_model = torch.jit.load(model_path)
        # strict=False: the state dict is loaded without requiring its keys to
        # match this module's keys exactly.
        self.load_state_dict(jit_model.state_dict(), strict=False)

    def load_model(self, model_path: str) -> None:
        """
        Loads a model saved in pytorch format

        Args:
            model_path: Path of the file passed to ``torch.load``; its content
                is used directly as the state dict
        """
        # strict=False: the state dict is loaded without requiring its keys to
        # match this module's keys exactly.
        self.load_state_dict(torch.load(model_path), strict=False)
torch.Tensor forward(self, torch.Tensor x)
Forward pass through the network.
Definition cnn.py:64
None load_cpp_model(self, str model_path)
Loads a model saved in cpp jit format.
Definition cnn.py:74
None load_model(self, str model_path)
Loads a model saved in pytorch format.
Definition cnn.py:81
Implements a multi-layer convolutional neural network, with leaky-ReLU non-linearities between layers...
None parse(self, dict config)
Parse the configuration for the CNN model.