Coverage Control Library
Loading...
Searching...
No Matches
gnn_backbone.py
Go to the documentation of this file.
1
# This file is part of the CoverageControl library
2
#
3
# Author: Saurav Agarwal
4
# Contact: sauravag@seas.upenn.edu, agr.saurav1@gmail.com
5
# Repository: https://github.com/KumarRobotics/CoverageControl
6
#
7
# Copyright (c) 2024, Saurav Agarwal
8
#
9
# The CoverageControl library is free software: you can redistribute it and/or
10
# modify it under the terms of the GNU General Public License as published by
11
# the Free Software Foundation, either version 3 of the License, or (at your
12
# option) any later version.
13
#
14
# The CoverageControl library is distributed in the hope that it will be
15
# useful, but WITHOUT ANY WARRANTY; without even the implied warranty of
16
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
17
# Public License for more details.
18
#
19
# You should have received a copy of the GNU General Public License along with
20
# CoverageControl library. If not, see <https://www.gnu.org/licenses/>.
21
22
"""
23
Implements a GNN architecture.
24
"""
25
import
torch
26
import
torch_geometric
27
28
from
.config_parser
import
GNNConfigParser
29
30
31
class GNNBackBone(torch.nn.Module, GNNConfigParser):
    """
    Stack of TAGConv graph-convolution layers built from a config.

    The constructor calls ``self.parse(config)`` and then reads the
    attributes ``input_dim``, ``latent_size``, ``num_hops`` and
    ``num_layers`` from the instance. Each layer is registered as a
    submodule named ``graph_conv_<i>``.
    """

    def __init__(self, config, input_dim=None):
        """
        Build the convolution layers.

        Args:
            config: Configuration object handed to ``self.parse``.
            input_dim: When not ``None``, replaces ``self.input_dim`` after
                the config has been parsed.
        """
        super().__init__()

        self.parse(config)

        if input_dim is not None:
            self.input_dim = input_dim

        # Input width of every layer: the first layer consumes input_dim,
        # each following layer consumes latent_size. A non-positive repeat
        # count yields an empty list, so graph_conv_0 is always created.
        in_widths = [self.input_dim] + [self.latent_size] * (self.num_layers - 1)

        for layer_idx, in_width in enumerate(in_widths):
            conv = torch_geometric.nn.TAGConv(
                in_channels=in_width,
                out_channels=self.latent_size,
                K=self.num_hops,
            )
            self.add_module(f"graph_conv_{layer_idx}", conv)

    def forward(
        self, x: torch.Tensor, edge_index: torch.Tensor, edge_weight=None
    ) -> torch.Tensor:
        """
        Run ``x`` through the first ``num_layers`` convolutions, applying a
        ReLU after each one.

        Args:
            x: Input tensor given to the first convolution.
            edge_index: Forwarded unchanged to every convolution.
            edge_weight: Forwarded unchanged to every convolution.

        Returns:
            The output of the last ReLU (``x`` itself if no layer is applied).
        """
        for layer_idx in range(self.num_layers):
            conv = self._modules[f"graph_conv_{layer_idx}"]
            x = torch.relu(conv(x, edge_index, edge_weight))

        return x
coverage_control.nn.models.config_parser.GNNConfigParser
Class to parse the configuration for the GNN model.
Definition
config_parser.py:60
coverage_control.nn.models.config_parser.GNNConfigParser.parse
None parse(self, dict config)
Parse the configuration for the GNN model.
Definition
config_parser.py:73
coverage_control.nn.models.config_parser.GNNConfigParser.input_dim
input_dim
Definition
config_parser.py:64
coverage_control.nn.models.config_parser.GNNConfigParser.num_hops
num_hops
Definition
config_parser.py:66
coverage_control.nn.models.config_parser.GNNConfigParser.latent_size
latent_size
Definition
config_parser.py:68
coverage_control.nn.models.config_parser.GNNConfigParser.num_layers
num_layers
Definition
config_parser.py:67
coverage_control.nn.models.gnn_backbone.GNNBackBone
Implements a GNN architecture, according to hyperparameters specified in the input config.
Definition
gnn_backbone.py:35
coverage_control.nn.models.gnn_backbone.GNNBackBone.forward
torch.Tensor forward(self, torch.Tensor x, torch.Tensor edge_index, edge_weight=None)
Definition
gnn_backbone.py:66
coverage_control.nn.models.gnn_backbone.GNNBackBone.input_dim
input_dim
Definition
gnn_backbone.py:43
coverage_control.nn.models.gnn_backbone.GNNBackBone.__init__
__init__(self, config, input_dim=None)
Definition
gnn_backbone.py:37
coverage_control.nn.models.gnn_backbone.GNNBackBone.num_layers
num_layers
Definition
gnn_backbone.py:54
python
coverage_control
nn
models
gnn_backbone.py
Generated by
1.12.0