Skip to content
Snippets Groups Projects
Commit f6144c50 authored by AjUm-HEIDI's avatar AjUm-HEIDI
Browse files

update the device type

parent 9ee785aa
No related branches found
No related tags found
No related merge requests found
......@@ -161,7 +161,7 @@ class GNN(torch.nn.Module):
for node_type in self.label_nodes:
# Here the targets are per-node.
target = self.data[node_type].y
target = self.data[node_type].y.to(self.device)
loss = F.cross_entropy(
out_dict[node_type],
target,
......
......@@ -41,7 +41,7 @@ def run_gnn(structuredDataset: Base, entity_name, datasetName, results_dir):
original_labels = structuredDataset.dataset[entity_name].y
predicted_labels = model.predict_all()
cm = confusion_matrix(original_labels, predicted_labels[entity_name])
cm = confusion_matrix(original_labels.cpu().numpy(), predicted_labels[entity_name].cpu().numpy())
print("Confusion Matrix:")
print(cm)
evaluations["confusion_matrix"] = cm.tolist()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment