Commit 009d7aba authored by AjUm-HEIDI

Fix issues in the experiment

parent d53ba529
@@ -36,9 +36,6 @@ class LinearPressureFitness(AbstractFitness):
         quality = individual.quality.values[0]
         fitness = self.gain*quality - self.penalty*len(individual)
-        print(individual)
-        print(self.gain, quality, self.gain*quality, len(individual), fitness)
         individual.fitness.values = (round(fitness, 5),)
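For intuition, a minimal standalone sketch of the fitness this class computes; the helper name and the concrete gain/penalty numbers are illustrative, while the formula and the 5-decimal rounding follow the hunk above:

def linear_pressure_fitness(quality: float, length: int,
                            gain: float = 1.0, penalty: float = 0.5) -> float:
    # Reward explanation quality, apply linear pressure against long expressions.
    return round(gain * quality - penalty * length, 5)

# Quality 0.9 and length 4 give 0.9 - 0.5*4 = -1.1, so among equally
# accurate class expressions the shorter one wins.
print(linear_pressure_fitness(0.9, 4))  # -1.1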
class DiscriminativeExplainer:
@@ -128,7 +125,7 @@ class DiscriminativeExplainer:
                 max_runtime: Optional[int] = 60,
                 num_generations: Optional[int] = 600,
                 quality_func: Optional[AbstractScorer] = None,
-                length_penalty: Optional[int] = 5,) -> OWLClassExpression:
+                length_penalty: Optional[float] = 0.5) -> OWLClassExpression:
         """Explains a given label based on the GNN. The explanation is in the form of a Class Expression.
         Args:
......
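A hypothetical call sketch for the updated signature; the keyword defaults match the hunk above, but the `explainer` object and the positional label argument are assumptions for illustration:

expression = explainer.explain(
    1,                    # label to explain (assumed positional argument)
    max_runtime=60,       # seconds for the concept learner
    num_generations=600,  # evolutionary search budget
    length_penalty=0.5,   # pressure against long class expressions
)
print(expression)         # an OWLClassExpression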
 import torch
 import torch.nn.functional as F
-from torch.nn import Module, Linear, BatchNorm1d, Dropout
+from torch.nn import Linear, BatchNorm1d, Dropout
 from torch.optim import Adam
 from torch.optim.lr_scheduler import ReduceLROnPlateau
-from torch_geometric.nn import HeteroConv, GraphConv, global_mean_pool, global_add_pool
+from torch_geometric.nn import HeteroConv, GraphConv
 from torch_geometric.data import HeteroData
 from sklearn.metrics import confusion_matrix, f1_score, accuracy_score, precision_score, recall_score
 import numpy as np
@@ -12,76 +12,59 @@ from torch_geometric.loader import DataLoader
 from tabulate import tabulate
 from Utils.Utils import get_feature_sizes_and_edge_config, find_classes_with_y_labels
 class GNN(torch.nn.Module):
     def __init__(self, data: HeteroData, hidden_channels: int = 256,
                  num_hidden_layers: int = 4) -> None:
-        """Initialize the Heterogeneous Graph Neural Network.
+        """
+        Initialize the Heterogeneous Graph Neural Network.
         Args:
-            data (HeteroData): The heterogeneous graph data
+            data (HeteroData): The heterogeneous graph data.
             hidden_channels (int, optional): Number of hidden channels. Defaults to 256.
             num_hidden_layers (int, optional): Number of hidden GNN layers. Defaults to 4.
         """
         super(GNN, self).__init__()
-        # Get feature sizes and edge configuration
+        # Get feature sizes and edge configuration.
         self.feature_sizes, self.edge_config = get_feature_sizes_and_edge_config(data)
         self.num_hidden_layers = num_hidden_layers
         self.data = data
-        # Initial normalization layers for each node type
+        # Initial normalization layers for each node type.
         self.norm_layers = torch.nn.ModuleDict({
             node_type: BatchNorm1d(feature_size)
             for node_type, feature_size in self.feature_sizes.items()
         })
-        # Separate conv layers and normalization for heterogeneous graphs
+        # Build heterogeneous convolution layers with corresponding batch norms.
         self.convs = torch.nn.ModuleList()
         self.bns = torch.nn.ModuleList()
-        # Create multiple conv layers with improved structure
         for layer in range(num_hidden_layers):
             conv_dict = {}
             for (src, rel, dst) in self.edge_config:
-                in_channels = ((self.feature_sizes[src] if layer == 0 else hidden_channels),
-                               (self.feature_sizes[dst] if layer == 0 else hidden_channels))
+                in_channels = (
+                    self.feature_sizes[src] if layer == 0 else hidden_channels,
+                    self.feature_sizes[dst] if layer == 0 else hidden_channels
+                )
                 conv_dict[(src, rel, dst)] = GraphConv(
                     in_channels=in_channels,
                     out_channels=hidden_channels
                 )
             self.convs.append(HeteroConv(conv_dict, aggr='mean'))
             # Add batch norm for each node type
             self.bns.append(torch.nn.ModuleDict({
                 node_type: BatchNorm1d(hidden_channels)
                 for node_type in self.feature_sizes
             }))
-        # Multi-head pooling for each node type
-        self.pooling = torch.nn.ModuleDict({
-            node_type: torch.nn.ModuleDict({
-                'mean': Linear(hidden_channels, hidden_channels),
-                'add': Linear(hidden_channels, hidden_channels)
-            }) for node_type in self.feature_sizes
-        })
         # Find nodes with labels
         self.label_nodes = find_classes_with_y_labels(self.data, first_only=False)
         # Classification heads for each labeled node type
-        self.classifiers = torch.nn.ModuleDict()
+        self.node_classifiers = torch.nn.ModuleDict()
         for node_type in self.label_nodes:
             num_classes = len(torch.unique(data[node_type].y))
-            self.classifiers[node_type] = torch.nn.Sequential(
-                Linear(hidden_channels * 2, hidden_channels),
-                torch.nn.ELU(),
-                Dropout(p=0.2),
-                Linear(hidden_channels, hidden_channels // 2),
+            self.node_classifiers[node_type] = torch.nn.Sequential(
+                Linear(hidden_channels, hidden_channels),
+                torch.nn.ELU(),
+                Dropout(p=0.2),
-                Linear(hidden_channels // 2, num_classes)
+                Linear(hidden_channels, num_classes)
             )
         self.dropout = Dropout(p=0.2)
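The (src, dst) tuple passed to GraphConv above enables bipartite message passing between node types with different feature sizes; a minimal sketch with illustrative node types and dimensions:

from torch_geometric.nn import GraphConv, HeteroConv

# Source nodes carry 128 features, destination nodes 64; both are projected
# into a shared 256-dimensional hidden space, as in the first layer above.
conv = HeteroConv({
    ('author', 'writes', 'paper'): GraphConv((128, 64), 256),
}, aggr='mean')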
@@ -91,88 +74,65 @@ class GNN(torch.nn.Module):
     def forward(self, x_dict: Dict[str, torch.Tensor],
                 edge_index_dict: Dict[Tuple[str, str, str], torch.Tensor]) -> Dict[str, torch.Tensor]:
-        """Forward pass of the heterogeneous GNN.
+        """
+        Forward pass of the heterogeneous GNN for node-level prediction.
         Args:
-            x_dict (Dict[str, torch.Tensor]): Dictionary of node features for each node type
-            edge_index_dict (Dict[Tuple[str, str, str], torch.Tensor]): Dictionary of edge indices
-                for each edge type, where the tuple key is (source_type, edge_type, target_type)
+            x_dict (Dict[str, torch.Tensor]): Node features for each node type.
+            edge_index_dict (Dict[Tuple[str, str, str], torch.Tensor]): Edge indices for each edge type.
         Returns:
-            Dict[str, torch.Tensor]: Dictionary of predictions for each labeled node type
+            Dict[str, torch.Tensor]: Log softmax predictions for each node in the labeled node types.
         """
-        # Initial feature normalization
+        # Initial feature normalization.
         out_dict = {
             node_type: self.norm_layers[node_type](x)
             for node_type, x in x_dict.items()
         }
-        # Process through conv layers
+        # Process through convolution layers.
         for i, conv in enumerate(self.convs):
             # Store for residual connection
             identity = out_dict
             # Apply convolution
             conv_out = conv(out_dict, edge_index_dict)
             # Apply batch norm, activation, and dropout for each node type
             conv_out = {
-                node_type: self.dropout(
-                    F.elu(self.bns[i][node_type](features))
-                )
+                node_type: self.dropout(F.elu(self.bns[i][node_type](features)))
                 for node_type, features in conv_out.items()
             }
-            # Add residual connection after first layer
+            # Add residual connection after first layer.
             if i > 0:
                 conv_out = {
                     node_type: features + 0.1 * identity[node_type]
                     for node_type, features in conv_out.items()
                 }
             out_dict = conv_out
-        # Multi-head pooling and classification for each labeled node type
+        # Apply node-level classifiers directly on node embeddings.
         final_out = {}
         for node_type in self.label_nodes:
+            final_out[node_type] = self.node_classifiers[node_type](out_dict[node_type])
-            # Use global pooling functions
-            mean_pooled = global_mean_pool(out_dict[node_type])
-            add_pooled = global_add_pool(out_dict[node_type])
-            # Post-pooling transformations
-            mean_pooled = self.pooling[node_type]['mean'](mean_pooled)
-            add_pooled = self.pooling[node_type]['add'](add_pooled)
-            # Concatenate and classify
-            pooled = torch.cat([mean_pooled, add_pooled], dim=1)
-            final_out[node_type] = self.classifiers[node_type](pooled)
-        # Apply log softmax to outputs
+        # Return per-node log softmax predictions.
         return {
-            node_type: F.log_softmax(out, dim=1)
-            for node_type, out in final_out.items()
+            node_type: F.log_softmax(pred, dim=1)
+            for node_type, pred in final_out.items()
         }
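A short usage sketch for the new node-level output; `model` and `data` are assumed to be in scope, and 'author' is an illustrative node type:

out = model(data.x_dict, data.edge_index_dict)
log_probs = out['author']       # [num_author_nodes, num_classes] log-probabilities
probs = log_probs.exp()         # rows sum to 1
pred = log_probs.argmax(dim=1)  # hard class per node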
     def train_model(self, epochs: int = 300, lr: float = 0.001,
                     show_progress: bool = False) -> Dict[str, Dict[str, float]]:
-        """Train the heterogeneous GNN model.
+        """
+        Train the heterogeneous GNN model (node-level classification).
         Args:
             epochs (int, optional): Number of training epochs. Defaults to 300.
             lr (float, optional): Learning rate. Defaults to 0.001.
-            show_progress (bool, optional): Whether to display training progress. Defaults to False.
+            show_progress (bool, optional): Whether to show training progress. Defaults to False.
         Returns:
-            Dict[str, Dict[str, float]]: Dictionary of best metrics for each node type
-        """
-        """
-        Train with enhanced learning schedule and early stopping
+            Dict[str, Dict[str, float]]: Best metrics for each labeled node type.
         """
         self.data = self.data.to(self.device)
         optimizer = Adam(self.parameters(), lr=lr, weight_decay=1e-4)
-        scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.7,
-                                      patience=20, min_lr=1e-5)
+        scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.7, patience=20, min_lr=1e-5)
         # Initialize best metrics for each node type
         best_metrics = {
@@ -180,60 +140,52 @@ class GNN(torch.nn.Module):
             for node_type in self.label_nodes
         }
-        # Calculate class weights for each node type
+        # Compute class weights for each labeled node type.
         class_weights = {}
         for node_type in self.label_nodes:
-            labels = self.data[node_type].y.cpu().numpy()
-            counts = np.bincount(labels)
-            class_weights[node_type] = torch.FloatTensor(1.0 / counts).to(self.device)
+            num_classes = len(torch.unique(self.data[node_type].y))
+            class_weights[node_type] = torch.ones(num_classes).to(self.device)
         patience = 150
         no_improve = 0
         best_avg_f1 = 0
         for epoch in range(epochs):
             # Training phase
             self.train()
             optimizer.zero_grad()
-            # Forward pass
+            # Forward pass (node-level predictions).
             out_dict = self(self.data.x_dict, self.data.edge_index_dict)
             total_loss = 0
             current_metrics = {node_type: {} for node_type in self.label_nodes}
             # Calculate loss and metrics for each node type
             for node_type in self.label_nodes:
+                # Here the targets are per-node.
+                target = self.data[node_type].y
                 loss = F.cross_entropy(
                     out_dict[node_type],
-                    self.data[node_type].y,
+                    target,
                     weight=class_weights[node_type]
                 )
                 total_loss += loss
                 # Calculate metrics
                 pred = torch.argmax(out_dict[node_type], dim=1)
-                y_true = self.data[node_type].y.cpu().numpy()
+                y_true = target.cpu().numpy()
                 y_pred = pred.cpu().numpy()
                 current_metrics[node_type] = {
                     'accuracy': accuracy_score(y_true, y_pred),
-                    'precision': precision_score(y_true, y_pred, average='weighted',
-                                                 zero_division=1),
-                    'recall': recall_score(y_true, y_pred, average='weighted',
-                                           zero_division=1),
+                    'precision': precision_score(y_true, y_pred, average='weighted', zero_division=1),
+                    'recall': recall_score(y_true, y_pred, average='weighted', zero_division=1),
                     'f1': f1_score(y_true, y_pred, average='weighted')
                 }
             # Backward pass and optimization
             total_loss.backward()
             torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
             optimizer.step()
             # Calculate average F1 score across all node types
             avg_f1 = np.mean([metrics['f1'] for metrics in current_metrics.values()])
             scheduler.step(avg_f1)
             # Update best metrics and model state
             if avg_f1 > best_avg_f1:
                 best_avg_f1 = avg_f1
                 best_metrics = {
@@ -261,20 +213,17 @@ class GNN(torch.nn.Module):
         return best_metrics
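The hunk above swaps inverse-frequency class weights for uniform ones (an unweighted loss). A self-contained sketch of both variants on toy labels:

import numpy as np
import torch

labels = torch.tensor([0, 0, 0, 1, 2, 2])        # toy per-node labels

# Removed variant: inverse-frequency weights emphasise rare classes.
counts = np.bincount(labels.numpy())             # array([3, 1, 2])
inv_freq = torch.FloatTensor(1.0 / counts)       # tensor([0.3333, 1.0000, 0.5000])

# New variant: uniform weights, equivalent to unweighted cross-entropy.
uniform = torch.ones(len(torch.unique(labels)))  # tensor([1., 1., 1.])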
     def predict(self, node_type: str, idx: Optional[int] = None) -> torch.Tensor:
-        """Make predictions for a specific node type.
+        """
+        Make predictions for a specific node type.
         Args:
-            node_type (str): Type of node to make predictions for
-            idx (Optional[int], optional): Specific node index. If None, predicts for all nodes.
-                Defaults to None.
+            node_type (str): Type of node to make predictions for.
+            idx (Optional[int], optional): If provided, returns prediction for that node index.
+                Otherwise, returns predictions for all nodes.
         Returns:
-            torch.Tensor: Predicted class indices
-        Raises:
-            ValueError: If node_type has no labels
+            torch.Tensor: Predicted class indices.
         """
-        """Prediction with optional indexing"""
         if node_type not in self.label_nodes:
             raise ValueError(f"Node type {node_type} has no labels")
@@ -282,28 +231,28 @@ class GNN(torch.nn.Module):
         with torch.no_grad():
             predictions = self(self.data.x_dict, self.data.edge_index_dict)
             pred = torch.argmax(predictions[node_type], dim=1)
-            return pred[idx] if idx is not None else pred
+            return pred if idx is None else pred[idx]
     def predict_all(self) -> Dict[str, torch.Tensor]:
-        """Make predictions for all labeled node types.
+        """
+        Make node-level predictions for all labeled node types.
         Returns:
-            Dict[str, torch.Tensor]: Dictionary of predictions for each labeled node type
+            Dict[str, torch.Tensor]: A dictionary containing predictions for each labeled node type.
         """
-        """Predict for all labeled node types"""
         self.eval()
         with torch.no_grad():
             predictions = self(self.data.x_dict, self.data.edge_index_dict)
             return {
-                node_type: torch.argmax(predictions[node_type], dim=1)
-                for node_type in self.label_nodes
+                node_type: torch.argmax(pred, dim=1)
+                for node_type, pred in predictions.items()
             }
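Usage sketch for the two prediction entry points; the node type 'author' and index 0 are assumptions for illustration:

all_preds = model.predict_all()             # {node_type: LongTensor of class ids}
author_preds = model.predict('author')      # classes for every author node
first_author = model.predict('author', 0)   # class of a single node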
 def main():
-    """Example usage with DBLP dataset"""
+    """Example usage with DBLP dataset."""
     from CustomDataset.Text.DBLP import DBLP
-    # Load DBLP dataset
+    # Load DBLP dataset.
     dblp = DBLP(root='rawData/DBLP')
     dataset = dblp.dataset
@@ -341,7 +290,7 @@ def main():
     predictions = model.predict_all()
     for node_type in model.label_nodes:
         print(f"\nConfusion Matrix for {node_type}:")
-        print(confusion_matrix(data[node_type].y.cpu(), predictions[node_type].cpu()))
+        print(confusion_matrix(dataset[node_type].y.cpu(), predictions[node_type].cpu()))
     # Print comparison table
     headers = ['Num Layers', 'Accuracy', 'Precision', 'Recall', 'F1']
......
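For reference, a sketch of how the comparison table at the end of main() can be rendered with tabulate; the row values are placeholders, not measured results:

from tabulate import tabulate

headers = ['Num Layers', 'Accuracy', 'Precision', 'Recall', 'F1']
rows = [[4, 0.91, 0.90, 0.91, 0.90]]  # placeholder metrics
print(tabulate(rows, headers=headers, tablefmt='grid'))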
@@ -101,23 +101,24 @@ class SingleGraphOWLConverter(BaseOWLConverter):
     # Builds OWL datatype properties (attributes) for each node type in the heterodata.
     def _buildDataProperties(self):
         classNamespace = Namespace(self.namespace)
+        xsdRange = XSD.boolean if self.create_data_properties_as_boolean else XSD.double
         for node in self.dataset.node_types:
             if "x" in self.dataset[node]:
-                xsdRange = XSD.boolean if self.create_data_properties_as_boolean else XSD.double
                 n = self.dataset[node].x.size(1)
                 for i in range(n):
-                    propertyObjectPropertyName = f'{node}_property_{i+1}'
+                    property_name = f'{node}_property_{i+1}'
                     if "xKeys" in self.dataset[node] and len(self.dataset[node].xKeys) > i:
-                        propertyObjectPropertyName = self.dataset[node].xKeys[i]
+                        property_name = self.dataset[node].xKeys[i]
                     if self.create_data_properties_as_boolean:
-                        propertyObjectPropertyName = "has_" + propertyObjectPropertyName
-                    propertyObjectProperty = classNamespace[propertyObjectPropertyName]
-                    self.graph.add((propertyObjectProperty, RDF.type, OWL.DatatypeProperty))
-                    self.graph.add((propertyObjectProperty, RDFS.domain, classNamespace[node]))
-                    self.graph.add((propertyObjectProperty, RDFS.range, xsdRange))
+                        property_name = "has_" + property_name
+                    property_name_namespace = classNamespace[property_name]
+                    self.graph.add((property_name_namespace, RDF.type, OWL.DatatypeProperty))
+                    self.graph.add((property_name_namespace, RDFS.domain, classNamespace[node]))
+                    self.graph.add((property_name_namespace, RDFS.range, xsdRange))
             # Add high-level concepts if available for this node type
             if node in self.high_level_concepts:
+                xsdRange = XSD.boolean if self.create_high_level_concepts_as_boolean else XSD.double
                 for theme in self.high_level_concepts[node].get('themes', []):
                     propertyObjectPropertyName = f'has_theme_{theme}'
                     if not self.create_high_level_concepts_as_boolean:
@@ -125,7 +126,7 @@ class SingleGraphOWLConverter(BaseOWLConverter):
                     theme_namespace = classNamespace[propertyObjectPropertyName]
                     self.graph.add((theme_namespace, RDF.type, OWL.DatatypeProperty))
                     self.graph.add((theme_namespace, RDFS.domain, classNamespace[node]))
-                    self.graph.add((theme_namespace, RDFS.range, XSD.boolean if self.create_high_level_concepts_as_boolean else XSD.double))
+                    self.graph.add((theme_namespace, RDFS.range, xsdRange))
     def _buildObjectProperties(self):
         classNamespace = Namespace(self.namespace)
......
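For readers unfamiliar with the rdflib calls used in _buildDataProperties, a self-contained sketch of declaring one OWL datatype property; the namespace URI, class, and property names are illustrative:

from rdflib import Graph, Namespace
from rdflib.namespace import OWL, RDF, RDFS, XSD

g = Graph()
ns = Namespace("http://example.org/ontology#")  # illustrative namespace

# Declare a boolean datatype property 'has_keyword' with domain 'Paper'.
prop = ns["has_keyword"]
g.add((prop, RDF.type, OWL.DatatypeProperty))
g.add((prop, RDFS.domain, ns["Paper"]))
g.add((prop, RDFS.range, XSD.boolean))

print(g.serialize(format="turtle"))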