# Coverage Control Library
# save_gnn_params.py
import io
import os

import torch

def save_tensor(device, my_tensor, filename):
    """Serialize a single tensor to ``filename`` using torch's zipfile format.

    The tensor is first written to an in-memory buffer with
    ``torch.save(..., _use_new_zipfile_serialization=True)``. That buffer's
    bytes are then written to ``filename``.

    Args:
        device: Device (e.g. ``'cpu'``) the tensor is moved to before it is
            saved. The stored tensor's device therefore does not depend on
            where the source module happened to live.
        my_tensor: Tensor to serialize.
        filename: Output path. Missing parent directories are created.
    """
    # Honor ``device``: without this, a tensor that lives on a GPU would be
    # serialized with a CUDA device tag even when the caller asked for 'cpu'.
    tensor = my_tensor.to(device)

    parent = os.path.dirname(filename)
    if parent:
        # Export paths are nested (e.g. k3_params/py/...); create them instead
        # of failing with FileNotFoundError when the directory is missing.
        os.makedirs(parent, exist_ok=True)

    with io.BytesIO() as buffer:
        torch.save(tensor, buffer, _use_new_zipfile_serialization=True)
        with open(filename, "wb") as out_f:
            # Copy the serialized stream to the output file.
            out_f.write(buffer.getvalue())

# Load the full pickled model object. It is not a plain state_dict: the script
# accesses sub-modules such as ``gnn_backbone`` directly.
# map_location='cpu' places every tensor on the CPU regardless of the device
# the checkpoint was saved from. The GNN backbone is never moved with
# .to('cpu') in this script, so otherwise its weights could keep a CUDA device.
model_path = 'model_k3_1024_2l.pt'
try:
    # Unpickling a whole nn.Module needs weights_only=False on torch >= 2.6,
    # where the default became True. This runs arbitrary pickle code, so only
    # load checkpoints you trust.
    model = torch.load(model_path, map_location='cpu', weights_only=False)
except TypeError:
    # torch < 1.13 does not accept the weights_only keyword.
    model = torch.load(model_path, map_location='cpu')
print("[python] model: ", model)
gnn = model.gnn_backbone

# Freeze all parameters; nothing in this script needs gradients.
for param in model.parameters():
    param.requires_grad = False

# Architecture constants. nlayers and K set how many graph_conv_<l> layers
# and lins.<k> weights the GNN export reads, so they must match the
# checkpoint. num_nodes, nfeatures and latent_size are not used in the visible
# script; presumably kept as reference values for the consumer. TODO confirm.
nlayers = 2
K = 3
num_nodes = 10
nfeatures = 34
latent_size = 256

# Export the GNN backbone. For each graph_conv_<layer> this writes:
#   lins.0.weight               -> lin_<layer>_0.pt
#   bias                        -> bias_<layer>.pt
#   lins.1.weight..lins.K.weight -> lin_<layer>_1.pt .. lin_<layer>_K.pt
gnn_state_dict = gnn.state_dict()
for layer in range(nlayers):
    print(f'layer: {layer}')
    prefix = f'graph_conv_{layer}'

    # Look up both tensors before writing either file.
    base_weight = gnn_state_dict[f'{prefix}.lins.0.weight']
    layer_bias = gnn_state_dict[f'{prefix}.bias']
    save_tensor('cpu', base_weight, f'k3_params/py/lin_{layer}_{0}.pt')
    save_tensor('cpu', layer_bias, f'k3_params/py/bias_{layer}.pt')

    for lin_idx in range(1, K + 1):
        # The progress print keeps the 0-based counter (lin_idx - 1).
        print(f'k: {lin_idx - 1}')
        weight = gnn_state_dict[f'{prefix}.lins.{lin_idx}.weight']
        save_tensor('cpu', weight, f'k3_params/py/lin_{layer}_{lin_idx}.pt')

# Export the GNN's MLP head: linear layers lins.0 and lins.1, plus the
# weight/bias and running statistics of norm module norms.0.
mlp_layer = model.gnn_mlp
mlp_layer.to('cpu')
mlp_layer.eval()

mlp_layer_state_dict = mlp_layer.state_dict()
print(mlp_layer_state_dict.keys())

# lins.0: both tensors are looked up (dict literal, left to right) before
# any file is written.
first_linear = {
    'k3_params/py/mlp_lin_0.pt': mlp_layer_state_dict['lins.0.weight'],
    'k3_params/py/mlp_bias_0.pt': mlp_layer_state_dict['lins.0.bias'],
}
for out_path, tensor in first_linear.items():
    save_tensor('cpu', tensor, out_path)

print("============================")
print(mlp_layer.norms[0].module.state_dict().keys())
print("============================")

# norms.0.module: presumably a BatchNorm-style module (it has running_mean,
# running_var and num_batches_tracked). Confirm against the model class.
# All five entries are fetched before any of them is written.
norm_fields = ('weight', 'bias', 'running_mean', 'running_var', 'num_batches_tracked')
norm_tensors = [(field, mlp_layer_state_dict[f'norms.0.module.{field}']) for field in norm_fields]
for field, tensor in norm_tensors:
    save_tensor('cpu', tensor, f'k3_params/py/mlp_norm_0_{field}.pt')

# lins.1: same fetch-then-write pattern as lins.0.
second_linear = {
    'k3_params/py/mlp_lin_1.pt': mlp_layer_state_dict['lins.1.weight'],
    'k3_params/py/mlp_bias_1.pt': mlp_layer_state_dict['lins.1.bias'],
}
for out_path, tensor in second_linear.items():
    save_tensor('cpu', tensor, out_path)

# Export the model's final linear layer (model.output_linear).
output_layer = model.output_linear
output_layer.to('cpu')
output_layer.eval()
output_layer_state_dict = output_layer.state_dict()

# Both tensors are looked up before either file is written.
out_params = [output_layer_state_dict[name] for name in ('weight', 'bias')]
out_paths = ('k3_params/py/outlayer_wts.pt', 'k3_params/py/outlayer_bias.pt')
for tensor, out_path in zip(out_params, out_paths):
    save_tensor('cpu', tensor, out_path)

# actions_mean / actions_std are read directly off the model object. They are
# presumably the normalization statistics for predicted actions; confirm
# against the model class.
for attr_name, out_path in (
    ('actions_mean', 'k3_params/py/actions_mean.pt'),
    ('actions_std', 'k3_params/py/actions_std.pt'),
):
    save_tensor('cpu', getattr(model, attr_name), out_path)