Chunk - 2019 December 23
[1]:
import os

import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from tqdm import tqdm

from src.dataset import OzeDataset
from src.Transformer import Transformer
from src.utils import visual_sample
[2]:
# Training parameters
DATASET_PATH = 'datasets/dataset_large.npz'
BATCH_SIZE = 4
NUM_WORKERS = 4
LR = 3e-4
EPOCHS = 20
TIME_CHUNK = True
SEED = 42  # Random seed for weight initialisation and data shuffling
# Testing parameters
TEST_DATASET_PATH = 'datasets/dataset_test.npz'
# Model parameters
K = 672 # Time window length
d_model = 48 # Latent dim
q = 8 # Query size
v = 8 # Value size
h = 4 # Number of heads
N = 4 # Number of encoder and decoder to stack
pe = None # Positional encoding
d_input = 37 # From dataset
d_output = 8 # From dataset
# Config
sns.set()
# Seed NumPy and PyTorch (CPU and CUDA) before any model or DataLoader is created, so that weight
# initialisation and shuffling are repeatable. Some GPU kernels may still be nondeterministic.
np.random.seed(SEED)
torch.manual_seed(SEED)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device {device}")
Using device cuda:0
Training
Load dataset
[3]:
# Training loader: shuffled mini-batches of BATCH_SIZE samples, fetched by NUM_WORKERS worker processes.
dataloader = DataLoader(
    dataset=OzeDataset(DATASET_PATH),
    num_workers=NUM_WORKERS,
    shuffle=True,
    batch_size=BATCH_SIZE,
)
Load network
[4]:
# Transformer built from the model hyper-parameters of the config cell, placed on `device`.
net = Transformer(d_input, d_model, d_output, q, v, h, K, N, TIME_CHUNK, pe).to(device)
# Adam over all parameters of the network, with learning rate LR.
optimizer = optim.Adam(params=net.parameters(), lr=LR)
# One mean-squared-error criterion per target group (temperature / consumption).
temperature_loss_function, consumption_loss_function = nn.MSELoss(), nn.MSELoss()
Train
[5]:
# Train for EPOCHS epochs and record the mean batch loss of every epoch.
# Loss = log(1 + MSE_T) + 0.3 * log(1 + MSE_Q), where the last output channel is treated as
# temperature and all other channels as consumption.
hist_loss = np.zeros(EPOCHS)
# Put train-only layers (dropout etc., if any) back in training mode, in case an evaluation cell
# called net.eval() before this cell was (re-)run.
net.train()
for idx_epoch in range(EPOCHS):
    running_loss = 0
    with tqdm(total=len(dataloader.dataset), desc=f"[Epoch {idx_epoch+1:3d}/{EPOCHS}]") as pbar:
        for idx_batch, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()
            # Propagate input
            netout = net(x.to(device))
            # Compute loss
            y = y.to(device)
            delta_Q = consumption_loss_function(netout[..., :-1], y[..., :-1])
            delta_T = temperature_loss_function(netout[..., -1], y[..., -1])
            loss = torch.log(1 + delta_T) + 0.3 * torch.log(1 + delta_Q)
            # Backpropagate loss
            loss.backward()
            # Update weights
            optimizer.step()
            running_loss += loss.item()
            pbar.set_postfix({'loss': running_loss/(idx_batch+1)})
            pbar.update(x.shape[0])
    hist_loss[idx_epoch] = running_loss/len(dataloader)

# Training curve, one point per epoch
fig, ax = plt.subplots()
ax.plot(np.arange(1, EPOCHS + 1), hist_loss, 'o-')
ax.set(title='Training loss', xlabel='epoch', ylabel='loss')

final_loss = float(hist_loss[-1])
print(f"Loss: {final_loss:f}")
# Checkpoint name = first 5 decimals of the final loss (0.002201 -> model_00220.pth). Fixed-point
# formatting avoids the scientific notation str() switches to for small values (e.g. '1e-05').
str_loss = f"{final_loss:.5f}".split('.')[-1]
# torch.save does not create missing directories; fail before saving is not an option after training.
os.makedirs("models", exist_ok=True)
torch.save(net, f"models/model_{str_loss}.pth")
[Epoch 1/20]: 100%|██████████| 7500/7500 [03:08<00:00, 39.71it/s, loss=0.0127]
[Epoch 2/20]: 100%|██████████| 7500/7500 [03:10<00:00, 39.42it/s, loss=0.00693]
[Epoch 3/20]: 100%|██████████| 7500/7500 [03:08<00:00, 39.89it/s, loss=0.00605]
[Epoch 4/20]: 100%|██████████| 7500/7500 [03:07<00:00, 40.03it/s, loss=0.00541]
[Epoch 5/20]: 100%|██████████| 7500/7500 [03:07<00:00, 39.94it/s, loss=0.00508]
[Epoch 6/20]: 100%|██████████| 7500/7500 [03:08<00:00, 39.79it/s, loss=0.00466]
[Epoch 7/20]: 100%|██████████| 7500/7500 [03:10<00:00, 39.32it/s, loss=0.00428]
[Epoch 8/20]: 100%|██████████| 7500/7500 [03:11<00:00, 39.22it/s, loss=0.00394]
[Epoch 9/20]: 100%|██████████| 7500/7500 [03:10<00:00, 39.43it/s, loss=0.00372]
[Epoch 10/20]: 100%|██████████| 7500/7500 [03:09<00:00, 39.66it/s, loss=0.00344]
[Epoch 11/20]: 100%|██████████| 7500/7500 [03:09<00:00, 39.52it/s, loss=0.00331]
[Epoch 12/20]: 100%|██████████| 7500/7500 [03:08<00:00, 39.71it/s, loss=0.0031]
[Epoch 13/20]: 100%|██████████| 7500/7500 [03:08<00:00, 39.75it/s, loss=0.00293]
[Epoch 14/20]: 100%|██████████| 7500/7500 [03:08<00:00, 39.80it/s, loss=0.00283]
[Epoch 15/20]: 100%|██████████| 7500/7500 [03:07<00:00, 40.09it/s, loss=0.00269]
[Epoch 16/20]: 100%|██████████| 7500/7500 [03:07<00:00, 40.07it/s, loss=0.0026]
[Epoch 17/20]: 100%|██████████| 7500/7500 [03:06<00:00, 40.24it/s, loss=0.00246]
[Epoch 18/20]: 100%|██████████| 7500/7500 [03:06<00:00, 40.13it/s, loss=0.00238]
[Epoch 19/20]: 100%|██████████| 7500/7500 [03:07<00:00, 40.03it/s, loss=0.00227]
[Epoch 20/20]: 100%|██████████| 7500/7500 [03:08<00:00, 39.85it/s, loss=0.0022]
Loss: 0.002201
Validation
Load dataset and network
[6]:
# Test loader: same batching as training but in fixed dataset order (no shuffling), so that batch
# outputs can be written back to consecutive rows of a prediction buffer.
datatestloader = DataLoader(
    dataset=OzeDataset(TEST_DATASET_PATH),
    num_workers=NUM_WORKERS,
    shuffle=False,
    batch_size=BATCH_SIZE,
)
[7]:
# Optionally replace the network trained above by a saved checkpoint (e.g. 'models/model_00247.pth').
# Left as None, this cell does nothing and the freshly trained `net` is used.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust. Recent PyTorch releases
# default to weights_only=True, which may refuse a full model saved with torch.save(net) - TODO confirm.
PRETRAINED_MODEL_PATH = None
if PRETRAINED_MODEL_PATH is not None:
    net = torch.load(PRETRAINED_MODEL_PATH, map_location=device)
Plot results on a sample
[8]:
# Switch to evaluation mode before inference: disables train-only behaviour such as dropout, if the
# Transformer has any (no-op otherwise).
net.eval()
# Presumably plots the network output against ground truth for a test sample (src.utils.visual_sample).
visual_sample(datatestloader, net, device)
plt.savefig("fig.jpg")
Plot encoder attention map
[9]:
# Heatmap of the attention map stored by the first encoder layer during its most recent forward pass.
encoder = net.layers_encoding[0]
# Index 0 of the stored maps (presumably first head or first batch element - TODO confirm).
# detach(): if the last forward pass ran with gradients enabled, the stored tensor is still part of
# the autograd graph and cannot be converted to NumPy for plotting.
attn_map = encoder.attention_map[0].detach().cpu().numpy()
# Plot
fig, ax = plt.subplots(figsize=(20, 20))
sns.heatmap(attn_map, ax=ax)
ax.set_title("Encoder layer 0 - attention map 0")
fig.savefig("attention_map.jpg")
Evaluate on the test dataset
[10]:
# Run the network over the whole test set (unshuffled loader, so row i = dataset sample i) and
# collect its outputs into `predictions`.
# Evaluation mode: disables train-only behaviour such as dropout, if the Transformer has any.
net.eval()
predictions = np.empty(shape=(len(datatestloader.dataset), K, d_output))
idx_prediction = 0
with torch.no_grad():
    for x, y in tqdm(datatestloader, total=len(datatestloader)):
        netout = net(x.to(device)).cpu().numpy()
        predictions[idx_prediction:idx_prediction+x.shape[0]] = netout
        idx_prediction += x.shape[0]
# np.empty leaves unwritten rows as garbage: make sure every test sample received a prediction.
assert idx_prediction == len(predictions), f"filled {idx_prediction} of {len(predictions)} predictions"
100%|██████████| 125/125 [00:05<00:00, 24.60it/s]
[11]:
# Per-channel absolute error between rescaled predictions and ground truth over the test set:
# mean and std across samples at every time step, with occupied periods shaded in green.
test_dataset = datatestloader.dataset
# squeeze=False keeps `axes` 2-D even when there is a single output channel.
fig, axes = plt.subplots(d_output, 1, figsize=(20, 5 * d_output), squeeze=False)
fig.subplots_adjust(bottom=0.05)

# Occupancy profile: a time step counts as occupied when the test-set mean of the "occupancy"
# input feature exceeds 0.5 (assumes _x is (samples, time, features) - TODO confirm).
occupancy_feature = test_dataset.labels["Z"].index("occupancy")
occupancy = (test_dataset._x.numpy()[..., occupancy_feature].mean(axis=0) > 0.5).astype(float)
# Start / end (exclusive) index of every run of occupied steps. Padding with 0 on both sides gives each
# run exactly one rising and one falling edge, even when the profile is occupied at the first or last
# step (pairing raw transitions two by two gives wrong spans or an IndexError in that case).
occupancy_edges = np.diff(np.concatenate(([0.], occupancy, [0.])))
occupied_starts = np.flatnonzero(occupancy_edges > 0)
occupied_ends = np.flatnonzero(occupancy_edges < 0)

y_true_all = test_dataset._y.numpy()
for idx_label, (label, ax) in enumerate(zip(test_dataset.labels['X'], axes[:, 0])):
    # Select output channel and rescale back to physical units
    y_true = test_dataset.rescale(y_true_all[..., idx_label], idx_label)
    y_pred = test_dataset.rescale(predictions[..., idx_label], idx_label)
    # Absolute error statistics across test samples, per time step
    delta = np.abs(y_true - y_pred)
    mean = delta.mean(axis=0)
    std = delta.std(axis=0)
    # Units: consumption outputs ('Q_' prefix) in kW, temperatures in °C
    y_label_unit = 'kW' if label.startswith('Q_') else '°C'
    # Occupancy: one span per occupied run, centred on the sampled time steps
    for start, end in zip(occupied_starts, occupied_ends):
        ax.axvspan(start - 0.5, end - 0.5, facecolor='green', alpha=.15)
    # Std band
    ax.fill_between(np.arange(mean.shape[0]), (mean - std), (mean + std), alpha=.4, label='std')
    # Mean
    ax.plot(mean, label='mean')
    # Title and labels
    ax.set_title(label)
    ax.set_xlabel('time', fontsize=16)
    ax.set_ylabel(y_label_unit, fontsize=16)
    ax.legend()
fig.savefig('error_mean_std.jpg')