23Utility functions for combining and splitting datasets.
30if sys.version_info[1] < 11:
31 import tomli
as tomllib
35from coverage_control
import IOUtils
def split_save_data(path, data_name, num_train, num_val, num_test):
    """Split one saved tensor into train/val/test shards on disk.

    Loads ``<path><data_name>.pt``, densifies it, slices the first axis
    into train/val/test partitions, converts each shard back to sparse,
    writes the shards under the ``train``/``val``/``test`` subdirectories
    of ``path``, and finally deletes the original combined file.

    Note: ``num_test`` is accepted for interface symmetry; the test shard
    simply takes everything after the first ``num_train + num_val`` rows.
    """
    source_file = path + data_name + '.pt'
    full = IOUtils.load_tensor(source_file)
    full = full.to_dense()
    boundary = num_train + num_val
    shards = (
        ('/train/', full[:num_train].clone()),
        ('/val/', full[num_train:boundary].clone()),
        ('/test/', full[boundary:].clone()),
    )
    for subdir, shard in shards:
        torch.save(shard.to_sparse(), path + subdir + data_name + '.pt')
    os.remove(source_file)
def normalize_data(data):
    """Standardize *data* per feature over all leading dimensions.

    Flattens every dimension except the last, computes the per-feature
    mean and (unbiased) standard deviation over the flattened rows, and
    rescales the full tensor to zero mean and unit variance.

    Returns:
        (mean, std, normalized_data) tuple; mean/std have shape
        ``(data.shape[-1],)``.
    """
    flattened = data.view(-1, data.shape[-1])
    mean = torch.mean(flattened, dim=0)
    std = torch.std(flattened, dim=0)
    normalized = (data - mean) / std
    return mean, std, normalized
def split_dataset(config_path):
    """Split a dataset into training, validation, and test sets.

    The split parameters are read from a TOML config file (loaded via
    ``IOUtils.load_toml``): ``DataDir`` locates the dataset,
    ``NumDataset`` gives the total sample count, and
    ``DataSetSplit.TrainRatio`` / ``DataSetSplit.ValRatio`` set the
    train/val fractions; the test set receives the remainder.

    Args:
        config_path: Path to the TOML config file (``~`` is expanded).
    """
    config = IOUtils.load_toml(os.path.expanduser(config_path))
    data_path = os.path.expanduser(config['DataDir'])
    data_dir = data_path + '/data/'
    train_dir = data_dir + '/train'
    val_dir = data_dir + '/val'
    test_dir = data_dir + '/test'
    # Create the output directories; exist_ok avoids a racy exists() check.
    # (Bug fix: the val/test makedirs bodies were missing entirely.)
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(val_dir, exist_ok=True)
    os.makedirs(test_dir, exist_ok=True)
    num_dataset = config['NumDataset']
    train_ratio = config['DataSetSplit']['TrainRatio']
    val_ratio = config['DataSetSplit']['ValRatio']
    num_train = int(train_ratio * num_dataset)
    num_val = int(val_ratio * num_dataset)
    # Test set takes whatever remains after the (floored) train/val counts.
    num_test = num_dataset - num_train - num_val
    data_names = ['local_maps', 'comm_maps', 'obstacle_maps',
                  'actions', 'normalized_actions',
                  'edge_weights', 'robot_positions',
                  'coverage_features', 'normalized_coverage_features']
    for data_name in data_names:
        split_save_data(data_dir, data_name, num_train, num_val, num_test)
def combine_dataset(config_path, subdir_list):
    """Combine split per-subdirectory datasets into one dataset.

    The dataset location is read from a TOML config file (loaded via
    ``IOUtils.load_toml``). For each data tensor, the per-subdirectory
    files are concatenated along dim 0 and saved to the top-level data
    directory; ``actions`` and ``coverage_features`` additionally get
    normalized copies plus their mean/std tensors saved alongside.

    Args:
        config_path: Path to the TOML config file (``~`` is expanded).
        subdir_list: Subdirectories to combine, e.g. ``['0', '1', '2']``.
    """
    config = IOUtils.load_toml(os.path.expanduser(config_path))
    data_path = os.path.expanduser(config['DataDir'])
    data_dir = data_path + '/data/'
    data_names = ['local_maps', 'comm_maps', 'obstacle_maps',
                  'actions', 'edge_weights', 'robot_positions',
                  'coverage_features']
    normalize_data_names = ['actions', 'coverage_features']
    for data_name in data_names:
        combined_data = None
        is_sparse = False
        for idx, subdir in enumerate(subdir_list):
            data = IOUtils.load_tensor(
                data_dir + subdir + '/' + data_name + '.pt')
            if idx == 0:
                # Sparsity of the first shard decides the output format.
                is_sparse = data.is_sparse
                # Bug fix: the conversion was unconditional/orphaned; it
                # must only apply to sparse shards.
                if is_sparse:
                    data = data.to_dense()
                combined_data = data.clone()
            else:
                # Bug fix: restored the missing else-branch guard.
                if is_sparse:
                    data = data.to_dense()
                combined_data = torch.cat(
                    (combined_data, data.clone()), dim=0)
        if is_sparse:
            combined_data = combined_data.to_sparse()
        torch.save(combined_data, data_dir + data_name + '.pt')
        if data_name in normalize_data_names:
            combined_data = combined_data.to_dense()
            combined_data_mean, combined_data_std, combined_data = \
                normalize_data(combined_data)
            if is_sparse:
                combined_data = combined_data.to_sparse()
            torch.save(combined_data_mean,
                       data_dir + data_name + '_mean.pt')
            torch.save(combined_data_std,
                       data_dir + data_name + '_std.pt')
            torch.save(combined_data,
                       data_dir + 'normalized_' + data_name + '.pt')
146__all__ = [
'split_dataset',
'combine_dataset']