Classic - 2019 December 03

[1]:
import numpy as np
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import seaborn as sns; sns.set()

from src.dataset import OzeDataset
from src.Transformer import Transformer
[2]:
BATCH_SIZE = 2
NUM_WORKERS = 4
LR = 1e-2
EPOCHS = 5
TIME_CHUNK = False

K = 672 # Time window length
d_model = 48 # Latent dim
q = 8 # Query size
v = 8 # Value size
h = 4 # Number of heads
N = 4 # Number of encoder and decoder layers to stack

d_input = 37 # From dataset
d_output = 8 # From dataset

Load dataset

[3]:
dataloader = DataLoader(OzeDataset("dataset.npz"),
                        batch_size=BATCH_SIZE,
                        shuffle=True,
                        num_workers=NUM_WORKERS
                       )
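
As a quick sanity check, one batch can be pulled from the loader to confirm the tensor shapes match the constants above; a minimal sketch, assuming OzeDataset returns (input, target) pairs shaped (K, d_input) and (K, d_output):

# Peek at a single batch to verify shapes
x_sample, y_sample = next(iter(dataloader))
print(x_sample.shape)  # expected: torch.Size([BATCH_SIZE, 672, 37])
print(y_sample.shape)  # expected: torch.Size([BATCH_SIZE, 672, 8])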

Load network

[4]:
# Load transformer with Adam optimizer and MSE loss function
net = Transformer(d_input, d_model, d_output, q, v, h, K, N, TIME_CHUNK)
optimizer = optim.Adam(net.parameters(), lr=LR)
loss_function = nn.MSELoss()
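
The run below uses the CPU defaults. If a GPU is available, the model (and each batch inside the training loop) could be moved to it first; a minimal sketch:

# Optional: move the model to GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = net.to(device)
# Inside the training loop, batches would then also need: x, y = x.to(device), y.to(device)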

Train

[5]:
# Prepare loss history
hist_loss = np.zeros(EPOCHS)
for idx_epoch in range(EPOCHS):
    running_loss = 0
    with tqdm(total=len(dataloader.dataset), desc=f"[Epoch {idx_epoch+1:3d}/{EPOCHS}]") as pbar:
        for idx_batch, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()

            # Propagate input
            netout = net(x)

            # Compute loss
            loss = loss_function(netout, y)

            # Backpropagate loss
            loss.backward()

            # Update weights
            optimizer.step()

            running_loss += loss.item()
            pbar.set_postfix({'loss': running_loss/(idx_batch+1)})
            pbar.update(BATCH_SIZE)

    hist_loss[idx_epoch] = running_loss/len(dataloader)
plt.plot(hist_loss, 'o-')
print(f"Loss: {float(hist_loss[-1]):5f}")
[Epoch   1/5]: 100%|██████████| 1000/1000 [05:35<00:00,  2.98it/s, loss=0.019]
[Epoch   2/5]: 100%|██████████| 1000/1000 [05:55<00:00,  2.81it/s, loss=0.0126]
[Epoch   3/5]: 100%|██████████| 1000/1000 [05:33<00:00,  3.00it/s, loss=0.0115]
[Epoch   4/5]: 100%|██████████| 1000/1000 [05:23<00:00,  3.09it/s, loss=0.0108]
[Epoch   5/5]: 100%|██████████| 1000/1000 [05:21<00:00,  3.11it/s, loss=0.0103]
Loss: 0.010339
[Figure: training loss per epoch]

Plot results sample

[6]:
# Select training example
idx = np.random.randint(0, len(dataloader.dataset))
x, y = dataloader.dataset[idx]

# Run predictions
with torch.no_grad():
    x = torch.Tensor(x[np.newaxis, ...])
    netout = net(x)

plt.figure(figsize=(30, 30))
for idx_output_var in range(8):
    # Select real temperature
    y_true = y[:, idx_output_var]

    y_pred = netout[0, :, idx_output_var]
    y_pred = y_pred.numpy()

    plt.subplot(8, 1, idx_output_var+1)

    plt.plot(y_true, label="Truth")
    plt.plot(y_pred, label="Prediction")
    plt.title(dataloader.dataset.labels["X"][idx_output_var])
plt.legend()
plt.savefig("fig.jpg")
[Figure: truth vs. prediction for the 8 output variables]
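
To attach a number to the sample above, the per-variable MSE can be computed from the same `y` and `netout` already in scope; a small sketch, assuming `y` comes back from the dataset as a NumPy array as in the cell above:

# Per-output-variable MSE on the selected sample
y_true_all = torch.Tensor(y)   # shape (K, d_output)
y_pred_all = netout[0]         # shape (K, d_output)
per_var_mse = ((y_pred_all - y_true_all) ** 2).mean(dim=0)
for name, mse in zip(dataloader.dataset.labels["X"], per_var_mse):
    print(f"{name}: {mse.item():.5f}")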

Display encoding attention map

[7]:
# Select first encoding layer
encoder = net.layers_encoding[0]

# Get the first attention map
attn_map = encoder.attention_map[0]

# Plot
plt.figure(figsize=(20, 20))
sns.heatmap(attn_map)
plt.savefig("attention_map.jpg")
[Figure: encoder attention map heatmap]
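
The same inspection can be repeated for every encoding layer; a short sketch, assuming each layer in `net.layers_encoding` exposes `attention_map` the same way as the first one:

# Plot the first attention map of each of the N encoding layers
plt.figure(figsize=(20, 5 * N))
for idx_layer, layer in enumerate(net.layers_encoding):
    plt.subplot(N, 1, idx_layer + 1)
    sns.heatmap(layer.attention_map[0])
    plt.title(f"Encoding layer {idx_layer}")
plt.savefig("attention_maps_all_layers.jpg")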