Skip to content
Snippets Groups Projects
Commit 7bf0513d authored by AjUm-HEIDI's avatar AjUm-HEIDI
Browse files

Set the correct results dir variable

parent b3333a2b
No related branches found
No related tags found
No related merge requests found
import torch
import torch.nn.functional as F
from typing import Dict, List, Optional, Tuple
class GNN(torch.nn.Module):
def __init__(self, device: Optional[str] = None):
"""
Base class for Graph Neural Networks, supporting both homogeneous
and heterogeneous GNNs.
Args:
device (Optional[str]): Device to use ('cuda' or 'cpu'). If None, auto-detect.
"""
super(GNN, self).__init__()
self.device = device if device else ('cuda' if torch.cuda.is_available() else 'cpu')
self.best_model_state = None
def forward(self, *args, **kwargs):
"""
Forward pass to be implemented by child classes.
"""
raise NotImplementedError("The `forward` method must be implemented in a subclass.")
def predict(self, *args, **kwargs) -> torch.Tensor:
"""
Predict labels for a single graph or node, depending on the type of GNN.
To be implemented in child classes if customization is required.
Returns:
torch.Tensor: Predicted label(s).
"""
self.eval()
with torch.no_grad():
return self._predict(*args, **kwargs)
def predict_all(self, *args, **kwargs) -> torch.Tensor:
"""
Predict labels for all graphs or nodes, depending on the type of GNN.
To be implemented in child classes if customization is required.
Returns:
torch.Tensor: Predicted labels.
"""
self.eval()
with torch.no_grad():
return self._predict_all(*args, **kwargs)
def _predict(self, *args, **kwargs) -> torch.Tensor:
"""
Internal predict method for subclasses to override.
"""
raise NotImplementedError("The `_predict` method must be implemented in a subclass.")
def _predict_all(self, *args, **kwargs) -> torch.Tensor:
"""
Internal predict_all method for subclasses to override.
"""
raise NotImplementedError("The `_predict_all` method must be implemented in a subclass.")
......@@ -213,8 +213,8 @@ def experiment(grouped_keyword_dir, dataset_name, entity_name, iterations=5, num
# Create timestamp directory for all results
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
results = Path(f"evaluation_results/{timestamp}_{dataset_name}")
results.mkdir(parents=True, exist_ok=True)
results_dir = Path(f"evaluation_results/{timestamp}_{dataset_name}")
results_dir.mkdir(parents=True, exist_ok=True)
aggregated_results = {}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment