Classic - February 25, 2020

[1]:
import datetime

import numpy as np
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
import seaborn as sns

from tst import Transformer
from tst.loss import OZELoss

from src.dataset import OzeDataset
from src.utils import visual_sample, compute_loss
[2]:
# Training parameters
DATASET_PATH = 'datasets/dataset_v6_full.npz'
BATCH_SIZE = 8
NUM_WORKERS = 4
LR = 2e-4
EPOCHS = 50

# Model parameters
d_model = 64  # Latent dimension
q = 8  # Query size
v = 8  # Value size
h = 4  # Number of heads
N = 4  # Number of encoder and decoder layers to stack
attention_size = 24  # Attention window size
dropout = 0.2  # Dropout rate
pe = None  # Positional encoding
chunk_mode = None

d_input = 38 # From dataset
d_output = 8 # From dataset

# Config
sns.set()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device {device}")
Using device cuda:0

Training

Load dataset

[3]:
ozeDataset = OzeDataset(DATASET_PATH)

dataset_train, dataset_val, dataset_test = random_split(ozeDataset, (38000, 500, 500))

dataloader_train = DataLoader(dataset_train,
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              num_workers=NUM_WORKERS,
                              pin_memory=False
                             )

dataloader_val = DataLoader(dataset_val,
                            batch_size=BATCH_SIZE,
                            shuffle=True,
                            num_workers=NUM_WORKERS
                           )

dataloader_test = DataLoader(dataset_test,
                             batch_size=BATCH_SIZE,
                             shuffle=False,
                             num_workers=NUM_WORKERS
                            )
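
The split above is random, so train/val/test membership changes between runs. For a reproducible split, recent PyTorch versions (>= 1.6) accept a seeded generator; a minimal sketch:

[ ]:
# Reproducible split: same indices on every run (requires PyTorch >= 1.6)
generator = torch.Generator().manual_seed(42)
dataset_train, dataset_val, dataset_test = random_split(
    ozeDataset, (38000, 500, 500), generator=generator
)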

Load network

[4]:
# Load transformer with Adam optimizer and MSE loss function
net = Transformer(d_input, d_model, d_output, q, v, h, N,
                  attention_size=attention_size,
                  dropout=dropout,
                  chunk_mode=chunk_mode,
                  pe=pe).to(device)
optimizer = optim.Adam(net.parameters(), lr=LR)
loss_function = OZELoss(alpha=0.3)
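
OZELoss is defined in the tst package and its implementation is not shown in this notebook. Purely to illustrate the role of alpha, below is a hypothetical weighted-MSE stand-in; the split between consumption (Q) and temperature (T) channels is an assumption, not the library's actual code:

[ ]:
# Hypothetical stand-in for OZELoss, for illustration only
class WeightedMSELoss(nn.Module):
    """Convex combination of MSE terms; the channel ordering is assumed."""

    def __init__(self, alpha=0.3):
        super().__init__()
        self.alpha = alpha
        self.mse = nn.MSELoss()

    def forward(self, y_true, y_pred):
        delta_T = self.mse(y_pred[..., -1], y_true[..., -1])    # temperature (assumed last channel)
        delta_Q = self.mse(y_pred[..., :-1], y_true[..., :-1])  # consumptions (assumed other channels)
        return self.alpha * delta_T + (1 - self.alpha) * delta_Q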

Train

[5]:
model_save_path = f'models/model_{datetime.datetime.now().strftime("%Y_%m_%d__%H%M%S")}.pth'
val_loss_best = np.inf

# Prepare loss history
hist_loss = np.zeros(EPOCHS)
hist_loss_val = np.zeros(EPOCHS)
for idx_epoch in range(EPOCHS):
    running_loss = 0
    with tqdm(total=len(dataloader_train.dataset), desc=f"[Epoch {idx_epoch+1:3d}/{EPOCHS}]") as pbar:
        for idx_batch, (x, y) in enumerate(dataloader_train):
            optimizer.zero_grad()

            # Propagate input
            netout = net(x.to(device))

            # Compute loss
            loss = loss_function(y.to(device), netout)

            # Backpropagate loss
            loss.backward()

            # Update weights
            optimizer.step()

            running_loss += loss.item()
            pbar.set_postfix({'loss': running_loss/(idx_batch+1)})
            pbar.update(x.shape[0])

        train_loss = running_loss/len(dataloader_train)
        val_loss = compute_loss(net, dataloader_val, loss_function, device).item()
        pbar.set_postfix({'loss': train_loss, 'val_loss': val_loss})

        hist_loss[idx_epoch] = train_loss
        hist_loss_val[idx_epoch] = val_loss

        if val_loss < val_loss_best:
            val_loss_best = val_loss
            torch.save(net.state_dict(), model_save_path)

plt.plot(hist_loss, 'o-', label='train')
plt.plot(hist_loss_val, 'o-', label='val')
plt.legend()
print(f"model exported to {model_save_path} with loss {val_loss_best:5f}")
[Epoch   1/50]: 100%|██████████| 38000/38000 [14:33<00:00, 43.49it/s, loss=0.00563, val_loss=0.00277]
[Epoch   2/50]: 100%|██████████| 38000/38000 [14:35<00:00, 43.40it/s, loss=0.00223, val_loss=0.00155]
[Epoch   3/50]: 100%|██████████| 38000/38000 [14:33<00:00, 43.50it/s, loss=0.00149, val_loss=0.00123]
[Epoch   4/50]: 100%|██████████| 38000/38000 [14:35<00:00, 43.41it/s, loss=0.00113, val_loss=0.000995]
[Epoch   5/50]: 100%|██████████| 38000/38000 [14:33<00:00, 43.50it/s, loss=0.000901, val_loss=0.00084]
[Epoch   6/50]: 100%|██████████| 38000/38000 [14:35<00:00, 43.41it/s, loss=0.000759, val_loss=0.000615]
[Epoch   7/50]: 100%|██████████| 38000/38000 [14:32<00:00, 43.53it/s, loss=0.00065, val_loss=0.000555]
[Epoch   8/50]: 100%|██████████| 38000/38000 [14:35<00:00, 43.42it/s, loss=0.000573, val_loss=0.000527]
[Epoch   9/50]: 100%|██████████| 38000/38000 [14:33<00:00, 43.53it/s, loss=0.000514, val_loss=0.000619]
[Epoch  10/50]: 100%|██████████| 38000/38000 [14:35<00:00, 43.42it/s, loss=0.000473, val_loss=0.000503]
[Epoch  11/50]: 100%|██████████| 38000/38000 [14:33<00:00, 43.53it/s, loss=0.000445, val_loss=0.000407]
[Epoch  12/50]: 100%|██████████| 38000/38000 [14:35<00:00, 43.41it/s, loss=0.000402, val_loss=0.000384]
[Epoch  13/50]: 100%|██████████| 38000/38000 [14:32<00:00, 43.55it/s, loss=0.000388, val_loss=0.000408]
[Epoch  14/50]: 100%|██████████| 38000/38000 [14:35<00:00, 43.41it/s, loss=0.000371, val_loss=0.000333]
[Epoch  15/50]: 100%|██████████| 38000/38000 [14:33<00:00, 43.52it/s, loss=0.000344, val_loss=0.000333]
[Epoch  16/50]: 100%|██████████| 38000/38000 [14:35<00:00, 43.41it/s, loss=0.000331, val_loss=0.000407]
[Epoch  17/50]: 100%|██████████| 38000/38000 [14:33<00:00, 43.52it/s, loss=0.000309, val_loss=0.000326]
[Epoch  18/50]: 100%|██████████| 38000/38000 [14:35<00:00, 43.41it/s, loss=0.000304, val_loss=0.000302]
[Epoch  19/50]: 100%|██████████| 38000/38000 [14:32<00:00, 43.54it/s, loss=0.00029, val_loss=0.000312]
[Epoch  20/50]: 100%|██████████| 38000/38000 [14:35<00:00, 43.42it/s, loss=0.000287, val_loss=0.000266]
[Epoch  21/50]: 100%|██████████| 38000/38000 [14:32<00:00, 43.54it/s, loss=0.000269, val_loss=0.00029]
[Epoch  22/50]: 100%|██████████| 38000/38000 [14:35<00:00, 43.42it/s, loss=0.000265, val_loss=0.000237]
[Epoch  23/50]: 100%|██████████| 38000/38000 [14:32<00:00, 43.54it/s, loss=0.000255, val_loss=0.000237]
[Epoch  24/50]: 100%|██████████| 38000/38000 [14:35<00:00, 43.41it/s, loss=0.000255, val_loss=0.00024]
[Epoch  25/50]: 100%|██████████| 38000/38000 [14:32<00:00, 43.53it/s, loss=0.000244, val_loss=0.000225]
[Epoch  26/50]: 100%|██████████| 38000/38000 [14:35<00:00, 43.41it/s, loss=0.000239, val_loss=0.000231]
[Epoch  27/50]: 100%|██████████| 38000/38000 [14:32<00:00, 43.54it/s, loss=0.000229, val_loss=0.000241]
[Epoch  28/50]: 100%|██████████| 38000/38000 [14:35<00:00, 43.42it/s, loss=0.000226, val_loss=0.000245]
[Epoch  29/50]: 100%|██████████| 38000/38000 [14:33<00:00, 43.52it/s, loss=0.000221, val_loss=0.000221]
[Epoch  30/50]: 100%|██████████| 38000/38000 [14:34<00:00, 43.43it/s, loss=0.000226, val_loss=0.000208]
[Epoch  31/50]: 100%|██████████| 38000/38000 [14:32<00:00, 43.54it/s, loss=0.000209, val_loss=0.000219]
[Epoch  32/50]: 100%|██████████| 38000/38000 [14:35<00:00, 43.42it/s, loss=0.000223, val_loss=0.000222]
[Epoch  33/50]: 100%|██████████| 38000/38000 [14:32<00:00, 43.55it/s, loss=0.000217, val_loss=0.000224]
[Epoch  34/50]: 100%|██████████| 38000/38000 [14:35<00:00, 43.42it/s, loss=0.000202, val_loss=0.000199]
[Epoch  35/50]: 100%|██████████| 38000/38000 [14:32<00:00, 43.54it/s, loss=0.000194, val_loss=0.000191]
[Epoch  36/50]: 100%|██████████| 38000/38000 [14:35<00:00, 43.42it/s, loss=0.000198, val_loss=0.000185]
[Epoch  37/50]: 100%|██████████| 38000/38000 [14:33<00:00, 43.51it/s, loss=0.000189, val_loss=0.000211]
[Epoch  38/50]: 100%|██████████| 38000/38000 [14:36<00:00, 43.35it/s, loss=0.000195, val_loss=0.00018]
[Epoch  39/50]: 100%|██████████| 38000/38000 [14:33<00:00, 43.51it/s, loss=0.000183, val_loss=0.00029]
[Epoch  40/50]: 100%|██████████| 38000/38000 [14:35<00:00, 43.42it/s, loss=0.000183, val_loss=0.000161]
[Epoch  41/50]: 100%|██████████| 38000/38000 [14:32<00:00, 43.55it/s, loss=0.000181, val_loss=0.000168]
[Epoch  42/50]: 100%|██████████| 38000/38000 [14:35<00:00, 43.42it/s, loss=0.000178, val_loss=0.000179]
[Epoch  43/50]: 100%|██████████| 38000/38000 [14:32<00:00, 43.53it/s, loss=0.000174, val_loss=0.000174]
[Epoch  44/50]: 100%|██████████| 38000/38000 [14:35<00:00, 43.42it/s, loss=0.000181, val_loss=0.000155]
[Epoch  45/50]: 100%|██████████| 38000/38000 [14:32<00:00, 43.53it/s, loss=0.000168, val_loss=0.000191]
[Epoch  46/50]: 100%|██████████| 38000/38000 [14:34<00:00, 43.43it/s, loss=0.000165, val_loss=0.000185]
[Epoch  47/50]: 100%|██████████| 38000/38000 [14:32<00:00, 43.53it/s, loss=0.00017, val_loss=0.000159]
[Epoch  48/50]: 100%|██████████| 38000/38000 [14:35<00:00, 43.42it/s, loss=0.00017, val_loss=0.000159]
[Epoch  49/50]: 100%|██████████| 38000/38000 [14:32<00:00, 43.54it/s, loss=0.000165, val_loss=0.000173]
[Epoch  50/50]: 100%|██████████| 38000/38000 [14:35<00:00, 43.41it/s, loss=0.000161, val_loss=0.000166]
model exported to models/model_2020_02_25__102558.pth with loss 0.000155

[Image: training and validation loss curves]

Validation

[6]:
_ = net.eval()
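
Note that the final epoch is not necessarily the best one; before evaluating, it can be worth restoring the checkpoint with the lowest validation loss. A minimal sketch using standard PyTorch:

[ ]:
# Restore the weights that achieved the best validation loss
net.load_state_dict(torch.load(model_save_path, map_location=device))
_ = net.eval()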

Plot results on a sample

[7]:
visual_sample(dataloader_test, net, device)
plt.savefig("fig")
[Image: sample prediction plotted against ground truth]
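
visual_sample comes from src.utils and its internals are not shown here. If it were unavailable, a rough equivalent for a single test batch might look like the sketch below (an illustration, not the helper's actual behavior):

[ ]:
# Hypothetical minimal replacement for visual_sample: plot one output channel
x, y = next(iter(dataloader_test))
with torch.no_grad():
    netout = net(x.to(device)).cpu()

plt.figure(figsize=(20, 5))
plt.plot(y[0, :, 0].numpy(), label='ground truth')
plt.plot(netout[0, :, 0].numpy(), label='prediction')
plt.legend()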

Plot encoding attention map

[8]:
# Select first encoding layer
encoder = net.layers_encoding[0]

# Get the first attention map
attn_map = encoder.attention_map[0].cpu()

# Plot
plt.figure(figsize=(20, 20))
sns.heatmap(attn_map)
plt.savefig("attention_map")
[Image: heatmap of the first encoder attention map]
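
As an aside, attention_size = 24 suggests each position attends only within a local window, which should appear as a band around the diagonal in the heatmap above. The tst internals are not shown here; the sketch below is an assumption about what such a window mask could look like, not the library's implementation:

[ ]:
# Hypothetical local attention mask: True where attention would be blocked
def local_attention_mask(seq_len, attention_size):
    idx = torch.arange(seq_len)
    distance = (idx.unsqueeze(0) - idx.unsqueeze(1)).abs()
    return distance > attention_size

# Positions more than 24 steps apart would be masked out before the softmax
mask = local_attention_mask(seq_len=attn_map.shape[0], attention_size=attention_size)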

Evaluate on the test dataset

[9]:
predictions = np.empty(shape=(len(dataloader_test.dataset), 168, 8))  # (samples, time steps, outputs)

idx_prediction = 0
with torch.no_grad():
    for x, _ in tqdm(dataloader_test, total=len(dataloader_test)):
        netout = net(x.to(device)).cpu().numpy()
        predictions[idx_prediction:idx_prediction+x.shape[0]] = netout
        idx_prediction += x.shape[0]
100%|██████████| 63/63 [00:05<00:00, 12.26it/s]
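
Before plotting, a quick aggregate sanity check: per-output RMSE in normalized units, reusing the same dataset accessors as the next cell. A minimal sketch:

[ ]:
# Ground truth for the test split, in the same normalized scale as `predictions`
y_true_full = dataloader_test.dataset.dataset._y[dataloader_test.dataset.indices].numpy()

# RMSE per output channel, averaged over samples and time steps
rmse = np.sqrt(((predictions - y_true_full) ** 2).mean(axis=(0, 1)))
for label, value in zip(dataloader_test.dataset.dataset.labels['X'], rmse):
    print(f"{label}: RMSE = {value:.5f}")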
[10]:
fig, axes = plt.subplots(8, 1)
fig.set_figwidth(20)
fig.set_figheight(40)
plt.subplots_adjust(bottom=0.05)

test_dataset = dataloader_test.dataset.dataset  # underlying OzeDataset

# Occupancy profile: mean over all samples, thresholded at 0.5
occupancy_idx = test_dataset.labels["Z"].index("occupancy")
occupancy = (test_dataset._x.numpy()[..., occupancy_idx].mean(axis=0) > 0.5).astype(float)
y_true_full = test_dataset._y[dataloader_test.dataset.indices].numpy()

for idx_label, (label, ax) in enumerate(zip(test_dataset.labels['X'], axes)):
    # Select output to plot
    y_true = y_true_full[..., idx_label]
    y_pred = predictions[..., idx_label]

    # Rescale to physical units
    y_true = test_dataset.rescale(y_true, idx_label)
    y_pred = test_dataset.rescale(y_pred, idx_label)

    if label.startswith('Q_'):
        # Convert kJ/h to kW
        y_true /= 3600
        y_pred /= 3600

    # Compute delta, mean and std
    delta = np.abs(y_true - y_pred)

    mean = delta.mean(axis=0)
    std = delta.std(axis=0)

    # Plot
    # Label units: consumption in kW, temperature in °C
    if label.startswith('Q_'):
        y_label_unit = 'kW'
    else:
        y_label_unit = '°C'

    # Shade occupied periods (occupancy change points are assumed to pair up)
    occupancy_idxes = np.where(np.diff(occupancy) != 0)[0]
    for idx in range(0, len(occupancy_idxes), 2):
        ax.axvspan(occupancy_idxes[idx], occupancy_idxes[idx+1], facecolor='green', alpha=.15)

    # Std
    ax.fill_between(np.arange(mean.shape[0]), (mean - std), (mean + std), alpha=.4, label='std')

    # Mean
    ax.plot(mean, label='mean')

    # Title and labels
    ax.set_title(label)
    ax.set_xlabel('time', fontsize=16)
    ax.set_ylabel(y_label_unit, fontsize=16)

    ax.legend()

plt.savefig('error_mean_std')
[Image: per-output prediction error (mean ± std) over time, with occupied periods shaded]