Commit 8c22e8ee authored by marvnsch's avatar marvnsch

add get training data

parent c415e5c6
@@ -91,3 +91,45 @@ def create_tokenizers(source_data_path: str, target_data_path: str, source_langu
def count_words(string: str) -> int:
    return len(string.split())

def training_data(source: list[str],
                  target: list[str],
                  dataset_size: int,
                  batch_size: int = 64,
                  sort: bool = True) -> Iterator[tuple[torch.Tensor, torch.Tensor]]:
    # the function is a generator, so the annotation uses Iterator
    # (from typing, imported at the top of the file); tokenizer_de,
    # tokenizer_en and device are module-level globals defined earlier
    tokenizer_de.no_padding()
    tokenizer_en.no_padding()
    if dataset_size > len(source):
        raise IndexError("Dataset size is larger than the source data")
    # restrict both sides to the requested dataset size
    source = source[:dataset_size]
    target = target[:dataset_size]
    # sort the sentence pairs by combined length so that each batch
    # contains sequences of similar length and needs little padding
    if sort:
        temp = [list(a) for a in zip(source, target)]
        temp.sort(key=lambda s: len(s[0]) + len(s[1]))
        source, target = list(zip(*temp))
    # walk over the data in consecutive full batches
    for i in range(0, len(source) - batch_size, batch_size):
        x_training_data = source[i:i + batch_size]
        y_training_data = target[i:i + batch_size]
        # tokenize the batch with padding enabled
        tokenizer_en.enable_padding(pad_id=3)
        x_training_data = tokenizer_en.encode_batch(x_training_data)
        tokenizer_de.enable_padding(pad_id=3)
        y_training_data = tokenizer_de.encode_batch(y_training_data)
        # extract the token ids for every sequence
        for j in range(batch_size):
            x_training_data[j] = x_training_data[j].ids
            y_training_data[j] = y_training_data[j].ids
        # put the ids into tensors on the target device
        x_training_data = torch.tensor(x_training_data, device=device)
        y_training_data = torch.tensor(y_training_data, device=device)
        # transpose to (seq_len, batch_size) to match the LSTM input layout
        x_training_data = torch.transpose(x_training_data, 0, 1)
        y_training_data = torch.transpose(y_training_data, 0, 1)
        yield x_training_data, y_training_data
\ No newline at end of file
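
For context, a minimal sketch of how the new generator might be consumed in a training loop. Only training_data itself is part of this commit; num_epochs, source_sentences, target_sentences, model, criterion and optimizer are hypothetical placeholders.

    # hypothetical training loop consuming the generator; model,
    # criterion and optimizer are placeholders, not part of this commit
    for epoch in range(num_epochs):
        for x_batch, y_batch in training_data(source_sentences,
                                              target_sentences,
                                              dataset_size=100_000,
                                              batch_size=64):
            optimizer.zero_grad()
            # each batch arrives as (seq_len, batch_size) tensors
            # already placed on the module-level device
            output = model(x_batch, y_batch)
            loss = criterion(output, y_batch)
            loss.backward()
            optimizer.step()

Because sort=True orders the pairs by combined length before batching, sentences of similar length land in the same batch, which keeps the amount of padding per batch small compared to batching in random order.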