Skip to content
Snippets Groups Projects
Commit 2dd20086 authored by AjUm-HEIDI's avatar AjUm-HEIDI
Browse files

Add conditions to explain ceratin labels

parent 2bcb1466
No related branches found
No related tags found
No related merge requests found
......@@ -58,7 +58,9 @@ def main():
structured_datasets_experiment(
dataset["dataset_name"],
add_node_type=dataset["add_node_type"],
iterations=args.iterations
iterations=args.iterations,
create_high_level_concepts_as_boolean=create_high_level_concepts_as_boolean,
selected_labels=args.labels
)
elif dataset_type == "text":
text_based_datasets_experiment(
......
......@@ -14,7 +14,7 @@ from ontolearn.metrics import Accuracy, Precision, Recall, F1
from pathlib import Path
def explain_gnn(model, dataset, datasetName, run_dir, add_node_type, high_level_concepts=None):
def explain_gnn(model, dataset, datasetName, run_dir, add_node_type, high_level_concepts=None, create_high_level_concepts_as_boolean=True, selected_labels=None):
"""Explain GNN predictions and store results in a CSV file."""
renderer = DLSyntaxObjectRenderer()
explainer = DiscriminativeExplainer(
......@@ -25,11 +25,13 @@ def explain_gnn(model, dataset, datasetName, run_dir, add_node_type, high_level_
generate_new_owl_file=True,
ignore_nodes=False,
high_level_concepts=high_level_concepts,
create_high_level_concepts_as_boolean=False,
create_high_level_concepts_as_boolean=create_high_level_concepts_as_boolean,
add_node_type=add_node_type
)
for label in range(2):
labels_to_evaluate = selected_labels or range(2)
for label in labels_to_evaluate:
# Generate explanations
hypotheses, explainer_model = explainer.explain(label, 5, max_runtime=90, num_generations=1000)
......@@ -74,7 +76,7 @@ def explain_gnn(model, dataset, datasetName, run_dir, add_node_type, high_level_
])
def experiment(datasetName: str, add_node_type = True, iterations: int = 1):
def experiment(datasetName: str, add_node_type = True, iterations: int = 1, create_high_level_concepts_as_boolean=True, selected_labels=None):
"""
Run the experiment for the specified dataset multiple times.
......@@ -163,7 +165,7 @@ def experiment(datasetName: str, add_node_type = True, iterations: int = 1):
structuredDataset.visualize_pattern_in_graph(pattern_idx, graph_idx)
print("\nAfter finding motifs:")
explain_gnn(model, structuredDataset.dataset, datasetName, run_dir, add_node_type, high_level_concepts)
explain_gnn(model, structuredDataset.dataset, datasetName, run_dir, add_node_type, high_level_concepts, create_high_level_concepts_as_boolean, selected_labels)
print(f"Results for iteration {iteration + 1} saved in {run_dir}")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment