Commit 17adcf9e authored by Konstantin Julius Lotzgeselle

Copied code from the tutorial and tried training it with our dataset. (It didn't work.)

parent 4c0ed80c
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from __future__ import unicode_literals, print_function, division\n",
"from io import open\n",
"import unicodedata\n",
"import re\n",
"import random\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch import optim\n",
"import torch.nn.functional as F\n",
"\n",
"import numpy as np\n",
"from torch.utils.data import TensorDataset, DataLoader, RandomSampler\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"SOS_token = 0\n",
"EOS_token = 1\n",
"\n",
"class Lang:\n",
" def __init__(self, name):\n",
" self.name = name\n",
" self.word2index = {}\n",
" self.word2count = {}\n",
" self.index2word = {0: \"SOS\", 1: \"EOS\"}\n",
" self.n_words = 2 # Count SOS and EOS\n",
"\n",
" def addSentence(self, sentence):\n",
" for word in sentence.split(' '):\n",
" self.addWord(word)\n",
"\n",
" def addWord(self, word):\n",
" if word not in self.word2index:\n",
" self.word2index[word] = self.n_words\n",
" self.word2count[word] = 1\n",
" self.index2word[self.n_words] = word\n",
" self.n_words += 1\n",
" else:\n",
" self.word2count[word] += 1"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Turn a Unicode string to plain ASCII, thanks to\n",
"# https://stackoverflow.com/a/518232/2809427\n",
"def unicodeToAscii(s):\n",
" return ''.join(\n",
" c for c in unicodedata.normalize('NFD', s)\n",
" if unicodedata.category(c) != 'Mn'\n",
" )\n",
"\n",
"# Lowercase, trim, and remove non-letter characters\n",
"def normalizeString(s):\n",
" s = unicodeToAscii(s.lower().strip())\n",
" s = re.sub(r\"([.!?])\", r\" \\1\", s)\n",
" s = re.sub(r\"[^a-zA-Z!?]+\", r\" \", s)\n",
" return s.strip()"
]
},
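{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick look at what `normalizeString` does to a German sentence (sample call for illustration only):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Accents are stripped via NFD normalization, the text is lowercased, and any\n",
"# character outside a-z, A-Z, ! and ? is collapsed into a single space.\n",
"print(normalizeString(\"Frau Präsidentin, zur Geschäftsordnung.\"))"
]
},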
{
"cell_type": "markdown",
"metadata": {},
"source": [
"def readLangs(lang1, lang2, reverse=False):\n",
" print(\"Reading lines...\")\n",
"\n",
" # Read the file and split into lines\n",
" lines = open('data/%s-%s.txt' % (lang1, lang2), encoding='utf-8').\\\n",
" read().strip().split('\\n')\n",
"\n",
" # Split every line into pairs and normalize\n",
" pairs = [[normalizeString(s) for s in l.split('\\t')] for l in lines]\n",
"\n",
" # Reverse pairs, make Lang instances\n",
" if reverse:\n",
" pairs = [list(reversed(p)) for p in pairs]\n",
" input_lang = Lang(lang2)\n",
" output_lang = Lang(lang1)\n",
" else:\n",
" input_lang = Lang(lang1)\n",
" output_lang = Lang(lang2)\n",
"\n",
" return input_lang, output_lang, pairs"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def readLangs(lang1, lang2, reverse=False) -> (str, str, list[str]):\n",
" with open(\"data/training-data/eup/europarl-v7.de-en.de\", \"r\", encoding=\"utf8\") as f:\n",
" data_de = [line.rstrip(\"\\n\") for line in f]\n",
" with open(\"data/training-data/eup/europarl-v7.de-en.en\", \"r\", encoding=\"utf8\") as f:\n",
" data_en = [line.rstrip(\"\\n\") for line in f]\n",
" \n",
" ltd = set() # save lines to delete later\n",
"\n",
" for i in range(max(len(data_de), len(data_en))):\n",
" # Move sentence to next line if line is empty other file\n",
" if data_de[i] == \"\":\n",
" data_en[i+1] = data_en[i] + \" \" + data_en[i+1]\n",
" ltd.add(i)\n",
" if data_en[i] == \"\":\n",
" data_de[i+1] = data_de[i] + \" \" + data_de[i+1]\n",
" ltd.add(i)\n",
" \n",
" # Remove lines, where difference in words is > 40%\n",
" if abs(count_words(data_de[i]) - count_words(data_en[i])) / (max(count_words(data_de[i]), count_words(data_en[i])) + 1) > 0.4:\n",
" ltd.add(i)\n",
" \n",
" # Remove lines < 3 words or > 25 words\n",
" if max(count_words(data_de[i]), count_words(data_en[i])) < 3 or max(count_words(data_de[i]), count_words(data_en[i])) > 25:\n",
" ltd.add(i)\n",
"\n",
" temp_de = [l for i, l in enumerate(data_de) if i not in ltd]\n",
" data_de = temp_de\n",
" temp_en = [l for i, l in enumerate(data_en) if i not in ltd]\n",
" data_en = temp_en\n",
" print(len(data_de),len(data_en))\n",
" \n",
" # Print 3 random sentence pairs\n",
" ix = torch.randint(low=0, high=max(len(data_de), len(data_en)), size=(3, ))\n",
" for i in ix:\n",
" print(f\"Zeile: {i}\\nDeutsch: {data_de[i]}\\nEnglish: {data_en[i]}\\n\")\n",
" \n",
" print(f\"\\nNumber of lines: {len(data_de), len(data_en)}\")\n",
"\n",
" pairs = [[de, en] for de, en in zip(data_de, data_en)]\n",
" if reverse:\n",
" pairs = [list(reversed(p)) for p in pairs]\n",
" input_lang = Lang(lang2)\n",
" output_lang = Lang(lang1)\n",
" else:\n",
" input_lang = Lang(lang1)\n",
" output_lang = Lang(lang2)\n",
"\n",
" print(pairs[:10])\n",
" return input_lang, output_lang, pairs\n",
"\n",
"def count_words(string: str) -> int:\n",
" return len(string.split())\n",
"\n",
"\n",
"#data_de_1, data_en_1 = load_data_1()\n"
]
},
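{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small sanity check of the 40% word-count rule above, using a made-up pair rather than corpus data:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical pair chosen to trip the 40% word-count rule\n",
"toy_de = \"Das ist ein sehr kurzer Satz\"\n",
"toy_en = \"Short\"\n",
"diff = abs(count_words(toy_de) - count_words(toy_en))\n",
"ratio = diff / (max(count_words(toy_de), count_words(toy_en)) + 1)\n",
"print(f\"{count_words(toy_de)} vs. {count_words(toy_en)} words, ratio {ratio:.2f}\")\n",
"print(\"dropped\" if ratio > 0.4 else \"kept\")"
]
},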
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"MAX_LENGTH = 10\n",
"\n",
"eng_prefixes = (\n",
" \"i am \", \"i m \",\n",
" \"he is\", \"he s \",\n",
" \"she is\", \"she s \",\n",
" \"you are\", \"you re \",\n",
" \"we are\", \"we re \",\n",
" \"they are\", \"they re \"\n",
")\n",
"\n",
"def filterPair(p):\n",
" return len(p[0].split(' ')) < MAX_LENGTH and \\\n",
" len(p[1].split(' ')) < MAX_LENGTH and \\\n",
" p[1].startswith(eng_prefixes)\n",
"\n",
"\n",
"def filterPairs(pairs):\n",
" return [pair for pair in pairs if filterPair(pair)]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1046809 1046809\n",
"Zeile: 692607\n",
"Deutsch: Momentan haben wir also eine Arbeitsgruppe.\n",
"English: For the time being, then, we have a working group.\n",
"\n",
"Zeile: 821888\n",
"Deutsch: Hier hat der Staat aber noch eine zweite, und ich glaube noch sehr viel wichtigere Verantwortung.\n",
"English: However, the state has a second and, to my mind, far more important responsibility here.\n",
"\n",
"Zeile: 14166\n",
"Deutsch: Die Berichterstatterin hat Änderungen vorgeschlagen, um diesen Vorschlag zu verstärken.\n",
"English: The rapporteur has proposed amendments to reinforce this proposal.\n",
"\n",
"\n",
"Number of lines: (1046809, 1046809)\n",
"[['Resumption of the session', 'Wiederaufnahme der Sitzungsperiode'], ['You have requested a debate on this subject in the course of the next few days, during this part-session.', 'Im Parlament besteht der Wunsch nach einer Aussprache im Verlauf dieser Sitzungsperiode in den nächsten Tagen.'], [\"Please rise, then, for this minute' s silence.\", 'Ich bitte Sie, sich zu einer Schweigeminute zu erheben.'], [\"(The House rose and observed a minute' s silence)\", '(Das Parlament erhebt sich zu einer Schweigeminute.)'], ['Madam President, on a point of order.', 'Frau Präsidentin, zur Geschäftsordnung.'], ['You will be aware from the press and television that there have been a number of bomb explosions and killings in Sri Lanka.', 'Wie Sie sicher aus der Presse und dem Fernsehen wissen, gab es in Sri Lanka mehrere Bombenexplosionen mit zahlreichen Toten.'], ['Yes, Mr Evans, I feel an initiative of the type you have just suggested would be entirely appropriate.', 'Ja, Herr Evans, ich denke, daß eine derartige Initiative durchaus angebracht ist.'], ['If the House agrees, I shall do as Mr Evans has suggested.', 'Wenn das Haus damit einverstanden ist, werde ich dem Vorschlag von Herrn Evans folgen.'], ['Madam President, on a point of order.', 'Frau Präsidentin, zur Geschäftsordnung.'], ['I would like your advice about Rule 143 concerning inadmissibility.', 'Könnten Sie mir eine Auskunft zu Artikel 143 im Zusammenhang mit der Unzulässigkeit geben?']]\n",
"Read 1046809 sentence pairs\n",
"Trimmed to 0 sentence pairs\n",
"Counting words...\n",
"Counted words:\n",
"en 2\n",
"de 2\n"
]
},
{
"ename": "IndexError",
"evalue": "Cannot choose from an empty sequence",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mIndexError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[1;32mIn[6], line 16\u001b[0m\n\u001b[0;32m 13\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m input_lang, output_lang, pairs\n\u001b[0;32m 15\u001b[0m input_lang, output_lang, pairs \u001b[38;5;241m=\u001b[39m prepareData(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mde\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124men\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m---> 16\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[43mrandom\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mchoice\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpairs\u001b[49m\u001b[43m)\u001b[49m)\n",
"File \u001b[1;32mc:\\Users\\konst\\miniconda3\\envs\\nlp-machine-learning-project\\Lib\\random.py:373\u001b[0m, in \u001b[0;36mRandom.choice\u001b[1;34m(self, seq)\u001b[0m\n\u001b[0;32m 370\u001b[0m \u001b[38;5;66;03m# As an accommodation for NumPy, we don't use \"if not seq\"\u001b[39;00m\n\u001b[0;32m 371\u001b[0m \u001b[38;5;66;03m# because bool(numpy.array()) raises a ValueError.\u001b[39;00m\n\u001b[0;32m 372\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(seq):\n\u001b[1;32m--> 373\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mIndexError\u001b[39;00m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mCannot choose from an empty sequence\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m 374\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m seq[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_randbelow(\u001b[38;5;28mlen\u001b[39m(seq))]\n",
"\u001b[1;31mIndexError\u001b[0m: Cannot choose from an empty sequence"
]
}
],
"source": [
"def prepareData(lang1, lang2, reverse=False):\n",
" input_lang, output_lang, pairs = readLangs(lang1, lang2, reverse)\n",
" print(\"Read %s sentence pairs\" % len(pairs))\n",
" pairs = filterPairs(pairs)\n",
" print(\"Trimmed to %s sentence pairs\" % len(pairs))\n",
" print(\"Counting words...\")\n",
" for pair in pairs:\n",
" input_lang.addSentence(pair[0])\n",
" output_lang.addSentence(pair[1])\n",
" print(\"Counted words:\")\n",
" print(input_lang.name, input_lang.n_words)\n",
" print(output_lang.name, output_lang.n_words)\n",
" return input_lang, output_lang, pairs\n",
"\n",
"input_lang, output_lang, pairs = prepareData('de', 'en', True)\n",
"print(random.choice(pairs))"
]
},
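{
"cell_type": "markdown",
"metadata": {},
"source": [
"\"Trimmed to 0 sentence pairs\" in the output above points at the filtering step rather than the data: with `reverse=True` each pair comes back as `[english, german]`, so `filterPair` checks the German (and still un-normalized) sentence against `eng_prefixes`, essentially every pair is rejected, and `random.choice(pairs)` then fails on the empty list. The `MAX_LENGTH < 10` check alone also removes most Europarl sentences, since the cleaning step above keeps pairs up to 25 words.\n",
"\n",
"Below is a minimal sketch of a length-only filter adapted to this ordering; the conversational `eng_prefixes` filter from the tutorial hardly applies to parliamentary text anyway:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch only (not used above): keep the MAX_LENGTH check, drop the English\n",
"# prefix filter, and remember that after reverse=True index 0 is English.\n",
"def filterPairDeEn(p):\n",
"    return len(p[0].split(' ')) < MAX_LENGTH and len(p[1].split(' ')) < MAX_LENGTH\n",
"\n",
"def filterPairsDeEn(pairs):\n",
"    return [pair for pair in pairs if filterPairDeEn(pair)]\n",
"\n",
"# e.g.: raw = readLangs('de', 'en', True)[2]; print(len(filterPairsDeEn(raw)))"
]
},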
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"class EncoderRNN(nn.Module):\n",
" def __init__(self, input_size, hidden_size, dropout_p=0.1):\n",
" super(EncoderRNN, self).__init__()\n",
" self.hidden_size = hidden_size\n",
"\n",
" self.embedding = nn.Embedding(input_size, hidden_size)\n",
" self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)\n",
" self.dropout = nn.Dropout(dropout_p)\n",
"\n",
" def forward(self, input):\n",
" embedded = self.dropout(self.embedding(input))\n",
" output, hidden = self.gru(embedded)\n",
" return output, hidden"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"class DecoderRNN(nn.Module):\n",
" def __init__(self, hidden_size, output_size):\n",
" super(DecoderRNN, self).__init__()\n",
" self.embedding = nn.Embedding(output_size, hidden_size)\n",
" self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)\n",
" self.out = nn.Linear(hidden_size, output_size)\n",
"\n",
" def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):\n",
" batch_size = encoder_outputs.size(0)\n",
" decoder_input = torch.empty(batch_size, 1, dtype=torch.long, device=device).fill_(SOS_token)\n",
" decoder_hidden = encoder_hidden\n",
" decoder_outputs = []\n",
"\n",
" for i in range(MAX_LENGTH):\n",
" decoder_output, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)\n",
" decoder_outputs.append(decoder_output)\n",
"\n",
" if target_tensor is not None:\n",
" # Teacher forcing: Feed the target as the next input\n",
" decoder_input = target_tensor[:, i].unsqueeze(1) # Teacher forcing\n",
" else:\n",
" # Without teacher forcing: use its own predictions as the next input\n",
" _, topi = decoder_output.topk(1)\n",
" decoder_input = topi.squeeze(-1).detach() # detach from history as input\n",
"\n",
" decoder_outputs = torch.cat(decoder_outputs, dim=1)\n",
" decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)\n",
" return decoder_outputs, decoder_hidden, None # We return `None` for consistency in the training loop\n",
"\n",
" def forward_step(self, input, hidden):\n",
" output = self.embedding(input)\n",
" output = F.relu(output)\n",
" output, hidden = self.gru(output, hidden)\n",
" output = self.out(output)\n",
" return output, hidden"
]
},
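{
"cell_type": "markdown",
"metadata": {},
"source": [
"A tiny shape check of the encoder and plain decoder on random token ids (sketch only; the vocabulary sizes are arbitrary placeholders, not taken from the data):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Smoke test with made-up vocabulary sizes; only the tensor shapes matter here\n",
"_enc = EncoderRNN(input_size=50, hidden_size=16).to(device)\n",
"_dec = DecoderRNN(hidden_size=16, output_size=60).to(device)\n",
"_x = torch.randint(0, 50, (4, MAX_LENGTH), device=device)  # (batch, seq) of token ids\n",
"_enc_out, _enc_hidden = _enc(_x)\n",
"_dec_out, _, _ = _dec(_enc_out, _enc_hidden)\n",
"print(_enc_out.shape, _dec_out.shape)  # (4, MAX_LENGTH, 16) and (4, MAX_LENGTH, 60)"
]
},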
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"class BahdanauAttention(nn.Module):\n",
" def __init__(self, hidden_size):\n",
" super(BahdanauAttention, self).__init__()\n",
" self.Wa = nn.Linear(hidden_size, hidden_size)\n",
" self.Ua = nn.Linear(hidden_size, hidden_size)\n",
" self.Va = nn.Linear(hidden_size, 1)\n",
"\n",
" def forward(self, query, keys):\n",
" scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))\n",
" scores = scores.squeeze(2).unsqueeze(1)\n",
"\n",
" weights = F.softmax(scores, dim=-1)\n",
" context = torch.bmm(weights, keys)\n",
"\n",
" return context, weights\n",
"\n",
"class AttnDecoderRNN(nn.Module):\n",
" def __init__(self, hidden_size, output_size, dropout_p=0.1):\n",
" super(AttnDecoderRNN, self).__init__()\n",
" self.embedding = nn.Embedding(output_size, hidden_size)\n",
" self.attention = BahdanauAttention(hidden_size)\n",
" self.gru = nn.GRU(2 * hidden_size, hidden_size, batch_first=True)\n",
" self.out = nn.Linear(hidden_size, output_size)\n",
" self.dropout = nn.Dropout(dropout_p)\n",
"\n",
" def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):\n",
" batch_size = encoder_outputs.size(0)\n",
" decoder_input = torch.empty(batch_size, 1, dtype=torch.long, device=device).fill_(SOS_token)\n",
" decoder_hidden = encoder_hidden\n",
" decoder_outputs = []\n",
" attentions = []\n",
"\n",
" for i in range(MAX_LENGTH):\n",
" decoder_output, decoder_hidden, attn_weights = self.forward_step(\n",
" decoder_input, decoder_hidden, encoder_outputs\n",
" )\n",
" decoder_outputs.append(decoder_output)\n",
" attentions.append(attn_weights)\n",
"\n",
" if target_tensor is not None:\n",
" # Teacher forcing: Feed the target as the next input\n",
" decoder_input = target_tensor[:, i].unsqueeze(1) # Teacher forcing\n",
" else:\n",
" # Without teacher forcing: use its own predictions as the next input\n",
" _, topi = decoder_output.topk(1)\n",
" decoder_input = topi.squeeze(-1).detach() # detach from history as input\n",
"\n",
" decoder_outputs = torch.cat(decoder_outputs, dim=1)\n",
" decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)\n",
" attentions = torch.cat(attentions, dim=1)\n",
"\n",
" return decoder_outputs, decoder_hidden, attentions\n",
"\n",
"\n",
" def forward_step(self, input, hidden, encoder_outputs):\n",
" embedded = self.dropout(self.embedding(input))\n",
"\n",
" query = hidden.permute(1, 0, 2)\n",
" context, attn_weights = self.attention(query, encoder_outputs)\n",
" input_gru = torch.cat((embedded, context), dim=2)\n",
"\n",
" output, hidden = self.gru(input_gru, hidden)\n",
" output = self.out(output)\n",
"\n",
" return output, hidden, attn_weights"
]
},
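{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same kind of shape check for the attention decoder (placeholder sizes again); the third return value is the stacked attention weights:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: attention weights come out as (batch, target steps, source steps)\n",
"_enc2 = EncoderRNN(input_size=50, hidden_size=16).to(device)\n",
"_attn_dec = AttnDecoderRNN(hidden_size=16, output_size=60).to(device)\n",
"_src = torch.randint(0, 50, (4, MAX_LENGTH), device=device)\n",
"_eo, _eh = _enc2(_src)\n",
"_do, _dh, _attn = _attn_dec(_eo, _eh)\n",
"print(_do.shape, _attn.shape)  # (4, MAX_LENGTH, 60) and (4, MAX_LENGTH, MAX_LENGTH)"
]
},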
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"def indexesFromSentence(lang, sentence):\n",
" return [lang.word2index[word] for word in sentence.split(' ')]\n",
"\n",
"def tensorFromSentence(lang, sentence):\n",
" indexes = indexesFromSentence(lang, sentence)\n",
" indexes.append(EOS_token)\n",
" return torch.tensor(indexes, dtype=torch.long, device=device).view(1, -1)\n",
"\n",
"def tensorsFromPair(pair):\n",
" input_tensor = tensorFromSentence(input_lang, pair[0])\n",
" target_tensor = tensorFromSentence(output_lang, pair[1])\n",
" return (input_tensor, target_tensor)\n",
"\n",
"def get_dataloader(batch_size):\n",
" input_lang, output_lang, pairs = prepareData('eng', 'fra', True)\n",
"\n",
" n = len(pairs)\n",
" input_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)\n",
" target_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)\n",
"\n",
" for idx, (inp, tgt) in enumerate(pairs):\n",
" inp_ids = indexesFromSentence(input_lang, inp)\n",
" tgt_ids = indexesFromSentence(output_lang, tgt)\n",
" inp_ids.append(EOS_token)\n",
" tgt_ids.append(EOS_token)\n",
" input_ids[idx, :len(inp_ids)] = inp_ids\n",
" target_ids[idx, :len(tgt_ids)] = tgt_ids\n",
"\n",
" train_data = TensorDataset(torch.LongTensor(input_ids).to(device),\n",
" torch.LongTensor(target_ids).to(device))\n",
"\n",
" train_sampler = RandomSampler(train_data)\n",
" train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)\n",
" return input_lang, output_lang, train_dataloader"
]
},
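{
"cell_type": "markdown",
"metadata": {},
"source": [
"How `get_dataloader` packs a single sentence into a fixed-length row, shown with made-up word indices (note that the padding value 0 coincides with `SOS_token`):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical indices: a 3-token sentence plus EOS, left-aligned and zero-padded\n",
"_ids = [5, 17, 4]\n",
"_ids.append(EOS_token)\n",
"_row = np.zeros(MAX_LENGTH, dtype=np.int32)\n",
"_row[:len(_ids)] = _ids\n",
"print(_row)  # e.g. [ 5 17  4  1  0  0  0  0  0  0]"
]
},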
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"def train_epoch(dataloader, encoder, decoder, encoder_optimizer,\n",
" decoder_optimizer, criterion):\n",
"\n",
" total_loss = 0\n",
" for data in dataloader:\n",
" input_tensor, target_tensor = data\n",
"\n",
" encoder_optimizer.zero_grad()\n",
" decoder_optimizer.zero_grad()\n",
"\n",
" encoder_outputs, encoder_hidden = encoder(input_tensor)\n",
" decoder_outputs, _, _ = decoder(encoder_outputs, encoder_hidden, target_tensor)\n",
"\n",
" loss = criterion(\n",
" decoder_outputs.view(-1, decoder_outputs.size(-1)),\n",
" target_tensor.view(-1)\n",
" )\n",
" loss.backward()\n",
"\n",
" encoder_optimizer.step()\n",
" decoder_optimizer.step()\n",
"\n",
" total_loss += loss.item()\n",
"\n",
" return total_loss / len(dataloader)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"import math\n",
"\n",
"def asMinutes(s):\n",
" m = math.floor(s / 60)\n",
" s -= m * 60\n",
" return '%dm %ds' % (m, s)\n",
"\n",
"def timeSince(since, percent):\n",
" now = time.time()\n",
" s = now - since\n",
" es = s / (percent)\n",
" rs = es - s\n",
" return '%s (- %s)' % (asMinutes(s), asMinutes(rs))"
]
},
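{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick illustration of the two timing helpers (arbitrary values):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(asMinutes(125))                      # 2m 5s\n",
"print(timeSince(time.time() - 60, 0.25))   # elapsed ~1m, roughly 3m remaining"
]
},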
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"plt.switch_backend('agg')\n",
"import matplotlib.ticker as ticker\n",
"import numpy as np\n",
"\n",
"def showPlot(points):\n",
" plt.figure()\n",
" fig, ax = plt.subplots()\n",
" # this locator puts ticks at regular intervals\n",
" loc = ticker.MultipleLocator(base=0.2)\n",
" ax.yaxis.set_major_locator(loc)\n",
" plt.plot(points)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"def train(train_dataloader, encoder, decoder, n_epochs, learning_rate=0.001,\n",
" print_every=100, plot_every=100):\n",
" start = time.time()\n",
" plot_losses = []\n",
" print_loss_total = 0 # Reset every print_every\n",
" plot_loss_total = 0 # Reset every plot_every\n",
"\n",
" encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)\n",
" decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)\n",
" criterion = nn.NLLLoss()\n",
"\n",
" for epoch in range(1, n_epochs + 1):\n",
" loss = train_epoch(train_dataloader, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)\n",
" print_loss_total += loss\n",
" plot_loss_total += loss\n",
"\n",
" if epoch % print_every == 0:\n",
" print_loss_avg = print_loss_total / print_every\n",
" print_loss_total = 0\n",
" print('%s (%d %d%%) %.4f' % (timeSince(start, epoch / n_epochs),\n",
" epoch, epoch / n_epochs * 100, print_loss_avg))\n",
"\n",
" if epoch % plot_every == 0:\n",
" plot_loss_avg = plot_loss_total / plot_every\n",
" plot_losses.append(plot_loss_avg)\n",
" plot_loss_total = 0\n",
"\n",
" showPlot(plot_losses)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"def evaluate(encoder, decoder, sentence, input_lang, output_lang):\n",
" with torch.no_grad():\n",
" input_tensor = tensorFromSentence(input_lang, sentence)\n",
"\n",
" encoder_outputs, encoder_hidden = encoder(input_tensor)\n",
" decoder_outputs, decoder_hidden, decoder_attn = decoder(encoder_outputs, encoder_hidden)\n",
"\n",
" _, topi = decoder_outputs.topk(1)\n",
" decoded_ids = topi.squeeze()\n",
"\n",
" decoded_words = []\n",
" for idx in decoded_ids:\n",
" if idx.item() == EOS_token:\n",
" decoded_words.append('<EOS>')\n",
" break\n",
" decoded_words.append(output_lang.index2word[idx.item()])\n",
" return decoded_words, decoder_attn"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"def evaluateRandomly(encoder, decoder, n=10):\n",
" for i in range(n):\n",
" pair = random.choice(pairs)\n",
" print('>', pair[0])\n",
" print('=', pair[1])\n",
" output_words, _ = evaluate(encoder, decoder, pair[0], input_lang, output_lang)\n",
" output_sentence = ' '.join(output_words)\n",
" print('<', output_sentence)\n",
" print('')"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Reading lines...\n"
]
},
{
"ename": "FileNotFoundError",
"evalue": "[Errno 2] No such file or directory: 'data/eng-fra.txt'",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mFileNotFoundError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[1;32mIn[19], line 4\u001b[0m\n\u001b[0;32m 1\u001b[0m hidden_size \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m128\u001b[39m\n\u001b[0;32m 2\u001b[0m batch_size \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m32\u001b[39m\n\u001b[1;32m----> 4\u001b[0m input_lang, output_lang, train_dataloader \u001b[38;5;241m=\u001b[39m \u001b[43mget_dataloader\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 6\u001b[0m encoder \u001b[38;5;241m=\u001b[39m EncoderRNN(input_lang\u001b[38;5;241m.\u001b[39mn_words, hidden_size)\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[0;32m 7\u001b[0m decoder \u001b[38;5;241m=\u001b[39m AttnDecoderRNN(hidden_size, output_lang\u001b[38;5;241m.\u001b[39mn_words)\u001b[38;5;241m.\u001b[39mto(device)\n",
"Cell \u001b[1;32mIn[11], line 15\u001b[0m, in \u001b[0;36mget_dataloader\u001b[1;34m(batch_size)\u001b[0m\n\u001b[0;32m 14\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget_dataloader\u001b[39m(batch_size):\n\u001b[1;32m---> 15\u001b[0m input_lang, output_lang, pairs \u001b[38;5;241m=\u001b[39m \u001b[43mprepareData\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43meng\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mfra\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[0;32m 17\u001b[0m n \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlen\u001b[39m(pairs)\n\u001b[0;32m 18\u001b[0m input_ids \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mzeros((n, MAX_LENGTH), dtype\u001b[38;5;241m=\u001b[39mnp\u001b[38;5;241m.\u001b[39mint32)\n",
"Cell \u001b[1;32mIn[7], line 2\u001b[0m, in \u001b[0;36mprepareData\u001b[1;34m(lang1, lang2, reverse)\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mprepareData\u001b[39m(lang1, lang2, reverse\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m):\n\u001b[1;32m----> 2\u001b[0m input_lang, output_lang, pairs \u001b[38;5;241m=\u001b[39m \u001b[43mreadLangs\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlang1\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlang2\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreverse\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 3\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mRead \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m sentence pairs\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m%\u001b[39m \u001b[38;5;28mlen\u001b[39m(pairs))\n\u001b[0;32m 4\u001b[0m pairs \u001b[38;5;241m=\u001b[39m filterPairs(pairs)\n",
"Cell \u001b[1;32mIn[5], line 5\u001b[0m, in \u001b[0;36mreadLangs\u001b[1;34m(lang1, lang2, reverse)\u001b[0m\n\u001b[0;32m 2\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mReading lines...\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m 4\u001b[0m \u001b[38;5;66;03m# Read the file and split into lines\u001b[39;00m\n\u001b[1;32m----> 5\u001b[0m lines \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mopen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mdata/\u001b[39;49m\u001b[38;5;132;43;01m%s\u001b[39;49;00m\u001b[38;5;124;43m-\u001b[39;49m\u001b[38;5;132;43;01m%s\u001b[39;49;00m\u001b[38;5;124;43m.txt\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m%\u001b[39;49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43mlang1\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlang2\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mencoding\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mutf-8\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39m\\\n\u001b[0;32m 6\u001b[0m read()\u001b[38;5;241m.\u001b[39mstrip()\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124m'\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m 8\u001b[0m \u001b[38;5;66;03m# Split every line into pairs and normalize\u001b[39;00m\n\u001b[0;32m 9\u001b[0m pairs \u001b[38;5;241m=\u001b[39m [[normalizeString(s) \u001b[38;5;28;01mfor\u001b[39;00m s \u001b[38;5;129;01min\u001b[39;00m l\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124m'\u001b[39m\u001b[38;5;130;01m\\t\u001b[39;00m\u001b[38;5;124m'\u001b[39m)] \u001b[38;5;28;01mfor\u001b[39;00m l \u001b[38;5;129;01min\u001b[39;00m lines]\n",
"\u001b[1;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'data/eng-fra.txt'"
]
}
],
"source": [
"hidden_size = 128\n",
"batch_size = 32\n",
"\n",
"input_lang, output_lang, train_dataloader = get_dataloader(batch_size)\n",
"\n",
"encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)\n",
"decoder = AttnDecoderRNN(hidden_size, output_lang.n_words).to(device)\n",
"\n",
"train(train_dataloader, encoder, decoder, 80, print_every=5, plot_every=5)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'encoder' is not defined",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[1;32mIn[20], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m \u001b[43mencoder\u001b[49m\u001b[38;5;241m.\u001b[39meval()\n\u001b[0;32m 2\u001b[0m decoder\u001b[38;5;241m.\u001b[39meval()\n\u001b[0;32m 3\u001b[0m evaluateRandomly(encoder, decoder)\n",
"\u001b[1;31mNameError\u001b[0m: name 'encoder' is not defined"
]
}
],
"source": [
"encoder.eval()\n",
"decoder.eval()\n",
"evaluateRandomly(encoder, decoder)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def showAttention(input_sentence, output_words, attentions):\n",
" fig = plt.figure()\n",
" ax = fig.add_subplot(111)\n",
" cax = ax.matshow(attentions.cpu().numpy(), cmap='bone')\n",
" fig.colorbar(cax)\n",
"\n",
" # Set up axes\n",
" ax.set_xticklabels([''] + input_sentence.split(' ') +\n",
" ['<EOS>'], rotation=90)\n",
" ax.set_yticklabels([''] + output_words)\n",
"\n",
" # Show label at every tick\n",
" ax.xaxis.set_major_locator(ticker.MultipleLocator(1))\n",
" ax.yaxis.set_major_locator(ticker.MultipleLocator(1))\n",
"\n",
" plt.show()\n",
"\n",
"\n",
"def evaluateAndShowAttention(input_sentence):\n",
" output_words, attentions = evaluate(encoder, decoder, input_sentence, input_lang, output_lang)\n",
" print('input =', input_sentence)\n",
" print('output =', ' '.join(output_words))\n",
" showAttention(input_sentence, output_words, attentions[0, :len(output_words), :])\n",
"\n",
"\n",
"#evaluateAndShowAttention('il n est pas aussi grand que son pere')\n",
"\n",
"#evaluateAndShowAttention('je suis trop fatigue pour conduire')\n",
"\n",
"#evaluateAndShowAttention('je suis desole si c est une question idiote')\n",
"\n",
"#evaluateAndShowAttention('je suis reellement fiere de vous')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "nlp-machine-learning-project",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}