Coverage Control Library
Loading...
Searching...
No Matches
lpac.py
1# This file is part of the CoverageControl library
2#
3# Author: Saurav Agarwal
4# Contact: sauravag@seas.upenn.edu, agr.saurav1@gmail.com
5# Repository: https://github.com/KumarRobotics/CoverageControl
6#
7# Copyright (c) 2024, Saurav Agarwal
8#
9# The CoverageControl library is free software: you can redistribute it and/or
10# modify it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or (at your
12# option) any later version.
13#
14# The CoverageControl library is distributed in the hope that it will be
15# useful, but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
17# Public License for more details.
18#
19# You should have received a copy of the GNU General Public License along with
20# CoverageControl library. If not, see <https://www.gnu.org/licenses/>.
21
22import torch
23import torch_geometric
24from torch_geometric.nn import MLP
25
26from .config_parser import GNNConfigParser
27from .cnn_backbone import CNNBackBone
28from .gnn_backbone import GNNBackBone
29
__all__ = ["LPAC"]

# Module for the LPAC model.
35
36
class LPAC(torch.nn.Module, GNNConfigParser):
    """
    LPAC neural network architecture.

    Pipeline (see ``forward``): a CNN backbone over each node's input stack,
    concatenation of the node positions, a GNN backbone over
    ``data.edge_index``, then a small MLP and a final linear layer.

    Args:
        in_config: configuration dict with a ``"CNNBackBone"`` section (passed
            to ``CNNBackBone``) and a ``"GNNBackBone"`` section (passed to
            ``self.parse``).
    """

    def __init__(self, in_config):
        super().__init__()
        self.cnn_config = in_config["CNNBackBone"]
        # NOTE(review): parse() presumably sets ``config``, ``latent_size`` and
        # ``output_dim`` on self (all used below); the parser is not visible here.
        self.parse(in_config["GNNBackBone"])
        # The CNN backbone must exist before the GNN backbone, whose input size
        # is derived from ``cnn_backbone.latent_size``.
        self.cnn_backbone = CNNBackBone(self.cnn_config)
        # +2 input features: the node positions concatenated in forward().
        # A variant without positions would use ``latent_size`` here and skip
        # the concatenation in forward().
        self.gnn_backbone = GNNBackBone(self.config, self.cnn_backbone.latent_size + 2)
        self.gnn_mlp = MLP([self.latent_size, 32, 32])
        self.output_linear = torch.nn.Linear(32, self.output_dim)
        # Registered as buffers (not parameters), so they are stored in the
        # state dict and move with the module across devices. They are not
        # applied in forward(); presumably callers use them to (de)normalize
        # actions.
        self.register_buffer("actions_mean", torch.zeros(self.output_dim))
        self.register_buffer("actions_std", torch.ones(self.output_dim))

    def forward(self, data: torch_geometric.data.Data) -> torch.Tensor:
        """
        Forward pass of the LPAC model.

        Args:
            data: graph data providing ``x``, ``edge_index`` and ``pos``. The
                last three dimensions of ``x`` form one CNN input sample; any
                leading dimensions are flattened into the batch dimension.
                ``pos`` must contribute 2 feature columns per row of the CNN
                output (the GNN backbone is built for ``latent_size + 2``).

        Returns:
            Tensor whose last dimension is ``output_dim``.
        """
        x, edge_index, pos = data.x, data.edge_index, data.pos
        # Flatten leading dims so the CNN always receives a rank-4 batch;
        # reshape (unlike view) also accepts non-contiguous inputs.
        cnn_output = self.cnn_backbone(x.reshape(-1, *x.shape[-3:]))
        gnn_input = torch.cat([cnn_output, pos], dim=-1)
        # Edge weights (if present in ``data``) are not passed to the backbone.
        gnn_output = self.gnn_backbone(gnn_input, edge_index)
        return self.output_linear(self.gnn_mlp(gnn_output))

    def _load_partial_state_dict(self, state_dict, source: str) -> None:
        """Load ``state_dict`` with ``strict=False``, warning if no key matched."""
        import warnings

        result = self.load_state_dict(state_dict, strict=False)
        # strict=False silently skips keys that do not match this module's
        # names; if every key was skipped, nothing at all was loaded.
        if len(state_dict) > 0 and len(result.unexpected_keys) == len(state_dict):
            warnings.warn(
                f"No entries from '{source}' matched the LPAC state dict; nothing was loaded.",
                stacklevel=3,
            )

    def load_model(self, model_state_dict_path: str, *, map_location=None) -> None:
        """
        Load the model from the state dict.

        Args:
            model_state_dict_path: path to a file saved from a state dict.
            map_location: forwarded to ``torch.load`` (e.g. ``"cpu"`` to load
                a checkpoint saved on GPU onto a CPU-only machine).

        Keys are matched non-strictly: missing or extra keys are ignored,
        and a warning is emitted if no key matched at all.
        """
        state_dict = torch.load(model_state_dict_path, map_location=map_location)
        self._load_partial_state_dict(state_dict, model_state_dict_path)

    def load_cnn_backbone(self, model_path: str, *, map_location=None) -> None:
        """
        Load the CNN backbone from the model path.

        The file must contain a pickled ``torch.nn.Module``; its
        ``state_dict()`` is loaded non-strictly into this model, so only keys
        named as in this module (e.g. prefixed ``cnn_backbone.``) take effect.

        Args:
            model_path: path to the saved module.
            map_location: forwarded to ``torch.load``.
        """
        # NOTE(review): torch.load unpickles arbitrary objects -- only load
        # trusted files. Newer PyTorch versions default to weights_only=True,
        # which may refuse to unpickle a full module; confirm for the torch
        # version in use.
        module = torch.load(model_path, map_location=map_location)
        self._load_partial_state_dict(module.state_dict(), model_path)

    def load_gnn_backbone(self, model_path: str, *, map_location=None) -> None:
        """
        Load the GNN backbone from the model path.

        The file must contain a pickled ``torch.nn.Module``; its
        ``state_dict()`` is loaded non-strictly into this model, so only keys
        named as in this module (e.g. prefixed ``gnn_backbone.``) take effect.

        Args:
            model_path: path to the saved module.
            map_location: forwarded to ``torch.load``.
        """
        # NOTE(review): torch.load unpickles arbitrary objects -- only load
        # trusted files (see also the weights_only default in newer PyTorch).
        module = torch.load(model_path, map_location=map_location)
        self._load_partial_state_dict(module.state_dict(), model_path)
Implements a multi-layer convolutional neural network, with leaky-ReLU non-linearities between layers...
None parse(self, dict config)
Parse the configuration for the GNN model.
Implements a GNN architecture, according to hyperparameters specified in the input config.
torch.Tensor forward(self, torch_geometric.data.Data data)
Forward pass of the LPAC model.
Definition lpac.py:59
None load_cnn_backbone(self, str model_path)
Load the CNN backbone from the model path.
Definition lpac.py:92
None load_gnn_backbone(self, str model_path)
Load the GNN backbone from the model path.
Definition lpac.py:98
None load_model(self, str model_state_dict_path)
Load the model from the state dict.
Definition lpac.py:86