Skip to content
Snippets Groups Projects
Commit e02b2136 authored by AjUm-HEIDI's avatar AjUm-HEIDI
Browse files

Update main files

parent 5ca9eabd
No related branches found
No related tags found
No related merge requests found
......@@ -126,8 +126,18 @@ def experiment(datasetName: str, add_node_type = True, iterations: int = 1, crea
metrics = model.train_model(epochs=300, lr=0.001)
original_labels = np.array([data.y.item() for data in structuredDataset.dataset])
predicted_labels = model.predict_all().clone().detach().cpu().numpy()
cm = confusion_matrix(original_labels, predicted_labels)
predicted_labels = model.predict_all()
valid_original_labels = []
valid_predicted_labels = []
for idx, predicted_label in enumerate(predicted_labels):
if predicted_label is not None:
valid_original_labels.append(original_labels[idx])
valid_predicted_labels.append(predicted_label.item())
cm = confusion_matrix(valid_original_labels, valid_predicted_labels)
with open(run_dir / f"gnn_results.csv", "w", newline="") as f:
writer = csv.writer(f)
......
......@@ -33,16 +33,25 @@ def run_gnn(structuredDataset: Base, entity_name, datasetName, results_dir):
print("Initializing GNN model...")
model = GNN(structuredDataset.dataset)
print("Training model...")
metrics = model.train_model(epochs=150, lr=0.01)
metrics = model.train_model(epochs=150, lr=0.001, show_progress=True)
evaluations["gnn"] = metrics[entity_name]
print("\nBest Training Metrics:")
print("\nGNN Metrics:")
for metric, value in metrics[entity_name].items():
print(f"{metric.capitalize()}: {value:.4f}")
original_labels = structuredDataset.dataset[entity_name].y
predicted_labels = model.predict_all()
cm = confusion_matrix(original_labels.cpu().numpy(), predicted_labels[entity_name].cpu().numpy())
valid_original_labels = []
valid_predicted_labels = []
for idx, predicted_label in enumerate(predicted_labels[entity_name]):
if predicted_label is not None:
valid_original_labels.append(original_labels[idx])
valid_predicted_labels.append(predicted_label)
cm = confusion_matrix(valid_original_labels, valid_predicted_labels)
print("Confusion Matrix:")
print(cm)
evaluations["confusion_matrix"] = cm.tolist()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment