Skip to content
Snippets Groups Projects
Commit ce33a8b4 authored by AjUm-HEIDI's avatar AjUm-HEIDI
Browse files

fix issue in the text experiment

parent b421fe46
No related branches found
No related tags found
No related merge requests found
{
"structured": [
{
"datasetName": "BA2Motif"
},
{
"datasetName": "BAMultiShape"
},
{
"datasetName": "MUTAG"
}
],
"text": [
{
"datasetName": "dblp",
"grouped_keyword_dir": "rawData/dblp/groups",
"entity_name": "author"
},
{
"datasetName": "imdb",
"grouped_keyword_dir": "rawData/imdb/groups",
"entity_name": "movie"
}
]
}
......@@ -185,9 +185,10 @@ def summarize_aggregated_results(aggregated_results, summary_filename):
print(f"Summary results saved to {summary_filename}")
def experiment(grouped_keyword_dir, dataset_name, entity_name, bag_of_words_size=1000, iterations=5):
def experiment(grouped_keyword_dir, dataset_name, entity_name, bag_of_words_size=1000, iterations=5, num_groups_list=[0, 5, 10, 15, 20, 25], create_high_level_concepts_as_boolean=False):
"""
Handles dataset loading and evaluation for experiments.
Manages the experiment based on specified number of groups and boolean concept creation settings.
"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"CUDA is {'available. Using GPU.' if device.type == 'cuda' else 'not available. Using CPU.'}")
......@@ -204,30 +205,22 @@ def experiment(grouped_keyword_dir, dataset_name, entity_name, bag_of_words_size
dataset = load_datasets(dataset_name=dataset_name, bag_of_words_size=bag_of_words_size)
model = run_gnn(dataset, entity_name, dataset_name, run_timestamp)
grouped_keyword_files = [
os.path.join(grouped_keyword_dir, f)
for f in os.listdir(grouped_keyword_dir)
if f.startswith('groupedKeywords_') and f.endswith('.json')
]
grouped_keyword_files.insert(0, "") # Allow the possibility of no grouped keywords
write_header = True
for create_high_level_concepts_as_boolean in [True, False]:
for group_keyword_file in sorted(grouped_keyword_files):
num_groups = 0 if group_keyword_file == "" else int(group_keyword_file.split('_')[1].split('.')[0])
for num_groups in num_groups_list:
group_keyword_file = "" if num_groups == 0 else os.path.join(grouped_keyword_dir, f'groupedKeywords_{num_groups}.json')
owl_graph_path = f'./owlGraphs/{dataset_name}_{run_timestamp}_{num_groups}_groups_{"bool" if create_high_level_concepts_as_boolean else "data"}.owl'
print("\n" + "=" * 50)
print(f"Running experiment {run} with create_high_level_concepts_as_boolean={create_high_level_concepts_as_boolean} and num_groups={num_groups}")
print("=" * 50)
high_level_concepts = fetch_high_level_concepts(dataset, num_groups, group_keyword_file) if num_groups != 0 else None
high_level_concepts = None if num_groups == 0 else fetch_high_level_concepts(dataset, num_groups, group_keyword_file)
results = explain_and_evaluate(
model, dataset.dataset, entity_name, owl_graph_path, high_level_concepts, create_high_level_concepts_as_boolean
)
append_to_csv_file(results, run_csv_filename, dataset_name, num_groups, write_header=write_header)
append_to_csv_file(results, run_csv_filename, dataset_name, num_groups, create_high_level_concepts_as_boolean, write_header=write_header)
for label, data in results.items():
# Initialize aggregation for this label and number of groups if not yet present
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment