Classic - 2019 December 03
[1]:
import numpy as np
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import seaborn as sns; sns.set()
from src.dataset import OzeDataset
from src.Transformer import Transformer
[2]:
# Training configuration
BATCH_SIZE = 2
NUM_WORKERS = 4
LR = 1e-2
EPOCHS = 5
TIME_CHUNK = False  # forwarded to Transformer; meaning defined in src.Transformer

# Model hyper-parameters
K = 672  # Time window length
d_model = 48  # Latent dim
q = 8  # Query size
v = 8  # Value size
h = 4  # Number of heads
N = 4  # Number of encoder and decoder layers to stack
d_input = 37  # From dataset
d_output = 8  # From dataset

# Reproducibility: seed the global RNGs before anything stochastic runs.
# NumPy drives the random sample selection further down; PyTorch drives
# DataLoader shuffling and (presumably) the Transformer weight initialisation.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
Load dataset
[3]:
# Mini-batch loader over the dataset archive; with shuffle=True the sample
# order is re-drawn at the start of every epoch.
dataloader = DataLoader(
    dataset=OzeDataset("dataset.npz"),
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=True,
)
Load network
[4]:
# Model: Transformer built from the hyper-parameters in the config cell.
net = Transformer(d_input, d_model, d_output, q, v, h, K, N, TIME_CHUNK)
# Objective: mean-squared error between prediction and target.
loss_function = nn.MSELoss()
# Optimiser: Adam over every trainable parameter of the model.
optimizer = optim.Adam(params=net.parameters(), lr=LR)
Train
[5]:
# Train for EPOCHS passes over the dataset. hist_loss[e] holds the mean
# per-batch training loss of epoch e.
# Explicitly enable training mode so re-running this cell after an inference
# cell that switched the model to eval mode still trains correctly.
net.train()
# Prepare loss history
hist_loss = np.zeros(EPOCHS)
for idx_epoch in range(EPOCHS):
    running_loss = 0
    with tqdm(total=len(dataloader.dataset), desc=f"[Epoch {idx_epoch+1:3d}/{EPOCHS}]") as pbar:
        for idx_batch, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()
            # Propagate input
            netout = net(x)
            # Compute loss
            loss = loss_function(netout, y)
            # Backpropagate loss
            loss.backward()
            # Update weights
            optimizer.step()
            running_loss += loss.item()
            pbar.set_postfix({'loss': running_loss/(idx_batch+1)})
            # Advance by the real batch length: the final batch of an epoch can
            # be smaller than BATCH_SIZE, which would otherwise overshoot `total`.
            pbar.update(len(x))
    # Mean loss over the batches of this epoch
    hist_loss[idx_epoch] = running_loss/len(dataloader)
# Loss curve with 1-based epoch numbers, matching the progress-bar labels
plt.plot(np.arange(1, EPOCHS + 1), hist_loss, 'o-')
plt.xlabel("Epoch")
plt.ylabel("Mean training loss")
print(f"Loss: {float(hist_loss[-1]):5f}")
[Epoch 1/5]: 100%|██████████| 1000/1000 [05:35<00:00, 2.98it/s, loss=0.019]
[Epoch 2/5]: 100%|██████████| 1000/1000 [05:55<00:00, 2.81it/s, loss=0.0126]
[Epoch 3/5]: 100%|██████████| 1000/1000 [05:33<00:00, 3.00it/s, loss=0.0115]
[Epoch 4/5]: 100%|██████████| 1000/1000 [05:23<00:00, 3.09it/s, loss=0.0108]
[Epoch 5/5]: 100%|██████████| 1000/1000 [05:21<00:00, 3.11it/s, loss=0.0103]
Loss: 0.010339
Plot a sample of the results
[6]:
## Select a random training example and compare the prediction to the ground truth
idx = np.random.randint(0, len(dataloader.dataset))
x, y = dataloader.dataset[idx]
# Run predictions in eval mode without building an autograd graph, then
# restore whatever mode the model was in so later training is unaffected.
was_training = net.training
net.eval()
with torch.no_grad():
    # Add a leading batch dimension of size 1
    x = torch.Tensor(x[np.newaxis, ...])
    netout = net(x)
net.train(was_training)
# One subplot per output variable (d_output of them, as configured above)
plt.figure(figsize=(30, 30))
for idx_output_var in range(d_output):
    # Ground truth and prediction for this output variable
    y_true = y[:, idx_output_var]
    y_pred = netout[0, :, idx_output_var].numpy()
    plt.subplot(d_output, 1, idx_output_var+1)
    plt.plot(y_true, label="Truth")
    plt.plot(y_pred, label="Prediction")
    plt.title(dataloader.dataset.labels["X"][idx_output_var])
    plt.legend()
plt.savefig("fig.jpg")
Display the encoder attention map
[7]:
# Select first encoding layer
encoder = net.layers_encoding[0]
# Get the first attention map stored on the layer (presumably recorded during
# the most recent forward pass - TODO confirm in src.Transformer)
attn_map = encoder.attention_map[0]
# seaborn needs a NumPy-convertible array; a tensor recorded with autograd
# enabled (e.g. during training) cannot be converted without detaching first.
if isinstance(attn_map, torch.Tensor):
    attn_map = attn_map.detach().cpu().numpy()
# Plot
plt.figure(figsize=(20, 20))
sns.heatmap(attn_map)
plt.title("Encoder layer 0 - attention map")
plt.savefig("attention_map.jpg")