"""Implements an architecture consisting of a multi-layer CNN followed by an MLP, according to parameters specified in the input config."""
import torch
from torch_geometric.nn import MLP

from .cnn_backbone import CNNBackBone
from .config_parser import CNNConfigParser
    """Implements an architecture consisting of a multi-layer CNN followed by an MLP, according to parameters specified in the input config."""
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the network."""
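    def load_cpp_model(self, model_path: str) -> None:
        """Loads a model saved in TorchScript (C++ JIT) format."""
        jit_model = torch.jit.load(model_path)
        # strict=False: missing or unexpected parameter names are ignored instead of raising.
        self.load_state_dict(jit_model.state_dict(), strict=False)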
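    def load_model(self, model_path: str) -> None:
        """Loads a model saved in PyTorch format (a state dict written with torch.save)."""
        # strict=False: missing or unexpected parameter names are ignored instead of raising.
        self.load_state_dict(torch.load(model_path), strict=False)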
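The two loaders differ only in where the state dict comes from. A minimal round-trip sketch, assuming a small hypothetical `TinyNet` in place of the real model: one file is written with `torch.save(model.state_dict(), ...)` (the format `load_model` reads), the other with `torch.jit.trace` plus `torch.jit.save` (the TorchScript archive `load_cpp_model` reads), and both are restored with `load_state_dict(..., strict=False)`.

```python
import os
import tempfile

import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the real model; expects (N, 1, 5, 5) inputs."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(1, 4, kernel_size=3)  # 5x5 -> 3x3 feature maps
        self.head = nn.Linear(4 * 3 * 3, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(torch.flatten(self.conv(x), 1))


def round_trip(tmp_dir: str) -> tuple[TinyNet, TinyNet, TinyNet]:
    src = TinyNet()

    # PyTorch format: a plain state dict, the file format load_model reads.
    pt_path = os.path.join(tmp_dir, "model.pt")
    torch.save(src.state_dict(), pt_path)
    from_pt = TinyNet()
    from_pt.load_state_dict(torch.load(pt_path), strict=False)

    # TorchScript format: a traced module archive, the file format load_cpp_model reads.
    jit_path = os.path.join(tmp_dir, "model_jit.pt")
    torch.jit.save(torch.jit.trace(src, torch.randn(1, 1, 5, 5)), jit_path)
    from_jit = TinyNet()
    from_jit.load_state_dict(torch.jit.load(jit_path).state_dict(), strict=False)

    return src, from_pt, from_jit
```

Because `strict=False` is passed, `load_state_dict` reports mismatched parameter names in its return value (`missing_keys`, `unexpected_keys`) rather than raising. The methods above discard that value, so a checkpoint from a different architecture can load only partially without any error being surfaced.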
__init__(self, dict config)
torch.Tensor forward(self, torch.Tensor x)
Forward pass through the network.
None load_cpp_model(self, str model_path)
Loads a model saved in TorchScript (C++ JIT) format.
None load_model(self, str model_path)
Loads a model saved in PyTorch format.
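The class summary describes a CNN whose output feeds an MLP. A minimal sketch of that composition in plain PyTorch follows. It swaps a hand-built `nn.Sequential` in for `torch_geometric.nn.MLP` so it has no PyTorch Geometric dependency, and the config keys (`in_channels`, `conv_channels`, `kernel_size`, `mlp_hidden`, `out_dim`) and the global average pool between the two stages are assumptions, not the parser's actual schema.

```python
import torch
from torch import nn


class CNNThenMLP(nn.Module):
    """Hypothetical CNN -> pool -> MLP model built from a config dict (keys are illustrative)."""

    def __init__(self, config: dict) -> None:
        super().__init__()
        k = config["kernel_size"]
        conv_layers: list[nn.Module] = []
        c_in = config["in_channels"]
        for c_out in config["conv_channels"]:
            # padding=k // 2 preserves height and width for odd kernel sizes.
            conv_layers += [nn.Conv2d(c_in, c_out, k, padding=k // 2), nn.LeakyReLU()]
            c_in = c_out
        self.cnn = nn.Sequential(*conv_layers)
        # Assumption: collapse each feature map to 1x1 so the MLP width is independent of image size.
        self.pool = nn.AdaptiveAvgPool2d(1)
        mlp_layers: list[nn.Module] = []
        d_in = c_in
        for d_out in config["mlp_hidden"]:
            mlp_layers += [nn.Linear(d_in, d_out), nn.ReLU()]
            d_in = d_out
        mlp_layers.append(nn.Linear(d_in, config["out_dim"]))
        self.mlp = nn.Sequential(*mlp_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (N, C_in, H, W) -> (N, C_last, 1, 1) -> (N, C_last) -> (N, out_dim)
        return self.mlp(self.pool(self.cnn(x)).flatten(1))
```

The pooling step is one way to bridge a 4-D feature map and a 2-D MLP input; flattening the full map instead would tie the MLP's input width to the image size.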
Implements a multi-layer convolutional neural network, with leaky-ReLU non-linearities between layers...
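A minimal sketch of a convolutional stack with leaky-ReLU placed only between consecutive layers, assuming a hypothetical `make_backbone(channels, kernel_size, negative_slope)` helper; the real `CNNBackBone` constructor arguments are not shown here.

```python
import torch
from torch import nn


def make_backbone(channels: list[int], kernel_size: int = 3, negative_slope: float = 0.01) -> nn.Sequential:
    """Hypothetical helper: conv layers channels[0] -> channels[1] -> ..., LeakyReLU between them."""
    if len(channels) < 2:
        raise ValueError("need at least input and output channel counts")
    layers: list[nn.Module] = []
    for i, (c_in, c_out) in enumerate(zip(channels[:-1], channels[1:])):
        if i > 0:
            # The non-linearity sits between conv layers, not after the last one.
            layers.append(nn.LeakyReLU(negative_slope))
        layers.append(nn.Conv2d(c_in, c_out, kernel_size, padding=kernel_size // 2))
    return nn.Sequential(*layers)
```

With an odd `kernel_size`, `padding=kernel_size // 2` keeps the spatial size unchanged, so only the channel count varies from layer to layer.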
This file contains the configuration parser for the models.
None parse(self, dict config)
Parse the configuration for the CNN model.
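A minimal sketch of a parser with the documented `parse(self, dict config) -> None` shape, storing validated values on the instance. The class name `CNNConfigSketch`, the key names, and the defaults are all hypothetical and only illustrate the pattern.

```python
from dataclasses import dataclass, field, fields
from typing import Any


@dataclass
class CNNConfigSketch:
    """Hypothetical parser; key names and defaults are illustrative only."""

    in_channels: int = 1
    conv_channels: list[int] = field(default_factory=list)
    kernel_size: int = 3
    mlp_hidden: list[int] = field(default_factory=list)
    out_dim: int = 1

    def parse(self, config: dict[str, Any]) -> None:
        # Reject keys the parser does not know, so typos fail loudly.
        known = {f.name for f in fields(self)}
        unknown = set(config) - known
        if unknown:
            raise KeyError(f"unknown config keys: {sorted(unknown)}")
        # Store the provided values on the instance; missing keys keep their defaults.
        for key, value in config.items():
            setattr(self, key, value)
        if not self.conv_channels:
            raise ValueError("conv_channels must list at least one layer width")
        if self.kernel_size < 1 or self.out_dim < 1:
            raise ValueError("kernel_size and out_dim must be positive")
```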