Coverage Control Library
Loading...
Searching...
No Matches
gnn_backbone.py
Go to the documentation of this file.
# This file is part of the CoverageControl library
#
# Author: Saurav Agarwal
# Contact: sauravag@seas.upenn.edu, agr.saurav1@gmail.com
# Repository: https://github.com/KumarRobotics/CoverageControl
#
# Copyright (c) 2024, Saurav Agarwal
#
# The CoverageControl library is free software: you can redistribute it and/or
# modify it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or (at your
# option) any later version.
#
# The CoverageControl library is distributed in the hope that it will be
# useful, but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along with
# CoverageControl library. If not, see <https://www.gnu.org/licenses/>.
21
22"""
23Implements a GNN architecture.
24"""
25import torch
26import torch_geometric
27
28from .config_parser import GNNConfigParser
29
30
class GNNBackBone(torch.nn.Module, GNNConfigParser):
    """
    Implements a GNN architecture,
    according to hyperparameters specified in the input config

    The network is a stack of ``num_layers`` ``TAGConv`` layers, each followed
    by a ReLU. The first layer maps ``input_dim`` channels to ``latent_size``
    channels; every later layer maps ``latent_size`` to ``latent_size``. Every
    layer uses ``K=num_hops``.

    Layers are registered as ``graph_conv_0`` ... ``graph_conv_{num_layers-1}``.
    These names determine the ``state_dict`` keys, so they must stay stable to
    keep previously saved checkpoints loadable.
    """

    def __init__(self, config, input_dim=None):
        """
        Build the layer stack from ``config``.

        Args:
            config: configuration handed to ``self.parse`` (from
                ``GNNConfigParser``). That call is expected to set
                ``input_dim``, ``latent_size``, ``num_hops`` and ``num_layers``
                on ``self`` -- NOTE(review): confirm against the parser.
            input_dim: optional override for the number of input channels of
                the first layer; when ``None`` the parsed value is used.

        Raises:
            ValueError: if the parsed ``num_layers`` is less than 1. Previously
                such a config silently built ``graph_conv_0`` but ``forward``
                applied no layers and returned its input unchanged.
        """
        super().__init__()

        self.parse(config)

        if input_dim is not None:
            self.input_dim = input_dim

        if self.num_layers < 1:
            raise ValueError(
                f"GNNBackBone requires num_layers >= 1, got {self.num_layers}"
            )

        # First layer: input features -> latent space.
        self.add_module(
            "graph_conv_0",
            torch_geometric.nn.TAGConv(
                in_channels=self.input_dim,
                out_channels=self.latent_size,
                K=self.num_hops,
            ),
        )

        # Remaining layers: latent space -> latent space.
        for i in range(1, self.num_layers):
            self.add_module(
                f"graph_conv_{i}",
                torch_geometric.nn.TAGConv(
                    in_channels=self.latent_size,
                    out_channels=self.latent_size,
                    K=self.num_hops,
                ),
            )

    def forward(
        self, x: torch.Tensor, edge_index: torch.Tensor, edge_weight=None
    ) -> torch.Tensor:
        """
        Run the input through every graph-convolution layer in order.

        Args:
            x: node feature tensor; presumably (num_nodes, input_dim) per the
                TAGConv convention -- TODO confirm with callers.
            edge_index: graph connectivity, passed unchanged to every layer.
            edge_weight: optional edge weights, passed unchanged to every layer.

        Returns:
            The output of the last layer after ReLU.
        """
        for i in range(self.num_layers):
            conv = getattr(self, f"graph_conv_{i}")
            # ReLU is applied after every layer, including the last one.
            x = torch.relu(conv(x, edge_index, edge_weight))

        return x
Class to parse the configuration for the GNN model.
None parse(self, dict config)
Parse the configuration for the GNN model.
Implements a GNN architecture, according to hyperparameters specified in the input config.
torch.Tensor forward(self, torch.Tensor x, torch.Tensor edge_index, edge_weight=None)