Coverage Control Library
Loading...
Searching...
No Matches
controllers.py
Go to the documentation of this file.
1# This file is part of the CoverageControl library
2#
3# Author: Saurav Agarwal
4# Contact: sauravag@seas.upenn.edu, agr.saurav1@gmail.com
5# Repository: https://github.com/KumarRobotics/CoverageControl
6#
7# Copyright (c) 2024, Saurav Agarwal
8#
9# The CoverageControl library is free software: you can redistribute it and/or
10# modify it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or (at your
12# option) any later version.
13#
14# The CoverageControl library is distributed in the hope that it will be
15# useful, but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
17# Public License for more details.
18#
19# You should have received a copy of the GNU General Public License along with
20# CoverageControl library. If not, see <https://www.gnu.org/licenses/>.
21# @file controllers.py
22# @brief Base classes for CVT and neural network based controllers
23import coverage_control.nn as cc_nn
24import torch
25torch.set_float32_matmul_precision('high')
26
27from . import CentralizedCVT
28from . import ClairvoyantCVT
29from . import DecentralizedCVT
30from . import NearOptimalCVT
31from .. import IOUtils
32from .. import CoverageEnvUtils
33from ..core import CoverageSystem
34from ..core import Parameters
35from ..core import PointVector
36
37__all__ = ["ControllerCVT", "ControllerNN"]
38
39
40class ControllerCVT:
41 """
42 Controller class for CVT based controllers
43 """
44
45 def __init__(self, config: dict, params: Parameters, env: CoverageSystem):
46 """
47 Constructor for the CVT controller
48 Args:
49 config: Configuration dictionary
50 params: Parameters object
51 env: CoverageSystem object
52 """
53 self.name = config["Name"]
54 self.params = params
55 match config["Algorithm"]:
56 case "DecentralizedCVT":
57 self.alg = DecentralizedCVT(params, env)
58 case "ClairvoyantCVT":
59 self.alg = ClairvoyantCVT(params, env)
60 case "CentralizedCVT":
61 self.alg = CentralizedCVT(params, env)
62 case "NearOptimalCVT":
63 self.alg = NearOptimalCVT(params, env)
64 case _:
65 raise ValueError(f"Unknown controller type: {controller_type}")
66
67 def step(self, env: CoverageSystem) -> (float, bool):
68 """
69 Step function for the CVT controller
70
71 Performs three steps:
72 1. Compute actions using the CVT algorithm
73 2. Get the actions from the algorithm
74 3. Step the environment using the actions
75 Args:
76 env: CoverageSystem object
77 Returns:
78 Objective value and convergence flag
79 """
80 self.alg.ComputeActions()
81 actions = self.alg.GetActions()
82 converged = self.alg.IsConverged()
83 error_flag = env.StepActions(actions)
84
85 if error_flag:
86 raise ValueError("Error in step")
87
88 return env.GetObjectiveValue(), converged
89
90
class ControllerNN:
    """
    Controller class for neural network based controllers.

    Loads a trained model (either a fully pickled model file or an LPAC model
    built from learning parameters plus a state dict) and uses it to produce
    actions for a CoverageSystem.
    """

    def __init__(self, config: dict, params: Parameters, env: CoverageSystem):
        """
        Constructor for the neural network controller.

        Args:
            config: Configuration dictionary. Reads "Name", "UseCNN",
                "UseCommMap", "CNNMapSize", and either "ModelFile" or both
                "LearningParams" and "ModelStateDict".
            params: Parameters object, used when building model inputs.
            env: CoverageSystem object (not used by the constructor).
        """
        self.config = config
        self.params = params
        self.name = self.config["Name"]
        self.use_cnn = self.config["UseCNN"]
        self.use_comm_map = self.config["UseCommMap"]
        self.cnn_map_size = self.config["CNNMapSize"]

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        if "ModelFile" in self.config:
            self.model_file = IOUtils.sanitize_path(self.config["ModelFile"])
            # map_location lets a model saved on a GPU be loaded when this
            # process falls back to the CPU. weights_only=False is required to
            # unpickle a full model object (PyTorch >= 2.6 defaults to
            # weights_only=True).
            # SECURITY: unpickling can execute arbitrary code; only load model
            # files from trusted sources.
            self.model = torch.load(
                self.model_file, map_location=self.device, weights_only=False
            ).to(self.device)
        else:  # Load from ModelStateDict
            self.learning_params_file = IOUtils.sanitize_path(
                self.config["LearningParams"]
            )
            self.learning_params = IOUtils.load_toml(self.learning_params_file)
            self.model = cc_nn.LPAC(self.learning_params).to(self.device)
            self.model.load_model(IOUtils.sanitize_path(self.config["ModelStateDict"]))

        # Statistics used to map raw model outputs back to action space in step().
        self.actions_mean = self.model.actions_mean.to(self.device)
        self.actions_std = self.model.actions_std.to(self.device)
        self.model = self.model.to(self.device)
        self.model.eval()
        self.model = torch.compile(self.model, dynamic=True)

    def step(self, env: CoverageSystem) -> tuple[float, bool]:
        """
        Step function for the neural network controller.

        Performs three steps:
        1. Get the data from the environment
        2. Get the actions from the model
        3. Step the environment using the actions

        Args:
            env: CoverageSystem object to advance by one step.

        Returns:
            Tuple of (objective value after the step, convergence flag). The
            controller is considered converged when every action component is
            within an absolute tolerance of 1e-5 of zero.
        """
        # NOTE(review): use_cnn is passed as a hard-coded True rather than
        # self.use_cnn (read from config["UseCNN"]) — confirm this is intended.
        pyg_data = CoverageEnvUtils.get_torch_geometric_data(
            env, self.params, True, self.use_comm_map, self.cnn_map_size
        ).to(self.device)
        with torch.no_grad():
            actions = self.model(pyg_data)
        # Inverse standardization using the model's stored mean/std.
        actions = actions * self.actions_std + self.actions_mean
        point_vector_actions = PointVector(actions.cpu().numpy())
        # NOTE(review): unlike ControllerCVT.step, the return value of
        # StepActions (an error flag there) is not checked here.
        env.StepActions(point_vector_actions)

        # Converged when all actions are (numerically) zero, within atol=1e-5.
        converged = torch.allclose(actions, torch.zeros_like(actions), atol=1e-5)
        return env.GetObjectiveValue(), converged
Controller class for CVT based controllers.
__init__(self, dict config, Parameters params, CoverageSystem env)
Constructor for the CVT controller.
(float, bool) step(self, CoverageSystem env)
Step function for the CVT controller.
Controller class for neural network based controllers.
__init__(self, dict config, Parameters params, CoverageSystem env)
Constructor for the neural network controller.
step(self, env)
Step function for the neural network controller.