Benchmark LSTM - 2020 March 31¶
[1]:
import datetime
import os

import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

from src.benchmark import LSTM
from src.dataset import OzeDataset
from src.utils import compute_loss
from src.visualization import map_plot_function, plot_values_distribution, plot_error_distribution, plot_errors_threshold, plot_visual_sample
from tst.loss import OZELoss
[2]:
# Training parameters
DATASET_PATH = 'datasets/dataset_CAPT_v7.npz'
BATCH_SIZE = 8
NUM_WORKERS = 4
LR = 1e-4
EPOCHS = 30
# Seed for the numpy and torch global RNGs, so the random split, weight
# initialisation and DataLoader shuffling are repeatable across runs.
SEED = 42

# Model parameters
d_model = 48  # Latent dim
N = 4  # Number of layers
dropout = 0.2  # Dropout rate
d_input = 38  # From dataset
d_output = 8  # From dataset

# Config
sns.set()
np.random.seed(SEED)
torch.manual_seed(SEED)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device {device}")
Using device cuda:0
Training¶
Load dataset¶
[3]:
# Load the full dataset and split it at random into train / validation / test
# subsets. Validation and test sizes are fixed; training takes the remainder,
# so the split still works if the dataset length changes.
ozeDataset = OzeDataset(DATASET_PATH)
N_VAL, N_TEST = 1000, 1000
n_train = len(ozeDataset) - N_VAL - N_TEST
assert n_train > 0, f"dataset too small ({len(ozeDataset)} samples) for the requested split"
dataset_train, dataset_val, dataset_test = random_split(ozeDataset, (n_train, N_VAL, N_TEST))
[4]:
# Batch the three splits. Only the training loader shuffles:
# - the validation loss has no need for shuffled batches, and shuffling there
#   only adds run-to-run nondeterminism;
# - the test loader MUST stay in order, because the prediction cell fills
#   `predictions` in loader order and the plots map rows back via
#   `dataset_test.indices`.
dataloader_train = DataLoader(dataset_train,
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              num_workers=NUM_WORKERS,
                              pin_memory=False
                              )
dataloader_val = DataLoader(dataset_val,
                            batch_size=BATCH_SIZE,
                            shuffle=False,
                            num_workers=NUM_WORKERS
                            )
dataloader_test = DataLoader(dataset_test,
                             batch_size=BATCH_SIZE,
                             shuffle=False,
                             num_workers=NUM_WORKERS
                             )
Load network¶
[5]:
# Build the LSTM benchmark model on the selected device, an Adam optimizer over
# its parameters, and the OZELoss criterion (alpha=0.3) from tst.loss.
net = LSTM(d_input, d_model, d_output, N, dropout=dropout).to(device)
optimizer = optim.Adam(net.parameters(), lr=LR)
loss_function = OZELoss(alpha=0.3)
Train¶
[6]:
# Train for EPOCHS epochs. After each epoch the validation loss is computed and
# the weights are checkpointed whenever it beats the best seen so far; the best
# checkpoint is restored at the end so downstream cells evaluate it.
model_save_path = f'models/model_LSTM_{datetime.datetime.now().strftime("%Y_%m_%d__%H%M%S")}.pth'
# torch.save does not create missing directories.
os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
val_loss_best = np.inf

# Prepare loss history
hist_loss = np.zeros(EPOCHS)
hist_loss_val = np.zeros(EPOCHS)
for idx_epoch in range(EPOCHS):
    # (Re-)enable dropout: the validation step below switches to eval mode.
    net.train()
    running_loss = 0
    with tqdm(total=len(dataloader_train.dataset), desc=f"[Epoch {idx_epoch+1:3d}/{EPOCHS}]") as pbar:
        for idx_batch, (x, y) in enumerate(dataloader_train):
            optimizer.zero_grad()
            # Propagate input
            netout = net(x.to(device))
            # Compute loss
            loss = loss_function(y.to(device), netout)
            # Backpropagate loss
            loss.backward()
            # Update weights
            optimizer.step()
            running_loss += loss.item()
            pbar.set_postfix({'loss': running_loss/(idx_batch+1)})
            pbar.update(x.shape[0])
        train_loss = running_loss/len(dataloader_train)
        # Evaluate with dropout disabled so the validation loss, and the
        # checkpoint selection that depends on it, is not perturbed by
        # random masking.
        net.eval()
        val_loss = compute_loss(net, dataloader_val, loss_function, device).item()
        pbar.set_postfix({'loss': train_loss, 'val_loss': val_loss})
    hist_loss[idx_epoch] = train_loss
    hist_loss_val[idx_epoch] = val_loss
    if val_loss < val_loss_best:
        val_loss_best = val_loss
        torch.save(net.state_dict(), model_save_path)

# Restore the best checkpoint so the test-set cells use it rather than the
# last-epoch weights (the file is absent only if no epoch ever improved on inf).
if os.path.exists(model_save_path):
    net.load_state_dict(torch.load(model_save_path, map_location=device))

plt.plot(hist_loss, 'o-', label='train')
plt.plot(hist_loss_val, 'o-', label='val')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
print(f"model exported to {model_save_path} with loss {val_loss_best:5f}")
[Epoch 1/30]: 100%|██████████| 38000/38000 [09:33<00:00, 66.30it/s, loss=0.0153, val_loss=0.00872]
[Epoch 2/30]: 100%|██████████| 38000/38000 [09:31<00:00, 66.49it/s, loss=0.00701, val_loss=0.00584]
[Epoch 3/30]: 100%|██████████| 38000/38000 [09:33<00:00, 66.29it/s, loss=0.00527, val_loss=0.00495]
[Epoch 4/30]: 100%|██████████| 38000/38000 [09:32<00:00, 66.41it/s, loss=0.00461, val_loss=0.00438]
[Epoch 5/30]: 100%|██████████| 38000/38000 [09:31<00:00, 66.48it/s, loss=0.00417, val_loss=0.00407]
[Epoch 6/30]: 100%|██████████| 38000/38000 [09:33<00:00, 66.29it/s, loss=0.00387, val_loss=0.00379]
[Epoch 7/30]: 100%|██████████| 38000/38000 [09:32<00:00, 66.41it/s, loss=0.00363, val_loss=0.00355]
[Epoch 8/30]: 100%|██████████| 38000/38000 [09:31<00:00, 66.48it/s, loss=0.00343, val_loss=0.00344]
[Epoch 9/30]: 100%|██████████| 38000/38000 [09:33<00:00, 66.26it/s, loss=0.00326, val_loss=0.00322]
[Epoch 10/30]: 100%|██████████| 38000/38000 [09:31<00:00, 66.48it/s, loss=0.00313, val_loss=0.00312]
[Epoch 11/30]: 100%|██████████| 38000/38000 [09:31<00:00, 66.47it/s, loss=0.00302, val_loss=0.00299]
[Epoch 12/30]: 100%|██████████| 38000/38000 [09:33<00:00, 66.31it/s, loss=0.00292, val_loss=0.00289]
[Epoch 13/30]: 100%|██████████| 38000/38000 [09:32<00:00, 66.41it/s, loss=0.00283, val_loss=0.00282]
[Epoch 14/30]: 100%|██████████| 38000/38000 [09:31<00:00, 66.52it/s, loss=0.00275, val_loss=0.00273]
[Epoch 15/30]: 100%|██████████| 38000/38000 [09:33<00:00, 66.27it/s, loss=0.00267, val_loss=0.00268]
[Epoch 16/30]: 100%|██████████| 38000/38000 [09:32<00:00, 66.42it/s, loss=0.00259, val_loss=0.00259]
[Epoch 17/30]: 100%|██████████| 38000/38000 [09:31<00:00, 66.48it/s, loss=0.00252, val_loss=0.0025]
[Epoch 18/30]: 100%|██████████| 38000/38000 [09:33<00:00, 66.29it/s, loss=0.00245, val_loss=0.0025]
[Epoch 19/30]: 100%|██████████| 38000/38000 [09:32<00:00, 66.39it/s, loss=0.00239, val_loss=0.00239]
[Epoch 20/30]: 100%|██████████| 38000/38000 [09:31<00:00, 66.55it/s, loss=0.00233, val_loss=0.00232]
[Epoch 21/30]: 100%|██████████| 38000/38000 [09:32<00:00, 66.33it/s, loss=0.00226, val_loss=0.00232]
[Epoch 22/30]: 100%|██████████| 38000/38000 [09:31<00:00, 66.46it/s, loss=0.00222, val_loss=0.00225]
[Epoch 23/30]: 100%|██████████| 38000/38000 [09:31<00:00, 66.51it/s, loss=0.00218, val_loss=0.00218]
[Epoch 24/30]: 100%|██████████| 38000/38000 [09:32<00:00, 66.33it/s, loss=0.00215, val_loss=0.00216]
[Epoch 25/30]: 100%|██████████| 38000/38000 [09:31<00:00, 66.48it/s, loss=0.00213, val_loss=0.00212]
[Epoch 26/30]: 100%|██████████| 38000/38000 [09:31<00:00, 66.46it/s, loss=0.0021, val_loss=0.00212]
[Epoch 27/30]: 100%|██████████| 38000/38000 [09:33<00:00, 66.30it/s, loss=0.00207, val_loss=0.00209]
[Epoch 28/30]: 100%|██████████| 38000/38000 [09:32<00:00, 66.40it/s, loss=0.00205, val_loss=0.00208]
[Epoch 29/30]: 100%|██████████| 38000/38000 [09:31<00:00, 66.49it/s, loss=0.00203, val_loss=0.00206]
[Epoch 30/30]: 100%|██████████| 38000/38000 [09:32<00:00, 66.33it/s, loss=0.00201, val_loss=0.00201]
model exported to models/model_LSTM_2020_03_31__112637.pth with loss 0.002010
Evaluation¶
[7]:
# Put the network in inference mode (train(False) is what eval() calls),
# disabling dropout before the test-set evaluation. The assignment keeps the
# cell from echoing the module repr.
_ = net.train(mode=False)
Evaluate on the test dataset¶
[8]:
# Run the trained network over the (unshuffled) test loader and stack its
# outputs, so row i of `predictions` corresponds to dataset_test.indices[i].
# The output shape is taken from the network instead of hard-coded (168, 8).
net.eval()  # idempotent; makes this cell correct even if run on its own
prediction_batches = []
with torch.no_grad():
    for x, _y in tqdm(dataloader_test, total=len(dataloader_test)):
        prediction_batches.append(net(x.to(device)).cpu().numpy())
# Cast to float64 to keep the dtype the previous preallocated np.empty array had.
predictions = np.concatenate(prediction_batches, axis=0).astype(np.float64)
assert len(predictions) == len(dataloader_test.dataset)
100%|██████████| 125/125 [00:03<00:00, 35.96it/s]
Plot results on a sample¶
[9]:
# Apply plot_visual_sample to the test predictions; dataset_test is the very
# Subset wrapped by dataloader_test, so its .indices map prediction rows back
# to positions in ozeDataset.
map_plot_function(
    ozeDataset,
    predictions,
    plot_visual_sample,
    dataset_indices=dataset_test.indices,
)
Plot error distributions¶
[10]:
# Apply plot_error_distribution to the test predictions with time_limit=24
# (presumably limits the plotted time steps — confirm in src.visualization).
map_plot_function(
    ozeDataset,
    predictions,
    plot_error_distribution,
    dataset_indices=dataset_test.indices,
    time_limit=24,
)
Plot misprediction thresholds¶
[11]:
# Apply plot_errors_threshold to the test predictions, forwarding
# error_band=0.1 to the plotting function through plot_kwargs.
map_plot_function(
    ozeDataset,
    predictions,
    plot_errors_threshold,
    plot_kwargs=dict(error_band=0.1),
    dataset_indices=dataset_test.indices,
)