Commit 62b7eb17 authored by marvnsch

Some enhancements to the model

parent 8f7ee777
%% Cell type:code id:initial_id tags:
``` python
import torch
import torch.nn as nn
import torch.nn.functional as F
import seaborn as sns
from matplotlib import pyplot as plt
from pathlib import Path
```
%% Cell type:markdown id:d8d7f32150682efd tags:
## 0. Prepare the data
%% Cell type:code id:f7c39c06ce3a14db tags:
``` python
# load the parallel dev data (news-test2008) for quick tokenizer/demo checks
data_array_en = open(Path('./data/training-data/dev/news-test2008.en'), 'r').readlines()
data_array_de = open(Path('./data/training-data/dev/news-test2008.de'), 'r').readlines()


def count_words(string: str) -> int:
    return len(string.split())


def load_data() -> tuple[list[str], list[str]]:
    with open("data/training-data/eup/europarl-v7.de-en.de", "r", encoding="utf8") as f:
        data_de = [line.rstrip("\n") for line in f]
    with open("data/training-data/eup/europarl-v7.de-en.en", "r", encoding="utf8") as f:
        data_en = [line.rstrip("\n") for line in f]

    ltd = set()  # indices of lines to delete later
    for i in range(max(len(data_de), len(data_en))):
        # If one side of pair i is empty, merge its counterpart into the next
        # line and mark pair i for deletion
        if data_de[i] == "":
            data_en[i + 1] = data_en[i] + " " + data_en[i + 1]
            ltd.add(i)
        if data_en[i] == "":
            data_de[i + 1] = data_de[i] + " " + data_de[i + 1]
            ltd.add(i)
        # Remove pairs whose word counts differ by more than 40%
        if abs(count_words(data_de[i]) - count_words(data_en[i])) / (max(count_words(data_de[i]), count_words(data_en[i])) + 1) > 0.4:
            ltd.add(i)
        # Remove pairs shorter than 3 words or longer than 25 words
        if max(count_words(data_de[i]), count_words(data_en[i])) < 3 or max(count_words(data_de[i]), count_words(data_en[i])) > 25:
            ltd.add(i)

    data_de = [l for i, l in enumerate(data_de) if i not in ltd]
    data_en = [l for i, l in enumerate(data_en) if i not in ltd]
    print(len(data_de), len(data_en))

    # Print 3 random sentence pairs
    ix = torch.randint(low=0, high=max(len(data_de), len(data_en)), size=(3, ))
    for i in ix:
        print(f"Zeile: {i}\nDeutsch: {data_de[i]}\nEnglish: {data_en[i]}\n")
    print(f"\nNumber of lines: {len(data_de), len(data_en)}")
    return data_de, data_en


# Print 3 random sentence pairs from the dev data
idx = torch.randint(low=0, high=2000, size=(3, ))
for id in idx:
    print(id.item())
    print("ENG: " + data_array_en[id.item()] + "DEU: " + data_array_de[id.item()])

source, target = load_data()
```
%% Output
1163
ENG: JPMorgan: recommends to weigh carefully both cases for "similar threats" (saturated market and worse economic climate).
DEU: JPMorgan: empfiehlt angesichts der „ähnlichen Bedrohungen“ (gesättigter Markt und schlechteres Wirtschaftsklima) die Erwartungen in beiden Fällen zurückzusetzen
1816
ENG: The charge that she concentrated too much on foreign affairs, she dismissed with a terribly presumptuous statement.
DEU: Den Vorwurf, dass sie sich zu sehr auf die Außenpolitik konzentriere, hat Angela Merkel mit einem arg überheblichen Satz zurückgewiesen.
1846
ENG: One day after resigning as army chief, Pakistani ruler Musharraf was sworn in as president.
DEU: Einen Tag nach seinem Rücktritt als Armeechef ist der pakistanische Machthaber Musharraf als Präsident vereidigt worden.
1046809 1046809
Zeile: 993209
Deutsch: Aber es muß auch darum gehen, Anreize zu schaffen für einen umweltfreundlichen lokalen öffentlichen Nahverkehr.
English: But it is also necessary to create incentives for environmentally friendly local public transport.
Zeile: 459853
Deutsch: Vielleicht sollte er dramatisch verlangsamt werden?
English: Perhaps it should be slowed down dramatically?
Zeile: 605086
Deutsch: Die Prämien haben im Übrigen durchaus positive grenzüberschreitende Wirkungen.
English: The incentives have also had a positive cross-border impact.
Number of lines: (1046809, 1046809)
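The filtering rules in `load_data` are easiest to see on a toy pair. The sketch below uses made-up sentences (not from the corpus) and reuses `count_words` to show which pairs the 40% word-count-difference rule and the 3–25-word length rule would keep.
``` python
# Minimal sketch of the filter rules from load_data() on hypothetical sentence pairs.
toy_pairs = [
    ("Das ist ein kurzer Satz .", "This is a short sentence ."),  # kept
    ("Ja .", "Yes ."),                                            # dropped: fewer than 3 words
    ("Das Parlament hat heute zugestimmt .",
     "The European Parliament voted in favour of the proposal after a long debate today ."),  # dropped: >40% length difference
]

for de, en in toy_pairs:
    diff = abs(count_words(de) - count_words(en)) / (max(count_words(de), count_words(en)) + 1)
    too_short_or_long = max(count_words(de), count_words(en)) < 3 or max(count_words(de), count_words(en)) > 25
    keep = diff <= 0.4 and not too_short_or_long
    print(f"keep={keep} diff={diff:.2f} | {de} || {en}")
```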
%% Cell type:markdown id:f2beddcc4122495a tags:
## 1. Text tokenization
%% Cell type:code id:d8ccbafa97fba573 tags:
``` python
# set up the tokenizer
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.processors import TemplateProcessing

# setting the unknown token (e.g. for emojis)
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))

# adding special tokens
# [UNK]  : unknown word/token
# [CLS]  : start token (new sentence sequence)
# [SEP]  : separator for chaining multiple sentences
# [PAD]  : padding needed for fixed-length encoder input
# [MASK] : mask token (only needed for masked-language-model training)
trainer = BpeTrainer(vocab_size=50000, special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"])

# set up the pre-tokenizer -> this ensures that no token spans more than one word
from tokenizers.pre_tokenizers import Whitespace

tokenizer.pre_tokenizer = Whitespace()
```
%% Cell type:code id:55cbac65a50a0199 tags:
``` python
tokenizer.train(["data/training-data/eup/europarl-v7.de-en.de", "data/training-data/eup/europarl-v7.de-en.en"], trainer)

# configure post processing: wrap every sequence in [CLS] ... [SEP]
tokenizer.post_processor = TemplateProcessing(
    single="[CLS] $A [SEP]",
    pair="[CLS] $A [SEP] $B:1 [SEP]:1",
    special_tokens=[
        ("[CLS]", tokenizer.token_to_id("[CLS]")),
        ("[SEP]", tokenizer.token_to_id("[SEP]")),
    ],
)

vocab_size = tokenizer.get_vocab_size()
```
%% Output
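To make the template above concrete, here is a small check of what the post-processor adds for a single sentence and for a sentence pair (the example sentences are made up; this is a sketch, not part of the training pipeline):
``` python
# Sketch: the post-processor wraps single sequences as [CLS] ... [SEP]
# and pairs as [CLS] A [SEP] B [SEP], with type ids 0/1.
single = tokenizer.encode("This is a test .")
pair = tokenizer.encode("This is a test .", "Das ist ein Test .")

print(single.tokens)    # starts with [CLS], ends with [SEP]
print(pair.tokens)      # both sentences, each closed by [SEP]
print(pair.type_ids)    # 0 for the first sentence, 1 for the second
```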
%% Cell type:code id:569b9a3425aa5800 tags:
``` python
# testing the trained tokenizer
print(data_array_en[15])
test_en = tokenizer.encode(data_array_en[14])
test_de = tokenizer.encode(data_array_de[11])
print(test_en.tokens)
#print(test_en.ids)
print(test_de.tokens)
#print(test_de.ids)
```
%% Output
Government crisis coming, says Gallup
['[CLS]', 'They', 'also', 'predict', 'that', 'the', 'ECB', 'will', 'cut', 'interest', 'rates', 'twice', 'during', 'the', 'course', 'of', '2008', '.', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']
['[CLS]', 'D', 'er', 'Welt', 'mark', 't', 'pre', 'is', 'f', 'ü', 'r', 'Ro', 'h', '[UNK]', 'l', 'st', 'ie', 'g', 'in', 'dies', 'em', 'J', 'ah', 'r', 'um', '52', 'Pro', 'z', 'ent', '-', 'im', 'ver', 'gang', 'en', 'en', 'Mon', 'at', 'er', 're', 'ich', 'te', 'der', 'Pre', 'is', 'pro', 'F', 'ass', 'des', 'sch', 'war', 'zen', 'G', 'old', 'es', 'na', 'he', 'z', 'u', '100', 'US', 'Dol', 'lar', '.', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']
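The long [PAD] runs in the output above come from the tokenizer's padding settings. Fixed-length batches need both padding and truncation; a minimal sketch (the sentence and the length of 50 are arbitrary choices for illustration):
``` python
# Sketch: pad *and* truncate so every encoding has exactly `length` ids.
# Without enable_truncation, sentences longer than `length` keep their full
# length and a batch can no longer be stacked into a single tensor.
tokenizer.enable_padding(pad_id=3, pad_token="[PAD]", length=50)
tokenizer.enable_truncation(max_length=50)

enc = tokenizer.encode("A short test sentence .")
print(len(enc.ids))   # 50

# reset, so the training cells below configure their own padding
tokenizer.no_padding()
tokenizer.no_truncation()
```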
%% Cell type:markdown id:9c0f853775a802ec tags:
## 2. Prepare the training data
%% Cell type:code id:2e4dc87ce98b6cdd tags:
``` python
# Prepare a training batch
def training_data(batch_size: int = 10, max_tokens: int = 50) -> tuple[torch.Tensor, torch.Tensor]:
    x_training_data = []
    y_training_data = []

    # select random sentences
    batch_indices = torch.randint(0, len(source), (batch_size, ))
    for idx in batch_indices:
        x_training_data.append(target[idx])
        y_training_data.append(source[idx])

    # tokenize data: pad *and* truncate to max_tokens, otherwise sentences
    # longer than max_tokens break torch.tensor() below
    tokenizer.enable_padding(pad_id=3, length=max_tokens)
    tokenizer.enable_truncation(max_length=max_tokens)
    x_training_data = tokenizer.encode_batch(x_training_data)
    y_training_data = tokenizer.encode_batch(y_training_data)

    # extract ids for every sequence
    for i in range(len(batch_indices)):
        x_training_data[i] = x_training_data[i].ids
        y_training_data[i] = y_training_data[i].ids

    # 'tensorfy' x data
    x_training_data = torch.tensor(x_training_data)
    # 'tensorfy' & one hot encode y data
    #y_training_data = F.one_hot(torch.tensor(y_training_data), num_classes=vocab_size)
    y_training_data = torch.tensor(y_training_data)
    return x_training_data, y_training_data


print(training_data())
```
%% Output
(tensor([[    1,   556,  9472,   344,   386,   346,   984,   362,   472,    18,
          ...]]), tensor([[   1, 6182,   63,  ...,    3,    3,    3],
        [   1,   54,  281,  ...,    3,    3,    3],
        [   1,   25,   22,  ...,    3,    3,    3],
        ...,
        [   1,   35,  211,  ...,    3,    3,    3],
        [   1,   35,  211,  ...,    3,    3,    3],
        [   1, 8389,   42,  ...,    3,    3,    3]]))
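A quick sanity check on the batch shapes (batch size and token count chosen arbitrarily here) makes the contract of `training_data` explicit: both tensors are `(batch_size, max_tokens)` id matrices.
``` python
# Sketch: verify that a batch comes back as two equally shaped id tensors.
x_batch, y_batch = training_data(batch_size=4, max_tokens=50)
print(x_batch.shape, y_batch.shape)   # torch.Size([4, 50]) torch.Size([4, 50])
print(x_batch.dtype)                  # torch.int64 token ids, ready for nn.Embedding
```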
%% Cell type:markdown id:689e2e565cce2845 tags:
## 3. Build the sequence2sequence RNN
%% Cell type:code id:e8d99510479108f4 tags:
``` python
embedding_dimension = 100
embedding_matrix_enc = torch.nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dimension)
embedding_matrix_dec = torch.nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dimension)


class Encoder(torch.nn.Module):
    def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1, bidirectional: bool = False):
        super(Encoder, self).__init__()
        self._hidden_size = hidden_size
        self._num_layers = num_layers

        # lstm layer
        self._lstm = torch.nn.LSTM(input_size=input_size,
                                   hidden_size=hidden_size,
                                   num_layers=num_layers,
                                   bidirectional=bidirectional,
                                   batch_first=True)
        self._dropout = torch.nn.Dropout(0.1)

    def forward(self, embedded_sequence: torch.Tensor):
        h_0 = torch.zeros(self._num_layers, embedded_sequence.size(0), self._hidden_size)  # hidden state WITH batches
        c_0 = torch.zeros(self._num_layers, embedded_sequence.size(0), self._hidden_size)  # internal state WITH batches
        #h_0 = torch.zeros(self._num_layers, self._hidden_size)  # hidden state WITHOUT batches
        #c_0 = torch.zeros(self._num_layers, self._hidden_size)  # internal state WITHOUT batches
        output, (hn, cn) = self._lstm(embedded_sequence, (h_0, c_0))
        return output, hn, cn


class Decoder(torch.nn.Module):
    def __init__(self, input_size: int, hidden_size: int, output_size: int,
                 num_layers: int = 1, bidirectional: bool = False,
                 max_tokens: int = 40):
        super(Decoder, self).__init__()
        self._hidden_size = hidden_size
        self._num_layers = num_layers
        self._max_tokens = max_tokens

        # embedding matrix
        self._embedding = embedding_matrix_dec

        # lstm layer
        self._lstm = torch.nn.LSTM(input_size=input_size,
                                   hidden_size=hidden_size,
                                   num_layers=num_layers,
                                   bidirectional=bidirectional,
                                   batch_first=True)

        # output layer (fully connected linear layer)
        self._out = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # x is the encoder result: (outputs, hidden_state, cell_state)
        batch_size = x[0].size(0)
        hidden_state = x[1]
        cell_state = x[2]
        outputs = []

        # prepare start token ([CLS] has id 1)
        x_in = torch.empty(batch_size, 1, dtype=torch.long).fill_(1)
        for i in range(self._max_tokens):
            out, hidden_state, cell_state = self.forward_step(x_in, hidden_state, cell_state)
            outputs.append(out)
            # Without teacher forcing: use its own predictions as the next input
            _, topi = out.topk(1)
            x_in = topi.squeeze(-1).detach()  # detach from history as input

        # concatenate the per-step outputs of shape (batch, 1, vocab) along the
        # time axis -> (batch, max_tokens, vocab), then turn them into log-probabilities
        outputs = torch.cat(outputs, dim=1)
        outputs = F.log_softmax(outputs, dim=-1)
        return outputs, hidden_state, cell_state

    def forward_step(self, x_in, hidden_state, cell_state):
        output = self._embedding(x_in)
        output = F.relu(output)
        output, (h_t, c_t) = self._lstm(output, (hidden_state, cell_state))
        output = self._out(output)
        return output, h_t, c_t
```
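The decoder above always feeds its own greedy prediction back in. A common variant is teacher forcing, where the gold target ids are fed as the next input during training. Below is a minimal sketch of how such a step could look; the helper name, the `targets` tensor, and the mixing ratio are hypothetical and not part of the model above.
``` python
import random

def decode_with_teacher_forcing(decoder, encoder_state, targets, teacher_forcing_ratio=0.5):
    """Sketch: run the decoder, sometimes feeding the gold token instead of
    the model's own prediction. `targets` is a (batch, max_tokens) id tensor."""
    hidden_state, cell_state = encoder_state[1], encoder_state[2]
    batch_size = targets.size(0)
    x_in = torch.full((batch_size, 1), 1, dtype=torch.long)  # [CLS] start token
    outputs = []
    for t in range(targets.size(1)):
        out, hidden_state, cell_state = decoder.forward_step(x_in, hidden_state, cell_state)
        outputs.append(out)
        if random.random() < teacher_forcing_ratio:
            x_in = targets[:, t].unsqueeze(1)                   # feed the gold token
        else:
            x_in = out.topk(1).indices.squeeze(-1).detach()     # feed the model's own prediction
    return F.log_softmax(torch.cat(outputs, dim=1), dim=-1)
```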
%% Cell type:markdown id:535bc20b2f12f2da tags:
## 4. Train the model
%% Cell type:code id:1f8d3152359f6658 tags:
``` python
LSTM_hidden_size = 128
max_tokens_per_sequence = 70

model = nn.Sequential(
    embedding_matrix_enc,
    Encoder(input_size=embedding_dimension, hidden_size=LSTM_hidden_size),
    Decoder(input_size=embedding_dimension, hidden_size=LSTM_hidden_size,
            output_size=vocab_size, max_tokens=max_tokens_per_sequence)
)

num_epochs = 1000
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)
loss_function = torch.nn.NLLLoss()

for i in range(1, num_epochs + 1):
    # reset gradients
    optimizer.zero_grad()

    # make prediction
    x_train, y_train = training_data(batch_size=32, max_tokens=max_tokens_per_sequence)
    predict = model(x_train)[0]

    # match dimensions of prediction & gold_label vector
    predict = predict.view(-1, predict.size(-1))
    y_train = y_train.view(-1)

    # calculate loss & propagate it backwards
    loss = loss_function(predict, y_train)
    loss.backward()
    optimizer.step()

    if i % 10 == 0:
        print("---- Iteration " + str(i) + " ----")
        print("loss: " + str(loss.item()))
```
%% Output
---- Iteration 10 ----
loss: 9.291775703430176
---- Iteration 20 ----
loss: 6.553857803344727
---- Iteration 30 ----
loss: 4.213151454925537
---- Iteration 40 ----
loss: 3.1044561862945557
---- Iteration 50 ----
loss: 3.47859263420105
---- Iteration 60 ----
loss: 3.166140079498291
---- Iteration 70 ----
loss: 3.3509914875030518
---- Iteration 80 ----
loss: 2.626647710800171
---- Iteration 90 ----
loss: 3.137316942214966
---- Iteration 100 ----
loss: 3.088139295578003
---- Iteration 110 ----
loss: 2.9085235595703125
---- Iteration 120 ----
loss: 2.8475253582000732
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[17], line 20
17 optimizer.zero_grad()
19 # make prediction
---> 20 x_train, y_train = training_data(batch_size=32, max_tokens=max_tokens_per_sequence)
21 predict = model(x_train)[0]
23 # match dimensions of prediction & gold_label vector
Cell In[8], line 28, in training_data(batch_size, max_tokens)
24 x_training_data = torch.tensor(x_training_data)
26 # 'tensorfy' & one hot encode y data
27 #y_training_data = F.one_hot(torch.tensor(y_training_data), num_classes=vocab_size)
---> 28 y_training_data = torch.tensor(y_training_data)
29 return x_training_data, y_training_data
ValueError: expected sequence of length 50 at dim 1 (got 62)
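One further enhancement worth noting (not applied above): with so many [PAD] targets in every batch, the loss is dominated by padding. `NLLLoss` can be told to skip those positions; a minimal sketch, assuming pad id 3 as configured for the tokenizer:
``` python
# Sketch: ignore [PAD] positions (id 3) when computing the loss, so the model
# is only trained on real target tokens.
loss_function = torch.nn.NLLLoss(ignore_index=3)

# usage in the training loop stays unchanged:
# loss = loss_function(predict.view(-1, predict.size(-1)), y_train.view(-1))
```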
%% Cell type:markdown id:44f9b74f91565a4a tags:
## 5. Sample from the model
%% Cell type:code id:b95fb365f686125d tags:
``` python
# translate a single test sentence with the (partially trained) model
test_sequence = "Ist dies der Weg, oder nicht?."
test_sequence_enc = tokenizer.encode(test_sequence)
print(test_sequence_enc.ids)

test_sequence_batched = torch.tensor(test_sequence_enc.ids).view(1, -1)
predict, _, _ = model(test_sequence_batched)
_, topi = predict.topk(1)
decoded_ids = topi.squeeze()
tokenizer.decode(decoded_ids.tolist())
```
%% Output
[1, 6264, 432, 352, 2398, 16, 886, 474, 49994, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]
'Ich ist der der der'
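The raw greedy output keeps generating until `max_tokens` is reached. A small post-processing sketch (assuming [SEP] id 2, as defined for the tokenizer above) cuts the hypothesis at the first end token and drops special tokens when decoding:
``` python
# Sketch: stop the hypothesis at the first [SEP] (id 2) and skip special tokens.
ids = decoded_ids.tolist()
if 2 in ids:
    ids = ids[:ids.index(2)]
print(tokenizer.decode(ids, skip_special_tokens=True))
```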