Coverage Control Library
Loading...
Searching...
No Matches
cnn_backbone.py
1# This file is part of the CoverageControl library
2#
3# Author: Saurav Agarwal
4# Contact: sauravag@seas.upenn.edu, agr.saurav1@gmail.com
5# Repository: https://github.com/KumarRobotics/CoverageControl
6#
7# Copyright (c) 2024, Saurav Agarwal
8#
9# The CoverageControl library is free software: you can redistribute it and/or
10# modify it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or (at your
12# option) any later version.
13#
14# The CoverageControl library is distributed in the hope that it will be
15# useful, but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
17# Public License for more details.
18#
19# You should have received a copy of the GNU General Public License along with
20# CoverageControl library. If not, see <https://www.gnu.org/licenses/>.
21
22"""
23Implements a multi-layer convolutional neural network
24"""
25import torch
26
27from .config_parser import CNNConfigParser
28
29
class CNNBackBone(torch.nn.Module, CNNConfigParser):
    """
    Implements a multi-layer convolutional neural network,
    with leaky-ReLU non-linearities between layers,
    according to hyperparameters specified in the config.

    The network is a stack of ``num_layers`` blocks, each made of
    ``Conv2d -> BatchNorm2d -> leaky-ReLU``, followed by a flatten and a
    single linear layer (``linear_1``) that is also passed through a
    leaky-ReLU.

    The attributes ``input_dim``, ``latent_size``, ``kernel_size``,
    ``num_layers`` and ``image_size`` are read after ``self.parse(config)``
    and are expected to be set by it (the parser is defined in
    ``config_parser`` and is not visible here).
    """

    def __init__(self, config: dict):
        """
        Build the convolutional stack and the final linear layer.

        Args:
            config: configuration dictionary handed to ``self.parse``

        Raises:
            ValueError: if the convolutions would shrink the feature map
                to a non-positive spatial size
        """
        super().__init__()
        self.parse(config)

        # Each convolution uses the default stride (1) and padding (0), so it
        # shrinks every spatial side by ``kernel_size - 1``.
        spatial_size = self.image_size - self.num_layers * (self.kernel_size - 1)
        if spatial_size <= 0:
            # Without this check the squared term below would silently turn a
            # negative side length into a positive (and wrong) flatten size.
            raise ValueError(
                f"image_size={self.image_size} is too small for "
                f"num_layers={self.num_layers} convolutions with "
                f"kernel_size={self.kernel_size}"
            )

        # First block maps the input channels to the latent channel count.
        self.add_module(
            "conv0",
            torch.nn.Conv2d(
                self.input_dim, self.latent_size, kernel_size=self.kernel_size
            ),
        )
        self.add_module("batch_norm0", torch.nn.BatchNorm2d(self.latent_size))

        # Remaining blocks keep the channel count at latent_size.
        for layer in range(self.num_layers - 1):
            self.add_module(
                f"conv{layer + 1}",
                torch.nn.Conv2d(
                    self.latent_size, self.latent_size, kernel_size=self.kernel_size
                ),
            )
            self.add_module(
                f"batch_norm{layer + 1}", torch.nn.BatchNorm2d(self.latent_size)
            )

        # Number of features per sample after the last convolution:
        # channels * height * width, with height == width == spatial_size.
        self.flatten_size = self.latent_size * spatial_size**2

        self.add_module(
            "linear_1", torch.nn.Linear(self.flatten_size, self.latent_size)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the network

        Args:
            x: input tensor; presumably of shape
                (batch, input_dim, image_size, image_size) so that the
                flattened features match ``flatten_size`` -- confirm
                against the caller

        Returns:
            Tensor with ``latent_size`` features per sample (the output of
            ``linear_1`` after a leaky-ReLU)
        """
        for layer in range(self.num_layers):
            conv = getattr(self, f"conv{layer}")
            batch_norm = getattr(self, f"batch_norm{layer}")
            x = torch.nn.functional.leaky_relu(batch_norm(conv(x)))

        # Keep the batch dimension and flatten everything else.
        x = x.flatten(1)
        return torch.nn.functional.leaky_relu(self.linear_1(x))
torch.Tensor forward(self, torch.Tensor x)
Forward pass through the network.
None parse(self, dict config)
Parse the configuration for the CNN model.