24from torch_geometric.nn
import MLP
26from .config_parser
import GNNConfigParser
27from .cnn_backbone
import CNNBackBone
28from .gnn_backbone
import GNNBackBone
37class LPAC(torch.nn.Module, GNNConfigParser):
39 LPAC neural network architecture
41 def __init__(self, in_config):
44 self.
parse(in_config[
"GNNBackBone"])
56 def forward(self, data: torch_geometric.data.Data) -> torch.Tensor:
58 Forward pass of the LPAC model
60 x, edge_index, edge_weight = data.x, data.edge_index, data.edge_weight
62 cnn_output = self.
cnn_backbone(x.view(-1, x.shape[-3], x.shape[-2], x.shape[-1]))
71 gnn_backbone_in = torch.cat([cnn_output, pos], dim=-1)
83 def load_model(self, model_state_dict_path: str) ->
None:
85 Load the model from the state dict
87 self.load_state_dict(torch.load(model_state_dict_path), strict=
False)
91 Load the CNN backbone from the model path
93 self.load_state_dict(torch.load(model_path).state_dict(), strict=
False)
97 Load the GNN backbone from the model path
99 self.load_state_dict(torch.load(model_path).state_dict(), strict=
False)
Implements a multi-layer convolutional neural network, with leaky-ReLU non-linearities between layers...
None parse(self, dict config)
Parse the configuration for the GNN model.
Implements a GNN architecture according to the hyperparameters specified in the input config.
torch.Tensor forward(self, torch_geometric.data.Data data)
Forward pass of the LPAC model.
None load_cnn_backbone(self, str model_path)
Load the CNN backbone from the model path.
None load_gnn_backbone(self, str model_path)
Load the GNN backbone from the model path.
None load_model(self, str model_state_dict_path)
Load the model from the state dict.