Classic - 2020 June 27

[1]:
import datetime

import numpy as np
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
import seaborn as sns

from tst import Transformer
from tst.loss import OZELoss

from src.dataset import OzeDataset
from src.utils import compute_loss
from src.visualization import map_plot_function, plot_values_distribution, plot_error_distribution, plot_errors_threshold, plot_visual_sample
[2]:
# Training parameters
DATASET_PATH = 'datasets/dataset_57M.npz'
BATCH_SIZE = 8
NUM_WORKERS = 0
LR = 2e-4
EPOCHS = 30

# Model parameters
d_model = 64 # Latent dimension
q = 8 # Query size
v = 8 # Value size
h = 8 # Number of heads
N = 4 # Number of encoder and decoder layers to stack
attention_size = 12 # Attention window size
dropout = 0.2 # Dropout rate
pe = None # Positional encoding
chunk_mode = None # Attention chunk mode (None disables chunking)

d_input = 27 # From dataset
d_output = 8 # From dataset

# Config
sns.set()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device {device}")
Using device cuda:0
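
The run below is otherwise unseeded; as a minimal sketch (not part of the original notebook), the global seeds could be fixed here so that weight initialisation, dropout and the dataset split below are repeatable:

SEED = 42  # hypothetical value; any fixed integer works
torch.manual_seed(SEED)  # seeds weight init, dropout and random_split
np.random.seed(SEED)     # seeds NumPy-based randomness, if any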

Training

Load dataset

[3]:
ozeDataset = OzeDataset(DATASET_PATH)
[4]:
dataset_train, dataset_val, dataset_test = random_split(ozeDataset, (23000, 1000, 1000))

dataloader_train = DataLoader(dataset_train,
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              num_workers=NUM_WORKERS,
                              pin_memory=False
                             )

dataloader_val = DataLoader(dataset_val,
                            batch_size=BATCH_SIZE,
                            shuffle=True,
                            num_workers=NUM_WORKERS
                           )

dataloader_test = DataLoader(dataset_test,
                             batch_size=BATCH_SIZE,
                             shuffle=False,
                             num_workers=NUM_WORKERS
                            )
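
As a quick sanity check (a sketch, not in the original notebook), one batch can be drawn to confirm the tensors match d_input and d_output; the time dimension of 168 is taken from the evaluation cell further down:

x_sample, y_sample = next(iter(dataloader_train))
# Expected: torch.Size([8, 168, 27]) and torch.Size([8, 168, 8])
print(x_sample.shape, y_sample.shape)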

Load network

[5]:
# Load transformer with Adam optimizer and OZE loss function
net = Transformer(d_input, d_model, d_output, q, v, h, N, attention_size=attention_size, dropout=dropout, chunk_mode=chunk_mode, pe=pe).to(device)
optimizer = optim.Adam(net.parameters(), lr=LR)
loss_function = OZELoss(alpha=0.3)
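
Before training, the model size can be checked with a one-liner (a sketch, not in the original notebook, using only standard PyTorch):

num_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(f"Number of trainable parameters: {num_params:,}")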

Train

[6]:
model_save_path = f'models/model_{datetime.datetime.now().strftime("%Y_%m_%d__%H%M%S")}.pth'
val_loss_best = np.inf

# Prepare loss history
hist_loss = np.zeros(EPOCHS)
hist_loss_val = np.zeros(EPOCHS)
for idx_epoch in range(EPOCHS):
    running_loss = 0
    with tqdm(total=len(dataloader_train.dataset), desc=f"[Epoch {idx_epoch+1:3d}/{EPOCHS}]") as pbar:
        for idx_batch, (x, y) in enumerate(dataloader_train):
            optimizer.zero_grad()

            # Propagate input
            netout = net(x.to(device))

            # Compute loss
            loss = loss_function(y.to(device), netout)

            # Backpropagate loss
            loss.backward()

            # Update weights
            optimizer.step()

            running_loss += loss.item()
            pbar.set_postfix({'loss': running_loss/(idx_batch+1)})
            pbar.update(x.shape[0])

        train_loss = running_loss/len(dataloader_train)
        val_loss = compute_loss(net, dataloader_val, loss_function, device).item()
        pbar.set_postfix({'loss': train_loss, 'val_loss': val_loss})

        hist_loss[idx_epoch] = train_loss
        hist_loss_val[idx_epoch] = val_loss

        if val_loss < val_loss_best:
            val_loss_best = val_loss
            torch.save(net.state_dict(), model_save_path)

plt.plot(hist_loss, 'o-', label='train')
plt.plot(hist_loss_val, 'o-', label='val')
plt.legend()
print(f"model exported to {model_save_path} with loss {val_loss_best:5f}")
[Epoch   1/30]: 100%|██████████| 23000/23000 [05:04<00:00, 75.44it/s, loss=0.0043, val_loss=0.00177]
[Epoch   2/30]: 100%|██████████| 23000/23000 [05:04<00:00, 75.48it/s, loss=0.00127, val_loss=0.0013]
[Epoch   3/30]: 100%|██████████| 23000/23000 [05:02<00:00, 76.07it/s, loss=0.000871, val_loss=0.000957]
[Epoch   4/30]: 100%|██████████| 23000/23000 [05:04<00:00, 75.47it/s, loss=0.000632, val_loss=0.000511]
[Epoch   5/30]: 100%|██████████| 23000/23000 [05:04<00:00, 75.65it/s, loss=0.000491, val_loss=0.000418]
[Epoch   6/30]: 100%|██████████| 23000/23000 [05:04<00:00, 75.60it/s, loss=0.000394, val_loss=0.000349]
[Epoch   7/30]: 100%|██████████| 23000/23000 [05:05<00:00, 75.27it/s, loss=0.000325, val_loss=0.000378]
[Epoch   8/30]: 100%|██████████| 23000/23000 [05:03<00:00, 75.82it/s, loss=0.000285, val_loss=0.000268]
[Epoch   9/30]: 100%|██████████| 23000/23000 [05:02<00:00, 75.96it/s, loss=0.000254, val_loss=0.000223]
[Epoch  10/30]: 100%|██████████| 23000/23000 [05:05<00:00, 75.38it/s, loss=0.000222, val_loss=0.00022]
[Epoch  11/30]: 100%|██████████| 23000/23000 [05:03<00:00, 75.86it/s, loss=0.000206, val_loss=0.000187]
[Epoch  12/30]: 100%|██████████| 23000/23000 [05:02<00:00, 75.97it/s, loss=0.000191, val_loss=0.000182]
[Epoch  13/30]: 100%|██████████| 23000/23000 [05:05<00:00, 75.40it/s, loss=0.000177, val_loss=0.000174]
[Epoch  14/30]: 100%|██████████| 23000/23000 [05:04<00:00, 75.60it/s, loss=0.000169, val_loss=0.000169]
[Epoch  15/30]: 100%|██████████| 23000/23000 [05:04<00:00, 75.42it/s, loss=0.00016, val_loss=0.00015]
[Epoch  16/30]: 100%|██████████| 23000/23000 [05:05<00:00, 75.40it/s, loss=0.000149, val_loss=0.00014]
[Epoch  17/30]: 100%|██████████| 23000/23000 [05:04<00:00, 75.46it/s, loss=0.000145, val_loss=0.000163]
[Epoch  18/30]: 100%|██████████| 23000/23000 [05:04<00:00, 75.53it/s, loss=0.000138, val_loss=0.000142]
[Epoch  19/30]: 100%|██████████| 23000/23000 [05:04<00:00, 75.54it/s, loss=0.000132, val_loss=0.000162]
[Epoch  20/30]: 100%|██████████| 23000/23000 [05:04<00:00, 75.46it/s, loss=0.000127, val_loss=0.000135]
[Epoch  21/30]: 100%|██████████| 23000/23000 [05:04<00:00, 75.59it/s, loss=0.000121, val_loss=0.000136]
[Epoch  22/30]: 100%|██████████| 23000/23000 [05:03<00:00, 75.79it/s, loss=0.000119, val_loss=0.000127]
[Epoch  23/30]: 100%|██████████| 23000/23000 [05:03<00:00, 75.73it/s, loss=0.000112, val_loss=0.000122]
[Epoch  24/30]: 100%|██████████| 23000/23000 [05:05<00:00, 75.37it/s, loss=0.000109, val_loss=0.000107]
[Epoch  25/30]: 100%|██████████| 23000/23000 [05:03<00:00, 75.67it/s, loss=0.000107, val_loss=0.000147]
[Epoch  26/30]: 100%|██████████| 23000/23000 [05:03<00:00, 75.68it/s, loss=0.000103, val_loss=0.000114]
[Epoch  27/30]: 100%|██████████| 23000/23000 [05:04<00:00, 75.60it/s, loss=0.000101, val_loss=0.000108]
[Epoch  28/30]: 100%|██████████| 23000/23000 [05:05<00:00, 75.23it/s, loss=9.82e-5, val_loss=0.000108]
[Epoch  29/30]: 100%|██████████| 23000/23000 [05:05<00:00, 75.32it/s, loss=9.44e-5, val_loss=0.000102]
[Epoch  30/30]: 100%|██████████| 23000/23000 [05:04<00:00, 75.50it/s, loss=9.13e-5, val_loss=0.000107]
model exported to models/model_2020_06_27__062220.pth with loss 0.000102
[Figure: training and validation loss curves over the 30 epochs]

Validation

[7]:
_ = net.eval()
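
The exported checkpoint corresponds to the best validation loss (0.000102 at epoch 29), not to the final epoch, so it is worth reloading it before evaluating. A minimal sketch using the model_save_path defined above:

net.load_state_dict(torch.load(model_save_path, map_location=device))
_ = net.eval()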

Evaluate on the test dataset

[8]:
predictions = np.empty(shape=(len(dataloader_test.dataset), 168, 8))  # (samples, time steps, d_output)

idx_prediction = 0
with torch.no_grad():
    for x, y in tqdm(dataloader_test, total=len(dataloader_test)):
        netout = net(x.to(device)).cpu().numpy()
        predictions[idx_prediction:idx_prediction+x.shape[0]] = netout
        idx_prediction += x.shape[0]
100%|██████████| 125/125 [00:04<00:00, 26.91it/s]
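
These predictions can be reduced to a single aggregate metric. A minimal sketch (not in the original notebook) that gathers the ground truth in the same order (valid because the test loader uses shuffle=False) and computes a global RMSE:

targets = np.empty_like(predictions)
idx = 0
for _, y in dataloader_test:
    targets[idx:idx + y.shape[0]] = y.numpy()
    idx += y.shape[0]
rmse = np.sqrt(np.mean((predictions - targets) ** 2))
print(f"Global test RMSE: {rmse:.6f}")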

Plot results on a sample

[9]:
map_plot_function(ozeDataset, predictions, plot_visual_sample, dataset_indices=dataloader_test.dataset.indices)
[Figure: predicted vs. true values on a test sample]

Plot error distributions

[10]:
map_plot_function(ozeDataset, predictions, plot_error_distribution, dataset_indices=dataloader_test.dataset.indices, time_limit=24)
[Figure: error distributions with time_limit=24]

Plot misprediction thresholds

[11]:
map_plot_function(ozeDataset, predictions, plot_errors_threshold, plot_kwargs={'error_band': 0.1}, dataset_indices=dataloader_test.dataset.indices)
[Figure: mispredictions relative to a 0.1 error band]