Coverage Control Library
Loading...
Searching...
No Matches
lpac.py
Go to the documentation of this file.
1# This file is part of the CoverageControl library
2#
3# Author: Saurav Agarwal
4# Contact: sauravag@seas.upenn.edu, agr.saurav1@gmail.com
5# Repository: https://github.com/KumarRobotics/CoverageControl
6#
7# Copyright (c) 2024, Saurav Agarwal
8#
9# The CoverageControl library is free software: you can redistribute it and/or
10# modify it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or (at your
12# option) any later version.
13#
14# The CoverageControl library is distributed in the hope that it will be
15# useful, but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
17# Public License for more details.
18#
19# You should have received a copy of the GNU General Public License along with
20# CoverageControl library. If not, see <https://www.gnu.org/licenses/>.
21
22import torch
23import torch_geometric
24from torch_geometric.nn import MLP
25
26from .config_parser import GNNConfigParser
27from .cnn_backbone import CNNBackBone
28from .gnn_backbone import GNNBackBone
29
30__all__ = ["LPAC"]
31
32"""
33Module for LPAC model
34"""
35
36
class LPAC(torch.nn.Module, GNNConfigParser):
    """
    LPAC neural network architecture

    Data flow, as wired in ``forward``: the CNN backbone encodes ``data.x``,
    the encoding is concatenated with ``data.pos`` along the last dimension,
    the result goes through the GNN backbone together with
    ``data.edge_index``, and a small MLP followed by a linear layer produces
    the output.
    """

    def __init__(self, in_config):
        """
        Build the model from a configuration mapping.

        Args:
            in_config: Mapping with a "CNNBackBone" entry (passed to
                CNNBackBone) and a "GNNBackBone" entry (passed to ``parse``).
        """
        super().__init__()
        self.cnn_config = in_config["CNNBackBone"]
        # NOTE(review): parse() comes from GNNConfigParser; self.config,
        # self.latent_size and self.output_dim used below are presumably set
        # by it -- confirm against config_parser.
        self.parse(in_config["GNNBackBone"])
        # Must exist before the GNN backbone is built: its input width is
        # derived from the CNN latent size, and forward() calls it.
        self.cnn_backbone = CNNBackBone(self.cnn_config)
        # "+ 2" accounts for ``pos`` being concatenated to the CNN features
        # in forward(); assumes pos has two components -- TODO confirm.
        self.gnn_backbone = GNNBackBone(self.config, self.cnn_backbone.latent_size + 2)
        # Variant without positions:
        # self.gnn_backbone = GNNBackBone(self.config, self.cnn_backbone.latent_size)
        self.gnn_mlp = MLP([self.latent_size, 32, 32])
        self.output_linear = torch.nn.Linear(32, self.output_dim)
        # Register buffers to model so they are saved/loaded with the state
        # dict and moved with the module across devices.
        self.register_buffer("actions_mean", torch.zeros(self.output_dim))
        self.register_buffer("actions_std", torch.ones(self.output_dim))

    def forward(self, data: torch_geometric.data.Data) -> torch.Tensor:
        """
        Forward pass of the LPAC model

        Args:
            data: Graph data object; ``x``, ``pos`` and ``edge_index`` are read.

        Returns:
            Output of the final linear layer.
        """
        x, edge_index, pos = data.x, data.edge_index, data.pos
        # Collapse all leading dimensions so the CNN sees a 4-D tensor built
        # from the last three dimensions of x.
        cnn_output = self.cnn_backbone(x.view(-1, x.shape[-3], x.shape[-2], x.shape[-1]))

        # Variant without positions: feed cnn_output directly to the GNN.
        gnn_backbone_in = torch.cat([cnn_output, pos], dim=-1)
        gnn_output = self.gnn_backbone(gnn_backbone_in, edge_index)
        return self.output_linear(self.gnn_mlp(gnn_output))

    @staticmethod
    def _as_state_dict(loaded):
        """Return a state dict from a module-like object or a plain mapping."""
        # torch.load(..., weights_only=True) normally yields a plain mapping,
        # which has no state_dict(); objects exposing state_dict() are still
        # accepted so previously working inputs keep working.
        state_dict_fn = getattr(loaded, "state_dict", None)
        return state_dict_fn() if callable(state_dict_fn) else loaded

    def load_compiled_state_dict(self, model_state_dict_path: str) -> None:
        """
        Load a state dict whose keys contain the "_orig_mod." marker

        Args:
            model_state_dict_path: Path to the saved state dict.
        """
        state_dict = torch.load(model_state_dict_path, weights_only=True)
        # remove _orig_mod from the state dict keys
        new_state_dict = {
            key.replace("_orig_mod.", ""): value for key, value in state_dict.items()
        }
        self.load_state_dict(new_state_dict, strict=True)

    def load_model(self, model_state_dict_path: str) -> None:
        """
        Load the model from the state dict

        Args:
            model_state_dict_path: Path to the saved state dict.
        """
        self.load_state_dict(torch.load(model_state_dict_path, weights_only=True), strict=True)

    def load_model_state_dict(self, model_state_dict: dict) -> None:
        """
        Load the model from the state dict

        Args:
            model_state_dict: State dict to load (strict key matching).
        """
        self.load_state_dict(model_state_dict, strict=True)

    def load_cnn_backbone(self, model_path: str) -> None:
        """
        Load the CNN backbone from the model path

        NOTE(review): as in the original code, the loaded weights are applied
        to the whole model with strict key matching, not only to
        ``self.cnn_backbone`` -- confirm the intended target.

        Args:
            model_path: Path to the saved model or state dict.
        """
        loaded = torch.load(model_path, weights_only=True)
        self.load_state_dict(self._as_state_dict(loaded), strict=True)

    def load_gnn_backbone(self, model_path: str) -> None:
        """
        Load the GNN backbone from the model path

        NOTE(review): as in the original code, the loaded weights are applied
        to the whole model with strict key matching, not only to
        ``self.gnn_backbone`` -- confirm the intended target.

        Args:
            model_path: Path to the saved model or state dict.
        """
        loaded = torch.load(model_path, weights_only=True)
        self.load_state_dict(self._as_state_dict(loaded), strict=True)
Implements a multi-layer convolutional neural network, with leaky-ReLU non-linearities between layers...
Class to parse the configuration for the GNN model.
None parse(self, dict config)
Parse the configuration for the GNN model.
Implements a GNN architecture, according to hyperparameters specified in the input config.
LPAC neural network architecture.
Definition lpac.py:40
None load_compiled_state_dict(self, str model_state_dict_path)
Definition lpac.py:83
torch.Tensor forward(self, torch_geometric.data.Data data)
Forward pass of the LPAC model.
Definition lpac.py:59
None load_model_state_dict(self, dict model_state_dict)
Load the model from the state dict.
Definition lpac.py:100
None load_cnn_backbone(self, str model_path)
Load the CNN backbone from the model path.
Definition lpac.py:106
None load_gnn_backbone(self, str model_path)
Load the GNN backbone from the model path.
Definition lpac.py:112
None load_model(self, str model_state_dict_path)
Load the model from the state dict.
Definition lpac.py:94