Coverage Control Library
Loading...
Searching...
No Matches
cnn_backbone.py
Go to the documentation of this file.
1# This file is part of the CoverageControl library
2#
3# Author: Saurav Agarwal
4# Contact: sauravag@seas.upenn.edu, agr.saurav1@gmail.com
5# Repository: https://github.com/KumarRobotics/CoverageControl
6#
7# Copyright (c) 2024, Saurav Agarwal
8#
9# The CoverageControl library is free software: you can redistribute it and/or
10# modify it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or (at your
12# option) any later version.
13#
14# The CoverageControl library is distributed in the hope that it will be
15# useful, but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
17# Public License for more details.
18#
19# You should have received a copy of the GNU General Public License along with
20# CoverageControl library. If not, see <https://www.gnu.org/licenses/>.
21
22"""
23Implements a multi-layer convolutional neural network
24"""
25import torch
26
27from .config_parser import CNNConfigParser
28
29
class CNNBackBone(torch.nn.Module, CNNConfigParser):
    """
    Implements a multi-layer convolutional neural network,
    with leaky-ReLU non-linearities between layers,
    according to hyperparameters specified in the config

    The network is a stack of ``num_layers`` blocks, each being
    Conv2d -> BatchNorm2d -> leaky-ReLU, followed by a flatten and a single
    linear layer (also with leaky-ReLU) producing ``latent_size`` features.
    """

    def __init__(self, config: dict):
        """
        Build the convolutional stack and the final linear layer.

        Args:
            config: hyperparameter dictionary handed to ``self.parse``; after
                parsing, this class reads ``latent_size``, ``num_layers``,
                ``kernel_size`` and ``image_size`` from ``self``

        Raises:
            ValueError: if the convolutions would shrink the image to a
                non-positive spatial size
        """
        super().__init__()
        self.parse(config)

        # NOTE(review): the first convolution's input-channel attribute is
        # assumed to be `self.input_dim` -- confirm against
        # CNNConfigParser.parse.
        self.add_module(
            "conv0",
            torch.nn.Conv2d(
                self.input_dim, self.latent_size, kernel_size=self.kernel_size
            ),
        )
        self.add_module("batch_norm0", torch.nn.BatchNorm2d(self.latent_size))

        # Remaining layers map latent_size channels to latent_size channels.
        for layer in range(self.num_layers - 1):
            self.add_module(
                f"conv{layer + 1}",
                torch.nn.Conv2d(
                    self.latent_size, self.latent_size, kernel_size=self.kernel_size
                ),
            )
            self.add_module(
                f"batch_norm{layer + 1}", torch.nn.BatchNorm2d(self.latent_size)
            )

        # Each un-padded, stride-1 convolution shrinks every spatial side by
        # (kernel_size - 1).
        spatial_size = self.image_size - self.num_layers * (self.kernel_size - 1)
        if spatial_size <= 0:
            # Squaring below would silently hide a negative side length and
            # build a linear layer of the wrong size, so fail early instead.
            raise ValueError(
                "CNN configuration leaves no spatial extent after the "
                f"convolutions: image_size={self.image_size}, "
                f"num_layers={self.num_layers}, kernel_size={self.kernel_size}"
            )
        self.flatten_size = self.latent_size * spatial_size**2

        self.add_module(
            "linear_1", torch.nn.Linear(self.flatten_size, self.latent_size)
        )
        # self.add_module("linear_2", torch.nn.Linear(self.latent_size, self.backbone_output_dim))
        # self.add_module("linear_3", torch.nn.Linear(2 * self.output_dim, self.output_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the network

        Args:
            x: input tensor
               (assumed to be (batch, channels, image_size, image_size) --
               TODO confirm against callers)

        Returns:
            Tensor obtained by flattening the convolutional features from
            dimension 1 onward and applying ``linear_1`` with a leaky-ReLU
        """
        for layer in range(self.num_layers):
            # conv -> batch norm -> leaky-ReLU for every layer
            x = torch.nn.functional.leaky_relu(
                self._modules[f"batch_norm{layer}"](self._modules[f"conv{layer}"](x))
            )
            # x = self._modules["conv{}".format(layer)](x)
            # x = self._modules["batch_norm{}".format(layer)](x)
            # x = torch.nn.functional.leaky_relu(x)

        # Keep dimension 0 and collapse everything else for the linear layer.
        x = x.flatten(1)
        x = torch.nn.functional.leaky_relu(self.linear_1(x))

        return x
        # x = torch.nn.functional.leaky_relu(self.linear_2(x))
        # x = self.linear_3(x)
        # output = x.reshape(x.shape[0], self.latent_size, -1)
        # output, _ = torch.max(output, dim=2)
        # return output
Implements a multi-layer convolutional neural network, with leaky-ReLU non-linearities between layers...
torch.Tensor forward(self, torch.Tensor x)
Forward pass through the network.
This file contains the configuration parser for the models.
None parse(self, dict config)
Parse the configuration for the CNN model.