# Sparse GP Regression on Molecules
An example notebook for sparse GP regression to enable scalability to large molecular datasets.
[3]:
# Imports
import warnings
warnings.filterwarnings("ignore") # Turn off Graphein warnings
import time
from botorch import fit_gpytorch_model
import gpytorch
from mordred import Calculator, descriptors
import numpy as np
from rdkit import Chem
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from sklearn.preprocessing import StandardScaler
import torch
from gauche.dataloader import MolPropLoader
from gauche.dataloader.data_utils import transform_data
[4]:
# We define our sparse GP model using an inducing point kernel wrapped around the RQ kernel
from gpytorch.means import ConstantMean
from gpytorch.kernels import ScaleKernel, InducingPointKernel, RQKernel
from gpytorch.distributions import MultivariateNormal


class SparseGPModel(gpytorch.models.ExactGP):
    """Sparse GP regression model built on gpytorch's ExactGP.

    The mean is a learned constant. The covariance is an InducingPointKernel
    wrapped around a scaled rational quadratic (RQ) kernel, so the kernel is
    represented through a small set of inducing points rather than all
    training inputs.

    Args:
        train_x: training inputs (2-D tensor, indexed as ``train_x[:k, :]``).
            The first ``num_inducing`` rows initialise the inducing points. If
            ``train_x`` was shuffled beforehand (e.g. by ``train_test_split``),
            this is a random subset of the training data.
        train_y: training targets.
        likelihood: gpytorch likelihood, shared with the InducingPointKernel.
        num_inducing: number of inducing points. Defaults to 100, the value
            previously hard-coded. If ``train_x`` has fewer rows, all of them
            are used.

    Raises:
        ValueError: if ``num_inducing`` is smaller than 1.
    """

    def __init__(self, train_x, train_y, likelihood, num_inducing=100):
        if num_inducing < 1:
            raise ValueError("num_inducing must be a positive integer")
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = ConstantMean()
        self.base_covar_module = ScaleKernel(RQKernel())
        # Initialise the inducing points at the first `num_inducing` training inputs.
        # They are cloned so the kernel does not share storage with train_x.
        self.covar_module = InducingPointKernel(
            self.base_covar_module,
            inducing_points=train_x[:num_inducing, :].clone(),
            likelihood=likelihood,
        )

    def forward(self, x):
        """Return the GP prior at inputs ``x`` as a MultivariateNormal."""
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)
## Sparse GP Regression on the Lipophilicity Dataset
We define our experiment parameters. In this case we are working on the large Lipophilicity dataset [1], which contains 4,200 molecules.
[5]:
# Regression experiment parameters: how many random train/test splits to
# average the metrics over, and the held-out fraction used in each split.
n_trials, test_set_size = 20, 0.2
Load the Lipophilicity dataset via the `MolPropLoader` class, which provides several molecular property prediction benchmark datasets.
[7]:
# Load the Lipophilicity dataset
loader = MolPropLoader()
loader.load_benchmark("Lipophilicity")

# Parse SMILES into RDKit molecules. Chem.MolFromSmiles returns None for SMILES it
# cannot parse; stop here with a clear message so a None never reaches the calculator.
mols = [Chem.MolFromSmiles(smi) for smi in loader.features]
bad_idx = [idx for idx, mol in enumerate(mols) if mol is None]
if bad_idx:
    raise ValueError(f"RDKit could not parse {len(bad_idx)} SMILES, e.g. at indices {bad_idx[:10]}")

# Mordred descriptor computation is expensive
# NOTE(review): molecules built from SMILES carry no 3D conformers, so the 3D descriptors
# enabled by ignore_3D=False presumably come back missing and are removed with the other
# NaN columns in the next step; ignore_3D=True may be equivalent and cheaper — confirm.
calc = Calculator(descriptors, ignore_3D=False)
t0 = time.perf_counter()  # monotonic clock, appropriate for measuring elapsed time
X_mordred = [calc(mol) for mol in mols]
t1 = time.perf_counter()
print(f'Mordred descriptor computation takes {t1 - t0} seconds')
X_mordred = np.array(X_mordred).astype(np.float64)
y = loader.labels

# Collect NaN columns: any descriptor that is NaN for at least one molecule is dropped,
# so the downstream scaler and PCA receive a fully numeric feature matrix.
# Vectorised column-wise check replaces the per-row loop with list-membership tests.
nan_dims = np.flatnonzero(np.isnan(X_mordred).any(axis=0)).tolist()
X_mordred = np.delete(X_mordred, nan_dims, axis=1)
Mordred descriptor computation takes 639.975240945816 seconds
## Model Evaluation
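Here we define a training/evaluation loop assessing performance using the root mean-square error (RMSE), mean absolute error (MAE), and $R^2$ metrics. The `evaluate_model`
function also computes the GP confidence-error curve. This curve ranks the test predictions by their GP predictive variance. It then plots the MAE of the predictions that remain as an increasing percentage of the least confident ones is discarded. A well-calibrated model should show the error falling as the confidence percentile rises.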
[13]:
import warnings
warnings.filterwarnings("ignore")  # Turn off GPyTorch warnings
from matplotlib import pyplot as plt


def _fit_sparse_gp(X_train, y_train):
    """Fit a SparseGPModel to (X_train, y_train); return (model, likelihood) in eval mode."""
    likelihood = gpytorch.likelihoods.GaussianLikelihood()
    model = SparseGPModel(X_train, y_train, likelihood)
    # "Loss" for GPs - the marginal log likelihood
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
    # Use the BoTorch utility for fitting GPs in order to use the LBFGS-B optimiser (recommended)
    fit_gpytorch_model(mll)
    # Get into evaluation (predictive posterior) mode
    model.eval()
    likelihood.eval()
    return model, likelihood


def _confidence_mae_curve(y_true, y_pred, y_var):
    """Return the MAE of the k most confident predictions, for k = 1..n.

    Predictions are ranked by ascending predictive variance (most confident first).
    Entry k-1 of the returned length-n array is the MAE over the k lowest-variance points.
    A running mean of the sorted absolute errors gives all n values in O(n log n).
    """
    ranked = np.argsort(y_var, axis=0).flatten()
    abs_err = np.abs(np.ravel(y_true) - np.ravel(y_pred))[ranked]
    return np.cumsum(abs_err) / np.arange(1, len(abs_err) + 1)


def _plot_confidence_curve(mae_confidence_list, ylabel):
    """Plot the mean +- 1 std (across trials) confidence-error curve."""
    n_test = mae_confidence_list.shape[1]
    # x = percentage of the least confident predictions discarded. linspace guarantees
    # exactly n_test points; np.arange with a float step can return n_test + 1.
    confidence_percentiles = np.linspace(0, 100, n_test, endpoint=False)
    # Flip so index 0 is "all points kept" and the last index is "most confident point only"
    mae_mean = np.flip(np.mean(mae_confidence_list, axis=0))
    mae_std = np.flip(np.std(mae_confidence_list, axis=0))
    # 1 sigma errorbars
    lower = mae_mean - mae_std
    upper = mae_mean + mae_std

    fig, ax = plt.subplots()
    ax.plot(confidence_percentiles, mae_mean, label='mean')
    ax.fill_between(confidence_percentiles, lower, upper, alpha=0.2, label=r'$\pm 1$ std')
    ax.set_xlabel('Confidence Percentile')
    ax.set_ylabel(ylabel)
    # Headroom proportional to the data; a fixed +1 (and 5-unit ticks) only suited nm-scale targets
    ax.set_ylim([0, 1.1 * np.max(upper)])
    ax.set_xlim([0, 100 * ((n_test - 1) / n_test)])
    ax.legend()
    plt.show()


def evaluate_model(X, y, n_components=51, ylabel='MAE (logD)'):
    """Helper function for model evaluation.

    For each of the module-level ``n_trials`` random splits (test fraction
    ``test_set_size``), the inputs are standardised and projected with PCA, both
    fit on the training split only. The labels are standardised, a SparseGPModel
    is fit, and test-set metrics are recorded.

    Args:
        X: n x d NumPy array of inputs representing molecules
        y: n x 1 NumPy array of output labels
        n_components: number of PCA components kept (default 51, as before).
        ylabel: y-axis label of the confidence-error plot (MAE in the label's units).

    Returns:
        (rmse_list, mae_list): NumPy arrays of per-trial test RMSE and MAE.
        Also prints per-trial train RMSE and the mean +- standard error of
        R^2, RMSE and MAE, and shows the confidence-error curve plot.
    """
    # initialise performance metric lists
    r2_list = []
    rmse_list = []
    mae_list = []

    # We pre-allocate array for plotting confidence-error curves
    _, _, _, y_test = train_test_split(X, y, test_size=test_set_size)  # To get test set size
    n_test = len(y_test)
    mae_confidence_list = np.zeros((n_trials, n_test))

    print('\nBeginning training loop...')

    for i in range(0, n_trials):
        print(f'Starting trial {i}')
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_set_size, random_state=i)

        scaler = StandardScaler()
        X_train = scaler.fit_transform(X_train)
        X_test = scaler.transform(X_test)
        # Seed PCA: the 'auto' solver may pick randomized SVD, which is otherwise unseeded
        pca_mordred = PCA(n_components=n_components, random_state=i)
        X_train = pca_mordred.fit_transform(X_train)
        X_test = pca_mordred.transform(X_test)

        # We standardise the outputs
        _, y_train, _, y_test, y_scaler = transform_data(X_train, y_train, X_test, y_test)

        # Convert numpy arrays to PyTorch tensors and flatten the label vectors
        X_train = torch.tensor(X_train.astype(np.float64))
        X_test = torch.tensor(X_test.astype(np.float64))
        y_train = torch.tensor(y_train).flatten()
        y_test = torch.tensor(y_test).flatten()

        model, likelihood = _fit_sparse_gp(X_train, y_train)

        # Predictions need no autograd graph
        with torch.no_grad():
            # mean and variance GP prediction
            f_pred = model(X_test)
            y_pred = f_pred.mean.detach()
            y_var = f_pred.variance.detach()
            y_pred_train = model(X_train).mean.detach()

        # Transform back to real data space to compute metrics. Must unsqueeze dimension
        # to make compatible with inverse_transform in scikit-learn version > 1
        y_pred = y_scaler.inverse_transform(y_pred.unsqueeze(dim=1))
        y_test = y_scaler.inverse_transform(y_test.unsqueeze(dim=1))

        # Compute scores for confidence curve plotting.
        mae_confidence_list[i, :] = _confidence_mae_curve(y_test, y_pred, y_var.numpy())

        # Output Standardised RMSE and RMSE on Train Set
        train_rmse_stan = np.sqrt(mean_squared_error(y_train, y_pred_train))
        train_rmse = np.sqrt(mean_squared_error(y_scaler.inverse_transform(y_train.unsqueeze(dim=1)),
                                                y_scaler.inverse_transform(y_pred_train.unsqueeze(dim=1))))
        print(f'Train RMSE (standardised): {train_rmse_stan:.4f}, Train RMSE: {train_rmse:.4f}')

        # Compute R^2, RMSE and MAE on Test set
        score = r2_score(y_test, y_pred)
        rmse = np.sqrt(mean_squared_error(y_test, y_pred))
        mae = mean_absolute_error(y_test, y_pred)

        r2_list.append(score)
        rmse_list.append(rmse)
        mae_list.append(mae)

    r2_list = np.array(r2_list)
    rmse_list = np.array(rmse_list)
    mae_list = np.array(mae_list)

    # Mean +- standard error of the mean across trials
    print("\nmean R^2: {:.4f} +- {:.4f}".format(np.mean(r2_list), np.std(r2_list)/np.sqrt(len(r2_list))))
    print("mean RMSE: {:.4f} +- {:.4f}".format(np.mean(rmse_list), np.std(rmse_list)/np.sqrt(len(rmse_list))))
    print("mean MAE: {:.4f} +- {:.4f}\n".format(np.mean(mae_list), np.std(mae_list)/np.sqrt(len(mae_list))))

    # Plot confidence-error curves
    _plot_confidence_curve(mae_confidence_list, ylabel)

    return rmse_list, mae_list
Check the performance achieved by our sparse GP model.
[14]:
# Evaluate the sparse GP on the NaN-filtered Mordred descriptor features
rmse_mordred, mae_mordred = evaluate_model(X=X_mordred, y=y)
Beginning training loop...
Starting trial 0
---------------------------------------------------------------------------
KeyboardInterrupt Traceback (most recent call last)
Cell In[14], line 1
----> 1 rmse_mordred, mae_mordred = evaluate_model(X_mordred, y)
Cell In[13], line 63, in evaluate_model(X, y)
60 mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
62 # Use the BoTorch utility for fitting GPs in order to use the LBFGS-B optimiser (recommended)
---> 63 fit_gpytorch_model(mll)
65 # Get into evaluation (predictive posterior) mode
66 model.eval()
File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/botorch/fit.py:164, in fit_gpytorch_model(mll, optimizer, optimizer_kwargs, exclude, max_retries, **kwargs)
158 with (
159 nullcontext()
160 if exclude is None
161 else requires_grad_ctx(mll, assignments={name: False for name in exclude})
162 ):
163 try:
--> 164 mll = fit_gpytorch_mll(
165 mll,
166 optimizer=optimizer,
167 optimizer_kwargs=optimizer_kwargs,
168 **kwargs,
169 )
170 except ModelFittingError as err:
171 warn(str(err), RuntimeWarning)
File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/botorch/fit.py:105, in fit_gpytorch_mll(mll, closure, optimizer, closure_kwargs, optimizer_kwargs, **kwargs)
102 if optimizer is not None: # defer to per-method defaults
103 kwargs["optimizer"] = optimizer
--> 105 return FitGPyTorchMLL(
106 mll,
107 type(mll.likelihood),
108 type(mll.model),
109 closure=closure,
110 closure_kwargs=closure_kwargs,
111 optimizer_kwargs=optimizer_kwargs,
112 **kwargs,
113 )
File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/botorch/utils/dispatcher.py:93, in Dispatcher.__call__(self, *args, **kwargs)
91 func = self.__getitem__(types=types)
92 try:
---> 93 return func(*args, **kwargs)
94 except MDNotImplementedError:
95 # Traverses registered methods in order, yields whenever a match is found
96 funcs = self.dispatch_iter(*types)
File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/botorch/fit.py:252, in _fit_fallback(mll, _, __, closure, optimizer, closure_kwargs, optimizer_kwargs, max_attempts, warning_handler, caught_exception_types, **ignore)
250 with catch_warnings(record=True) as warning_list, debug(True):
251 simplefilter("always", category=OptimizationWarning)
--> 252 optimizer(mll, closure=closure, **optimizer_kwargs)
254 # Resolved warnings and determine whether or not to retry
255 done = True
File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/botorch/optim/fit.py:92, in fit_gpytorch_mll_scipy(mll, parameters, bounds, closure, closure_kwargs, method, options, callback, timeout_sec)
89 if closure_kwargs is not None:
90 closure = partial(closure, **closure_kwargs)
---> 92 result = scipy_minimize(
93 closure=closure,
94 parameters=parameters,
95 bounds=bounds,
96 method=method,
97 options=options,
98 callback=callback,
99 timeout_sec=timeout_sec,
100 )
101 if result.status != OptimizationStatus.SUCCESS:
102 warn(
103 f"`scipy_minimize` terminated with status {result.status}, displaying"
104 f" original message from `scipy.optimize.minimize`: {result.message}",
105 OptimizationWarning,
106 )
File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/botorch/optim/core.py:109, in scipy_minimize(closure, parameters, bounds, callback, x0, method, options, timeout_sec)
101 result = OptimizationResult(
102 step=next(call_counter),
103 fval=float(wrapped_closure(x)[0]),
104 status=OptimizationStatus.RUNNING,
105 runtime=monotonic() - start_time,
106 )
107 return callback(parameters, result) # pyre-ignore [29]
--> 109 raw = minimize_with_timeout(
110 wrapped_closure,
111 wrapped_closure.state if x0 is None else x0.astype(np_float64, copy=False),
112 jac=True,
113 bounds=bounds_np,
114 method=method,
115 options=options,
116 callback=wrapped_callback,
117 timeout_sec=timeout_sec,
118 )
120 # Post-processing and outcome handling
121 wrapped_closure.state = asarray(raw.x) # set parameter state to optimal values
File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/botorch/optim/utils/timeout.py:80, in minimize_with_timeout(fun, x0, args, method, jac, hess, hessp, bounds, constraints, tol, callback, options, timeout_sec)
77 wrapped_callback = callback
79 try:
---> 80 return optimize.minimize(
81 fun=fun,
82 x0=x0,
83 args=args,
84 method=method,
85 jac=jac,
86 hess=hess,
87 hessp=hessp,
88 bounds=bounds,
89 constraints=constraints,
90 tol=tol,
91 callback=wrapped_callback,
92 options=options,
93 )
94 except OptimizationTimeoutError as e:
95 msg = f"Optimization timed out after {e.runtime} seconds."
File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/scipy/optimize/_minimize.py:710, in minimize(fun, x0, args, method, jac, hess, hessp, bounds, constraints, tol, callback, options)
707 res = _minimize_newtoncg(fun, x0, args, jac, hess, hessp, callback,
708 **options)
709 elif meth == 'l-bfgs-b':
--> 710 res = _minimize_lbfgsb(fun, x0, args, jac, bounds,
711 callback=callback, **options)
712 elif meth == 'tnc':
713 res = _minimize_tnc(fun, x0, args, jac, bounds, callback=callback,
714 **options)
File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/scipy/optimize/_lbfgsb_py.py:365, in _minimize_lbfgsb(fun, x0, args, jac, bounds, disp, maxcor, ftol, gtol, eps, maxfun, maxiter, iprint, callback, maxls, finite_diff_rel_step, **unknown_options)
359 task_str = task.tobytes()
360 if task_str.startswith(b'FG'):
361 # The minimization routine wants f and g at the current x.
362 # Note that interruptions due to maxfun are postponed
363 # until the completion of the current minimization iteration.
364 # Overwrite f and g:
--> 365 f, g = func_and_grad(x)
366 elif task_str.startswith(b'NEW_X'):
367 # new iteration
368 n_iterations += 1
File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/scipy/optimize/_differentiable_functions.py:285, in ScalarFunction.fun_and_grad(self, x)
283 if not np.array_equal(x, self.x):
284 self._update_x_impl(x)
--> 285 self._update_fun()
286 self._update_grad()
287 return self.f, self.g
File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/scipy/optimize/_differentiable_functions.py:251, in ScalarFunction._update_fun(self)
249 def _update_fun(self):
250 if not self.f_updated:
--> 251 self._update_fun_impl()
252 self.f_updated = True
File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/scipy/optimize/_differentiable_functions.py:155, in ScalarFunction.__init__.<locals>.update_fun()
154 def update_fun():
--> 155 self.f = fun_wrapped(self.x)
File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/scipy/optimize/_differentiable_functions.py:137, in ScalarFunction.__init__.<locals>.fun_wrapped(x)
133 self.nfev += 1
134 # Send a copy because the user may overwrite it.
135 # Overwriting results in undefined behaviour because
136 # fun(self.x) will change self.x, with the two no longer linked.
--> 137 fx = fun(np.copy(x), *args)
138 # Make sure the function returns a true scalar
139 if not np.isscalar(fx):
File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/scipy/optimize/_optimize.py:77, in MemoizeJac.__call__(self, x, *args)
75 def __call__(self, x, *args):
76 """ returns the function value """
---> 77 self._compute_if_needed(x, *args)
78 return self._value
File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/scipy/optimize/_optimize.py:71, in MemoizeJac._compute_if_needed(self, x, *args)
69 if not np.all(x == self.x) or self._value is None or self.jac is None:
70 self.x = np.asarray(x).copy()
---> 71 fg = self.fun(x, *args)
72 self.jac = fg[1]
73 self._value = fg[0]
File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/botorch/optim/closures/core.py:150, in NdarrayOptimizationClosure.__call__(self, state, **kwargs)
147 self.state = state
149 try:
--> 150 value_tensor, grad_tensors = self.closure(**kwargs)
151 value = self.as_array(value_tensor)
152 grads = self._get_gradient_ndarray(fill_value=self.fill_value)
File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/botorch/optim/closures/core.py:66, in ForwardBackwardClosure.__call__(self, **kwargs)
64 values = self.forward(**kwargs)
65 value = values if self.reducer is None else self.reducer(values)
---> 66 self.backward(value)
68 grads = tuple(param.grad for param in self.parameters.values())
69 if self.callback:
File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/torch/_tensor.py:492, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
482 if has_torch_function_unary(self):
483 return handle_torch_function(
484 Tensor.backward,
485 (self,),
(...)
490 inputs=inputs,
491 )
--> 492 torch.autograd.backward(
493 self, gradient, retain_graph, create_graph, inputs=inputs
494 )
File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/torch/autograd/__init__.py:251, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
246 retain_graph = create_graph
248 # The reason we repeat the same comment below is that
249 # some Python versions print out the first line of a multi-line function
250 # calls in the traceback and some print out the last line
--> 251 Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
252 tensors,
253 grad_tensors_,
254 retain_graph,
255 create_graph,
256 inputs,
257 allow_unreachable=True,
258 accumulate_grad=True,
259 )
KeyboardInterrupt:
## References
[1] Anna Gaulton, Louisa J Bellis, A Patricia Bento, Jon Chambers, Mark Davies, Anne Hersey, Yvonne Light, Shaun McGlinchey, David Michalovich, Bissan Al-Lazikani, et al. ChEMBL: A large-scale bioactivity database for drug discovery. Nucleic Acids Research, 2012.
[2] Bajusz, D., Rácz, A. and Héberger, K., 2015. Why is Tanimoto index an appropriate choice for fingerprint-based similarity calculations? Journal of Cheminformatics, 7(1), pp.1-13.