Coverage Control Library
dataset_utils.py
# This file is part of the CoverageControl library
#
# Author: Saurav Agarwal
# Contact: sauravag@seas.upenn.edu, agr.saurav1@gmail.com
# Repository: https://github.com/KumarRobotics/CoverageControl
#
# Copyright (c) 2024, Saurav Agarwal
#
# The CoverageControl library is free software: you can redistribute it and/or
# modify it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or (at your
# option) any later version.
#
# The CoverageControl library is distributed in the hope that it will be
# useful, but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along with
# CoverageControl library. If not, see <https://www.gnu.org/licenses/>.

'''
Utility functions for combining and splitting datasets.
'''

import os
import sys

import torch

# tomllib is in the standard library from Python 3.11 onwards; fall back to
# the tomli backport on older interpreters
if sys.version_info < (3, 11):
    import tomli as tomllib
else:
    import tomllib

from coverage_control import IOUtils

def split_save_data(path, data_name, num_train, num_val, num_test):
    '''
    Load the tensor <path><data_name>.pt, split it into train/val/test
    chunks, and save the chunks under the train/, val/, and test/
    subdirectories of path. The original file is deleted afterwards.
    '''
    data = IOUtils.load_tensor(path + data_name + '.pt')
    # If the tensor is sparse, convert it to dense for slicing and remember
    # to convert the splits back before saving
    is_sparse = False
    if data.is_sparse:
        data = data.to_dense()
        is_sparse = True
    # Split data into training, validation, and test sets
    train_data = data[:num_train].clone()
    val_data = data[num_train:num_train + num_val].clone()
    test_data = data[num_train + num_val:num_train + num_val + num_test].clone()
    if is_sparse:
        train_data = train_data.to_sparse()
        val_data = val_data.to_sparse()
        test_data = test_data.to_sparse()
    # Save the splits
    torch.save(train_data, path + 'train/' + data_name + '.pt')
    torch.save(val_data, path + 'val/' + data_name + '.pt')
    torch.save(test_data, path + 'test/' + data_name + '.pt')
    # Delete tensors to free up memory
    del data
    del train_data
    del val_data
    del test_data
    # Delete the original (unsplit) data file
    os.remove(path + data_name + '.pt')

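# A small sanity sketch (hypothetical sizes, not part of the library): with
# num_train=7, num_val=2, num_test=1 on a tensor of 10 samples, the slices
# above are data[:7], data[7:9], and data[9:10], so every sample lands in
# exactly one split.
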
def normalize_data(data):
    '''
    Standardize data feature-wise: the mean and standard deviation are
    computed over all dimensions except the last, so each feature in the
    last dimension is normalized independently.
    Returns the mean, the standard deviation, and the normalized data.
    '''
    data_mean = torch.mean(data.view(-1, data.shape[-1]), dim=0)
    data_std = torch.std(data.view(-1, data.shape[-1]), dim=0)
    data = (data - data_mean) / data_std
    return data_mean, data_std, data

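# Usage sketch for normalize_data (hypothetical tensor, for illustration
# only):
#   actions = torch.randn(1000, 32, 2)   # e.g. (steps, robots, 2D actions)
#   mean, std, normalized = normalize_data(actions)
#   # mean.shape == std.shape == (2,); each action component now has
#   # approximately zero mean and unit standard deviation.
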
def split_dataset(config_path):
    '''
    Split dataset into training, validation, and test sets.
    The information is received via a toml config file.
    '''
    config = IOUtils.load_toml(os.path.expanduser(config_path))
    data_path = os.path.expanduser(config['DataDir'])
    data_dir = data_path + '/data/'

    train_dir = data_dir + 'train'
    val_dir = data_dir + 'val'
    test_dir = data_dir + 'test'

    # Create the split directories if they don't exist
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(val_dir, exist_ok=True)
    os.makedirs(test_dir, exist_ok=True)

    num_dataset = config['NumDataset']
    train_ratio = config['DataSetSplit']['TrainRatio']
    val_ratio = config['DataSetSplit']['ValRatio']

    num_train = int(train_ratio * num_dataset)
    num_val = int(val_ratio * num_dataset)
    num_test = num_dataset - num_train - num_val

    data_names = ['local_maps', 'comm_maps', 'obstacle_maps',
                  'actions', 'normalized_actions',
                  'edge_weights', 'robot_positions',
                  'coverage_features', 'normalized_coverage_features']
    for data_name in data_names:
        split_save_data(data_dir, data_name, num_train, num_val, num_test)

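# A minimal config sketch for split_dataset (key names taken from the
# accesses above; the values and the file name are hypothetical):
#
#   # config.toml
#   DataDir = "~/coverage_datasets/run1"
#   NumDataset = 1000
#
#   [DataSetSplit]
#   TrainRatio = 0.7
#   ValRatio = 0.2
#
# With these values, num_train = 700, num_val = 200, and num_test = 100.
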
def combine_dataset(config_path, subdir_list):
    '''
    Combine datasets stored in multiple subdirectories into one dataset.
    The information is received via a toml config file.
    subdir_list is a list of subdirectories to combine, e.g. ['0', '1', '2']
    '''
    config = IOUtils.load_toml(os.path.expanduser(config_path))
    data_path = os.path.expanduser(config['DataDir'])
    data_dir = data_path + '/data/'

    data_names = ['local_maps', 'comm_maps', 'obstacle_maps',
                  'actions', 'edge_weights', 'robot_positions',
                  'coverage_features']
    normalize_data_names = ['actions', 'coverage_features']
    for data_name in data_names:
        is_sparse = False
        for subdir in subdir_list:
            data = IOUtils.load_tensor(data_dir + subdir + '/' + data_name + '.pt')
            if subdir == subdir_list[0]:
                # Sparsity is determined by the first shard; densify for
                # concatenation and convert back once at the end
                is_sparse = data.is_sparse
                if is_sparse:
                    data = data.to_dense()
                combined_data = data.clone()
            else:
                if is_sparse:
                    data = data.to_dense()
                combined_data = torch.cat((combined_data, data.clone()), dim=0)
            del data
        if is_sparse:
            combined_data = combined_data.to_sparse()
        torch.save(combined_data, data_dir + data_name + '.pt')
        # For selected tensors, also save a normalized copy along with the
        # statistics needed to undo the normalization
        if data_name in normalize_data_names:
            combined_data = combined_data.to_dense()
            combined_data_mean, combined_data_std, combined_data = normalize_data(combined_data)
            if is_sparse:
                combined_data = combined_data.to_sparse()
            torch.save(combined_data_mean, data_dir + data_name + '_mean.pt')
            torch.save(combined_data_std, data_dir + data_name + '_std.pt')
            torch.save(combined_data, data_dir + 'normalized_' + data_name + '.pt')
        del combined_data

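# Expected on-disk layout (inferred from the path handling above; the shard
# names are whatever you pass in subdir_list):
#
#   <DataDir>/data/0/local_maps.pt
#   <DataDir>/data/1/local_maps.pt
#   ...
#
# After combining, <DataDir>/data/local_maps.pt holds the concatenation of
# all shards, and the normalized tensors come with matching *_mean.pt and
# *_std.pt files.
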
__all__ = ['split_dataset', 'combine_dataset']

# Example of how to call split_dataset from the terminal
# python -c 'from dataset_utils import split_dataset; split_dataset("config.toml")'

# Example of how to call combine_dataset from the terminal
# python -c 'from dataset_utils import combine_dataset; combine_dataset("config.toml", ["0", "1", "2"])'

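# A minimal CLI wrapper sketch (not part of the library; the argument names
# are made up for illustration). Appending something like this to the module
# would replace the python -c one-liners above:
#
#   if __name__ == '__main__':
#       import argparse
#       parser = argparse.ArgumentParser(description='Split or combine datasets')
#       parser.add_argument('command', choices=['split', 'combine'])
#       parser.add_argument('config', help='Path to the toml config file')
#       parser.add_argument('--subdirs', nargs='*', default=[],
#                           help='Shard subdirectories for combine, e.g. 0 1 2')
#       args = parser.parse_args()
#       if args.command == 'split':
#           split_dataset(args.config)
#       else:
#           combine_dataset(args.config, args.subdirs)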