Skip to content
Snippets Groups Projects
Commit 9ee785aa authored by AjUm-HEIDI's avatar AjUm-HEIDI
Browse files

update utils

parent 22ddc82d
No related branches found
No related tags found
No related merge requests found
......@@ -597,13 +597,13 @@ def group_themes(tensor, vocabulary, num_groups, groupedKeywordsPath=''):
# Create the grouped tensor
num_rows = tensor.size(0)
num_themes = len(grouped_themes)
grouped_tensor = torch.zeros((num_rows, num_themes), dtype=torch.float32)
grouped_tensor = torch.zeros((num_rows, num_themes), dtype=torch.float32, device=tensor.device)
# For each word in the vocabulary, add its count to the corresponding theme
for word_idx, word in enumerate(vocabulary):
if word in word_to_theme:
theme_idx = word_to_theme[word]
grouped_tensor[:, theme_idx] += tensor[:, word_idx]
grouped_tensor[:, theme_idx] += tensor[:, word_idx].to(grouped_tensor.device)
themes = list(grouped_themes.keys())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment