Commit 7733c9e8 authored by AjUm-HEIDI

Update the class structures of the structural database

parent b224916a
@@ -137,13 +137,22 @@ class PatternFinder:
(self.max_frequency is None or presence_count <= self.max_frequency)):
frequent_patterns.append((data["graph"], data["graphs"]))
# Remove patterns that are supergraphs of others
to_remove = set()
for i, (graph_i, _) in enumerate(frequent_patterns):
for j, (graph_j, _) in enumerate(frequent_patterns):
if i != j and i not in to_remove and j not in to_remove:
# Check if graph_i is a supergraph of graph_j
if self._check_pattern_in_graph(graph_j, graph_i):
to_remove.add(i)
# Sort by frequency and size
frequent_patterns.sort(key=lambda x: (-len(x[1]), -len(x[0].nodes())))
return frequent_patterns
final_patterns = [pattern for i, pattern in enumerate(frequent_patterns) if i not in to_remove]
return final_patterns
def _check_pattern_in_graph(self, pattern_graph: nx.Graph, target_graph: nx.Graph) -> bool:
"""
Check if a pattern exists in a target graph.
Check if a pattern exists in a target graph, including edge properties.
Args:
pattern_graph (nx.Graph): The pattern to search for.
@@ -152,11 +161,20 @@ class PatternFinder:
Returns:
bool: True if the pattern is found in the target graph.
"""
# target_graph.remove_edges_from(nx.selfloop_edges(target_graph))
# Define node and edge match functions
def node_match(node1, node2):
# Match nodes based on the 'type' attribute
return node1.get('type', None) == node2.get('type', None)
def edge_match(edge1, edge2):
# Match edges based on the 'type' attribute, if present
return edge1.get('type', None) == edge2.get('type', None)
# Initialize the graph matcher with both node and edge match functions
matcher = isomorphism.GraphMatcher(
target_graph,
pattern_graph,
node_match=self._node_match
node_match=node_match,
edge_match=edge_match
)
return matcher.subgraph_is_isomorphic()
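For reference, a minimal self-contained sketch of the subgraph check that _check_pattern_in_graph performs: NetworkX's GraphMatcher receives both a node_match and an edge_match callback, so only nodes and edges with equal 'type' attributes may be paired. The toy graphs and type values below are invented for illustration.

import networkx as nx
from networkx.algorithms import isomorphism

# Toy target: a triangle of carbons with single bonds, plus one oxygen attached by a double bond.
target = nx.Graph()
target.add_nodes_from([(0, {"type": "C"}), (1, {"type": "C"}), (2, {"type": "C"}), (3, {"type": "O"})])
target.add_edges_from([(0, 1, {"type": "single"}), (1, 2, {"type": "single"}),
                       (2, 0, {"type": "single"}), (2, 3, {"type": "double"})])

# Toy pattern: a carbon double-bonded to an oxygen.
pattern = nx.Graph()
pattern.add_nodes_from([(0, {"type": "C"}), (1, {"type": "O"})])
pattern.add_edge(0, 1, type="double")

matcher = isomorphism.GraphMatcher(
    target,
    pattern,
    node_match=lambda n1, n2: n1.get("type") == n2.get("type"),
    edge_match=lambda e1, e2: e1.get("type") == e2.get("type"),
)
print(matcher.subgraph_is_isomorphic())  # True: the C=O edge exists in the target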
@@ -538,3 +538,66 @@ def create_masks(data: HeteroData, key: str, split_ratio=(0.7, 0.15, 0.15)):
data[key].train_mask = train_mask
data[key].val_mask = val_mask
data[key].test_mask = test_mask
def group_themes(tensor, vocabulary, num_groups, groupedKeywordsPath=''):
"""
Groups tensor features into themes based on vocabulary words.
Args:
tensor (torch.Tensor): Original tensor with word frequencies
vocabulary (list): List of words corresponding to tensor columns
num_groups (int): Number of groups/themes to create
groupedKeywordsPath (str, optional): Path to save/load grouped keywords
Returns:
tuple: (grouped_tensor, vocabulary)
- grouped_tensor: Tensor with columns representing themes
- vocabulary: List of themes (new vocabulary for grouped tensor)
"""
# Try to load existing groups if path is provided
if groupedKeywordsPath and os.path.exists(groupedKeywordsPath):
try:
with open(groupedKeywordsPath, "r") as json_file:
grouped_themes = json.load(json_file)
except Exception as e:
print(f"Error loading grouped themes: {e}")
raise
else:
try:
result = group_keywords_into_themes(vocabulary, num_groups)
grouped_themes = result.get("success", {})
# Save groups if path is provided
if groupedKeywordsPath:
with open(groupedKeywordsPath, "w") as json_file:
json.dump(grouped_themes, json_file, indent=4)
print(f"Grouped themes saved to {groupedKeywordsPath}.")
except Exception as e:
print(f"Error grouping keywords into themes: {e}")
raise
# Replace spaces with underscores in theme names
grouped_themes = {key.replace(' ', '_'): value for key, value in grouped_themes.items()}
# Create a mapping from words to theme indices
word_to_theme = {}
for theme_idx, (theme, words) in enumerate(grouped_themes.items()):
for word in words:
word_to_theme[word] = theme_idx
# Create the grouped tensor
num_rows = tensor.size(0)
num_themes = len(grouped_themes)
grouped_tensor = torch.zeros((num_rows, num_themes), dtype=torch.float32)
# For each word in the vocabulary, add its count to the corresponding theme
for word_idx, word in enumerate(vocabulary):
if word in word_to_theme:
theme_idx = word_to_theme[word]
grouped_tensor[:, theme_idx] += tensor[:, word_idx]
themes = list(grouped_themes.keys())
return grouped_tensor, themes
\ No newline at end of file
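A minimal sketch of the column-summing step inside group_themes, with a hand-written theme mapping standing in for the output of group_keywords_into_themes; the vocabulary, themes, and counts below are made up for illustration.

import torch

# Toy word-frequency tensor: 2 documents x 4 vocabulary words.
vocabulary = ["graph", "network", "query", "index"]
tensor = torch.tensor([[2.0, 1.0, 0.0, 3.0],
                       [0.0, 4.0, 1.0, 0.0]])

# Hand-written grouping (stand-in for group_keywords_into_themes output).
grouped_themes = {"graph_theory": ["graph", "network"], "databases": ["query", "index"]}

word_to_theme = {w: i for i, (theme, words) in enumerate(grouped_themes.items()) for w in words}
grouped_tensor = torch.zeros((tensor.size(0), len(grouped_themes)))
for col, word in enumerate(vocabulary):
    if word in word_to_theme:
        grouped_tensor[:, word_to_theme[word]] += tensor[:, col]

print(grouped_tensor)  # tensor([[3., 3.], [4., 1.]])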
This diff is collapsed.
from collections import Counter, defaultdict
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import torch
import networkx as nx
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from torch_geometric.datasets import BA2MotifDataset
from torch_geometric.utils import to_networkx
from ConceptLearner.Visualiser import Visualiser
from ConceptLearner.PatternFinder import PatternFinder
from networkx.algorithms import isomorphism
from customDBs.StructuredDataset import StructuredDataset
import networkx as nx
class MultiShape:
def __init__(self, path='../rawData/BA2Motif'):
self.path = path
self.dataset = None
self.nx_graphs = None
self.patterns = []
self.pattern_combos = None
self.visualizer = None
self.pattern_tracker = None
self.dataset = BA2MotifDataset(root=self.path)
@staticmethod
def _fetch_known_patterns():
class BA2Motif(StructuredDataset):
"""
A class to process the BA2Motif dataset, find frequent patterns, and visualize graphs and patterns.
"""
def __init__(self, path='../rawData/BA2Motif'):
"""
Initialize the BA2Motif processor.
Args:
path (str): Path to the BA2MotifDataset directory.
"""
dataset = BA2MotifDataset(root=path)
super().__init__(dataset)
def _fetch_known_patterns(self):
"""
Generate predefined patterns for BA2Motif.
Returns:
list[nx.Graph]: List of predefined patterns.
"""
house = nx.Graph()
house.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 0), (4, 0), (4, 1)])
house.graph['title'] = 'house'
wheel = nx.wheel_graph(5)
wheel = nx.Graph()
wheel.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 4), (4, 0)])
wheel.graph['title'] = 'wheel'
return [house, wheel]
def detect_motifs(self, debug=False):
self.nx_graphs = [to_networkx(data, to_undirected=True) for data in self.dataset]
if debug:
self.patterns = self._fetch_known_patterns()
finder = PatternFinder(
self.nx_graphs,
patterns=self.patterns,
min_size=5,
max_size=9,
min_frequency=1,
min_degree=2,
community_detection=debug
)
self.pattern_tracker = finder.find_patterns()
self.visualizer = Visualiser(self.nx_graphs, self.pattern_tracker)
keys_to_remove = set()
for i, (graph1, _) in enumerate(self.pattern_tracker):
for j, (graph2, _) in enumerate(self.pattern_tracker):
if i != j and i not in keys_to_remove:
matcher = isomorphism.GraphMatcher(graph2, graph1)
if matcher.subgraph_is_isomorphic():
# If graph1 is a subgraph of graph2, remove the larger graph (graph2)
keys_to_remove.add(j)
superclasses = []
for i, (graph, _) in enumerate(self.pattern_tracker):
self.patterns.append(graph)
superclasses.append(graph.graph.get("title", f"Pattern_{i}") )
print(superclasses)
self.dataset.data.super_classes = superclasses
index_to_superclasses = {i: [] for i in range(len(self.dataset))}
for superclass, (_, indices) in zip(superclasses, self.pattern_tracker):
for idx in indices:
index_to_superclasses[idx].append(superclass)
for i, data in enumerate(self.dataset):
data.super_classes = index_to_superclasses[i]
return self.dataset
def analyze_combinations(self):
self.pattern_combos = Counter(tuple(sorted(data.super_classes))
for data in self.dataset)
def visualize_graphs(self, graph_indices=None):
if not self.nx_graphs:
return None
self.visualizer.visualize_graphs(graph_indices)
return self.visualizer.output_dir
def visualize_patterns(self):
if not self.patterns:
return None
self.visualizer.visualize_patterns()
return self.visualizer.output_dir
def visualize_pattern_in_graph(self, pattern_idx: int, graph_idx: int):
if not self.patterns:
return None
self.visualizer.highlight_pattern(pattern_idx, graph_idx)
return self.visualizer.output_dir
def visualize_all(self):
if not self.patterns:
return None
self.visualizer.visualize_graphs()
self.visualizer.visualize_patterns()
for pattern_idx in range(len(self.patterns)):
for graph_idx in range(len(self.nx_graphs)):
self.visualizer.highlight_pattern(pattern_idx, graph_idx)
return self.visualizer.output_dir
def print_analysis(self):
if not hasattr(self, 'pattern_tracker') or not self.pattern_tracker:
print("Pattern analysis has not been performed yet.")
return
# Pattern Frequencies
print("Pattern Frequencies:")
for i, (graph, indices) in enumerate(self.pattern_tracker):
title = graph.graph.get("title", f"Pattern_{i}") # Use title if available, otherwise default to Pattern_i
print(f"{title}: found in {len(indices)} graphs")
# Pattern Combinations
print("\nPattern Combinations:")
# Create a Counter for combinations of patterns across dataset graphs
combination_counter = Counter(
tuple(sorted(graph.graph.get("title", f"Pattern_{i}") for i, (graph, indices) in enumerate(self.pattern_tracker) if idx in indices))
for idx in range(len(self.dataset))
)
for combo, count in combination_counter.most_common():
print(f"{' + '.join(combo)}: {count} graphs")
def get_dataset(self):
return self.dataset
def get_patterns(self):
return self.patterns
def get_graphs(self):
return self.nx_graphs
if __name__ == "__main__":
ms = MultiShape(path='../rawData/BAM2ultiShapes', debug=True)
ms = BA2Motif(path='../rawData/BAM2ultiShapes')
super_classes, presence_matrix = ms.detect_motifs(debug=True)
ms.visualize_patterns()
@@ -157,8 +52,8 @@ if __name__ == "__main__":
# Predict 1 if "house" pattern is in super_classes, else 0
predicted_labels = [
1 if 'house' in data.super_classes else 0
for data in dataset
1 if np.array_equal(data, np.array([1, 0])) else 0
for data in presence_matrix
]
metrics = {
@@ -187,7 +82,7 @@ if __name__ == "__main__":
if incorrect_indices:
print("\nSuperclasses of incorrectly predicted graphs:")
for idx in incorrect_indices:
super_classes = dataset[idx].super_classes
super_classes = presence_matrix[idx]
print(f"Graph {idx}: Superclasses: {super_classes}")
# Visualize all incorrect graphs
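For context, a self-contained sketch of how a presence matrix can be scored against the original labels, as in the block above; the tiny matrix and labels are invented, and the first column is taken to be the 'house' motif.

import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix

# Toy presence matrix: one row per graph, one column per motif (e.g. [house, wheel]).
presence_matrix = np.array([[1, 0],
                            [0, 1],
                            [1, 0],
                            [1, 1]])
original_labels = [1, 0, 1, 0]

# Predict class 1 exactly when the row equals [1, 0] (house present, wheel absent).
predicted_labels = [1 if np.array_equal(row, np.array([1, 0])) else 0 for row in presence_matrix]

print(accuracy_score(original_labels, predicted_labels))   # 1.0 on this toy data
print(confusion_matrix(original_labels, predicted_labels))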
......
from collections import defaultdict
from datetime import datetime
import argparse
import json
import os
import re
import torch
from torch_geometric.data import HeteroData
from ConceptLearner.Utils import group_keywords_into_themes, clean_false_entries_in_dataset
from ConceptLearner.Utils import group_keywords_into_themes, clean_false_entries_in_dataset, group_themes
from nltk.corpus import stopwords
import nltk
def load_dblp(path="./rawData/dblp", bag_of_words_size=10, groupKeywords=True, groupedKeywordsPath='', removeAllFalseValues=True):
if groupedKeywordsPath == '':
try:
stop_words = set(stopwords.words('english'))
except LookupError:
nltk.download('stopwords')
stop_words = set(stopwords.words('english'))
# Define label categories as constants
LABELS = ["Database", "Data Mining", "Artificial Intelligence", "Information Retrieval"]
def load_dblp(path="./rawData/dblp", bag_of_words_size=100, groupKeywords=True, groupedKeywordsPath='', removeAllFalseValues=True):
"""
Loads the DBLP dataset and constructs a HeteroData object for PyTorch Geometric.
Args:
path (str): Path to the DBLP data directory.
bag_of_words_size (int): Number of top words to consider if grouping is disabled.
groupKeywords (bool): Whether to group keywords into themes.
groupedKeywordsPath (str): Path to save/load grouped keywords.
removeAllFalseValues (bool): Whether to clean false entries in the dataset.
Returns:
HeteroData: The constructed heterogeneous graph data.
"""
if not groupedKeywordsPath:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
groupedKeywordsPath = f"groupedKeywords_{timestamp}.json"
groupedKeywordsPath = os.path.join(path, f"groupedKeywords_{timestamp}.json")
# Load authors
author_ids, author_labels, author_id_dict = _get_authors(path)
# Load papers and grouped themes
paper_tensor, paper_id_dict, bag_of_words = _get_papers(path, bag_of_words_size, groupKeywords, groupedKeywordsPath)
author_paper_mappings = _get_author_paper_mappings(path, author_id_dict, paper_id_dict)
conf_ids, conf_id_dict = _get_conference(path)
paper_conference_mappings = _get_paper_conference_mappings(path, paper_id_dict, conf_id_dict)
# Construct HeteroData
dataset = HeteroData()
# Author Nodes
dataset['author'].num_nodes = len(author_labels)
dataset['author'].y = torch.tensor(list(author_labels))
dataset['author'].yLabel = ["Database", "Data Mining", "Artificial Intelligence", "Information Retrieval"]
dataset['author'].x = torch.full((dataset['author'].num_nodes, 1), 0.0, dtype=torch.float32) # Dummy feature matrix with float32
dataset['author'].y = torch.tensor(author_labels)
dataset['author'].yLabel = LABELS
dataset['author'].x = torch.zeros((dataset['author'].num_nodes, 1), dtype=torch.float32) # Dummy feature matrix
# Paper Nodes
dataset['paper'].x = paper_tensor.float()
dataset['paper'].xKeys = bag_of_words
# Author-Paper Edges
dataset['author', 'writes', 'paper'].edge_index = author_paper_mappings.t()
dataset['paper', 'written_by', 'author'].edge_index = author_paper_mappings.t()[[1, 0], :]
# dataset['conference'].num_nodes = len(conf_ids)
# dataset['paper', 'published_in', 'conference'].edge_index = paper_conference_mappings.t()
# dataset['conference', 'publishes', 'paper'].edge_index = paper_conference_mappings.clone().flip(dims=[1])
# if removeAllFalseValues:
# clean_false_entries_in_dataset(dataset, 'paper')
dataset['conference'].x = torch.tensor(conf_ids)
dataset['paper', 'published_in', 'conference'].edge_index = paper_conference_mappings.t()
return dataset
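For orientation, a rough self-contained sketch of the HeteroData layout that load_dblp assembles (author labels, dummy author features, a paper feature matrix, and author-paper edges in both directions); the tiny tensors are placeholders, not real DBLP data.

import torch
from torch_geometric.data import HeteroData

data = HeteroData()

# Three authors with integer class labels and a dummy one-dimensional feature.
data['author'].num_nodes = 3
data['author'].y = torch.tensor([0, 2, 1])
data['author'].x = torch.zeros((3, 1), dtype=torch.float32)

# Two papers described by a 4-dimensional bag-of-words / theme vector.
data['paper'].x = torch.tensor([[1.0, 0.0, 2.0, 0.0],
                                [0.0, 3.0, 0.0, 1.0]])

# author -> paper edges as a [2, num_edges] index tensor, plus the reverse relation.
writes = torch.tensor([[0, 1, 2],
                       [0, 0, 1]])
data['author', 'writes', 'paper'].edge_index = writes
data['paper', 'written_by', 'author'].edge_index = writes[[1, 0], :]

print(data)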
def _get_authors(path):
file_path = path + "/author_label.txt"
"""
Reads the author_label.txt file and extracts author IDs and labels.
Args:
path (str): Path to the DBLP data directory.
Returns:
tuple: (ids, labels, id_dict)
"""
file_path = os.path.join(path, "author_label.txt")
ids = []
labels = []
try:
with open(file_path, "r") as file:
for idx, line in enumerate(file):
for line in file:
line = line.strip()
id, label, name = line.split("\t")
if not line:
continue # Skip empty lines
parts = line.split("\t")
if len(parts) != 3:
continue
id, label, name = parts
ids.append(id)
labels.append(int(label))
except FileNotFoundError:
print(f"File not found: {file_path}")
raise
except Exception as e:
print(f"Error reading {file_path}: {e}")
raise
id_dict = {id: idx for idx, id in enumerate(ids)}
# Validate labels
if labels and max(labels) >= len(LABELS):
print(f"Encountered a label {max(labels)} outside the defined yLabel categories.")
raise ValueError("Invalid label found in author_label.txt.")
return ids, labels, id_dict
def _get_papers(path, bag_of_words_size, groupKeywords, groupedKeywordsPath):
file_path = path + "/paper.txt"
def fetch_themes(dataset, num_groups, groupedKeywordsPath=''):
"""
Groups the paper bag-of-words features into themes and returns them as high-level concepts.
Args:
dataset (HeteroData): Dataset whose 'paper' features and vocabulary ('xKeys') are grouped
num_groups (int): Number of groups/themes to create
groupedKeywordsPath (str, optional): Path to save/load grouped keywords
Returns:
dict: High-level concepts mapping 'paper' to its grouped themes and presence matrix
"""
vocabulary = dataset['paper'].xKeys
tensor = dataset['paper'].x
grouped_tensor, grouped_themes = group_themes(tensor, vocabulary, num_groups, groupedKeywordsPath)
high_level_concepts = {
"paper": {
"themes" : grouped_themes,
"presence_matrix": grouped_tensor
}
}
return high_level_concepts
def _get_papers(path, bag_of_words_size):
"""
Reads the paper.txt file and processes the text into a tensor.
Args:
path (str): Path to the DBLP data directory.
bag_of_words_size (int): Number of top words to consider.
groupKeywords (bool): Whether to group keywords into themes.
groupedKeywordsPath (str): Path to save/load grouped keywords.
Returns:
tuple: (paper_tensor, paper_id_dict, vocabulary, grouped_themes)
"""
file_path = os.path.join(path, "paper.txt")
bag_of_words = []
vocabulary = []
total_words = {}
total_words = defaultdict(int)
paper_id_dict = {}
id_paper_dict = {}
stop_words = stopwords.words('english')
count = 0
with open(file_path, "r") as file:
try:
with open(file_path, "r") as file:
for idx, line in enumerate(file):
count += 1
line = line.strip()
id, text = line.split("\t", 1)
if not line:
continue # Skip empty lines
parts = line.split("\t", 1)
if len(parts) != 2:
continue
id, text = parts
words = re.findall(r'\w+', text.lower()) # Tokenize and convert to lowercase
word_count = defaultdict(int)
for word in words:
if word in stop_words:
continue
if word not in vocabulary:
if word not in total_words:
vocabulary.append(word)
if word in total_words:
total_words[word] += 1
else:
total_words[word] = 1
word_count[word] += 1
bag_of_words.append(dict(word_count))
paper_id_dict[id] = idx
id_paper_dict[idx] = id
except FileNotFoundError:
print(f"File not found: {file_path}")
raise
except Exception as e:
print(f"Error processing line {idx}: {line}. Error: {str(e)}")
print(f"Error processing {file_path} at line {idx}: {e}")
raise
total_words = dict(sorted(total_words.items(), key=lambda item: item[1], reverse=True)[:(1800 if groupKeywords else bag_of_words_size)])
# Get top N words for vocabulary
top_n = bag_of_words_size
total_words = dict(sorted(total_words.items(), key=lambda item: item[1], reverse=True)[:top_n])
vocabulary = list(total_words.keys())
# Create the initial matrix
matrix = []
# Group keywords based on the groupKeywords flag
if groupKeywords:
fileExists = os.path.exists(groupedKeywordsPath)
grouped_themes = {}
if fileExists:
# Load the grouped themes from the JSON file
with open(groupedKeywordsPath, "r") as json_file:
grouped_themes = json.load(json_file)
else:
result = group_keywords_into_themes(vocabulary, bag_of_words_size)
grouped_themes = result["success"]
with open(groupedKeywordsPath, "w") as json_file:
json.dump(grouped_themes, json_file, indent=4)
grouped_themes = {key.replace(' ', '_'): value for key, value in grouped_themes.items()}
# Create a mapping from keywords to theme indices
keyword_to_theme = {}
for theme_idx, (theme, keywords) in enumerate(grouped_themes.items()):
for keyword in keywords:
keyword_to_theme[keyword] = theme_idx
# Create the matrix based on grouped themes
for item in bag_of_words:
theme_vector = [0] * len(grouped_themes)
for word, count in item.items():
if word in keyword_to_theme:
theme_idx = keyword_to_theme[word]
theme_vector[theme_idx] += count
matrix.append(theme_vector)
# Update vocabulary to reflect themes
vocabulary = list(grouped_themes.keys())
else:
# Create the matrix without grouping
for item in bag_of_words:
vector = [item.get(word, 0) for word in vocabulary]
matrix.append(vector)
# Replace spaces with underscores in vocabulary
vocabulary = [word.replace(' ', '_') for word in vocabulary]
# Convert to tensor
paper_tensor = torch.tensor(matrix, dtype=torch.float32)
# Convert the matrix to a PyTorch tensor
tensor = torch.tensor(matrix, dtype=torch.float32)
return tensor, paper_id_dict, vocabulary
return paper_tensor, paper_id_dict, vocabulary
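A compact, self-contained sketch of the same kind of bag-of-words construction as _get_papers (tokenize, drop stopwords, keep the most frequent words, build a count matrix); the example lines and the two-word stopword set are invented.

import re
from collections import Counter
import torch

lines = ["1\tEfficient query processing over graph data",
         "2\tGraph neural networks for query optimization"]
stop_words = {"over", "for"}   # stand-in for nltk's English stopword list
top_n = 4

doc_counts, totals = [], Counter()
for line in lines:
    _, text = line.split("\t", 1)
    words = [w for w in re.findall(r'\w+', text.lower()) if w not in stop_words]
    doc_counts.append(Counter(words))
    totals.update(words)

vocabulary = [w for w, _ in totals.most_common(top_n)]
matrix = [[counts.get(w, 0) for w in vocabulary] for counts in doc_counts]
paper_tensor = torch.tensor(matrix, dtype=torch.float32)
print(vocabulary, paper_tensor.shape)  # the 4 most frequent words, shape [2, 4]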
def _get_author_paper_mappings(path, author_id_dict, paper_id_dict):
file_path = path + "/paper_author.txt"
"""
Reads the paper_author.txt file and creates author-paper edge mappings.
Args:
path (str): Path to the DBLP data directory.
author_id_dict (dict): Mapping from author IDs to indices.
paper_id_dict (dict): Mapping from paper IDs to indices.
Returns:
torch.Tensor: Edge indices tensor of shape [2, num_edges].
"""
file_path = os.path.join(path, "paper_author.txt")
mappings = []
try:
with open(file_path, "r") as file:
for idx, line in enumerate(file):
line = line.strip()
paper, author = line.split("\t", 1)
if not line:
continue # Skip empty lines
parts = line.split("\t", 1)
if len(parts) != 2:
continue
paper, author = parts
if paper in paper_id_dict and author in author_id_dict:
mappings.append([author_id_dict[author], paper_id_dict[paper]])
except FileNotFoundError:
print(f"File not found: {file_path}")
raise
except Exception as e:
print(f"Error reading {file_path} at line {idx}: {e}")
raise
if mappings:
mappings = torch.tensor(mappings, dtype=torch.long)
mappings = torch.unique(mappings, dim=0) # Remove duplicates
else:
mappings = torch.empty((2, 0), dtype=torch.long)
mappings = torch.tensor(mappings)
return mappings
def _get_conference(path):
file_path = path + "/conf.txt"
"""
Reads the conf.txt file and extracts conference IDs and names.
Args:
path (str): Path to the DBLP data directory.
Returns:
tuple: (ids, id_dict)
"""
file_path = os.path.join(path, "conf.txt")
ids = []
try:
with open(file_path, "r") as file:
for idx, line in enumerate(file):
for line in file:
line = line.strip()
id, name = line.split("\t")
if not line:
continue # Skip empty lines
parts = line.split("\t")
if len(parts) != 2:
continue
id, name = parts
ids.append(id)
except FileNotFoundError:
print(f"File not found: {file_path}")
raise
except Exception as e:
print(f"Error reading {file_path}: {e}")
raise
id_dict = {id: idx for idx, id in enumerate(ids)}
return ids, id_dict
def _get_paper_conference_mappings(path, paper_id_dict, conf_id_dict):
file_path = path + "/paper_conf.txt"
"""
Reads the paper_conf.txt file and creates paper-conference edge mappings.
Args:
path (str): Path to the DBLP data directory.
paper_id_dict (dict): Mapping from paper IDs to indices.
conf_id_dict (dict): Mapping from conference IDs to indices.
Returns:
torch.Tensor: Edge indices tensor of shape [2, num_edges].
"""
file_path = os.path.join(path, "paper_conf.txt")
mappings = []
try:
with open(file_path, "r") as file:
for idx, line in enumerate(file):
line = line.strip()
paper, conference = line.split("\t", 1)
if not line:
continue # Skip empty lines
parts = line.split("\t", 1)
if len(parts) != 2:
continue
paper, conference = parts
if paper in paper_id_dict and conference in conf_id_dict:
mappings.append([paper_id_dict[paper], conf_id_dict[conference]])
except FileNotFoundError:
print(f"File not found: {file_path}")
raise
except Exception as e:
print(f"Error reading {file_path} at line {idx}: {e}")
raise
if mappings:
mappings = torch.tensor(mappings, dtype=torch.long)
mappings = torch.unique(mappings, dim=0) # Remove duplicates
else:
mappings = torch.empty((2, 0), dtype=torch.long)
mappings = torch.tensor(mappings)
return mappings
......
from collections import Counter, defaultdict
import torch
import networkx as nx
from torch_geometric.datasets import TUDataset
from torch_geometric.utils import to_networkx
from ConceptLearner.PatternFinder import PatternFinder
from ConceptLearner.Visualiser import Visualiser
from customDBs.StructuredDataset import StructuredDataset
class MUTAG:
class MUTAG(StructuredDataset):
"""
A class to process the MUTAG dataset, find frequent patterns, and visualize graphs and patterns.
"""
@@ -19,31 +13,10 @@ class MUTAG:
Args:
path (str): Path to the TUDataset directory.
"""
self.path = path
self.dataset = None
self.nx_graphs = []
self.patterns = []
self.pattern_tracker = None
self.visualizer = None
self._initialize()
def _initialize(self):
"""
Load the MUTAG dataset and convert graphs to NetworkX format.
"""
self.dataset = TUDataset(root=self.path, name='MUTAG')
self.nx_graphs = [to_networkx(data, to_undirected=True) for data in self.dataset]
# Add node type information to graphs
for data, nx_graph in zip(self.dataset, self.nx_graphs):
for node in nx_graph.nodes():
node_features = data.x[node]
node_type = torch.where(node_features == 1)[0].item()
nx_graph.nodes[node]['type'] = node_type
if self.add_super_class:
self.detect_motifs()
dataset = TUDataset(root=path, name='MUTAG')
node_labels = ["C", "N", "O", "F", "I", "Cl", "Br"]
edge_labels = ["aromatic", "single", "double", "triple"]
super().__init__(dataset, node_labels=node_labels, edge_labels=edge_labels)
def _fetch_known_patterns(self):
"""
@@ -52,169 +25,40 @@ class MUTAG:
Returns:
list[nx.Graph]: List of predefined patterns.
"""
cycle_4 = nx.cycle_graph(4)
cycle_4.graph['title'] = 'cycle_4'
star_5 = nx.star_graph(5)
star_5.graph['title'] = 'star_5'
return [cycle_4, star_5]
return []
def detect_motifs(self, debug=False):
"""
Find frequent patterns in the dataset using the PatternFinder.
Override to pass MUTAG-specific parameters to the base detect_motifs.
Args:
debug (bool): Whether to use predefined patterns for debugging.
Returns:
TUDataset: Dataset annotated with detected patterns and superclasses.
list: Titles of the discovered patterns.
np.ndarray: Presence matrix of patterns in graphs.
"""
if debug:
self.patterns = self._fetch_known_patterns()
finder = PatternFinder(
self.nx_graphs,
patterns=self.patterns,
return super().detect_motifs(
debug=debug,
min_size=3,
max_size=20,
min_frequency=5,
min_frequency=2,
min_density=0.4,
min_degree=1
)
self.pattern_tracker = finder.find_patterns()
self.visualizer = Visualiser(self.nx_graphs, self.pattern_tracker)
superclasses = []
for i, (graph, _) in enumerate(self.pattern_tracker):
self.patterns.append(graph)
superclasses.append(graph.graph.get("title", f"Pattern_{i}"))
self.dataset.data.super_classes = superclasses
index_to_superclasses = {i: [] for i in range(len(self.dataset))}
for superclass, (_, indices) in zip(superclasses, self.pattern_tracker):
for idx in indices:
index_to_superclasses[idx].append(superclass)
for i, data in enumerate(self.dataset):
data.super_classes = index_to_superclasses[i]
return self.dataset
def visualize_graphs(self, graph_indices=None):
"""
Visualize selected graphs in the dataset.
Args:
graph_indices (list[int], optional): Indices of graphs to visualize. Defaults to None.
Returns:
str: Directory where the graphs are saved.
"""
if not self.nx_graphs:
print("No graphs available to visualize.")
return None
self.visualizer.visualize_graphs(graph_indices)
return self.visualizer.output_dir
def visualize_patterns(self):
"""
Visualize all discovered patterns.
Returns:
str: Directory where the patterns are saved.
"""
if not self.patterns:
print("No patterns available to visualize.")
return None
self.visualizer.visualize_patterns()
return self.visualizer.output_dir
def visualize_pattern_in_graph(self, pattern_idx, graph_idx):
"""
Highlight a specific pattern within a specific graph.
Args:
pattern_idx (int): Index of the pattern to visualize.
graph_idx (int): Index of the graph to visualize the pattern in.
Returns:
str: Directory where the visualization is saved.
"""
if not self.patterns:
print("No patterns available to visualize.")
return None
self.visualizer.highlight_pattern(pattern_idx, graph_idx)
return self.visualizer.output_dir
def visualize_all(self):
"""
Visualize all graphs and patterns in the dataset.
"""
if not self.patterns:
print("No patterns available to visualize.")
return None
self.visualizer.visualize_graphs()
self.visualizer.visualize_patterns()
for pattern_idx in range(len(self.patterns)):
for graph_idx in range(len(self.nx_graphs)):
self.visualizer.highlight_pattern(pattern_idx, graph_idx)
return self.visualizer.output_dir
def print_analysis(self):
"""
Print a summary of the dataset and discovered patterns.
"""
if not hasattr(self, 'pattern_tracker') or not self.pattern_tracker:
print("Pattern analysis has not been performed yet.")
return
print("Pattern Frequencies:")
for i, (graph, indices) in enumerate(self.pattern_tracker):
title = graph.graph.get("title", f"Pattern_{i}")
print(f"{title}: found in {len(indices)} graphs")
print(f"Number of graphs: {len(self.dataset)}")
print(f"Number of classes: {self.dataset.num_classes}")
print(f"Number of patterns discovered: {len(self.patterns)}")
def get_dataset(self):
"""
Get the processed dataset.
Returns:
TUDataset: The processed MUTAG dataset.
"""
return self.dataset
def get_patterns(self):
"""
Get the discovered patterns.
Returns:
list[nx.Graph]: List of frequent patterns.
"""
return self.patterns
def get_graphs(self):
"""
Get the processed NetworkX graphs.
Returns:
list[nx.Graph]: List of NetworkX graphs.
"""
return self.nx_graphs
if __name__ == "__main__":
mutag = MUTAG()
original_labels = [data.y.item() for data in mutag.dataset]
print(original_labels)
superclasses, presence_matrix = mutag.detect_motifs()
print(presence_matrix)
mutag.visualize_patterns()
mutag.print_analysis()
mutag.detect_motifs()
graph_indices_to_visualize = [45, 81, 128]
graph_indices_to_visualize = [1, 2, 5, 7, 9, # Label 0
10, 11, 12, 14, 15] # Label 1
print(f"Visualizing graphs: {graph_indices_to_visualize}")
graph_visualization_dir = mutag.visualize_graphs(graph_indices_to_visualize)
if graph_visualization_dir:
@@ -227,3 +71,4 @@ if __name__ == "__main__":
for pattern_idx in range(min(2, len(mutag.patterns))):
for graph_idx in graph_indices_to_visualize:
mutag.visualize_pattern_in_graph(pattern_idx, graph_idx)
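To make the node and edge label annotation concrete, here is a small self-contained sketch, with an invented three-atom molecule, of converting one-hot node and edge features into the 'type' attributes that the pattern matching relies on; the mechanism mirrors what the StructuredDataset base class does when node_labels and edge_labels are passed.

import torch
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx

node_labels = ["C", "N", "O", "F", "I", "Cl", "Br"]
edge_labels = ["aromatic", "single", "double", "triple"]

# Toy molecule: three atoms (one-hot node features) and two typed bonds.
x = torch.eye(7)[[0, 0, 2]]                      # C, C, O
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
edge_attr = torch.eye(4)[[1, 1, 2, 2]]           # single, single, double, double
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

nx_graph = to_networkx(data, to_undirected=True)
for node in nx_graph.nodes():
    nx_graph.nodes[node]['type'] = node_labels[torch.where(data.x[node] == 1)[0].item()]
for (src, dst), feats in zip(data.edge_index.T.tolist(), data.edge_attr):
    nx_graph[src][dst]['type'] = edge_labels[torch.where(feats == 1)[0].item()]

print(nx_graph.nodes(data=True))
print(nx_graph.edges(data=True))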
from collections import Counter, defaultdict
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import torch
import networkx as nx
from torch_geometric.datasets import BAMultiShapesDataset
from torch_geometric.utils import to_networkx
from ConceptLearner.Visualiser import Visualiser
from ConceptLearner.PatternFinder import PatternFinder
from networkx.algorithms import isomorphism
from StructuredDatasets import StructuredDatasets
class MultiShape:
class MultiShape(StructuredDatasets):
"""
A class to process the BAMultiShapes dataset, find frequent patterns, and visualize graphs and patterns.
"""
def __init__(self, path='../rawData/BAMultiShapes'):
self.path = path
self.dataset = None
self.nx_graphs = None
self.patterns = []
self.pattern_combos = None
self.visualizer = None
self.pattern_tracker = None
self._initialize()
def _initialize(self):
self.dataset = BAMultiShapesDataset(root=self.path)
self.visualizer = Visualiser(self.nx_graphs, self.pattern_tracker)
"""
Initialize the MultiShape processor.
Args:
path (str): Path to the BAMultiShapesDataset directory.
"""
dataset = BAMultiShapesDataset(root=path)
super().__init__(dataset)
@staticmethod
def _fetch_known_patterns():
"""
Generate predefined patterns for BAMultiShapes.
Returns:
list[nx.Graph]: List of predefined patterns.
"""
house = nx.Graph()
house.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 0), (4, 0), (4, 1)])
house.graph['title'] = 'house'
@@ -39,114 +38,22 @@ class MultiShape:
return [house, grid, wheel]
def detect_motifs(self, debug=False):
self.nx_graphs = [to_networkx(data, to_undirected=True) for data in self.dataset]
if debug:
self.patterns = self._fetch_known_patterns()
finder = PatternFinder(
self.nx_graphs,
patterns=self.patterns,
min_size=5,
max_size=9,
min_frequency=1,
min_degree=2,
community_detection=debug
)
self.pattern_tracker = finder.find_patterns()
keys_to_remove = set()
for i, (graph1, _) in enumerate(self.pattern_tracker):
for j, (graph2, _) in enumerate(self.pattern_tracker):
if i != j and i not in keys_to_remove:
matcher = isomorphism.GraphMatcher(graph1, graph2)
if matcher.subgraph_is_isomorphic():
keys_to_remove.add(i)
superclasses = []
for i, (graph, _) in enumerate(self.pattern_tracker):
self.patterns.append(graph)
superclasses.append(graph.graph.get("title", f"Pattern_{i}"))
self.dataset.data.super_classes = superclasses
index_to_superclasses = {i: [] for i in range(len(self.dataset))}
for superclass, (_, indices) in zip(superclasses, self.pattern_tracker):
for idx in indices:
index_to_superclasses[idx].append(superclass)
for i, data in enumerate(self.dataset):
data.super_classes = index_to_superclasses[i]
return self.dataset
def visualize_graphs(self, graph_indices=None):
if not self.nx_graphs:
return None
self.visualizer.visualize_graphs(graph_indices)
return self.visualizer.output_dir
def visualize_patterns(self):
if not self.patterns:
return None
self.visualizer.visualize_patterns()
return self.visualizer.output_dir
def visualize_pattern_in_graph(self, pattern_idx: int, graph_idx: int):
if not self.patterns:
return None
self.visualizer.highlight_pattern(pattern_idx, graph_idx)
return self.visualizer.output_dir
def visualize_all(self):
if not self.patterns:
return None
self.visualizer.visualize_graphs()
self.visualizer.visualize_patterns()
for pattern_idx in range(len(self.patterns)):
for graph_idx in range(len(self.nx_graphs)):
self.visualizer.highlight_pattern(pattern_idx, graph_idx)
return self.visualizer.output_dir
def print_analysis(self):
if not hasattr(self, 'pattern_tracker') or not self.pattern_tracker:
print("Pattern analysis has not been performed yet.")
return
print("Pattern Frequencies:")
for i, (graph, indices) in enumerate(self.pattern_tracker):
title = graph.graph.get("title", f"Pattern_{i}")
print(f"{title}: found in {len(indices)} graphs")
print("\nPattern Combinations:")
combination_counter = Counter(
tuple(sorted(graph.graph.get("title", f"Pattern_{i}") for i, (graph, indices) in enumerate(self.pattern_tracker) if idx in indices))
for idx in range(len(self.dataset))
)
for combo, count in combination_counter.most_common():
print(f"{' + '.join(combo)}: {count} graphs")
def get_dataset(self):
return self.dataset
def get_patterns(self):
return self.patterns
def get_graphs(self):
return self.nx_graphs
if __name__ == "__main__":
ms = MultiShape(path='../rawData/BAMultiShapes', debug=True)
ms.detect_motifs()
ms.visualize_patterns()
ms.print_analysis()
ms = MultiShape(path='../rawData/BAMultiShapes')
ms.detect_motifs(debug=True)
patterns, presence_matrix = ms.detect_motifs(ms.dataset)
dataset = ms.get_dataset()
original_labels = [data.y.item() for data in dataset]
predicted_labels = [1 if len(data.super_classes) == 2 else 0 for data in dataset]
predicted_labels = [1 if row.sum() == 2 else 0 for row in presence_matrix]
print(patterns)
ms.visualize_patterns()
ms.print_analysis()
metrics = {
'accuracy': accuracy_score(original_labels, predicted_labels),
'confusion_matrix': confusion_matrix(original_labels, predicted_labels),
......
from collections import Counter
import numpy as np
import torch
from torch_geometric.utils import to_networkx
from ConceptLearner.PatternFinder import PatternFinder
from ConceptLearner.Visualiser import Visualiser
class StructuredDataset:
"""
Base class for processing graph datasets, finding patterns, and visualizing graphs.
"""
def __init__(self, dataset, node_labels=None, edge_labels=None):
"""
Initialize the StructuredDatasets.
Args:
dataset: The PyTorch Geometric dataset to process.
node_labels (list[str], optional): List of node label names.
edge_labels (list[str], optional): List of edge label names.
"""
self.dataset = dataset
self.nx_graphs = []
self.patterns = []
self.pattern_tracker = None
self.visualizer = None
self.node_labels = node_labels
self.edge_labels = edge_labels
self._initialize()
def _initialize(self):
"""
Convert the dataset into NetworkX graphs with annotated node and edge features.
"""
self.nx_graphs = [to_networkx(data, to_undirected=True) for data in self.dataset]
if self.node_labels:
for data, nx_graph in zip(self.dataset, self.nx_graphs):
for node in nx_graph.nodes():
node_features = data.x[node]
node_type = torch.where(node_features == 1)[0].item()
nx_graph.nodes[node]['type'] = self.node_labels[node_type]
if self.edge_labels:
for data, nx_graph in zip(self.dataset, self.nx_graphs):
for edge_index, edge_features in zip(data.edge_index.T, data.edge_attr):
src, dest = edge_index.tolist()
edge_type = torch.where(edge_features == 1)[0].item()
nx_graph[src][dest]['type'] = self.edge_labels[edge_type]
def detect_motifs(self, debug=False, **finder_kwargs):
"""
Find frequent patterns in the dataset.
Args:
debug (bool): Whether to use predefined patterns for debugging.
**finder_kwargs: Additional parameters for PatternFinder.
Returns:
list: Titles of the discovered patterns.
np.ndarray: Presence matrix of patterns in graphs.
"""
if debug and hasattr(self, "_fetch_known_patterns"):
self.patterns = self._fetch_known_patterns()
finder = PatternFinder(self.nx_graphs, patterns=self.patterns, **finder_kwargs)
self.pattern_tracker = finder.find_patterns()
self.visualizer = Visualiser(self.nx_graphs, self.pattern_tracker)
superclasses = []
for i, (graph, _) in enumerate(self.pattern_tracker):
self.patterns.append(graph)
superclasses.append(graph.graph.get("title", f"Pattern_{i}"))
num_graphs = len(self.dataset)
num_patterns = len(superclasses)
presence_matrix = np.zeros((num_graphs, num_patterns), dtype=int)
for pattern_idx, (_, graph_indices) in enumerate(self.pattern_tracker):
for graph_idx in graph_indices:
presence_matrix[graph_idx, pattern_idx] = 1
return superclasses, presence_matrix
def visualize_graphs(self, graph_indices=None):
"""
Visualize selected graphs in the dataset.
Args:
graph_indices (list[int], optional): Indices of graphs to visualize.
Returns:
str: Directory where the graphs are saved.
"""
if not self.nx_graphs:
return None
return self.visualizer.visualize_graphs(graph_indices)
def visualize_patterns(self):
"""
Visualize all discovered patterns.
Returns:
str: Directory where the patterns are saved.
"""
if not self.patterns:
return None
return self.visualizer.visualize_patterns()
def visualize_pattern_in_graph(self, pattern_idx, graph_idx):
"""
Highlight a specific pattern within a specific graph.
Args:
pattern_idx (int): Index of the pattern to visualize.
graph_idx (int): Index of the graph to visualize the pattern in.
Returns:
str: Directory where the visualization is saved.
"""
if not self.patterns:
return None
return self.visualizer.highlight_pattern(pattern_idx, graph_idx)
def print_analysis(self):
"""
Print a summary of the dataset and discovered patterns.
"""
if not self.pattern_tracker:
print("Pattern analysis has not been performed yet.")
return
print("Pattern Frequencies:")
for i, (graph, indices) in enumerate(self.pattern_tracker):
title = graph.graph.get("title", f"Pattern_{i}")
print(f"{title}: found in {len(indices)} graphs")
print(f"Number of graphs: {len(self.dataset)}")
print(f"Number of patterns discovered: {len(self.patterns)}")
def get_dataset(self):
"""
Get the processed dataset.
Returns:
Dataset: The processed dataset.
"""
return self.dataset
def get_patterns(self):
"""
Get the discovered patterns.
Returns:
list[nx.Graph]: List of frequent patterns.
"""
return self.patterns
def get_graphs(self):
"""
Get the processed NetworkX graphs.
Returns:
list[nx.Graph]: List of NetworkX graphs.
"""
return self.nx_graphs
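A standalone sketch of how detect_motifs turns the (pattern graph, graph indices) pairs from the pattern tracker into the superclass list and presence matrix it returns; the patterns and indices below are illustrative.

import numpy as np
import networkx as nx

# Stand-in pattern tracker: (pattern graph, indices of the dataset graphs that contain it).
house = nx.Graph()
house.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 0), (4, 0), (4, 1)])
house.graph['title'] = 'house'
wheel = nx.wheel_graph(5)
wheel.graph['title'] = 'wheel'
pattern_tracker = [(house, [0, 2]), (wheel, [1, 2, 3])]  # illustrative indices
num_graphs = 4

superclasses = [g.graph.get('title', f'Pattern_{i}') for i, (g, _) in enumerate(pattern_tracker)]
presence_matrix = np.zeros((num_graphs, len(pattern_tracker)), dtype=int)
for pattern_idx, (_, graph_indices) in enumerate(pattern_tracker):
    for graph_idx in graph_indices:
        presence_matrix[graph_idx, pattern_idx] = 1

print(superclasses)      # ['house', 'wheel']
print(presence_matrix)   # one row per graph, one column per pattern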
from datetime import datetime
from sklearn.metrics import confusion_matrix
import torch
import json
from ConceptLearner.DiscriminativeExplainer import DiscriminativeExplainer
from customDBs.MultiShape import MultiShape
from customDBs.MUTAG import MUTAG
from customDBs.BA2Motif import BA2Motif
from customDBs.StructuredDataset import StructuredDataset
from ConceptLearner.GNN4 import GNN
from ontolearn.owlapy.render import DLSyntaxObjectRenderer
from ontolearn.metrics import Accuracy, Precision, Recall, F1
from pathlib import Path
def explain_gnn(model, dataset, datasetName, explanations_dict, high_level_concepts = None):
"""Explain GNN predictions and store the best results for each label."""
renderer = DLSyntaxObjectRenderer()
explainer = DiscriminativeExplainer(
model,
dataset,
"http://example.org/",
owl_graph_path=f"./owlGraphs/{datasetName}_experiment{ '_with_motif' if high_level_concepts is not None else '_without_motif'}.owl",
generate_new_owl_file=True,
ignore_nodes=True,
high_level_concepts=high_level_concepts
)
for label in range(2):
# Generate explanations
hypotheses, explainer_model = explainer.explain(label, 5, max_runtime=30, num_generations=400, quality_func=F1())
for hypothesis in hypotheses:
print(renderer.render(hypothesis.concept), hypothesis.quality)
# Best hypothesis
best_hypothesis = hypotheses[0]
rendered_hypothesis = renderer.render(best_hypothesis.concept)
print(f"Best hypothesis for label {label}: {rendered_hypothesis}, Quality: {best_hypothesis.quality}, Length: {best_hypothesis.len}")
# Evaluate the best hypothesis
metrics = [Accuracy(), Recall(), Precision(), F1()]
evaluation = {
"addedMotifs": high_level_concepts is not None,
"concept": rendered_hypothesis,
"quality": best_hypothesis.quality,
"length": best_hypothesis.len
}
for metric in metrics:
evaluated_concept = explainer_model.kb.evaluate_concept(
best_hypothesis.concept, metric, explainer_model._learning_problem
)
evaluation[metric.name] = evaluated_concept.q
# Add explanation to the appropriate label
explanations_dict["explanation"].setdefault(f"label_{label}", []).append(evaluation)
print("Evaluation results:")
print(evaluation)
def experiment(structuredDataset: StructuredDataset, datasetName: str):
# Set device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
# Prepare output directory
output_dir = Path("evaluation_results")
output_dir.mkdir(parents=True, exist_ok=True)
# Initialize JSON structure
evaluations = {
"gnn": {},
"explanation": {}
}
# Initialize GNN model
print("Initializing GNN model...")
model = GNN(structuredDataset.dataset)
# Train model
print("Training model...")
metrics = model.train_model(epochs=300, lr=0.001, show_progress=True)
original_labels = [data.y.item() for data in structuredDataset.dataset]
predicted_labels = model.predict_all()
print("Confusion Matrix:")
print(confusion_matrix(original_labels, predicted_labels))
# Save GNN training metrics
evaluations["gnn"] = metrics
print("\nBest Training Metrics:")
for metric, value in metrics.items():
print(f"{metric.capitalize()}: {value:.4f}")
# Explain GNN before finding motifs
# print("\nBefore finding motifs:")
# explain_gnn(model, structuredDataset.dataset, datasetName, evaluations)
# Detect motifs and visualize patterns
print("\nDetecting motifs...")
patterns, presence_matrix = structuredDataset.detect_motifs()
high_level_concepts={
"patterns" : patterns,
"presence_matrix": presence_matrix
}
print("\nDetected motifs...")
print(patterns)
patterns_path = structuredDataset.visualize_patterns()
timeStamp = datetime.now().strftime("%Y%m%d_%H%M%S")  # timestamp used in the results file name
evaluations["path"] = patterns_path
# Explain GNN after finding motifs
print("\nAfter finding motifs:")
explain_gnn(model, structuredDataset.dataset, datasetName, evaluations, high_level_concepts)
output_file = output_dir / f"{datasetName}_evaluations_{timeStamp}.json"
# Save evaluations to file
with open(output_file, "w") as f:
json.dump(evaluations, f, indent=4, ensure_ascii=False)
def main():
datasets = [
# {
# "structuredDataset": BA2Motif(),
# "datasetName": "BA2Motif"
# }
# ,
# {
# "structuredDataset": MultiShape(),
# "datasetName": "MultiShape"
# }
# ,
{
"structuredDataset": MUTAG(),
"datasetName": "MUTAG"
}
]
for dataset in datasets:
experiment(dataset["structuredDataset"], dataset["datasetName"])
if __name__ == "__main__":
main()