24from torch_geometric.nn
import MLP
26from .config_parser
import GNNConfigParser
27from .cnn_backbone
import CNNBackBone
28from .gnn_backbone
import GNNBackBone
39 LPAC neural network architecture
44 self.
parse(in_config[
"GNNBackBone"])
56 def forward(self, data: torch_geometric.data.Data) -> torch.Tensor:
58 Forward pass of the LPAC model
60 x, edge_index, edge_weight = data.x, data.edge_index, data.edge_weight
62 cnn_output = self.
cnn_backbone(x.view(-1, x.shape[-3], x.shape[-2], x.shape[-1]))
71 gnn_backbone_in = torch.cat([cnn_output, pos], dim=-1)
85 state_dict = torch.load(model_state_dict_path, weights_only=
True)
87 for key
in state_dict.keys():
88 new_state_dict[key.replace(
"_orig_mod.",
"")] = state_dict[key]
89 self.load_state_dict(new_state_dict, strict=
True)
91 def load_model(self, model_state_dict_path: str) ->
None:
93 Load the model from the state dict
95 self.load_state_dict(torch.load(model_state_dict_path, weights_only=
True), strict=
True)
99 Load the model from the state dict
101 self.load_state_dict(model_state_dict, strict=
True)
105 Load the CNN backbone from the model path
107 self.load_state_dict(torch.load(model_path, weights_only=
True).state_dict(), strict=
True)
111 Load the GNN backbone from the model path
113 self.load_state_dict(torch.load(model_path, weights_only=
True).state_dict(), strict=
True)
Implements a multi-layer convolutional neural network, with leaky-ReLU non-linearities between layers...
Class to parse the configuration for the GNN model.
None parse(self, dict config)
Parse the configuration for the GNN model.
Implements a GNN architecture, according to hyperparameters specified in the input config.
LPAC neural network architecture.
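None load_compiled_state_dict(self, str model_state_dict_path)
Load a state dict whose keys carry the `_orig_mod.` prefix (as saved from a compiled model), stripping that prefix from every key before loading it strictly.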
__init__(self, in_config)
torch.Tensor forward(self, torch_geometric.data.Data data)
Forward pass of the LPAC model.
None load_model_state_dict(self, dict model_state_dict)
Load the model weights from an already-loaded (in-memory) state dict.
None load_cnn_backbone(self, str model_path)
Load the CNN backbone from the model path.
None load_gnn_backbone(self, str model_path)
Load the GNN backbone from the model path.
None load_model(self, str model_state_dict_path)
Load the model weights from a state-dict file at the given path.