Coverage Control Library
Loading...
Searching...
No Matches
loaders.py
Go to the documentation of this file.
# This file is part of the CoverageControl library
#
# Author: Saurav Agarwal
# Contact: sauravag@seas.upenn.edu, agr.saurav1@gmail.com
# Repository: https://github.com/KumarRobotics/CoverageControl
#
# Copyright (c) 2024, Saurav Agarwal
#
# The CoverageControl library is free software: you can redistribute it and/or
# modify it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or (at your
# option) any later version.
#
# The CoverageControl library is distributed in the hope that it will be
# useful, but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along with
# CoverageControl library. If not, see <https://www.gnu.org/licenses/>.
21
22"""
23Module for loading datasets
24"""
25
26import torch
27from coverage_control import IOUtils
28from torch_geometric.data import Dataset
29
30from ...coverage_env_utils import CoverageEnvUtils
31from .data_loader_utils import DataLoaderUtils
32
33
34
class LocalMapCNNDataset(Dataset):
    """
    Dataset for CNN training.

    Each item is one robot's stack of local maps paired with that robot's
    target action. The (sample, robot) axes of the stored data are flattened
    into a single item axis by ``load_data``.
    """

    def __init__(
        self,
        data_dir: str,
        stage: str,
        use_comm_map: bool,
        output_dim: int,
        preload: bool = True,
    ):
        """
        Constructor for the LocalMapCNNDataset class

        Args:
            data_dir (str): Directory containing the data
            stage (str): Stage of the data (train, val, test)
            use_comm_map (bool): Whether to use communication maps
            output_dim (int): Dimension of the output (stored; not currently
                used when loading targets)
            preload (bool): Whether to load the data immediately. If False,
                the data is loaded on the first call to ``len`` or ``get``
                (or an explicit ``load_data`` call).
        """
        super().__init__(None, None, None, None)

        self.stage = stage
        self.data_dir = data_dir
        self.output_dim = output_dim
        self.use_comm_map = use_comm_map

        # Filled in by load_data(); ``maps is None`` means "not loaded yet".
        self.maps = None
        self.targets = None
        self.targets_mean = None
        self.targets_std = None
        self.dataset_size = None

        if preload:
            self.load_data()

    def _ensure_loaded(self):
        """Load the data on first use when the dataset was built with preload=False."""
        if getattr(self, "maps", None) is None:
            self.load_data()

    def len(self):
        """Return the number of per-robot samples in the dataset."""
        self._ensure_loaded()
        return self.dataset_size

    def get(self, idx):
        """
        Return the sample at ``idx``.

        Returns:
            tuple: ``(maps, target)``, i.e. ``self.maps[idx]`` (shape
            ``(num_channels, image_size, image_size)`` for an integer ``idx``)
            and ``self.targets[idx]``.
        """
        self._ensure_loaded()
        return self.maps[idx], self.targets[idx]

    def load_data(self):
        """
        Load the data from the data directory.

        Reads the maps and actions stored under ``{data_dir}/{stage}`` and
        flattens their first two axes so that each item is a single robot's
        view.
        """
        stage_dir = f"{self.data_dir}/{self.stage}"

        # maps presumably has shape
        # (num_samples, num_robots, num_channels, image_size, image_size)
        maps = DataLoaderUtils.load_maps(stage_dir, self.use_comm_map)
        num_channels = maps.shape[2]
        image_size = maps.shape[3]

        # reshape (not view) so a non-contiguous loaded tensor does not raise
        self.maps = maps.reshape(-1, num_channels, image_size, image_size)
        self.dataset_size = self.maps.shape[0]

        # NOTE: targets are the recorded actions; output_dim is not used here
        # (it would only matter if features were loaded instead of actions).
        targets, self.targets_mean, self.targets_std = DataLoaderUtils.load_actions(
            stage_dir
        )
        # Flatten (num_samples, num_robots, action_dim) -> (N, action_dim) so the
        # targets line up one-to-one with the flattened maps.
        self.targets = targets.reshape(-1, targets.shape[2])
95
96
class CNNGNNDataset(Dataset):
    """
    Dataset for hybrid CNN-GNN training.

    Each item is one recorded sample: a graph built by
    ``DataLoaderUtils.to_torch_geometric_data`` from that sample's maps, edge
    weights and normalized robot positions, paired with its target actions.
    """

    def __init__(self, data_dir, stage, use_comm_map, world_size):
        """
        Constructor for the CNNGNNDataset class

        Args:
            data_dir (str): Directory containing the data
            stage (str): Stage of the data (train, val, test)
            use_comm_map (bool): Whether to use communication maps
            world_size (float): Side length of the world. Robot positions are
                shifted by ``world_size / 2`` and divided by ``world_size``,
                which maps [-world_size/2, world_size/2] onto [0, 1].
        """
        super().__init__(None, None, None, None)

        self.stage = stage
        stage_dir = f"{data_dir}/{stage}"

        self.maps = DataLoaderUtils.load_maps(stage_dir, use_comm_map)
        self.dataset_size = self.maps.shape[0]

        self.targets, self.targets_mean, self.targets_std = (
            DataLoaderUtils.load_actions(stage_dir)
        )
        self.edge_weights = DataLoaderUtils.load_edge_weights(stage_dir)

        self.robot_positions = DataLoaderUtils.load_robot_positions(stage_dir)
        self.robot_positions = (self.robot_positions + world_size / 2) / world_size

        # Print the details of the dataset with device information
        print(f"Dataset: {self.stage} | Size: {self.dataset_size}",
              f"Coverage Maps: {self.maps.shape}",
              f"Targets: {self.targets.shape}",
              f"Robot Positions: {self.robot_positions.shape}",
              f"Edge Weights: {self.edge_weights.shape}",
              )

    def len(self):
        """Return the number of samples (graphs) in the dataset."""
        return self.dataset_size

    def get(self, idx):
        """
        Return the graph and targets for sample ``idx``.

        Returns:
            tuple: ``(data, targets)`` where ``data`` is the graph built from
            the sample's maps, edge weights and robot positions, and
            ``targets`` is ``self.targets[idx]``. A 3-D ``targets`` (e.g. from
            slice indexing) is flattened to 2-D over its leading axes.
        """
        data = DataLoaderUtils.to_torch_geometric_data(
            self.maps[idx], self.edge_weights[idx], self.robot_positions[idx]
        )
        targets = self.targets[idx]

        # Tensor.dim is a method and must be called; comparing the bound
        # method itself to 3 is always False and silently skips the flatten.
        if targets.dim() == 3:
            targets = targets.reshape(-1, targets.shape[-1])

        return data, targets
142
143
144
class VoronoiGNNDataset(Dataset):
    """
    Dataset for non-hybrid GNN training.

    Each item is a graph built from one sample's precomputed features and
    edge weights, paired with that sample's target actions.
    """

    def __init__(self, data_dir, stage, output_dim):
        """
        Constructor for the VoronoiGNNDataset class

        Args:
            data_dir (str): Directory containing the data
            stage (str): Stage of the data (train, val, test)
            output_dim (int): Feature dimension forwarded to
                ``DataLoaderUtils.load_features``
        """
        super().__init__(None, None, None, None)

        self.stage = stage
        self.output_dim = output_dim
        stage_dir = f"{data_dir}/{stage}"

        # load_features returns a (features, mean, std) triple, like
        # load_actions. Unpack it so that ``self.features[idx]`` indexes
        # samples instead of the tuple itself (which returned mean/std for
        # small idx and raised IndexError beyond). TODO confirm against
        # DataLoaderUtils.load_features.
        self.features, self.features_mean, self.features_std = (
            DataLoaderUtils.load_features(stage_dir, output_dim)
        )
        self.dataset_size = self.features.shape[0]

        self.targets, self.targets_mean, self.targets_std = (
            DataLoaderUtils.load_actions(stage_dir)
        )
        self.edge_weights = DataLoaderUtils.load_edge_weights(stage_dir)

    def len(self):
        """Return the number of samples (graphs) in the dataset."""
        return self.dataset_size

    def get(self, idx):
        """
        Return the graph and targets for sample ``idx``.

        Returns:
            tuple: ``(data, targets)`` with ``targets = self.targets[idx]``.
        """
        targets = self.targets[idx]
        # NOTE(review): CNNGNNDataset passes robot positions as the third
        # argument of to_torch_geometric_data, so passing targets here may
        # store them as node positions rather than labels. The graph is built
        # exactly as before; targets are returned explicitly rather than via
        # ``data.y``, which may not be populated by that helper.
        data = DataLoaderUtils.to_torch_geometric_data(
            self.features[idx], self.edge_weights[idx], targets
        )

        return data, targets
Dataset for hybrid CNN-GNN training.
Definition loaders.py:100
__init__(self, data_dir, stage, use_comm_map, world_size)
Definition loaders.py:102
__init__(self, str data_dir, str stage, bool use_comm_map, int output_dim, bool preload=True)
Definition loaders.py:47
load_data(self)
Load the data from the data directory.
Definition loaders.py:79
__init__(self, data_dir, stage, output_dim)
Definition loaders.py:150