Skip to content
Snippets Groups Projects
Commit 7e2a208b authored by marvnsch's avatar marvnsch
Browse files

Refactor data_loader to return a generator factory, pass it into train() so fresh train/val loaders are created each epoch, add a per-epoch progress bar with a batch count, create a checkpoints directory, and adjust hyperparameters (embedding size 300→200, epochs 1→10, dataset size 1000→10000).

parent 2ddf20a9
No related branches found
No related tags found
No related merge requests found
......@@ -135,6 +135,7 @@ def data_loader(source: list[str],
data_split: tuple[float, float, float] = (0.8, 0.1, 0.1),
sort: bool = True):
def data_loader_generator():
if sum(data_split) != 1.0:
raise ValueError("The data split must add up to one")
......@@ -156,6 +157,7 @@ def data_loader(source: list[str],
target_tokenizer=target_tokenizer,
torch_device=torch_device,
sort=sort) for split in splits)
return data_loader_generator
def get_data_generator(source_data_raw: list[str], target_data_raw: list[str], batch_size: int,
......
from datetime import datetime
import os
import random
from pathlib import Path
from datetime import datetime
import torch
import torch.nn as nn
......@@ -104,15 +105,18 @@ vocab_size = 10000
input_size_encoder = vocab_size
input_size_decoder = vocab_size
output_size_decoder = vocab_size
encoder_embedding_size = 300
decoder_embedding_size = 300
encoder_embedding_size = 200
decoder_embedding_size = 200
model_hidden_size = 1024
model_num_layers = 2
num_epochs = 1
num_epochs = 10
learning_rate = 0.001
batch_size = 64
dataset_size = 1000
dataset_size = 10000
train_dev_val_split = (.8, .1, .1)
train_batches_count = int(train_dev_val_split[0] * dataset_size // batch_size)
# create model
encoder_net = Encoder(input_size=input_size_encoder,
......@@ -142,20 +146,31 @@ source_data, target_data = data.preprocessing.get_prepared_data(source_data_path
source_tokenizer, target_tokenizer = data.preprocessing.create_tokenizers(source_data_path=source_data_path,
target_data_path=target_data_path,
vocab_size=vocab_size)
training_loader, develop_loader, test_loader = data.preprocessing.data_loader(source=source_data,
data_loader = data.preprocessing.data_loader(source=source_data,
target=target_data,
batch_size=batch_size,
source_tokenizer=source_tokenizer,
target_tokenizer=target_tokenizer,
dataset_size=dataset_size,
torch_device=device)
torch_device=device,
data_split=train_dev_val_split)
source_data = None
target_data = None
# create checkpoints directory
try:
os.mkdir(work_dir / "./checkpoints")
except FileExistsError:
pass
# train the model
utils.training.train(model=model,
train_loader=training_loader,
val_loader=test_loader,
data_loader=data_loader,
criterion=criterion,
optimizer=optimizer,
num_epochs=num_epochs,
num_of_batches_per_epoch=train_batches_count,
saving_interval=1000,
model_output_path=model_output_path)
import itertools
import torch
from progressbar import progressbar
import progressbar
def train(model, train_loader, val_loader,
def train(model, data_loader, num_of_batches_per_epoch: int,
criterion: torch.nn.modules.loss, optimizer: torch.optim,
num_epochs: int, saving_interval: int, model_output_path: str):
"""
Train a model based on training data and validation data
:param num_of_batches_per_epoch: count of batches for train epoch
:param model_output_path: path to save the model to
:param model: the model to train (nn.Module)
:param train_loader: the generator object containing the training data
:param val_loader: the generator object containing the validation data
......@@ -23,7 +27,14 @@ def train(model, train_loader, val_loader,
for epoch in range(num_epochs):
print('Epoch {}/{}'.format(epoch + 1, num_epochs))
for batch_idx, (x_train, y_train) in enumerate(progressbar(train_loader)):
# get data generators
train_loader, _, val_loader = data_loader()
# reset progress bar value
progress = 0
with progressbar.ProgressBar(max_value=num_of_batches_per_epoch) as bar:
for batch_idx, (x_train, y_train) in enumerate(train_loader):
optimizer.zero_grad()
predict = model(x_train, y_train)
......@@ -42,6 +53,11 @@ def train(model, train_loader, val_loader,
# saving the model
torch.save(model.state_dict(), model_output_path)
# update the progress bar (and the counter)
bar.update(progress)
progress += 1
save_counter += 1
# checking loss with validation data
loss_value = 0
val_batch_count = 0
......
0% — Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment