def save_tensor(device, my_tensor, filename):
    """Serialize ``my_tensor`` with ``torch.save`` and write the bytes to ``filename``.

    The object is first saved into an in-memory buffer using the zipfile
    serialization format, then the buffer's contents are written to
    ``filename`` in binary mode (an existing file is overwritten).

    Args:
        device: Not used by this function; the tensor is saved on whatever
            device it currently lives on. NOTE(review): every visible caller
            passes 'cpu' — confirm whether a ``.to(device)`` was intended.
        my_tensor: Tensor (or any torch-serializable object) to save.
        filename: Destination path. Its parent directory must already
            exist, otherwise ``open`` raises ``FileNotFoundError``.
    """
    import io  # local import keeps the helper self-contained

    # The buffer was referenced but never created in the original body.
    buffer = io.BytesIO()
    torch.save(my_tensor, buffer, _use_new_zipfile_serialization=True)
    with open(filename, "wb") as out_f:
        out_f.write(buffer.getbuffer())
# Load the full pickled model object (not just a state dict): the code below
# reads attributes such as `gnn_backbone` and calls `parameters()` on it.
# weights_only=False is required for that on torch >= 2.6, where torch.load
# defaults to weights_only=True (the keyword exists since torch 1.13).
# Unpickling can execute arbitrary code, so only load trusted checkpoints here.
model = torch.load('model_k3_1024_2l.pt', weights_only=False)
print("[python] model: ", model)
gnn = model.gnn_backbone

# Freeze every parameter; the visible export code only reads weights out of
# the model and never trains it.
model.requires_grad_(False)
# Export each graph-convolution layer of the GNN backbone.
# NOTE(review): `nlayers` is not defined in this section; it is assumed to be
# set earlier in the script to the number of `graph_conv_{l}` layers — confirm.
gnn_state_dict = gnn.state_dict()
for l in range(nlayers):
    prefix = f'graph_conv_{l}'

    # Write every weight matrix lins.0, lins.1, ... present for this layer as
    # lin_{l}_{k}.pt. The count is read from the state dict instead of relying
    # on a separately defined loop bound (the original referenced an unbound
    # `k`). The key layout looks like PyG's TAGConv (K+1 lins + one bias) —
    # confirm.
    k = 0
    lin_key = f'{prefix}.lins.{k}.weight'
    while lin_key in gnn_state_dict:
        lin = gnn_state_dict[lin_key]
        save_tensor('cpu', lin, f'k3_params/py/lin_{l}_{k}.pt')
        k += 1
        lin_key = f'{prefix}.lins.{k}.weight'
    if k == 0:
        # Preserve the original failure: a layer without lins.0 is an error.
        raise KeyError(lin_key)

    bias = gnn_state_dict[f'{prefix}.bias']
    save_tensor('cpu', bias, f'k3_params/py/bias_{l}.pt')
# --- MLP head: first linear layer --------------------------------------------
mlp_layer = model.gnn_mlp
mlp_layer_state_dict = mlp_layer.state_dict()
# Debug aid: list every key stored in the MLP's state dict.
print(mlp_layer_state_dict.keys())

l0_wts, l0_bias = (
    mlp_layer_state_dict['lins.0.weight'],
    mlp_layer_state_dict['lins.0.bias'],
)
save_tensor(device='cpu', my_tensor=l0_wts,
            filename='k3_params/py/mlp_lin_0.pt')
save_tensor(device='cpu', my_tensor=l0_bias,
            filename='k3_params/py/mlp_bias_0.pt')
# --- MLP head: first normalization layer --------------------------------------
# Show the wrapped norm module's own state-dict keys between separator lines.
print("============================")
print(mlp_layer.norms[0].module.state_dict().keys())
print("============================")

# Affine parameters plus running statistics (the running_mean / running_var /
# num_batches_tracked keys suggest a batch-norm module — confirm).
n0_wts, n0_bias, running_mean, running_var, num_batches_tracked = (
    mlp_layer_state_dict['norms.0.module.weight'],
    mlp_layer_state_dict['norms.0.module.bias'],
    mlp_layer_state_dict['norms.0.module.running_mean'],
    mlp_layer_state_dict['norms.0.module.running_var'],
    mlp_layer_state_dict['norms.0.module.num_batches_tracked'],
)
save_tensor(device='cpu', my_tensor=n0_wts,
            filename='k3_params/py/mlp_norm_0_weight.pt')
save_tensor(device='cpu', my_tensor=n0_bias,
            filename='k3_params/py/mlp_norm_0_bias.pt')
save_tensor(device='cpu', my_tensor=running_mean,
            filename='k3_params/py/mlp_norm_0_running_mean.pt')
save_tensor(device='cpu', my_tensor=running_var,
            filename='k3_params/py/mlp_norm_0_running_var.pt')
save_tensor(device='cpu', my_tensor=num_batches_tracked,
            filename='k3_params/py/mlp_norm_0_num_batches_tracked.pt')
# --- MLP head: second linear layer -------------------------------------------
l1_wts, l1_bias = (
    mlp_layer_state_dict['lins.1.weight'],
    mlp_layer_state_dict['lins.1.bias'],
)
save_tensor(device='cpu', my_tensor=l1_wts,
            filename='k3_params/py/mlp_lin_1.pt')
save_tensor(device='cpu', my_tensor=l1_bias,
            filename='k3_params/py/mlp_bias_1.pt')
# --- Final output projection (model.output_linear) ---------------------------
output_layer = model.output_linear
output_layer_state_dict = output_layer.state_dict()

outlayer_wts, outlayer_bias = (
    output_layer_state_dict['weight'],
    output_layer_state_dict['bias'],
)
save_tensor(device='cpu', my_tensor=outlayer_wts,
            filename='k3_params/py/outlayer_wts.pt')
save_tensor(device='cpu', my_tensor=outlayer_bias,
            filename='k3_params/py/outlayer_bias.pt')
# --- Per-action statistics stored directly on the model ----------------------
# (names suggest mean/std used to normalize actions — confirm against model)
save_tensor(device='cpu', my_tensor=model.actions_mean,
            filename='k3_params/py/actions_mean.pt')
save_tensor(device='cpu', my_tensor=model.actions_std,
            filename='k3_params/py/actions_std.pt')