Commit 0f1ffd05 authored by Konstantin Julius Lotzgeselle
parents 74eab182 7ee878e7
@@ -105,43 +105,74 @@ def create_tokenizers(source_data_path: str, target_data_path: str, source_langu
    tokenizer_de.save(str(workdir / "tokenizer_de.json"))
def data_loader(source: list[str],
                target: list[str],
                source_tokenizer: Tokenizer,
                target_tokenizer: Tokenizer,
                dataset_size: int,
                torch_device: torch.device,
                batch_size: int = 64,
                data_split: tuple[float, float, float] = (0.8, 0.1, 0.1),
                sort: bool = True):
    # use a tolerance so float rounding in the split fractions cannot trigger a false error
    if abs(sum(data_split) - 1.0) > 1e-9:
        raise ValueError("The data split must add up to one")

    if dataset_size > len(source):
        raise IndexError("Dataset size is larger than the source data")

    # split the data into consecutive ranges, one per split fraction
    splits = []
    lower_border = 0
    for split in data_split:
        upper_border = lower_border + int(split * dataset_size)
        splits.append((source[lower_border:upper_border], target[lower_border:upper_border]))
        lower_border = upper_border

    # return one batch generator per split
    return (get_data_generator(source_data_raw=split[0],
                               target_data_raw=split[1],
                               batch_size=batch_size,
                               source_tokenizer=source_tokenizer,
                               target_tokenizer=target_tokenizer,
                               torch_device=torch_device,
                               sort=sort) for split in splits)
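With the default data_split of three fractions, the returned generator expression yields one batch generator per split and can be unpacked directly. A minimal usage sketch, assuming the tokenizer files written by create_tokenizers and hypothetical sentence lists en_sentences / de_sentences:

    train_gen, val_gen, test_gen = data_loader(source=en_sentences,
                                               target=de_sentences,
                                               source_tokenizer=Tokenizer.from_file("tokenizer_en.json"),
                                               target_tokenizer=Tokenizer.from_file("tokenizer_de.json"),
                                               dataset_size=100_000,
                                               torch_device=torch.device("cpu"))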
def get_data_generator(source_data_raw: list[str], target_data_raw: list[str], batch_size: int,
                       source_tokenizer: Tokenizer, target_tokenizer: Tokenizer,
                       torch_device: torch.device,
                       sort: bool) -> Iterator[tuple[torch.Tensor, torch.Tensor]]:
    # generator yielding (source, target) batch tensors; assumes
    # "from collections.abc import Iterator" among the module imports
    source_tokenizer.no_padding()
    target_tokenizer.no_padding()

    # sort the training data by combined sentence length if requested,
    # so sentences of similar length end up in the same batch
    if sort:
        temp = [list(a) for a in zip(source_data_raw, target_data_raw)]
        temp.sort(key=lambda s: len(s[0]) + len(s[1]))
        source_data, target_data = list(zip(*temp))
    else:
        source_data = source_data_raw
        target_data = target_data_raw

    for i in range(0, len(source_data) - batch_size, batch_size):
        x_data = source_data[i:i + batch_size]
        y_data = target_data[i:i + batch_size]

        # tokenize data, padding every sequence in the batch to the same length
        source_tokenizer.enable_padding(pad_id=3)
        x_data = source_tokenizer.encode_batch(x_data)
        target_tokenizer.enable_padding(pad_id=3)
        y_data = target_tokenizer.encode_batch(y_data)

        # extract ids for every sequence
        for j in range(batch_size):
            x_data[j] = x_data[j].ids
            y_data[j] = y_data[j].ids

        # put data into tensors
        x_data = torch.tensor(x_data, device=torch_device)
        y_data = torch.tensor(y_data, device=torch_device)

        # transpose tensors to match input requirements for the lstm
        x_data = torch.transpose(x_data, 0, 1)
        y_data = torch.transpose(y_data, 0, 1)

        yield x_data, y_data
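The final transpose produces tensors of shape (seq_len, batch_size), which matches the default batch_first=False input layout of torch.nn.LSTM. A minimal sketch of consuming one of the generators, continuing the example above; the vocabulary size and layer dimensions are arbitrary assumptions for illustration:

    vocab_size = 30_000  # assumed to match the trained tokenizer
    embedding = torch.nn.Embedding(num_embeddings=vocab_size, embedding_dim=256, padding_idx=3)
    encoder = torch.nn.LSTM(input_size=256, hidden_size=512)

    for x_batch, y_batch in train_gen:
        # x_batch holds token ids with shape (source_seq_len, batch_size)
        embedded = embedding(x_batch)            # (seq_len, batch_size, 256)
        outputs, (h_n, c_n) = encoder(embedded)  # outputs: (seq_len, batch_size, 512)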
import torch
from prettytable import PrettyTable

@@ -12,3 +13,16 @@ def get_available_device() -> torch.device:
        device = torch.device("cpu")
        print("device: cpu")
    return device
def print_model_parameters(model: torch.nn.Module):
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad:
            continue
        params = parameter.numel()
        table.add_row([name, params])
        total_params += params
    print(table)
    print(f"Total Trainable Params: {total_params}")
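A short usage sketch for the two helpers; the bare LSTM here is only a stand-in for the actual translation model:

    device = get_available_device()
    model = torch.nn.LSTM(input_size=256, hidden_size=512).to(device)
    print_model_parameters(model)  # one table row per named weight and bias tensor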