# Let float32 matrix multiplications use faster, lower-precision internal
# algorithms (e.g. TF32) where the hardware supports them; see
# torch.set_float32_matmul_precision. NOTE(review): `torch` is assumed to be
# imported above this chunk.
torch.set_float32_matmul_precision("high")

from . import CentralizedCVT
from . import ClairvoyantCVT
from . import DecentralizedCVT
from . import NearOptimalCVT
from .. import CoverageEnvUtils
from ..core import CoverageSystem
from ..core import Parameters
from ..core import PointVector

# Public API of this module.
__all__ = ["ControllerCVT", "ControllerNN"]
42 Controller class for CVT based controllers
45 def __init__(self, config: dict, params: Parameters, env: CoverageSystem):
47 Constructor for the CVT controller
49 config: Configuration dictionary
50 params: Parameters object
51 env: CoverageSystem object
55 match config[
"Algorithm"]:
56 case
"DecentralizedCVT":
57 self.
alg = DecentralizedCVT(params, env)
58 case
"ClairvoyantCVT":
59 self.
alg = ClairvoyantCVT(params, env)
60 case
"CentralizedCVT":
61 self.
alg = CentralizedCVT(params, env)
62 case
"NearOptimalCVT":
63 self.
alg = NearOptimalCVT(params, env)
65 raise ValueError(f
"Unknown controller type: {controller_type}")
67 def step(self, env: CoverageSystem) -> (float, bool):
69 Step function for the CVT controller
72 1. Compute actions using the CVT algorithm
73 2. Get the actions from the algorithm
74 3. Step the environment using the actions
76 env: CoverageSystem object
78 Objective value and convergence flag
80 self.
alg.ComputeActions()
81 actions = self.
alg.GetActions()
82 converged = self.
alg.IsConverged()
83 error_flag = env.StepActions(actions)
86 raise ValueError(
"Error in step")
88 return env.GetObjectiveValue(), converged
# NOTE(review): this listing of ControllerNN is incomplete. The `class ControllerNN:`
# header and docstring quote delimiters are absent, and gaps in the embedded line
# numbers (e.g. 102->111, 114->119, 119->123, 123->129, 129->133, 144->147) indicate
# that source lines were dropped, some of which presumably set self.config/self.model
# and the `def step(...)` header. Restore from the upstream file before relying on it.
93 Controller class for neural network based controllers
96 def __init__(self, config: dict, params: Parameters, env: CoverageSystem):
98 Constructor for the neural network controller
100 config: Configuration dictionary
101 params: Parameters object
102 env: CoverageSystem object
# Select the CUDA device when torch reports one is available, otherwise the CPU.
111 self.
device = torch.device(
"cuda" if torch.cuda.is_available()
else "cpu")
# NOTE(review): the bodies of this "ModelFile" branch (and any else-branch) are not
# visible here; the LearningParams lookup below is only a fragment of a statement.
114 if "ModelFile" in self.
config:
119 self.
config[
"LearningParams"]
# Load weights from the path in config["ModelStateDict"] after IOUtils.sanitize_path;
# presumably a state-dict load on the project's model class - confirm.
123 self.
model.load_model(IOUtils.sanitize_path(self.
config[
"ModelStateDict"]))
# Wrap the model with torch.compile; dynamic=True requests dynamic-shape compilation
# (presumably because graph sizes vary between calls - confirm).
129 self.
model = torch.compile(self.
model, dynamic=
True)
133 step function for the neural network controller
135 Performs three steps:
136 1. Get the data from the environment
137 2. Get the actions from the model
138 3. Step the environment using the actions
140 env: CoverageSystem object
142 Objective value and convergence flag
# NOTE(review): the arguments of get_torch_geometric_data (original lines 145-146) are
# missing from this listing.
144 pyg_data = CoverageEnvUtils.get_torch_geometric_data(
# Inference only: no autograd graph is recorded for the forward pass.
147 with torch.no_grad():
148 actions = self.
model(pyg_data)
# Move actions to host memory and convert to numpy to build the project's PointVector.
# NOTE(review): the return value of env.StepActions is discarded here, whereas
# ControllerCVT.step raises ValueError("Error in step") on a truthy result - confirm
# whether errors should be surfaced here too.
150 point_vector_actions = PointVector(actions.cpu().numpy())
151 env.StepActions(point_vector_actions)
# Converged when every action component is within 1e-5 of zero: allclose against a
# zero tensor reduces to |a| <= atol because the rtol term is multiplied by zero.
154 if torch.allclose(actions, torch.zeros_like(actions), atol=1e-5):
155 return env.GetObjectiveValue(),
True
156 return env.GetObjectiveValue(),
False
Controller class for CVT based controllers.
__init__(self, dict config, Parameters params, CoverageSystem env)
Constructor for the CVT controller.
(float, bool) step(self, CoverageSystem env)
Step function for the CVT controller.
Controller class for neural network based controllers.
__init__(self, dict config, Parameters params, CoverageSystem env)
Constructor for the neural network controller.
step(self, env)
Step function for the neural network controller.