%% Cell type:code id:initial_id tags:
``` python
import torch
import torch.nn as nn
import torch.nn.functional as F
import seaborn as sns
from matplotlib import pyplot as plt
from pathlib import Path
```
%% Cell type:code id:f7c39c06ce3a14db tags:
``` python
# load the parallel English-German sentences (a possible training/dev/test split is sketched below)
data_array_en = open(Path('./data/training-data/dev/news-test2008.en'), 'r').readlines()
data_array_de = open(Path('./data/training-data/dev/news-test2008.de'), 'r').readlines()
data_en = open(Path('./data/training-data/dev/news-test2008.en'), 'r').read()

# print a few random sentence pairs to sanity-check the data
idx = torch.randint(low=0, high=2000, size=(3, ))
for id in idx:
    print(id.item())
    print("ENG: " + data_array_en[id.item()] + "DEU: " + data_array_de[id.item()])
```
%% Output
728
ENG: Where to go?
DEU: Ausflugsziele
810
ENG: It follows from the strategy of national debt financing and control for 2007 that the government planned to borrow 159.2 billion crowns this year.
DEU: Aus der Finanzierungsstrategie und dem Umgang mit den Staatsschulden für 2007 geht hervor, dass die Regierung für dieses Jahr geplant hatte, 159,2 Milliarden Kronen aufzunehmen.
928
ENG: More than a third of pregnancies is unexpected.
DEU: Noch immer ist mehr als ein Drittel der Schwangerschaften nicht geplant
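%% Cell type:markdown tags:
The cell above only loads the dev file. A minimal sketch of how the loaded sentence pairs could be split into training/dev/test sets (assuming an 80/10/10 split; the `*_pairs` names are illustrative, not used later in the notebook) looks like this:
%% Cell type:code tags:
``` python
# sketch: random 80/10/10 split of the loaded sentence pairs
n = len(data_array_en)
perm = torch.randperm(n).tolist()
n_train, n_dev = int(0.8 * n), int(0.1 * n)
train_idx = perm[:n_train]
dev_idx = perm[n_train:n_train + n_dev]
test_idx = perm[n_train + n_dev:]
train_pairs = [(data_array_en[i], data_array_de[i]) for i in train_idx]
dev_pairs = [(data_array_en[i], data_array_de[i]) for i in dev_idx]
test_pairs = [(data_array_en[i], data_array_de[i]) for i in test_idx]
print(len(train_pairs), len(dev_pairs), len(test_pairs))
```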
%% Cell type:markdown id:f2beddcc4122495a tags:
## 1. Text tokenization
%% Cell type:code id:d8ccbafa97fba573 tags:
``` python
# set up the tokenizer
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.processors import TemplateProcessing
# setting the unknown token (e.g. for emojis)
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
# adding special tokens
# [UNK] : unknown word/token
# [CLS] : starting token (new sentence sequence)
# [SEP] : separator for chaining multiple sentences
# [PAD] : padding needed for encoder input
# [MASK] : mask token (used for masked-language-model objectives; not needed here)
trainer = BpeTrainer(special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"])
# set up the pre-tokenizer -> this ensures that no token spans more than one word
from tokenizers.pre_tokenizers import Whitespace
tokenizer.pre_tokenizer = Whitespace()
```
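%% Cell type:markdown tags:
A quick way to see what the `Whitespace` pre-tokenizer does: it splits the raw string into word and punctuation pieces (with character offsets) before BPE runs, so learned merges can never cross a word boundary. The example sentence below is just an illustration.
%% Cell type:code tags:
``` python
# inspect the pre-tokenization of an example sentence
print(tokenizer.pre_tokenizer.pre_tokenize_str("Government crisis coming, says Gallup"))
```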
%% Cell type:code id:55cbac65a50a0199 tags:
``` python
# train the BPE vocabulary (note: only on the English file, so German text will
# be split into many small subword pieces)
tokenizer.train(['./data/training-data/dev/newstest2013.en'], trainer)

# configure post-processing: wrap every encoded sequence in [CLS] ... [SEP]
tokenizer.post_processor = TemplateProcessing(
    single="[CLS] $A [SEP]",
    pair="[CLS] $A [SEP] $B:1 [SEP]:1",
    special_tokens=[
        ("[CLS]", tokenizer.token_to_id("[CLS]")),
        ("[SEP]", tokenizer.token_to_id("[SEP]")),
    ],
)
```
%% Output
%% Cell type:code id:569b9a3425aa5800 tags:
``` python
print(tokenizer.get_vocab_size())
print(data_array_en[15])
# test the trained tokenizer on an English and a German sentence
# (note: these are different sentences from the one printed above)
test_en = tokenizer.encode(data_array_en[14])
test_de = tokenizer.encode(data_array_de[11])
print(test_en.tokens)
#print(test_en.ids)
print(test_de.tokens)
#print(test_de.ids)
```
%% Output
14823
Government crisis coming, says Gallup
['[CLS]', 'They', 'also', 'predict', 'that', 'the', 'ECB', 'will', 'cut', 'interest', 'rates', 'twice', 'during', 'the', 'course', 'of', '2008', '.', '[SEP]']
['[CLS]', 'D', 'er', 'Welt', 'mark', 't', 'pre', 'is', 'f', 'ü', 'r', 'Ro', 'h', '[UNK]', 'l', 'st', 'ie', 'g', 'in', 'dies', 'em', 'J', 'ah', 'r', 'um', '52', 'Pro', 'z', 'ent', '-', 'im', 'ver', 'gang', 'en', 'en', 'Mon', 'at', 'er', 're', 'ich', 'te', 'der', 'Pre', 'is', 'pro', 'F', 'ass', 'des', 'sch', 'war', 'zen', 'G', 'old', 'es', 'na', 'he', 'z', 'u', '100', 'US', 'Dol', 'lar', '.', '[SEP]']
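%% Cell type:markdown tags:
The German sentence is split into many single-character and short subword pieces because the BPE model was trained on the English file only. A possible remedy (a sketch; the German file path below is an assumption about the data layout) is to train one shared vocabulary on both languages:
%% Cell type:code tags:
``` python
# sketch: train one shared BPE vocabulary on both languages
# (the German path below is an assumption, not a file referenced elsewhere in this notebook)
tokenizer.train(['./data/training-data/dev/newstest2013.en',
                 './data/training-data/dev/newstest2013.de'], trainer)
```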
%% Cell type:markdown id:9c0f853775a802ec tags:
## 2. Prepare the training data
%% Cell type:code id:2e4dc87ce98b6cdd tags:
``` python
# prepare a training batch of (source, target) token-id sequences
def training_data(batch_size: int = 10) -> tuple:
    x_training_data = []
    y_training_data = []
    # select random sentence pairs
    batch_indices = torch.randint(0, len(data_array_en), (batch_size, ))
    for idx in batch_indices:
        x_training_data.append(data_array_en[idx])
        y_training_data.append(data_array_de[idx])
    # tokenize: pad the encoder inputs to a common length, leave the targets unpadded
    tokenizer.enable_padding()
    x_training_data = tokenizer.encode_batch(x_training_data)
    tokenizer.no_padding()
    y_training_data = tokenizer.encode_batch(y_training_data)
    # 'tensorfy' the token ids
    for i in range(len(batch_indices)):
        x_training_data[i] = torch.tensor(x_training_data[i].ids)
        y_training_data[i] = torch.tensor(y_training_data[i].ids)
    return x_training_data, y_training_data

print(training_data())
```
%% Output
([tensor([ 1, 3645, 14, 2257, 366, 109, 940, 3268, 97, 2978, 1155, 111,
100, 233, 16, 2, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0]), tensor([ 1, 149, 4452, 852, 3745, 136, 259, 517, 11413, 97,
1222, 6, 45, 1173, 6, 2626, 16, 912, 111, 98,
3797, 14, 239, 209, 173, 435, 6859, 2020, 98, 11863,
109, 5902, 97, 98, 5351, 582, 16, 2, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0]), tensor([ 1, 549, 185, 1324, 984, 111, 209, 0, 16, 2, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0]), tensor([ 1, 6, 1569, 199, 317, 113, 291, 1423, 5144, 108, 271, 16,
2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0]), tensor([ 1, 50, 4244, 1485, 313, 4618, 142, 113, 9764, 60,
5707, 210, 33, 1393, 793, 813, 6601, 111, 12825, 97,
98, 6664, 16, 2, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0]), tensor([ 1, 420, 949, 98, 3605, 109, 531, 184, 118, 6207, 111, 125,
2352, 97, 1208, 161, 1066, 331, 103, 3159, 5802, 2062, 14, 559,
141, 6931, 6207, 1466, 134, 1937, 27, 134, 97, 152, 60, 5524,
305, 144, 103, 32, 1716, 100, 16, 2, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0]), tensor([ 1, 1571, 5713, 5264, 6, 2074, 627, 105, 6, 317,
2468, 4645, 101, 98, 13885, 109, 473, 84, 1516, 1608,
0, 78, 11797, 1695, 126, 154, 727, 83, 108, 1222,
1291, 1226, 14, 100, 10817, 7529, 103, 98, 6109, 102,
3479, 2241, 175, 14, 50, 9577, 1608, 72, 14, 259,
1329, 111, 60, 187, 265, 115, 5740, 28, 149, 1371,
6, 1695, 126, 154, 727, 83, 108, 6, 510, 304,
111, 98, 6109, 102, 3479, 2241, 175, 196, 1267, 2173,
2892, 8649, 176, 39, 419, 90, 73, 97, 3265, 136,
134, 1955, 211, 0, 116, 14, 14626, 14, 136, 98,
11838, 109, 98, 5566, 1432, 3041, 518, 234, 1695, 126,
154, 727, 83, 108, 109, 98, 6, 8732, 148, 50,
9577, 1608, 72, 6, 97, 7605, 101, 98, 895, 208,
1075, 16, 2]), tensor([ 1, 1036, 16, 42, 1207, 176, 317, 1909, 110, 101, 98, 644,
2373, 451, 1291, 1226, 103, 98, 39, 138, 6238, 1229, 97, 53,
68, 105, 3479, 16, 2, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0]), tensor([ 1, 8439, 1397, 97, 2302, 366, 2, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0]), tensor([ 1, 549, 136, 368, 1624, 68, 3584, 142, 98, 14357,
2368, 406, 12106, 98, 3605, 109, 98, 10738, 16, 2,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0])], [tensor([ 1, 35, 144, 113, 233, 105, 354, 79, 41, 1080, 99, 6106,
113, 0, 77, 710, 77, 3491, 298, 79, 97, 2978, 16, 2]), tensor([ 1, 8389, 194, 80, 202, 105, 37, 460, 140, 4186,
423, 243, 61, 73, 2289, 14, 1820, 73, 163, 97,
213, 67, 99, 245, 6, 45, 1173, 6, 15, 566,
1169, 195, 71, 73, 81, 106, 1364, 1649, 79, 678,
1308, 14, 163, 146, 120, 99, 216, 105, 73, 354,
327, 78, 216, 79, 278, 37, 460, 140, 102, 73,
216, 79, 6305, 1820, 36, 97, 85, 147, 208, 147,
195, 113, 278, 5902, 6415, 67, 107, 61, 525, 8693,
73, 392, 84, 11021, 16, 2]), tensor([ 1, 8412, 66, 180, 73, 216, 152, 85, 80, 6362, 16, 0,
2]), tensor([ 1, 6, 8389, 1418, 1987, 172, 1813, 578, 758, 6, 14, 78,
160, 179, 1423, 5144, 108, 16, 2]), tensor([ 1, 50, 4244, 15, 1485, 68, 105, 67, 103, 957, 2506, 428,
79, 14, 63, 334, 108, 9261, 33, 1393, 64, 340, 57, 176,
216, 102, 4186, 97, 6228, 249, 50, 97, 194, 9261, 33, 1393,
102, 67, 107, 606, 67, 103, 14, 125, 81, 106, 108, 78,
216, 63, 730, 80, 424, 140, 5084, 14, 100, 278, 42, 101,
1545, 85, 179, 126, 85, 2315, 67, 1163, 16, 2]), tensor([ 1, 3424, 541, 65, 95, 77, 6101, 67, 71, 635, 6002, 79,
424, 231, 354, 105, 1308, 35, 103, 105, 525, 57, 124, 756,
105, 331, 103, 253, 104, 9843, 1800, 952, 9261, 361, 3042, 73,
14, 125, 68, 32, 1716, 100, 66, 180, 108, 69, 110, 74,
140, 19, 3468, 6002, 79, 141, 411, 105, 65, 95, 77, 1937,
27, 485, 586, 16, 2]), tensor([ 1, 6118, 167, 72, 1820, 3491, 3066, 104, 948, 3992,
113, 4186, 6, 2074, 627, 105, 6, 81, 106, 64,
97, 147, 105, 2082, 103, 105, 57, 141, 68, 1875,
100, 278, 36, 140, 98, 113, 9261, 473, 84, 1516,
1608, 78, 1695, 126, 154, 15, 727, 9300, 525, 41,
1080, 168, 11797, 1364, 140, 95, 77, 79, 67, 103,
14, 70, 123, 1163, 413, 194, 6118, 134, 3491, 163,
578, 139, 6109, 102, 3479, 2241, 175, 97, 50, 9577,
1608, 72, 8993, 85, 79, 85, 80, 64, 97, 249,
9472, 95, 140, 1130, 1308, 6036, 243, 61, 73, 104,
28, 8389, 128, 140, 78, 6, 1695, 126, 154, 15,
727, 83, 108, 6, 14, 1820, 278, 64, 122, 1614,
147, 64, 2430, 78, 965, 70, 1315, 8649, 176, 39,
419, 90, 73, 167, 72, 6109, 102, 3479, 2241, 175,
3265, 956, 4808, 105, 70, 4186, 143, 2735, 79, 67,
103, 14, 78, 329, 42, 170, 68, 105, 15, 635,
367, 67, 71, 228, 140, 278, 38, 130, 0, 179,
126, 278, 77, 635, 67, 438, 79, 141, 446, 245,
1695, 126, 154, 15, 727, 9300, 81, 123, 6, 8732,
75, 50, 9577, 1608, 72, 6, 14, 1820, 228, 65,
167, 72, 895, 105, 42, 163, 120, 936, 79, 97,
52, 72, 71, 228, 65, 78, 329, 16, 2]), tensor([ 1, 0, 265, 63, 116, 2475, 9637, 149, 733, 67, 103, 179,
42, 1207, 176, 97, 54, 68, 105, 3491, 101, 81, 106, 85,
141, 68, 2082, 103, 105, 139, 39, 138, 6238, 15, 214, 534,
1364, 182, 9843, 16, 2]), tensor([ 1, 37, 95, 73, 65, 6104, 97, 6353, 1813, 41, 1080, 245,
2]), tensor([ 1, 1415, 68, 278, 5987, 65, 120, 248, 6294, 278, 599, 0,
72, 281, 72, 95, 456, 105, 1820, 53, 192, 216, 102, 163,
578, 64, 340, 535, 9312, 556, 68, 122, 9261, 47, 5113, 606,
265, 95, 3188, 216, 79, 147, 105, 16, 2])])
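%% Cell type:markdown tags:
A note on the padding above: `enable_padding()` without arguments pads with token id 0, which in this vocabulary is `[UNK]` rather than `[PAD]` (hence the zeros in the padded tensors). A small sketch of passing the `[PAD]` id explicitly:
%% Cell type:code tags:
``` python
# sketch: pad with the actual [PAD] token instead of the default id 0
pad_id = tokenizer.token_to_id("[PAD]")
tokenizer.enable_padding(pad_id=pad_id, pad_token="[PAD]")
print(tokenizer.encode_batch([data_array_en[0], data_array_en[1]])[0].ids)
tokenizer.no_padding()
```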
%% Cell type:markdown id:689e2e565cce2845 tags:
## 3. Build the sequence2sequence RNN
%% Cell type:code id:e8d99510479108f4 tags:
``` python
vocab_size = tokenizer.get_vocab_size()
embedding_dimension = 100

# separate embedding matrices for the encoder (source) and decoder (target) tokens
embedding_matrix_enc = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dimension)
embedding_matrix_dec = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dimension)


class Encoder(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1, bidirectional: bool = False):
        super(Encoder, self).__init__()
        self._hidden_size = hidden_size
        self._num_layers = num_layers
        # lstm layer
        self._lstm = nn.LSTM(input_size=input_size,
                             hidden_size=hidden_size,
                             num_layers=num_layers,
                             bidirectional=bidirectional,
                             batch_first=True)

    def forward(self, embedded_sequence: torch.Tensor):
        print("Data shape after embedding: " + str(embedded_sequence.shape))
        h_0 = torch.zeros(self._num_layers, embedded_sequence.size(0), self._hidden_size)  # hidden state WITH batches
        c_0 = torch.zeros(self._num_layers, embedded_sequence.size(0), self._hidden_size)  # internal state WITH batches
        output, (hn, cn) = self._lstm(embedded_sequence, (h_0, c_0))
        return output, hn, cn


class Decoder(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, output_size: int,
                 num_layers: int = 1, bidirectional: bool = False,
                 max_tokens: int = 40, batch_size: int = 10):
        super(Decoder, self).__init__()
        self._hidden_size = hidden_size
        self._num_layers = num_layers
        self._max_tokens = max_tokens
        self._batch_size = batch_size
        # embedding matrix for the previously generated target token
        self._embedding = embedding_matrix_dec
        # lstm layer
        self._lstm = nn.LSTM(input_size=input_size,
                             hidden_size=hidden_size,
                             num_layers=num_layers,
                             bidirectional=bidirectional,
                             batch_first=True)
        # output layer (fully connected linear layer)
        self._out = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # x is the encoder output: (outputs, hidden state, cell state)
        encoder_out = x[0]  # not used yet (would be needed for attention)
        hidden_state = x[1]
        cell_state = x[2]
        outputs = []
        # prepare the start token ([CLS] has id 1)
        x_in = torch.empty(self._batch_size, 1, dtype=torch.long).fill_(1)
        for i in range(self._max_tokens):
            out, hidden_state, cell_state = self.forward_step(x_in, hidden_state, cell_state)
            outputs.append(out)
            # without teacher forcing: use the model's own prediction as the next input
            _, topi = out.topk(1)
            x_in = topi.squeeze(-1).detach()  # detach from history as input
        outputs = torch.cat(outputs, dim=1)
        outputs = F.log_softmax(outputs, dim=-1)
        return outputs, hidden_state, cell_state

    def forward_step(self, x_in, hidden_state, cell_state):
        output = self._embedding(x_in)
        output = F.relu(output)
        output, (h_t, c_t) = self._lstm(output, (hidden_state, cell_state))
        output = self._out(output)
        return output, h_t, c_t


LSTM_hidden_size = 100
model = nn.Sequential(
    embedding_matrix_enc,
    Encoder(input_size=embedding_dimension, hidden_size=LSTM_hidden_size),
    Decoder(input_size=embedding_dimension, hidden_size=LSTM_hidden_size, output_size=vocab_size)
)
print(model)

train = training_data()
# stack the padded source sequences into one [batch, seq] tensor before the forward pass
print(model(torch.stack(train[0])))
```
%% Output
Sequential(
(0): Embedding(14823, 100)
(1): Encoder(
(_lstm): LSTM(100, 100, batch_first=True)
)
(2): Decoder(
(_embedding): Embedding(14823, 100)
(_lstm): LSTM(100, 100, batch_first=True)
(_out): Linear(in_features=100, out_features=14823, bias=True)
)
)
Food: Where European inflation slipped up
[Encoding(num_tokens=26, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing]), Encoding(num_tokens=26, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing])]
Data shape after embedding: torch.Size([10, 55, 100])
(tensor([[[-9.6415, -9.4431, -9.6636, ..., -9.5648, -9.5597, -9.5541],
[-9.7372, -9.4867, -9.5413, ..., -9.4937, -9.5778, -9.5429],
[-9.8106, -9.6389, -9.4536, ..., -9.5229, -9.5157, -9.4979],
...,
[-9.7573, -9.6155, -9.4762, ..., -9.5930, -9.5073, -9.5158],
[-9.8039, -9.7443, -9.3759, ..., -9.5943, -9.5311, -9.4868],
[-9.8037, -9.6423, -9.4548, ..., -9.5880, -9.4976, -9.5846]],
[[-9.6416, -9.4431, -9.6636, ..., -9.5648, -9.5596, -9.5541],
[-9.7373, -9.4867, -9.5413, ..., -9.4937, -9.5778, -9.5429],
[-9.8107, -9.6389, -9.4536, ..., -9.5229, -9.5157, -9.4979],
...,
[-9.7573, -9.6155, -9.4762, ..., -9.5930, -9.5073, -9.5158],
[-9.8039, -9.7443, -9.3759, ..., -9.5943, -9.5311, -9.4868],
[-9.8037, -9.6423, -9.4548, ..., -9.5880, -9.4976, -9.5846]],
[[-9.6415, -9.4431, -9.6635, ..., -9.5648, -9.5597, -9.5541],
[-9.7372, -9.4867, -9.5413, ..., -9.4937, -9.5778, -9.5429],
[-9.8106, -9.6389, -9.4536, ..., -9.5229, -9.5158, -9.4979],
...,
[-9.7573, -9.6155, -9.4762, ..., -9.5930, -9.5073, -9.5158],
[-9.8039, -9.7443, -9.3759, ..., -9.5943, -9.5311, -9.4868],
[-9.8037, -9.6423, -9.4548, ..., -9.5880, -9.4976, -9.5846]],
...,
[[-9.6415, -9.4431, -9.6636, ..., -9.5648, -9.5597, -9.5541],
[-9.7372, -9.4867, -9.5413, ..., -9.4937, -9.5778, -9.5429],
[-9.8106, -9.6389, -9.4536, ..., -9.5229, -9.5157, -9.4979],
...,
[-9.7573, -9.6155, -9.4762, ..., -9.5930, -9.5073, -9.5158],
[-9.8039, -9.7443, -9.3759, ..., -9.5943, -9.5311, -9.4868],
[-9.8037, -9.6423, -9.4548, ..., -9.5880, -9.4976, -9.5846]],
[[-9.6415, -9.4431, -9.6635, ..., -9.5648, -9.5597, -9.5541],
[-9.7372, -9.4867, -9.5413, ..., -9.4937, -9.5778, -9.5429],
[-9.8106, -9.6389, -9.4536, ..., -9.5229, -9.5158, -9.4979],
...,
[-9.7573, -9.6155, -9.4762, ..., -9.5930, -9.5073, -9.5158],
[-9.8039, -9.7443, -9.3759, ..., -9.5943, -9.5311, -9.4868],
[-9.8037, -9.6423, -9.4548, ..., -9.5880, -9.4976, -9.5846]],
[[-9.6398, -9.6152, -9.6075, ..., -9.5829, -9.5681, -9.5603],
[-9.6789, -9.6375, -9.4903, ..., -9.5866, -9.5447, -9.4690],
[-9.6374, -9.6482, -9.5078, ..., -9.5755, -9.5485, -9.4923],
...,
[-9.5746, -9.5585, -9.6480, ..., -9.6288, -9.5277, -9.4217],
[-9.7304, -9.7413, -9.4759, ..., -9.6331, -9.5177, -9.4472],
[-9.6139, -9.6048, -9.5876, ..., -9.6273, -9.5298, -9.4094]]],
grad_fn=<LogSoftmaxBackward0>), tensor([[[ 0.0023, 0.1486, 0.0237, -0.0758, -0.1192, 0.0142, -0.1573,
-0.0982, -0.0524, 0.2234, -0.1119, 0.2357, 0.0279, -0.1894,
0.0085, -0.1701, -0.1204, -0.1979, 0.2874, -0.0976, 0.2367,
-0.0215, -0.0603, 0.1947, 0.1782, 0.4303, 0.1435, 0.1800,
0.1081, 0.1067, 0.1762, 0.0943, -0.0293, -0.1086, 0.1312,
-0.0862, 0.0772, -0.1786, 0.1441, 0.0510, -0.2177, -0.0075,
-0.1853, 0.0700, -0.1061, -0.0112, -0.1742, -0.0911, -0.1217,
-0.0673, 0.3696, 0.0280, -0.1376, -0.0340, -0.1302, -0.1149,
0.0929, -0.0090, 0.0960, 0.0556, 0.1672, -0.0918, -0.0973,
0.0414, -0.1669, 0.2295, -0.0037, -0.0553, 0.1901, 0.2150,
0.0686, 0.0431, -0.2670, -0.0582, -0.0539, -0.0426, 0.3081,
0.0424, -0.1656, 0.0080, -0.1062, 0.0526, 0.0500, 0.0263,
-0.3661, 0.1119, 0.0373, -0.1885, 0.0189, 0.0363, 0.0817,
0.0535, -0.1849, 0.2174, 0.1434, 0.1933, 0.0359, 0.1224,
-0.0005, -0.0861],
[ 0.0023, 0.1486, 0.0237, -0.0758, -0.1192, 0.0142, -0.1573,
-0.0982, -0.0524, 0.2234, -0.1119, 0.2357, 0.0279, -0.1894,
0.0085, -0.1701, -0.1204, -0.1979, 0.2874, -0.0976, 0.2367,
-0.0215, -0.0603, 0.1947, 0.1782, 0.4303, 0.1435, 0.1800,
0.1081, 0.1067, 0.1762, 0.0943, -0.0293, -0.1086, 0.1312,
-0.0862, 0.0772, -0.1786, 0.1441, 0.0510, -0.2177, -0.0075,
-0.1853, 0.0700, -0.1061, -0.0112, -0.1742, -0.0911, -0.1217,
-0.0673, 0.3696, 0.0280, -0.1376, -0.0340, -0.1302, -0.1149,
0.0929, -0.0090, 0.0960, 0.0556, 0.1672, -0.0918, -0.0973,
0.0414, -0.1669, 0.2295, -0.0037, -0.0553, 0.1901, 0.2150,
0.0686, 0.0431, -0.2670, -0.0582, -0.0539, -0.0426, 0.3081,
0.0424, -0.1656, 0.0080, -0.1062, 0.0526, 0.0500, 0.0263,
-0.3661, 0.1119, 0.0373, -0.1885, 0.0189, 0.0363, 0.0817,
0.0535, -0.1849, 0.2174, 0.1434, 0.1933, 0.0359, 0.1224,
-0.0005, -0.0861],
[ 0.0023, 0.1486, 0.0237, -0.0758, -0.1192, 0.0142, -0.1573,
-0.0982, -0.0524, 0.2234, -0.1119, 0.2357, 0.0279, -0.1894,
0.0085, -0.1701, -0.1204, -0.1979, 0.2874, -0.0976, 0.2367,
-0.0215, -0.0603, 0.1947, 0.1782, 0.4303, 0.1435, 0.1800,
0.1081, 0.1067, 0.1762, 0.0943, -0.0293, -0.1086, 0.1312,
-0.0862, 0.0772, -0.1786, 0.1441, 0.0510, -0.2177, -0.0075,
-0.1853, 0.0700, -0.1061, -0.0112, -0.1742, -0.0911, -0.1217,
-0.0673, 0.3696, 0.0280, -0.1376, -0.0340, -0.1302, -0.1149,
0.0929, -0.0090, 0.0960, 0.0556, 0.1672, -0.0918, -0.0973,
0.0414, -0.1669, 0.2295, -0.0037, -0.0553, 0.1901, 0.2150,
0.0686, 0.0431, -0.2670, -0.0582, -0.0539, -0.0426, 0.3081,
0.0424, -0.1656, 0.0080, -0.1062, 0.0526, 0.0500, 0.0263,
-0.3661, 0.1119, 0.0373, -0.1885, 0.0189, 0.0363, 0.0817,
0.0535, -0.1849, 0.2174, 0.1434, 0.1933, 0.0359, 0.1224,
-0.0005, -0.0861],
[ 0.0023, 0.1486, 0.0237, -0.0758, -0.1192, 0.0142, -0.1573,
-0.0982, -0.0524, 0.2234, -0.1119, 0.2357, 0.0279, -0.1894,
0.0085, -0.1701, -0.1204, -0.1979, 0.2874, -0.0976, 0.2367,
-0.0215, -0.0603, 0.1947, 0.1782, 0.4303, 0.1435, 0.1800,
0.1081, 0.1067, 0.1762, 0.0943, -0.0293, -0.1086, 0.1312,
-0.0862, 0.0772, -0.1786, 0.1441, 0.0510, -0.2177, -0.0075,
-0.1853, 0.0700, -0.1061, -0.0112, -0.1742, -0.0911, -0.1217,
-0.0673, 0.3696, 0.0280, -0.1376, -0.0340, -0.1302, -0.1149,
0.0929, -0.0090, 0.0960, 0.0556, 0.1672, -0.0918, -0.0973,
0.0414, -0.1669, 0.2295, -0.0037, -0.0553, 0.1901, 0.2150,
0.0686, 0.0431, -0.2670, -0.0582, -0.0539, -0.0426, 0.3081,
0.0424, -0.1656, 0.0080, -0.1062, 0.0526, 0.0500, 0.0263,
-0.3661, 0.1119, 0.0373, -0.1885, 0.0189, 0.0363, 0.0817,
0.0535, -0.1849, 0.2174, 0.1434, 0.1933, 0.0359, 0.1224,
-0.0005, -0.0861],
[ 0.0023, 0.1486, 0.0237, -0.0758, -0.1192, 0.0142, -0.1573,
-0.0982, -0.0524, 0.2234, -0.1119, 0.2357, 0.0279, -0.1894,
0.0085, -0.1701, -0.1204, -0.1979, 0.2874, -0.0976, 0.2367,
-0.0215, -0.0603, 0.1947, 0.1782, 0.4303, 0.1435, 0.1800,
0.1081, 0.1067, 0.1762, 0.0943, -0.0293, -0.1086, 0.1312,
-0.0862, 0.0772, -0.1786, 0.1441, 0.0510, -0.2177, -0.0075,
-0.1853, 0.0700, -0.1061, -0.0112, -0.1742, -0.0911, -0.1217,
-0.0673, 0.3696, 0.0280, -0.1376, -0.0340, -0.1302, -0.1149,
0.0929, -0.0090, 0.0960, 0.0556, 0.1672, -0.0918, -0.0973,
0.0414, -0.1669, 0.2295, -0.0037, -0.0553, 0.1901, 0.2150,
0.0686, 0.0431, -0.2670, -0.0582, -0.0539, -0.0426, 0.3081,
0.0424, -0.1656, 0.0080, -0.1062, 0.0526, 0.0500, 0.0263,
-0.3661, 0.1119, 0.0373, -0.1885, 0.0189, 0.0363, 0.0817,
0.0535, -0.1849, 0.2174, 0.1434, 0.1933, 0.0359, 0.1224,
-0.0005, -0.0861],
[ 0.0023, 0.1486, 0.0237, -0.0758, -0.1192, 0.0142, -0.1573,
-0.0982, -0.0524, 0.2234, -0.1119, 0.2357, 0.0279, -0.1894,
0.0085, -0.1701, -0.1204, -0.1979, 0.2874, -0.0976, 0.2367,
-0.0215, -0.0603, 0.1947, 0.1782, 0.4303, 0.1435, 0.1800,
0.1081, 0.1067, 0.1762, 0.0943, -0.0293, -0.1086, 0.1312,
-0.0862, 0.0772, -0.1786, 0.1441, 0.0510, -0.2177, -0.0075,
-0.1853, 0.0700, -0.1061, -0.0112, -0.1742, -0.0911, -0.1217,
-0.0673, 0.3696, 0.0280, -0.1376, -0.0340, -0.1302, -0.1149,
0.0929, -0.0090, 0.0960, 0.0556, 0.1672, -0.0918, -0.0973,
0.0414, -0.1669, 0.2295, -0.0037, -0.0553, 0.1901, 0.2150,
0.0686, 0.0431, -0.2670, -0.0582, -0.0539, -0.0426, 0.3081,
0.0424, -0.1656, 0.0080, -0.1062, 0.0526, 0.0500, 0.0263,
-0.3661, 0.1119, 0.0373, -0.1885, 0.0189, 0.0363, 0.0817,
0.0535, -0.1849, 0.2174, 0.1434, 0.1933, 0.0359, 0.1224,
-0.0005, -0.0861],
[ 0.0023, 0.1486, 0.0237, -0.0758, -0.1192, 0.0142, -0.1573,
-0.0982, -0.0524, 0.2234, -0.1119, 0.2357, 0.0279, -0.1894,
0.0085, -0.1701, -0.1204, -0.1979, 0.2874, -0.0976, 0.2367,
-0.0215, -0.0603, 0.1947, 0.1782, 0.4303, 0.1435, 0.1800,
0.1081, 0.1067, 0.1762, 0.0943, -0.0293, -0.1086, 0.1312,
-0.0862, 0.0772, -0.1786, 0.1441, 0.0510, -0.2177, -0.0075,
-0.1853, 0.0700, -0.1061, -0.0112, -0.1742, -0.0911, -0.1217,
-0.0673, 0.3696, 0.0280, -0.1376, -0.0340, -0.1302, -0.1149,
0.0929, -0.0090, 0.0960, 0.0556, 0.1672, -0.0918, -0.0973,
0.0414, -0.1669, 0.2295, -0.0037, -0.0553, 0.1901, 0.2150,
0.0686, 0.0431, -0.2670, -0.0582, -0.0539, -0.0426, 0.3081,
0.0424, -0.1656, 0.0080, -0.1062, 0.0526, 0.0500, 0.0263,
-0.3661, 0.1119, 0.0373, -0.1885, 0.0189, 0.0363, 0.0817,
0.0535, -0.1849, 0.2174, 0.1434, 0.1933, 0.0359, 0.1224,
-0.0005, -0.0861],
[ 0.0023, 0.1486, 0.0237, -0.0758, -0.1192, 0.0142, -0.1573,
-0.0982, -0.0524, 0.2234, -0.1119, 0.2357, 0.0279, -0.1894,
0.0085, -0.1701, -0.1204, -0.1979, 0.2874, -0.0976, 0.2367,
-0.0215, -0.0603, 0.1947, 0.1782, 0.4303, 0.1435, 0.1800,
0.1081, 0.1067, 0.1762, 0.0943, -0.0293, -0.1086, 0.1312,
-0.0862, 0.0772, -0.1786, 0.1441, 0.0510, -0.2177, -0.0075,
-0.1853, 0.0700, -0.1061, -0.0112, -0.1742, -0.0911, -0.1217,
-0.0673, 0.3696, 0.0280, -0.1376, -0.0340, -0.1302, -0.1149,
0.0929, -0.0090, 0.0960, 0.0556, 0.1672, -0.0918, -0.0973,
0.0414, -0.1669, 0.2295, -0.0037, -0.0553, 0.1901, 0.2150,
0.0686, 0.0431, -0.2670, -0.0582, -0.0539, -0.0426, 0.3081,
0.0424, -0.1656, 0.0080, -0.1062, 0.0526, 0.0500, 0.0263,
-0.3661, 0.1119, 0.0373, -0.1885, 0.0189, 0.0363, 0.0817,
0.0535, -0.1849, 0.2174, 0.1434, 0.1933, 0.0359, 0.1224,
-0.0005, -0.0861],
[ 0.0023, 0.1486, 0.0237, -0.0758, -0.1192, 0.0142, -0.1573,
-0.0982, -0.0524, 0.2234, -0.1119, 0.2357, 0.0279, -0.1894,
0.0085, -0.1701, -0.1204, -0.1979, 0.2874, -0.0976, 0.2367,
-0.0215, -0.0603, 0.1947, 0.1782, 0.4303, 0.1435, 0.1800,
0.1081, 0.1067, 0.1762, 0.0943, -0.0293, -0.1086, 0.1312,
-0.0862, 0.0772, -0.1786, 0.1441, 0.0510, -0.2177, -0.0075,
-0.1853, 0.0700, -0.1061, -0.0112, -0.1742, -0.0911, -0.1217,
-0.0673, 0.3696, 0.0280, -0.1376, -0.0340, -0.1302, -0.1149,
0.0929, -0.0090, 0.0960, 0.0556, 0.1672, -0.0918, -0.0973,
0.0414, -0.1669, 0.2295, -0.0037, -0.0553, 0.1901, 0.2150,
0.0686, 0.0431, -0.2670, -0.0582, -0.0539, -0.0426, 0.3081,
0.0424, -0.1656, 0.0080, -0.1062, 0.0526, 0.0500, 0.0263,
-0.3661, 0.1119, 0.0373, -0.1885, 0.0189, 0.0363, 0.0817,
0.0535, -0.1849, 0.2174, 0.1434, 0.1933, 0.0359, 0.1224,
-0.0005, -0.0861],
[-0.0113, 0.0558, -0.0812, -0.0709, -0.0982, 0.1706, -0.2960,
-0.1090, -0.1950, 0.3196, 0.0115, 0.1616, 0.0742, 0.0353,
-0.1603, -0.3228, 0.1030, -0.1173, 0.1040, 0.1915, 0.1484,
-0.0626, -0.3212, -0.0640, -0.1243, 0.4256, -0.0017, 0.0050,
-0.1011, 0.1405, -0.3329, 0.0539, 0.2078, -0.2096, 0.1785,
-0.1552, 0.3053, -0.3544, 0.1163, -0.0297, -0.2161, 0.0857,
0.1611, 0.0786, 0.0451, -0.0543, -0.3378, -0.0603, -0.0304,
-0.0617, 0.2123, -0.0015, -0.1774, -0.0083, -0.1252, -0.1364,
0.0091, -0.1556, 0.0480, -0.0661, 0.0439, -0.0315, 0.1671,
0.1805, -0.1611, 0.2483, 0.0749, -0.1219, -0.0331, 0.2407,
0.2305, 0.2929, -0.1674, -0.0710, 0.0576, -0.2653, 0.2473,
0.1214, 0.0282, -0.1709, 0.0265, 0.0554, 0.1247, 0.0222,
-0.2213, 0.1275, 0.2263, -0.1711, 0.1263, 0.0723, -0.0991,
0.1312, -0.1859, 0.1661, -0.0202, 0.1720, 0.2466, -0.0232,
0.0257, -0.1481]]], grad_fn=<StackBackward0>), tensor([[[ 4.1388e-03, 3.4760e-01, 7.7964e-02, -1.3910e-01, -1.6001e-01,
3.0040e-02, -2.9443e-01, -2.3068e-01, -1.2384e-01, 3.9988e-01,
-2.7987e-01, 4.8966e-01, 6.5901e-02, -3.3702e-01, 1.9192e-02,
-3.1752e-01, -2.5456e-01, -3.7401e-01, 5.2014e-01, -2.6364e-01,
4.8279e-01, -3.6519e-02, -1.2773e-01, 3.8893e-01, 3.3339e-01,
1.3188e+00, 2.3630e-01, 3.4717e-01, 2.1574e-01, 1.9013e-01,
3.3517e-01, 2.1240e-01, -5.5276e-02, -2.6756e-01, 2.2808e-01,
-1.9447e-01, 1.2559e-01, -2.7378e-01, 2.9605e-01, 1.3176e-01,
-4.2958e-01, -1.4714e-02, -2.9502e-01, 1.3312e-01, -2.1871e-01,
-2.5984e-02, -4.0586e-01, -2.2980e-01, -3.2106e-01, -1.5043e-01,
8.0840e-01, 7.3231e-02, -3.2636e-01, -9.5514e-02, -2.9875e-01,
-2.0356e-01, 1.7600e-01, -1.7115e-02, 1.6503e-01, 8.3258e-02,
3.4476e-01, -1.7428e-01, -1.7552e-01, 1.1986e-01, -3.0016e-01,
5.2874e-01, -7.1281e-03, -1.0720e-01, 3.5520e-01, 5.5678e-01,
1.2603e-01, 7.3370e-02, -5.3671e-01, -1.1090e-01, -8.8089e-02,
-6.8762e-02, 7.0469e-01, 8.5552e-02, -3.6397e-01, 1.4415e-02,
-2.4050e-01, 1.2464e-01, 1.3162e-01, 3.9336e-02, -6.0231e-01,
2.3541e-01, 7.5265e-02, -5.2977e-01, 5.5146e-02, 1.1436e-01,
2.4792e-01, 9.5830e-02, -4.6578e-01, 5.3352e-01, 4.4076e-01,
3.4128e-01, 9.6502e-02, 2.2520e-01, -1.2290e-03, -2.0600e-01],
[ 4.1388e-03, 3.4760e-01, 7.7964e-02, -1.3910e-01, -1.6001e-01,
3.0040e-02, -2.9443e-01, -2.3068e-01, -1.2384e-01, 3.9988e-01,
-2.7987e-01, 4.8966e-01, 6.5901e-02, -3.3702e-01, 1.9192e-02,
-3.1752e-01, -2.5456e-01, -3.7401e-01, 5.2014e-01, -2.6364e-01,
4.8279e-01, -3.6519e-02, -1.2773e-01, 3.8893e-01, 3.3339e-01,
1.3188e+00, 2.3630e-01, 3.4717e-01, 2.1574e-01, 1.9013e-01,
3.3517e-01, 2.1240e-01, -5.5276e-02, -2.6756e-01, 2.2808e-01,
-1.9447e-01, 1.2559e-01, -2.7378e-01, 2.9605e-01, 1.3176e-01,
-4.2958e-01, -1.4714e-02, -2.9502e-01, 1.3312e-01, -2.1871e-01,
-2.5984e-02, -4.0586e-01, -2.2980e-01, -3.2106e-01, -1.5043e-01,
8.0840e-01, 7.3231e-02, -3.2636e-01, -9.5514e-02, -2.9875e-01,
-2.0356e-01, 1.7600e-01, -1.7115e-02, 1.6503e-01, 8.3258e-02,
3.4476e-01, -1.7428e-01, -1.7552e-01, 1.1986e-01, -3.0016e-01,
5.2874e-01, -7.1281e-03, -1.0720e-01, 3.5520e-01, 5.5678e-01,
1.2603e-01, 7.3370e-02, -5.3671e-01, -1.1090e-01, -8.8089e-02,
-6.8762e-02, 7.0469e-01, 8.5552e-02, -3.6397e-01, 1.4415e-02,
-2.4050e-01, 1.2464e-01, 1.3162e-01, 3.9336e-02, -6.0231e-01,
2.3541e-01, 7.5265e-02, -5.2977e-01, 5.5146e-02, 1.1436e-01,
2.4792e-01, 9.5830e-02, -4.6578e-01, 5.3352e-01, 4.4076e-01,
3.4128e-01, 9.6502e-02, 2.2520e-01, -1.2290e-03, -2.0600e-01],
[ 4.1388e-03, 3.4760e-01, 7.7964e-02, -1.3910e-01, -1.6001e-01,
3.0040e-02, -2.9443e-01, -2.3068e-01, -1.2384e-01, 3.9988e-01,
-2.7987e-01, 4.8966e-01, 6.5901e-02, -3.3702e-01, 1.9192e-02,
-3.1752e-01, -2.5456e-01, -3.7401e-01, 5.2014e-01, -2.6364e-01,
4.8279e-01, -3.6519e-02, -1.2773e-01, 3.8893e-01, 3.3339e-01,
1.3188e+00, 2.3630e-01, 3.4717e-01, 2.1574e-01, 1.9013e-01,
3.3517e-01, 2.1240e-01, -5.5276e-02, -2.6756e-01, 2.2808e-01,
-1.9447e-01, 1.2559e-01, -2.7378e-01, 2.9605e-01, 1.3176e-01,
-4.2958e-01, -1.4714e-02, -2.9502e-01, 1.3312e-01, -2.1871e-01,
-2.5984e-02, -4.0586e-01, -2.2980e-01, -3.2106e-01, -1.5043e-01,
8.0840e-01, 7.3231e-02, -3.2636e-01, -9.5514e-02, -2.9875e-01,
-2.0356e-01, 1.7600e-01, -1.7115e-02, 1.6503e-01, 8.3258e-02,
3.4476e-01, -1.7428e-01, -1.7552e-01, 1.1986e-01, -3.0016e-01,
5.2874e-01, -7.1281e-03, -1.0720e-01, 3.5520e-01, 5.5678e-01,
1.2603e-01, 7.3370e-02, -5.3671e-01, -1.1090e-01, -8.8089e-02,
-6.8762e-02, 7.0469e-01, 8.5552e-02, -3.6397e-01, 1.4415e-02,
-2.4050e-01, 1.2464e-01, 1.3162e-01, 3.9336e-02, -6.0231e-01,
2.3541e-01, 7.5265e-02, -5.2977e-01, 5.5146e-02, 1.1436e-01,
2.4792e-01, 9.5830e-02, -4.6578e-01, 5.3352e-01, 4.4076e-01,
3.4128e-01, 9.6502e-02, 2.2520e-01, -1.2290e-03, -2.0600e-01],
[ 4.1388e-03, 3.4760e-01, 7.7964e-02, -1.3910e-01, -1.6001e-01,
3.0040e-02, -2.9443e-01, -2.3068e-01, -1.2384e-01, 3.9988e-01,
-2.7987e-01, 4.8966e-01, 6.5901e-02, -3.3702e-01, 1.9192e-02,
-3.1752e-01, -2.5456e-01, -3.7401e-01, 5.2014e-01, -2.6364e-01,
4.8279e-01, -3.6519e-02, -1.2773e-01, 3.8893e-01, 3.3339e-01,
1.3188e+00, 2.3630e-01, 3.4717e-01, 2.1574e-01, 1.9013e-01,
3.3517e-01, 2.1240e-01, -5.5276e-02, -2.6756e-01, 2.2808e-01,
-1.9447e-01, 1.2559e-01, -2.7378e-01, 2.9605e-01, 1.3176e-01,
-4.2958e-01, -1.4714e-02, -2.9502e-01, 1.3312e-01, -2.1871e-01,
-2.5984e-02, -4.0586e-01, -2.2980e-01, -3.2106e-01, -1.5043e-01,
8.0840e-01, 7.3231e-02, -3.2636e-01, -9.5514e-02, -2.9875e-01,
-2.0356e-01, 1.7600e-01, -1.7115e-02, 1.6503e-01, 8.3258e-02,
3.4476e-01, -1.7428e-01, -1.7552e-01, 1.1986e-01, -3.0016e-01,
5.2874e-01, -7.1281e-03, -1.0720e-01, 3.5520e-01, 5.5678e-01,
1.2603e-01, 7.3370e-02, -5.3671e-01, -1.1090e-01, -8.8089e-02,
-6.8762e-02, 7.0469e-01, 8.5552e-02, -3.6397e-01, 1.4415e-02,
-2.4050e-01, 1.2464e-01, 1.3162e-01, 3.9336e-02, -6.0231e-01,
2.3541e-01, 7.5265e-02, -5.2977e-01, 5.5146e-02, 1.1436e-01,
2.4792e-01, 9.5830e-02, -4.6578e-01, 5.3352e-01, 4.4076e-01,
3.4128e-01, 9.6502e-02, 2.2520e-01, -1.2290e-03, -2.0600e-01],
[ 4.1388e-03, 3.4760e-01, 7.7964e-02, -1.3910e-01, -1.6001e-01,
3.0040e-02, -2.9443e-01, -2.3068e-01, -1.2384e-01, 3.9988e-01,
-2.7987e-01, 4.8966e-01, 6.5901e-02, -3.3702e-01, 1.9192e-02,
-3.1752e-01, -2.5456e-01, -3.7401e-01, 5.2014e-01, -2.6364e-01,
4.8279e-01, -3.6519e-02, -1.2773e-01, 3.8893e-01, 3.3339e-01,
1.3188e+00, 2.3630e-01, 3.4717e-01, 2.1574e-01, 1.9013e-01,
3.3517e-01, 2.1240e-01, -5.5276e-02, -2.6756e-01, 2.2808e-01,
-1.9447e-01, 1.2559e-01, -2.7378e-01, 2.9605e-01, 1.3176e-01,
-4.2958e-01, -1.4714e-02, -2.9502e-01, 1.3312e-01, -2.1871e-01,
-2.5984e-02, -4.0586e-01, -2.2980e-01, -3.2106e-01, -1.5043e-01,
8.0840e-01, 7.3231e-02, -3.2636e-01, -9.5514e-02, -2.9875e-01,
-2.0356e-01, 1.7600e-01, -1.7115e-02, 1.6503e-01, 8.3258e-02,
3.4476e-01, -1.7428e-01, -1.7552e-01, 1.1986e-01, -3.0016e-01,
5.2874e-01, -7.1281e-03, -1.0720e-01, 3.5520e-01, 5.5678e-01,
1.2603e-01, 7.3370e-02, -5.3671e-01, -1.1090e-01, -8.8089e-02,
-6.8762e-02, 7.0469e-01, 8.5552e-02, -3.6397e-01, 1.4415e-02,
-2.4050e-01, 1.2464e-01, 1.3162e-01, 3.9336e-02, -6.0231e-01,
2.3541e-01, 7.5265e-02, -5.2977e-01, 5.5146e-02, 1.1436e-01,
2.4792e-01, 9.5830e-02, -4.6578e-01, 5.3352e-01, 4.4076e-01,
3.4128e-01, 9.6502e-02, 2.2520e-01, -1.2290e-03, -2.0600e-01],
[ 4.1389e-03, 3.4760e-01, 7.7964e-02, -1.3910e-01, -1.6001e-01,
3.0040e-02, -2.9443e-01, -2.3068e-01, -1.2384e-01, 3.9988e-01,
-2.7987e-01, 4.8966e-01, 6.5901e-02, -3.3702e-01, 1.9192e-02,
-3.1752e-01, -2.5456e-01, -3.7401e-01, 5.2014e-01, -2.6364e-01,
4.8279e-01, -3.6519e-02, -1.2773e-01, 3.8893e-01, 3.3339e-01,
1.3188e+00, 2.3630e-01, 3.4717e-01, 2.1574e-01, 1.9013e-01,
3.3517e-01, 2.1240e-01, -5.5276e-02, -2.6756e-01, 2.2808e-01,
-1.9447e-01, 1.2559e-01, -2.7378e-01, 2.9605e-01, 1.3176e-01,
-4.2958e-01, -1.4714e-02, -2.9502e-01, 1.3312e-01, -2.1871e-01,
-2.5984e-02, -4.0586e-01, -2.2980e-01, -3.2106e-01, -1.5043e-01,
8.0840e-01, 7.3231e-02, -3.2636e-01, -9.5514e-02, -2.9875e-01,
-2.0356e-01, 1.7600e-01, -1.7115e-02, 1.6503e-01, 8.3258e-02,
3.4476e-01, -1.7428e-01, -1.7552e-01, 1.1986e-01, -3.0016e-01,
5.2874e-01, -7.1281e-03, -1.0720e-01, 3.5520e-01, 5.5678e-01,
1.2603e-01, 7.3370e-02, -5.3671e-01, -1.1090e-01, -8.8089e-02,
-6.8762e-02, 7.0469e-01, 8.5552e-02, -3.6397e-01, 1.4415e-02,
-2.4050e-01, 1.2464e-01, 1.3162e-01, 3.9336e-02, -6.0231e-01,
2.3541e-01, 7.5265e-02, -5.2977e-01, 5.5146e-02, 1.1436e-01,
2.4792e-01, 9.5830e-02, -4.6578e-01, 5.3352e-01, 4.4076e-01,
3.4128e-01, 9.6502e-02, 2.2520e-01, -1.2290e-03, -2.0600e-01],
[ 4.1388e-03, 3.4760e-01, 7.7964e-02, -1.3910e-01, -1.6001e-01,
3.0040e-02, -2.9443e-01, -2.3068e-01, -1.2384e-01, 3.9988e-01,
-2.7987e-01, 4.8966e-01, 6.5901e-02, -3.3702e-01, 1.9192e-02,
-3.1752e-01, -2.5456e-01, -3.7401e-01, 5.2014e-01, -2.6364e-01,
4.8279e-01, -3.6519e-02, -1.2773e-01, 3.8893e-01, 3.3339e-01,
1.3188e+00, 2.3630e-01, 3.4717e-01, 2.1574e-01, 1.9013e-01,
3.3517e-01, 2.1240e-01, -5.5276e-02, -2.6756e-01, 2.2808e-01,
-1.9447e-01, 1.2559e-01, -2.7378e-01, 2.9605e-01, 1.3176e-01,
-4.2958e-01, -1.4714e-02, -2.9502e-01, 1.3312e-01, -2.1871e-01,
-2.5984e-02, -4.0586e-01, -2.2980e-01, -3.2106e-01, -1.5043e-01,
8.0840e-01, 7.3231e-02, -3.2636e-01, -9.5514e-02, -2.9875e-01,
-2.0356e-01, 1.7600e-01, -1.7115e-02, 1.6503e-01, 8.3258e-02,
3.4476e-01, -1.7428e-01, -1.7552e-01, 1.1986e-01, -3.0016e-01,
5.2874e-01, -7.1281e-03, -1.0720e-01, 3.5520e-01, 5.5678e-01,
1.2603e-01, 7.3370e-02, -5.3671e-01, -1.1090e-01, -8.8089e-02,
-6.8762e-02, 7.0469e-01, 8.5552e-02, -3.6397e-01, 1.4415e-02,
-2.4050e-01, 1.2464e-01, 1.3162e-01, 3.9336e-02, -6.0231e-01,
2.3541e-01, 7.5265e-02, -5.2977e-01, 5.5146e-02, 1.1436e-01,
2.4792e-01, 9.5830e-02, -4.6578e-01, 5.3352e-01, 4.4076e-01,
3.4128e-01, 9.6502e-02, 2.2520e-01, -1.2290e-03, -2.0600e-01],
[ 4.1389e-03, 3.4760e-01, 7.7964e-02, -1.3910e-01, -1.6001e-01,
3.0040e-02, -2.9443e-01, -2.3068e-01, -1.2384e-01, 3.9988e-01,
-2.7987e-01, 4.8966e-01, 6.5901e-02, -3.3702e-01, 1.9192e-02,
-3.1752e-01, -2.5456e-01, -3.7401e-01, 5.2014e-01, -2.6364e-01,
4.8279e-01, -3.6519e-02, -1.2773e-01, 3.8893e-01, 3.3339e-01,
1.3188e+00, 2.3630e-01, 3.4717e-01, 2.1574e-01, 1.9013e-01,
3.3517e-01, 2.1240e-01, -5.5276e-02, -2.6756e-01, 2.2808e-01,
-1.9447e-01, 1.2559e-01, -2.7378e-01, 2.9605e-01, 1.3176e-01,
-4.2958e-01, -1.4714e-02, -2.9502e-01, 1.3312e-01, -2.1871e-01,
-2.5984e-02, -4.0586e-01, -2.2980e-01, -3.2106e-01, -1.5043e-01,
8.0840e-01, 7.3231e-02, -3.2636e-01, -9.5514e-02, -2.9875e-01,
-2.0356e-01, 1.7600e-01, -1.7115e-02, 1.6503e-01, 8.3258e-02,
3.4476e-01, -1.7428e-01, -1.7552e-01, 1.1986e-01, -3.0016e-01,
5.2874e-01, -7.1281e-03, -1.0720e-01, 3.5520e-01, 5.5678e-01,
1.2603e-01, 7.3370e-02, -5.3671e-01, -1.1090e-01, -8.8089e-02,
-6.8762e-02, 7.0469e-01, 8.5552e-02, -3.6397e-01, 1.4415e-02,
-2.4050e-01, 1.2464e-01, 1.3162e-01, 3.9336e-02, -6.0231e-01,
2.3541e-01, 7.5265e-02, -5.2977e-01, 5.5146e-02, 1.1436e-01,
2.4792e-01, 9.5830e-02, -4.6578e-01, 5.3352e-01, 4.4076e-01,
3.4128e-01, 9.6502e-02, 2.2520e-01, -1.2290e-03, -2.0600e-01],
[ 4.1388e-03, 3.4760e-01, 7.7964e-02, -1.3910e-01, -1.6001e-01,
3.0040e-02, -2.9443e-01, -2.3068e-01, -1.2384e-01, 3.9988e-01,
-2.7987e-01, 4.8966e-01, 6.5901e-02, -3.3702e-01, 1.9192e-02,
-3.1752e-01, -2.5456e-01, -3.7401e-01, 5.2014e-01, -2.6364e-01,
4.8279e-01, -3.6519e-02, -1.2773e-01, 3.8893e-01, 3.3339e-01,
1.3188e+00, 2.3630e-01, 3.4717e-01, 2.1574e-01, 1.9013e-01,
3.3517e-01, 2.1240e-01, -5.5276e-02, -2.6756e-01, 2.2808e-01,
-1.9447e-01, 1.2559e-01, -2.7378e-01, 2.9605e-01, 1.3176e-01,
-4.2958e-01, -1.4714e-02, -2.9502e-01, 1.3312e-01, -2.1871e-01,
-2.5984e-02, -4.0586e-01, -2.2980e-01, -3.2106e-01, -1.5043e-01,
8.0840e-01, 7.3231e-02, -3.2636e-01, -9.5514e-02, -2.9875e-01,
-2.0356e-01, 1.7600e-01, -1.7115e-02, 1.6503e-01, 8.3258e-02,
3.4476e-01, -1.7428e-01, -1.7552e-01, 1.1986e-01, -3.0016e-01,
5.2874e-01, -7.1281e-03, -1.0720e-01, 3.5520e-01, 5.5678e-01,
1.2603e-01, 7.3370e-02, -5.3671e-01, -1.1090e-01, -8.8089e-02,
-6.8762e-02, 7.0469e-01, 8.5552e-02, -3.6397e-01, 1.4415e-02,
-2.4050e-01, 1.2464e-01, 1.3162e-01, 3.9336e-02, -6.0231e-01,
2.3541e-01, 7.5265e-02, -5.2977e-01, 5.5146e-02, 1.1436e-01,
2.4792e-01, 9.5830e-02, -4.6578e-01, 5.3352e-01, 4.4076e-01,
3.4128e-01, 9.6502e-02, 2.2520e-01, -1.2290e-03, -2.0600e-01],
[-2.4798e-02, 1.6421e-01, -2.8511e-01, -1.4592e-01, -2.2922e-01,
4.2751e-01, -5.1420e-01, -2.0352e-01, -3.4673e-01, 5.6531e-01,
3.7713e-02, 6.2584e-01, 1.9636e-01, 5.9184e-02, -2.7884e-01,
-5.0163e-01, 1.8202e-01, -2.3392e-01, 2.9898e-01, 6.2062e-01,
3.1362e-01, -1.5483e-01, -5.5394e-01, -9.8141e-02, -3.0405e-01,
1.3070e+00, -3.5137e-03, 8.8364e-03, -1.9745e-01, 2.5917e-01,
-5.8083e-01, 1.2488e-01, 4.1021e-01, -5.6031e-01, 2.3878e-01,
-2.4106e-01, 4.3839e-01, -5.7050e-01, 2.1050e-01, -5.2716e-02,
-5.9602e-01, 1.6424e-01, 2.7367e-01, 1.4049e-01, 8.2080e-02,
-8.9452e-02, -6.7448e-01, -1.4769e-01, -7.4956e-02, -1.3364e-01,
4.0567e-01, -3.2883e-03, -3.6234e-01, -1.9333e-02, -3.2529e-01,
-2.3158e-01, 1.7457e-02, -3.0411e-01, 7.4791e-02, -1.2701e-01,
7.4187e-02, -5.5861e-02, 3.6002e-01, 4.6427e-01, -3.2839e-01,
5.6560e-01, 1.6330e-01, -2.4283e-01, -5.7410e-02, 5.7161e-01,
5.1104e-01, 4.7838e-01, -3.7463e-01, -2.9594e-01, 1.1476e-01,
-4.4387e-01, 6.2682e-01, 2.2299e-01, 6.1789e-02, -3.8020e-01,
5.2411e-02, 1.2436e-01, 3.3974e-01, 9.3500e-02, -4.4880e-01,
3.2380e-01, 4.9098e-01, -4.9971e-01, 2.6986e-01, 2.5225e-01,
-1.7665e-01, 2.6156e-01, -3.7751e-01, 3.1167e-01, -4.9966e-02,
2.7210e-01, 5.5340e-01, -4.0680e-02, 5.2126e-02, -2.6088e-01]]],
grad_fn=<StackBackward0>))
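%% Cell type:markdown tags:
The encoder-decoder above only runs a forward pass. Below is a minimal sketch of a single optimization step, assuming the padded source ids are stacked into one `[batch, seq]` tensor and the targets are padded or truncated to the decoder's fixed number of output steps (`max_tokens = 40`). Since the decoder already emits log-probabilities, `nn.NLLLoss` is the matching criterion. This is an illustration, not the notebook's training loop.
%% Cell type:code tags:
``` python
# sketch of one training step (assumption: model and training_data() as defined above)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
pad_id = tokenizer.token_to_id("[PAD]")
criterion = nn.NLLLoss(ignore_index=pad_id)  # decoder outputs are log-probabilities

x_batch, y_batch = training_data()
x_batch = torch.stack(x_batch)  # padded source ids -> [batch, seq] LongTensor

# pad/truncate the target sequences to the decoder's fixed 40 output steps
max_tokens = 40
targets = torch.full((len(y_batch), max_tokens), pad_id, dtype=torch.long)
for i, y in enumerate(y_batch):
    length = min(len(y), max_tokens)
    targets[i, :length] = y[:length]

optimizer.zero_grad()
log_probs, _, _ = model(x_batch)                       # [batch, max_tokens, vocab]
loss = criterion(log_probs.transpose(1, 2), targets)   # NLLLoss expects [batch, vocab, steps]
loss.backward()
optimizer.step()
print(loss.item())
```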
%% Cell type:code id:1f8d3152359f6658 tags:
``` python
```