Bayesian GNNs for Molecular Property Prediction#
In this notebook, we train Bayesian GNNs on molecular property regression tasks and compare their performance to that of GPs.
The models are inspired by Hwang et al. 2020 (https://pubs.acs.org/doi/abs/10.1021/acs.jcim.0c00416). The feature extractor used here is the GIN architecture, with the same graph features used in the graph kernel experiments, and the final layer is a Bayesian linear layer from Bayesian-Torch (https://github.com/IntelLabs/bayesian-torch).
The densely connected final layer has weight distributions rather than deterministic weights, and the model's uncertainty is obtained by repeatedly sampling the network for predictions. We recommend using a CUDA-capable GPU to speed up training of the GNN.
Install and import dependencies#
[1]:
%%capture
# install gauche and other dependencies
!pip install gauche[graphs] bayesian-torch torch_geometric
[2]:
import os, sys
sys.path.append('..')
from tqdm import tqdm
import copy
import warnings
warnings.filterwarnings("ignore")
# import sklearn
from gauche.dataloader import MolPropLoader
from gauche.dataloader.data_utils import transform_data
import pandas as pd
import numpy as np
import rdkit.Chem.AllChem as Chem
import matplotlib.pyplot as plt
from scipy.stats import norm
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing, global_mean_pool
from torch_geometric.loader import DataLoader
from torch_geometric.utils import add_self_loops
from bayesian_torch.layers import LinearFlipout, LinearReparameterization
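Before building the full pipeline, the following minimal sketch (not part of the benchmark below; the layer sizes and number of samples are arbitrary) illustrates the idea from the introduction: a Bayesian linear layer samples its weights on every forward pass, so repeating the forward pass yields a predictive mean and a spread that serves as an uncertainty estimate.
# toy Bayesian linear layer: 8 input features, 1 output
bayes_layer = LinearReparameterization(in_features=8, out_features=1)
x_toy = torch.randn(4, 8)  # toy batch of 4 inputs
# each forward pass draws new weights; the layer returns (output, kl)
mc_samples = torch.stack([bayes_layer(x_toy)[0] for _ in range(50)])
mc_mean, mc_std = mc_samples.mean(dim=0), mc_samples.std(dim=0)  # predictive mean and uncertainty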
Featurise Molecules and PyTorch Geometric Graphs#
To train GNNs on molecular data, we first need to featurise the molecules and convert them into PyTorch Geometric Data objects. We will represent each molecule as a graph with atoms as nodes and bonds as edges, using element numbers and chirality as node features and bond type and E/Z double bond stereo information as edge labels. To apply this featuriser to the benchmark datasets, we will simply provide it as a custom featuriser to the MolPropLoader().featurize function.
[3]:
# define a custom PyTorch Geometric featuriser that captures
# element number, bond types and chirality
allowable_features = {
"possible_atomic_num_list": list(range(1, 119)),
"possible_chirality_list": [
Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
Chem.rdchem.ChiralType.CHI_OTHER,
],
"possible_bonds": [
Chem.rdchem.BondType.SINGLE,
Chem.rdchem.BondType.DOUBLE,
Chem.rdchem.BondType.TRIPLE,
Chem.rdchem.BondType.AROMATIC,
],
# (E)/(Z) double bond stereo information
"possible_bond_dirs": [
Chem.rdchem.BondDir.NONE,
Chem.rdchem.BondDir.ENDUPRIGHT,
Chem.rdchem.BondDir.ENDDOWNRIGHT,
],
}
# define constants for featurisation and embedding
num_atom_type = 120 # including the extra mask tokens
num_chirality_tag = 3
num_bond_type = 6 # including aromatic and self-loop edge, and extra masked tokens
num_bond_direction = 3
self_loop_token = 4 # bond type for self-loop edge
masked_bond_token = 5 # bond type for masked edges
def mol_to_pyg(smiles):
"""
A featuriser that accepts a SMILES string and
converts it to a PyTorch Geometric data object that
is compatible with the GNN modules below.
Args:
smiles: SMILES string
Returns: PyTorch Geometric data object
"""
mol = Chem.MolFromSmiles(smiles)
# derive atom features: atomic number + chirality tag
atom_features = []
for atom in mol.GetAtoms():
atom_features.append(
[
allowable_features["possible_atomic_num_list"].index(
atom.GetAtomicNum()
),
allowable_features["possible_chirality_list"].index(
atom.GetChiralTag()
),
]
)
atom_features = torch.tensor(np.array(atom_features), dtype=torch.long)
# derive bond features: bond type + bond direction
# PyTorch Geometric only uses directed edges,
# so feature information needs to be added twice
edge_index = []
edge_attr = []
for bond in mol.GetBonds():
i = bond.GetBeginAtomIdx()
j = bond.GetEndAtomIdx()
edge_index.append((i, j))
edge_index.append((j, i))
# calculate edge features and append them to feature list
edge_feature = [
allowable_features["possible_bonds"].index(bond.GetBondType()),
allowable_features["possible_bond_dirs"].index(bond.GetBondDir()),
]
edge_attr.append(edge_feature)
edge_attr.append(edge_feature)
# set data.edge_index: Graph connectivity in COO format with shape [2, num_edges]
edge_index = torch.tensor(np.array(edge_index).T, dtype=torch.long)
# set data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
edge_attr = torch.tensor(np.array(edge_attr), dtype=torch.long)
return Data(x=atom_features, edge_index=edge_index, edge_attr=edge_attr)
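As a quick sanity check (a hypothetical example, not part of the original benchmark code), applying the featuriser to a single SMILES string produces a Data object whose tensors follow the conventions described in the comments above:
example = mol_to_pyg("CC(=O)Oc1ccccc1C(=O)O")  # aspirin, chosen arbitrarily
print(example.x.shape)           # [num_atoms, 2]: atomic number index + chirality tag
print(example.edge_index.shape)  # [2, 2 * num_bonds]: each bond is stored as two directed edges
print(example.edge_attr.shape)   # [2 * num_bonds, 2]: bond type + bond direction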
[4]:
# load PhotoSwitch dataset and apply mol_to_pyg featuriser
dataset = "Photoswitch"
loader = MolPropLoader()
loader.load_benchmark("Photoswitch")
loader.featurize(lambda smiles: [mol_to_pyg(s) for s in smiles])
Found 13 invalid labels [nan nan nan nan nan nan nan nan nan nan nan nan nan] at indices [41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 158]
To turn validation off, use dataloader.read_csv(..., validate=False).
Define GIN Layers and GNN#
Next, we need to define the GNN architecture. We will use Graph Isomorphism Network (GIN) convolutions from Xu et al., How Powerful are Graph Neural Networks?, defined in the GINConv class. The GNN class stacks multiple GINConv layers with batch normalisation in between and outputs node-level embeddings; the pooling and final (Bayesian) linear layer that map these embeddings to the desired output dimension are added in the next section.
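In symbols (a paraphrase of the code below, not a formula taken verbatim from the paper), the edge-aware GIN update implemented by GINConv is
$$h_v^{(k)} = \mathrm{MLP}^{(k)}\Big(\sum_{u \in \mathcal{N}(v) \cup \{v\}} \big(h_u^{(k-1)} + e_{uv}\big)\Big),$$
where $e_{uv}$ is the sum of the bond-type and bond-direction embeddings, and the centre node enters through a dedicated self-loop edge type (the learnable $\epsilon$ term of the original GIN is omitted).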
[5]:
class GINConv(MessagePassing):
"""
Extension of the Graph Isomorphism Network to incorporate
edge information by concatenating edge embeddings.
"""
def __init__(self, emb_dim, aggr="add"):
"""
Initialise GIN convolutional layer.
Args:
emb_dim: latent node embedding dimension
aggr: aggregation procedure
"""
super(GINConv, self).__init__()
self.mlp = torch.nn.Sequential(
torch.nn.Linear(emb_dim, 2 * emb_dim),
torch.nn.ReLU(),
torch.nn.Linear(2 * emb_dim, emb_dim),
)
self.edge_embedding1 = torch.nn.Embedding(num_bond_type, emb_dim)
self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, emb_dim)
torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)
self.aggr = aggr
def forward(self, x, edge_index, edge_attr):
"""
Message passing and aggregation function
of the adapted GIN convolutional layer.
Args:
x: node features
edge_index: adjacency list
edge_attr: edge features
Returns: transformed and aggregated node embeddings
"""
# add self loops to edge index
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
# update edge attributes to represent self-loop edges
self_loop_attr = torch.zeros(x.size(0), 2)
self_loop_attr[:, 0] = self_loop_token
self_loop_attr = self_loop_attr.to(edge_attr.device).to(
edge_attr.dtype
)
edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0)
# generate edge embeddings and propagate
edge_embeddings = self.edge_embedding1(
edge_attr[:, 0]
) + self.edge_embedding2(edge_attr[:, 1])
return self.propagate(edge_index, x=x, edge_attr=edge_embeddings)
def message(self, x_j, edge_attr):
return x_j + edge_attr
def update(self, aggr_out):
return self.mlp(aggr_out)
class GNN(torch.nn.Module):
"""
Combine multiple GNN layers into a network.
"""
def __init__(self, num_layers=5, embed_dim=300, gnn_type="gin"):
"""
Compose convolution layers into GNN. Pretrained parameters
exist for a 5-layer network with 300 hidden units.
Args:
num_layers: number of convolution layers
embed_dim: dimension of node embeddings
gnn_type: type of convolutional layer to use
"""
self.num_layers = num_layers
self.embed_dim = embed_dim
self.gnn_type = gnn_type
super(GNN, self).__init__()
# initialise label embeddings
self.x_embedding1 = torch.nn.Embedding(num_atom_type, self.embed_dim)
self.x_embedding2 = torch.nn.Embedding(
num_chirality_tag, self.embed_dim
)
torch.nn.init.xavier_uniform_(self.x_embedding1.weight.data)
torch.nn.init.xavier_uniform_(self.x_embedding2.weight.data)
# initialise GNN layers
self.gnns = torch.nn.ModuleList()
for layer in range(self.num_layers):
if gnn_type == "gin":
self.gnns.append(GINConv(emb_dim=self.embed_dim))
else:
raise NotImplementedError("Invalid GNN layer type.")
# initialise BatchNorm layers
self.batch_norms = torch.nn.ModuleList()
for layer in range(self.num_layers):
self.batch_norms.append(torch.nn.BatchNorm1d(self.embed_dim))
def forward(self, x, edge_index, edge_attr):
"""
Forward function of the GNN class that takes a PyTorch geometric
representation of a molecule or a batch of molecules
and generates the node embeddings for each atom.
Args:
x: node features
edge_index: adjacency list
edge_attr: edge features
Returns: tensor of num_nodes x embedding_dim embeddings
"""
# x[:, 0] corresponds to 'possible_atomic_num_list',
# x[:, 1] corresponds to 'possible_chirality_list'
x = self.x_embedding1(x[:, 0]) + self.x_embedding2(x[:, 1])
for layer in range(self.num_layers):
# x holds the current node embeddings and edge_attr the bond features of the molecule
x = self.gnns[layer](x, edge_index, edge_attr)
x = self.batch_norms[layer](x)
if layer != self.num_layers - 1:
x = F.relu(x)
return x
Construct the Bayesian GNN module and define the training and evaluation protocol#
Finally, we will define the Bayesian GNN module that takes the node-wise representations from the GNN, combines them into a graph-level representation via mean pooling and applies the LinearReparameterization layer from bayesian_torch as the final Bayesian linear layer. We will also define the training and evaluation protocol for the Bayesian GNNs: these are similar to the ones you might use for deterministic GNNs, with the difference that we can draw multiple samples from the distribution over network weights to obtain predictions and uncertainty estimates.
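For reference, the training loop further below minimises a mini-batch objective of the form (a paraphrase of the code, with $\beta$ = kld_beta and $|B|$ the batch size)
$$\mathcal{L} = \frac{1}{|B|}\sum_{i \in B}\big(y_i - \hat{y}_i\big)^2 + \frac{\beta}{|B|}\,\mathrm{KL}\big[q_\phi(w)\,\|\,p(w)\big],$$
i.e. the usual MSE data-fit term plus the KL divergence between the variational weight posterior and the prior returned by the Bayesian layer; up to likelihood constants and the choice of $\beta$, this is a mini-batch estimate of the negative ELBO.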
[6]:
class BayesianGNN(nn.Module):
def __init__(self, embed_dim=300, num_layers=5, gnn_type='gin', output_dim=1):
super().__init__()
self.gnn = GNN(num_layers=num_layers, embed_dim=embed_dim, gnn_type=gnn_type)
self.pooling = global_mean_pool
self.bayes_layer = LinearReparameterization(embed_dim, output_dim)
def forward(self, data):
x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
res = self.gnn(x, edge_index, edge_attr)
res = self.pooling(res, batch)
# bayesian layer
kl_sum = 0
res, kl = self.bayes_layer(res)
kl_sum += kl
return res, kl_sum
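As a quick smoke test (not part of the original notebook; the molecule is arbitrary), we can run the model on a single toy graph and confirm that the forward pass returns both a prediction and a KL term:
toy = mol_to_pyg("CCO")  # ethanol
batch = next(iter(DataLoader([toy], batch_size=1)))
model = BayesianGNN(embed_dim=300, num_layers=5)
out, kl = model(batch)
print(out.shape, kl.item())  # torch.Size([1, 1]) and a scalar KL value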
[7]:
def nlpd(y, y_pred, y_std):
nld = 0
for y_true, mu, std in zip(y.ravel(), y_pred.ravel(), y_std.ravel()):
nld += -norm(mu, std).logpdf(y_true)
return nld / len(y)
def predict(regressor, X, samples = 100):
preds = [regressor(X)[0] for i in range(samples)]
preds = torch.stack(preds)
means = preds.mean(axis=0)
var = preds.var(axis=0)
return means, var
def graph_append_label(X, y, device):
G = []
for g, label in zip(X, y):
g.y = label
g = g.to(device)
G.append(g)
return G
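For reference, the nlpd helper above computes the average negative log predictive density under a Gaussian predictive distribution,
$$\mathrm{NLPD} = -\frac{1}{n}\sum_{i=1}^{n}\log \mathcal{N}\big(y_i \mid \mu_i, \sigma_i^2\big),$$
where $\mu_i$ and $\sigma_i$ are the Monte-Carlo predictive mean and standard deviation; lower values indicate better-calibrated uncertainty estimates.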
[8]:
def evaluate_model(X, y, n_epochs=100, n_trials=20, kld_beta = 1., verbose=True):
test_set_size = 0.2
batch_size = 32
r2_list = []
rmse_list = []
mae_list = []
nlpd_list = []
# We pre-allocate array for plotting confidence-error curves
_, y_test = train_test_split(y, test_size=test_set_size) # To get test set size
n_test = len(y_test)
mae_confidence_list = np.zeros((n_trials, n_test))
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Device being used: {device}')
print('\nBeginning training loop...')
for i in range(0, n_trials):
print(f'Starting trial {i}')
# split data and perform standardization
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_set_size, random_state=i)
_, y_train, _, y_test, y_scaler = transform_data(y_train, y_train, y_test, y_test)
# include y in the pyg graph structure
G_train = graph_append_label(X_train, y_train, device)
G_test = graph_append_label(X_test, y_test, device)
dataloader_train = DataLoader(G_train, batch_size=batch_size, shuffle=True, drop_last=True)
dataloader_test = DataLoader(G_test, batch_size=len(G_test))
# initialize model and optimizer
model = BayesianGNN(gnn_type='gin').to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.MSELoss()
training_loss = []
status = {}
best_loss = np.inf
patience = 50
count = 0
pbar = tqdm(range(n_epochs))
for epoch in pbar:
running_kld_loss = 0
running_mse_loss = 0
running_loss = 0
for data in dataloader_train:
optimizer.zero_grad()
output, kl = model(data)
# calculate loss with kl term for Bayesian layers
target = torch.tensor(np.array(data.y), dtype=torch.float, device=device)
mse_loss = criterion(output, target)
loss = mse_loss + kl * kld_beta / batch_size
loss.backward()
optimizer.step()
running_mse_loss += mse_loss.detach().cpu().numpy()
running_kld_loss += kl.detach().cpu().numpy()
running_loss += loss.detach().cpu().numpy()
status.update({
'Epoch': epoch,
'loss': running_loss/len(dataloader_train),
'kl': running_kld_loss/len(dataloader_train),
'mse': running_mse_loss/len(dataloader_train)
})
training_loss.append(status)
pbar.set_postfix(status)
with torch.no_grad():
for data in dataloader_test:
y_pred, y_var = predict(model, data)
target = torch.tensor(np.array(data.y), dtype=torch.float, device=device)
val_loss = criterion(y_pred, target)
val_loss = val_loss.detach().cpu().numpy()
status.update({'val_loss': val_loss})
if best_loss > val_loss:
best_loss = val_loss
best_model = copy.deepcopy(model)
count = 0
else:
count += 1
if count >= patience:
if verbose: print(f'Early stopping reached! Best validation loss {best_loss}')
break
pbar.set_postfix(status)
# Get into evaluation (predictive posterior) mode
model = best_model
model.eval()
with torch.no_grad():
# mean and variance by sampling
for data in dataloader_test:
y_pred, y_var = predict(model, data, samples=100)
y_pred = y_pred.detach().cpu().numpy()
y_var = y_var.detach().cpu().numpy()
uq_nlpd = nlpd(y_test, y_pred, np.sqrt(y_var))
if verbose: print(f'NLPD: {uq_nlpd}')
# Transform back to real data space to compute metrics and detach gradients.
y_pred = y_scaler.inverse_transform(y_pred)
y_test = y_scaler.inverse_transform(y_test)
# Compute scores for confidence curve plotting.
ranked_confidence_list = np.argsort(y_var, axis=0).flatten()
for k in range(len(y_test)):
# Construct the MAE error for each level of confidence
conf = ranked_confidence_list[0:k+1]
mae = mean_absolute_error(y_test[conf], y_pred[conf])
mae_confidence_list[i, k] = mae
# Compute R^2, RMSE and MAE on Test set
score = r2_score(y_test, y_pred)
rmse = np.sqrt(mean_squared_error(y_test, y_pred))
mae = mean_absolute_error(y_test, y_pred)
r2_list.append(score)
rmse_list.append(rmse)
mae_list.append(mae)
nlpd_list.append(uq_nlpd)
r2_list = np.array(r2_list)
rmse_list = np.array(rmse_list)
mae_list = np.array(mae_list)
nlpd_list = np.array(nlpd_list)
print("\nmean R^2: {:.4f} +- {:.4f}".format(np.mean(r2_list), np.std(r2_list)/np.sqrt(len(r2_list))))
print("mean RMSE: {:.4f} +- {:.4f}".format(np.mean(rmse_list), np.std(rmse_list)/np.sqrt(len(rmse_list))))
print("mean MAE: {:.4f} +- {:.4f}\n".format(np.mean(mae_list), np.std(mae_list)/np.sqrt(len(mae_list))))
print("mean NLPD: {:.4f} +- {:.4f}\n".format(np.mean(nlpd_list), np.std(nlpd_list)/np.sqrt(len(nlpd_list))))
print()
# Plot confidence-error curves
# 1e-14 instead of 0 for numerical reasons
confidence_percentiles = np.arange(1e-14, 100, 100/len(y_test))
# We plot the Mean-absolute error confidence-error curves
mae_mean = np.mean(mae_confidence_list, axis=0)
mae_std = np.std(mae_confidence_list, axis=0)
mae_mean = np.flip(mae_mean)
mae_std = np.flip(mae_std)
# 1 sigma errorbars
lower = mae_mean - mae_std
upper = mae_mean + mae_std
fig = plt.figure()
plt.plot(confidence_percentiles, mae_mean, label='mean')
plt.fill_between(confidence_percentiles, lower, upper, alpha=0.2)
plt.xlabel('Confidence Percentile')
plt.ylabel('MAE')
plt.ylim([0, np.max(upper) + 1])
plt.xlim([0, 100 * ((len(y_test) - 1) / len(y_test))])
results = {
'confidence_percentiles': confidence_percentiles,
'mae_mean': mae_mean,
'mae_std': mae_std,
'mae': mae_list,
'rmse': rmse_list,
'r2': r2_list,
'nlpd': nlpd_list,
}
return results, fig
Run the training and evaluation loop#
[9]:
results, fig = evaluate_model(loader.features, loader.labels, n_epochs=300, n_trials=10, kld_beta=50.0)
fig.show()
Device being used: cuda
Beginning training loop...
Starting trial 0
47%|████▋ | 142/300 [02:00<02:14, 1.18it/s, Epoch=142, loss=3.74, kl=1.33, mse=1.66, val_loss=0.30306697]
Early stopping reached! Best validation loss 0.11581621319055557
NLPD: 0.7891262299567953
Starting trial 1
69%|██████▊ | 206/300 [02:35<01:11, 1.32it/s, Epoch=206, loss=3.04, kl=1.03, mse=1.43, val_loss=0.19800645]
Early stopping reached! Best validation loss 0.1911984235048294
NLPD: 0.9329165919360254
Starting trial 2
35%|███▍ | 104/300 [01:19<02:29, 1.31it/s, Epoch=104, loss=3.31, kl=1.7, mse=0.646, val_loss=0.29850987]
Early stopping reached! Best validation loss 0.22155441343784332
NLPD: 0.9101015414270804
Starting trial 3
24%|██▍ | 72/300 [00:55<02:54, 1.30it/s, Epoch=72, loss=4.45, kl=1.89, mse=1.49, val_loss=0.1720525]
Early stopping reached! Best validation loss 0.15270735323429108
NLPD: 0.513528300588651
Starting trial 4
22%|██▏ | 66/300 [00:50<03:00, 1.30it/s, Epoch=66, loss=3.92, kl=1.73, mse=1.22, val_loss=0.5326283]
Early stopping reached! Best validation loss 0.3228945732116699
NLPD: 1.683367726229072
Starting trial 5
53%|█████▎ | 158/300 [02:02<01:50, 1.29it/s, Epoch=158, loss=3.21, kl=1.09, mse=1.5, val_loss=0.3413453]
Early stopping reached! Best validation loss 0.2585623860359192
NLPD: 0.9266979137971999
Starting trial 6
30%|███ | 91/300 [01:08<02:38, 1.32it/s, Epoch=91, loss=3.96, kl=1.75, mse=1.21, val_loss=0.38592163]
Early stopping reached! Best validation loss 0.2153647243976593
NLPD: 0.6858583467681377
Starting trial 7
53%|█████▎ | 158/300 [01:59<01:47, 1.32it/s, Epoch=158, loss=2.86, kl=1.22, mse=0.95, val_loss=0.32246053]
Early stopping reached! Best validation loss 0.1867567002773285
NLPD: 0.996825519470283
Starting trial 8
100%|██████████| 300/300 [03:47<00:00, 1.32it/s, Epoch=299, loss=2.2, kl=0.876, mse=0.828, val_loss=0.25577274]
NLPD: 1.02045369536678
Starting trial 9
25%|██▌ | 75/300 [00:56<02:50, 1.32it/s, Epoch=75, loss=3.63, kl=1.99, mse=0.516, val_loss=0.24762282]
Early stopping reached! Best validation loss 0.16555529832839966
NLPD: 0.5259638768835783
mean R^2: 0.7894 +- 0.0227
mean RMSE: 29.7135 +- 1.3321
mean MAE: 21.4297 +- 0.9820
mean NLPD: 0.8985 +- 0.0993
For graph feature inputs, the results of the Bayesian GNN on the various benchmark datasets are summarised below:
|  | Photoswitch | Freesolv | ESOL | Lipophilicity |
|---|---|---|---|---|
| R2 | 0.8048 +- 0.0155 | 0.7884 +- 0.0056 | 0.8224 +- 0.0044 | 0.6208 +- 0.0199 |
| RMSE | 28.5302 +- 1.2050 | 0.9610 +- 0.0148 | 0.8800 +- 0.0098 | 0.7317 +- 0.0175 |
| MAE | 20.7182 +- 0.9928 | 0.7264 +- 0.0161 | 0.6622 +- 0.0079 | 0.5328 +- 0.0111 |
| NLPD | 0.9960 +- 0.1286 | 1.0060 +- 0.0153 | 1.6990 +- 0.1085 | 1.1406 +- 0.0120 |