# Sparse GP Regression on Molecules

An example notebook showing how sparse GP regression (an inducing-point approximation) lets Gaussian process models scale to large molecular datasets.

[3]:
# Imports

import warnings
warnings.filterwarnings("ignore") # Turn off Graphein warnings

import time

from botorch import fit_gpytorch_model
import gpytorch
from mordred import Calculator, descriptors
import numpy as np
from rdkit import Chem
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from sklearn.preprocessing import StandardScaler
import torch

from gauche.dataloader import MolPropLoader
from gauche.dataloader.data_utils import transform_data
[4]:
# We define our sparse GP model using an inducing point kernel wrapped around the RQ kernel

from gpytorch.means import ConstantMean
from gpytorch.kernels import ScaleKernel, InducingPointKernel, RQKernel
from gpytorch.distributions import MultivariateNormal

class SparseGPModel(gpytorch.models.ExactGP):
    """Sparse GP regression model built on an inducing-point kernel.

    The covariance is a GPyTorch ``InducingPointKernel`` wrapped around a scaled
    rational-quadratic (RQ) kernel. Inference cost is therefore governed by the number
    of inducing points rather than by the full number of training points.

    Args:
        train_x: 2D tensor of training inputs, (n, d).
        train_y: tensor of training targets (flattened to (n,) in this notebook).
        likelihood: GPyTorch likelihood; this notebook passes a ``GaussianLikelihood``.
        num_inducing: number of inducing points (default 100). They are initialised
            from a copy of the first ``num_inducing`` rows of ``train_x`` (all rows if
            there are fewer). ``InducingPointKernel`` stores them as learnable
            parameters, so they are optimised along with the kernel hyperparameters.

    Raises:
        ValueError: if ``num_inducing`` is smaller than 1.
    """

    def __init__(self, train_x, train_y, likelihood, num_inducing=100):
        if num_inducing < 1:
            raise ValueError("num_inducing must be a positive integer")
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = ConstantMean()
        self.base_covar_module = ScaleKernel(RQKernel())
        # In this notebook train_x comes from a shuffled train_test_split, so its first
        # rows are effectively a random subset of the training inputs.
        inducing_points = train_x[:num_inducing, :].clone()
        self.covar_module = InducingPointKernel(
            self.base_covar_module,
            inducing_points=inducing_points,
            likelihood=likelihood,
        )

    def forward(self, x):
        """Return the GP prior at inputs ``x`` as a multivariate normal."""
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)

## Sparse GP Regression on the Lipophilicity Dataset

We define our experiment parameters. In this case we are working on the large lipophilicity dataset [1] containing 4200 molecules.

[5]:
# Experiment settings: how many random train/test splits to average over,
# and the fraction of the data held out as the test set in each split.
n_trials, test_set_size = 20, 0.2

Load the Lipophilicity dataset via the MolPropLoader class, which provides several molecular property prediction benchmark datasets, and featurise each molecule with Mordred descriptors.

[7]:
# Load the Lipophilicity dataset

loader = MolPropLoader()
loader.load_benchmark("Lipophilicity")

# Mordred descriptor computation is expensive (roughly ten minutes in the recorded run)
# NOTE(review): ignore_3D=False requests 3D descriptors, but molecules parsed from SMILES
# carry no conformers; those descriptors presumably come out as NaN and are dropped below.
calc = Calculator(descriptors, ignore_3D=False)
mols = [Chem.MolFromSmiles(smi) for smi in loader.features]
# Chem.MolFromSmiles returns None for unparsable SMILES; fail early with a clear message.
assert all(mol is not None for mol in mols), "RDKit could not parse some SMILES strings"

t0 = time.perf_counter()  # monotonic clock intended for measuring elapsed time
X_mordred = [calc(mol) for mol in mols]
t1 = time.perf_counter()
print(f'Mordred descriptor computation takes {t1 - t0} seconds')
X_mordred = np.array(X_mordred).astype(np.float64)
y = loader.labels
assert len(X_mordred) == len(y), "descriptor rows and labels are misaligned"

# Collect the indices of descriptor columns that are NaN for at least one molecule
# and drop them. The vectorised column mask replaces a per-row Python loop with
# list-membership checks.
nan_dims = np.flatnonzero(np.isnan(X_mordred).any(axis=0))
X_mordred = np.delete(X_mordred, nan_dims, axis=1)
Mordred descriptor computation takes 639.975240945816 seconds

## Model Evaluation

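Here we define a training/evaluation loop assessing performance using the root mean-square error (RMSE), mean absolute error (MAE), and \(R^2\) metrics. The evaluate_model function also computes a GP confidence-error curve. Test molecules are ranked by their GP predictive variance, and the MAE is plotted as the least confident predictions are progressively discarded. A well-calibrated model shows the error falling as the confidence percentile increases.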

[13]:
import warnings
warnings.filterwarnings("ignore") # Turn off GPyTorch warnings

from matplotlib import pyplot as plt
%matplotlib inline


def _cumulative_confidence_mae(y_true, y_pred, y_var):
    """Return the MAE over the k most confident predictions, for every k = 1..n.

    Args:
        y_true: array of n true labels (any shape that flattens to n values).
        y_pred: array of n predictions aligned with ``y_true``.
        y_var: array of n predictive variances; lower variance = higher confidence.
    Returns:
        Length-n NumPy array whose entry k is the MAE over the k+1 lowest-variance points.
    """
    abs_err = np.abs(np.asarray(y_true).reshape(-1) - np.asarray(y_pred).reshape(-1))
    ranked = np.argsort(np.asarray(y_var).reshape(-1))
    # Running mean of the ranked absolute errors: equivalent to calling
    # mean_absolute_error on every prefix, but O(n) instead of O(n^2).
    return np.cumsum(abs_err[ranked]) / np.arange(1, len(ranked) + 1)


def _plot_confidence_curve(mae_confidence_list):
    """Plot the mean +- 1 std MAE confidence-error curve across trials.

    Args:
        mae_confidence_list: (n_trials, n_test) array; row i holds the output of
            ``_cumulative_confidence_mae`` for trial i.
    """
    n_test = mae_confidence_list.shape[1]

    # x = p means the p% least confident test points have been discarded. linspace
    # always yields exactly n_test points (np.arange with a float step can yield n_test + 1).
    confidence_percentiles = np.linspace(0, 100, n_test, endpoint=False)

    # Column k is the MAE of the k+1 most confident points. Flip the columns so the curve
    # starts with all test points (0th percentile) and ends with the single most confident one.
    mae_mean = np.flip(np.mean(mae_confidence_list, axis=0))
    mae_std = np.flip(np.std(mae_confidence_list, axis=0))

    # 1 sigma errorbars
    lower = mae_mean - mae_std
    upper = mae_mean + mae_std

    fig, ax = plt.subplots()
    ax.plot(confidence_percentiles, mae_mean, label='mean')
    ax.fill_between(confidence_percentiles, lower, upper, alpha=0.2)
    ax.set_xlabel('Confidence Percentile')
    # Units follow the labels y, which depend on the dataset.
    ax.set_ylabel('MAE')
    ax.set_title('Confidence-error curve (mean +- 1 std over trials)')
    ax.set_ylim(0, 1.1 * np.max(upper))
    ax.set_xlim(0, 100 * (n_test - 1) / n_test)
    plt.show()


def evaluate_model(X, y, n_components=51):
    """Train and evaluate the sparse GP over ``n_trials`` random train/test splits.

    In each trial (split seeded with ``random_state=i``):
      * the inputs are standardised and reduced with PCA, both fitted on the training
        split only;
      * the labels are standardised with ``transform_data``;
      * a ``SparseGPModel`` is fitted by maximising the exact marginal log likelihood.
    Test-set R^2, RMSE and MAE are computed in the original label units, and the train
    RMSE is printed. Uses the notebook-level settings ``n_trials`` and ``test_set_size``.

    Args:
        X: n x d NumPy array of inputs representing molecules
        y: n x 1 NumPy array of output labels
        n_components: number of principal components kept after standardisation
            (default 51).
    Returns:
        (rmse_list, mae_list): NumPy arrays with the test RMSE and MAE of every trial.
        Also prints mean +- standard error of R^2, RMSE and MAE, and plots the
        confidence-error curve averaged over trials.
    """

    # initialise performance metric lists
    r2_list = []
    rmse_list = []
    mae_list = []

    # The test-set size is the same for every split; obtain it from a cheap index split
    # so the confidence-curve array can be pre-allocated.
    _, test_idx = train_test_split(np.arange(len(y)), test_size=test_set_size, random_state=0)
    n_test = len(test_idx)

    mae_confidence_list = np.zeros((n_trials, n_test))

    print('\nBeginning training loop...')

    for i in range(n_trials):

        print(f'Starting trial {i}')

        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_set_size, random_state=i)

        # Standardise and project the inputs; transforms are fitted on the training split
        # only so no test information leaks into the features.
        scaler = StandardScaler()
        X_train = scaler.fit_transform(X_train)
        X_test = scaler.transform(X_test)
        pca_mordred = PCA(n_components=n_components)
        X_train = pca_mordred.fit_transform(X_train)
        X_test = pca_mordred.transform(X_test)

        #  We standardise the outputs
        _, y_train, _, y_test, y_scaler = transform_data(X_train, y_train, X_test, y_test)

        # Convert numpy arrays to PyTorch tensors and flatten the label vectors
        X_train = torch.tensor(X_train.astype(np.float64))
        X_test = torch.tensor(X_test.astype(np.float64))
        y_train = torch.tensor(y_train).flatten()
        y_test = torch.tensor(y_test).flatten()

        # initialise GP likelihood and model
        likelihood = gpytorch.likelihoods.GaussianLikelihood()
        model = SparseGPModel(X_train, y_train, likelihood)

        # "Loss" for GPs - the marginal log likelihood
        mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

        # Use the BoTorch utility for fitting GPs in order to use the LBFGS-B optimiser (recommended)
        fit_gpytorch_model(mll)

        # Get into evaluation (predictive posterior) mode
        model.eval()
        likelihood.eval()

        # Latent-function posterior. No gradients are needed for prediction; mean and
        # variance are materialised inside the context so they do not track gradients.
        with torch.no_grad():
            f_pred = model(X_test)
            y_pred = f_pred.mean
            y_var = f_pred.variance
            y_pred_train = model(X_train).mean

        # Transform back to real data space. scikit-learn scalers expect 2D (n, 1) arrays.
        y_pred = y_scaler.inverse_transform(y_pred.unsqueeze(dim=1).numpy())
        y_test = y_scaler.inverse_transform(y_test.unsqueeze(dim=1).numpy())

        # Scores for confidence curve plotting
        mae_confidence_list[i] = _cumulative_confidence_mae(y_test, y_pred, y_var.numpy())

        # Output standardised RMSE and RMSE on the train set
        train_rmse_stan = np.sqrt(mean_squared_error(y_train.numpy(), y_pred_train.numpy()))
        train_rmse = np.sqrt(mean_squared_error(
            y_scaler.inverse_transform(y_train.unsqueeze(dim=1).numpy()),
            y_scaler.inverse_transform(y_pred_train.unsqueeze(dim=1).numpy()),
        ))
        print(f'Train RMSE (standardised): {train_rmse_stan:.4f}, Train RMSE: {train_rmse:.4f}')

        # Compute R^2, RMSE and MAE on Test set
        r2_list.append(r2_score(y_test, y_pred))
        rmse_list.append(np.sqrt(mean_squared_error(y_test, y_pred)))
        mae_list.append(mean_absolute_error(y_test, y_pred))

    r2_list = np.array(r2_list)
    rmse_list = np.array(rmse_list)
    mae_list = np.array(mae_list)

    print("\nmean R^2: {:.4f} +- {:.4f}".format(np.mean(r2_list), np.std(r2_list)/np.sqrt(len(r2_list))))
    print("mean RMSE: {:.4f} +- {:.4f}".format(np.mean(rmse_list), np.std(rmse_list)/np.sqrt(len(rmse_list))))
    print("mean MAE: {:.4f} +- {:.4f}\n".format(np.mean(mae_list), np.std(mae_list)/np.sqrt(len(mae_list))))

    _plot_confidence_curve(mae_confidence_list)

    return rmse_list, mae_list

Check the performance achieved by our sparse GP model.

[14]:
# Run the full train/evaluate loop on the Mordred descriptors, keeping the per-trial test RMSE and MAE
mordred_results = evaluate_model(X_mordred, y)
rmse_mordred, mae_mordred = mordred_results

Beginning training loop...
Starting trial 0
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In[14], line 1
----> 1 rmse_mordred, mae_mordred = evaluate_model(X_mordred, y)

Cell In[13], line 63, in evaluate_model(X, y)
     60 mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
     62 # Use the BoTorch utility for fitting GPs in order to use the LBFGS-B optimiser (recommended)
---> 63 fit_gpytorch_model(mll)
     65 # Get into evaluation (predictive posterior) mode
     66 model.eval()

File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/botorch/fit.py:164, in fit_gpytorch_model(mll, optimizer, optimizer_kwargs, exclude, max_retries, **kwargs)
    158 with (
    159     nullcontext()
    160     if exclude is None
    161     else requires_grad_ctx(mll, assignments={name: False for name in exclude})
    162 ):
    163     try:
--> 164         mll = fit_gpytorch_mll(
    165             mll,
    166             optimizer=optimizer,
    167             optimizer_kwargs=optimizer_kwargs,
    168             **kwargs,
    169         )
    170     except ModelFittingError as err:
    171         warn(str(err), RuntimeWarning)

File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/botorch/fit.py:105, in fit_gpytorch_mll(mll, closure, optimizer, closure_kwargs, optimizer_kwargs, **kwargs)
    102 if optimizer is not None:  # defer to per-method defaults
    103     kwargs["optimizer"] = optimizer
--> 105 return FitGPyTorchMLL(
    106     mll,
    107     type(mll.likelihood),
    108     type(mll.model),
    109     closure=closure,
    110     closure_kwargs=closure_kwargs,
    111     optimizer_kwargs=optimizer_kwargs,
    112     **kwargs,
    113 )

File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/botorch/utils/dispatcher.py:93, in Dispatcher.__call__(self, *args, **kwargs)
     91 func = self.__getitem__(types=types)
     92 try:
---> 93     return func(*args, **kwargs)
     94 except MDNotImplementedError:
     95     # Traverses registered methods in order, yields whenever a match is found
     96     funcs = self.dispatch_iter(*types)

File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/botorch/fit.py:252, in _fit_fallback(mll, _, __, closure, optimizer, closure_kwargs, optimizer_kwargs, max_attempts, warning_handler, caught_exception_types, **ignore)
    250 with catch_warnings(record=True) as warning_list, debug(True):
    251     simplefilter("always", category=OptimizationWarning)
--> 252     optimizer(mll, closure=closure, **optimizer_kwargs)
    254 # Resolved warnings and determine whether or not to retry
    255 done = True

File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/botorch/optim/fit.py:92, in fit_gpytorch_mll_scipy(mll, parameters, bounds, closure, closure_kwargs, method, options, callback, timeout_sec)
     89 if closure_kwargs is not None:
     90     closure = partial(closure, **closure_kwargs)
---> 92 result = scipy_minimize(
     93     closure=closure,
     94     parameters=parameters,
     95     bounds=bounds,
     96     method=method,
     97     options=options,
     98     callback=callback,
     99     timeout_sec=timeout_sec,
    100 )
    101 if result.status != OptimizationStatus.SUCCESS:
    102     warn(
    103         f"`scipy_minimize` terminated with status {result.status}, displaying"
    104         f" original message from `scipy.optimize.minimize`: {result.message}",
    105         OptimizationWarning,
    106     )

File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/botorch/optim/core.py:109, in scipy_minimize(closure, parameters, bounds, callback, x0, method, options, timeout_sec)
    101         result = OptimizationResult(
    102             step=next(call_counter),
    103             fval=float(wrapped_closure(x)[0]),
    104             status=OptimizationStatus.RUNNING,
    105             runtime=monotonic() - start_time,
    106         )
    107         return callback(parameters, result)  # pyre-ignore [29]
--> 109 raw = minimize_with_timeout(
    110     wrapped_closure,
    111     wrapped_closure.state if x0 is None else x0.astype(np_float64, copy=False),
    112     jac=True,
    113     bounds=bounds_np,
    114     method=method,
    115     options=options,
    116     callback=wrapped_callback,
    117     timeout_sec=timeout_sec,
    118 )
    120 # Post-processing and outcome handling
    121 wrapped_closure.state = asarray(raw.x)  # set parameter state to optimal values

File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/botorch/optim/utils/timeout.py:80, in minimize_with_timeout(fun, x0, args, method, jac, hess, hessp, bounds, constraints, tol, callback, options, timeout_sec)
     77     wrapped_callback = callback
     79 try:
---> 80     return optimize.minimize(
     81         fun=fun,
     82         x0=x0,
     83         args=args,
     84         method=method,
     85         jac=jac,
     86         hess=hess,
     87         hessp=hessp,
     88         bounds=bounds,
     89         constraints=constraints,
     90         tol=tol,
     91         callback=wrapped_callback,
     92         options=options,
     93     )
     94 except OptimizationTimeoutError as e:
     95     msg = f"Optimization timed out after {e.runtime} seconds."

File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/scipy/optimize/_minimize.py:710, in minimize(fun, x0, args, method, jac, hess, hessp, bounds, constraints, tol, callback, options)
    707     res = _minimize_newtoncg(fun, x0, args, jac, hess, hessp, callback,
    708                              **options)
    709 elif meth == 'l-bfgs-b':
--> 710     res = _minimize_lbfgsb(fun, x0, args, jac, bounds,
    711                            callback=callback, **options)
    712 elif meth == 'tnc':
    713     res = _minimize_tnc(fun, x0, args, jac, bounds, callback=callback,
    714                         **options)

File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/scipy/optimize/_lbfgsb_py.py:365, in _minimize_lbfgsb(fun, x0, args, jac, bounds, disp, maxcor, ftol, gtol, eps, maxfun, maxiter, iprint, callback, maxls, finite_diff_rel_step, **unknown_options)
    359 task_str = task.tobytes()
    360 if task_str.startswith(b'FG'):
    361     # The minimization routine wants f and g at the current x.
    362     # Note that interruptions due to maxfun are postponed
    363     # until the completion of the current minimization iteration.
    364     # Overwrite f and g:
--> 365     f, g = func_and_grad(x)
    366 elif task_str.startswith(b'NEW_X'):
    367     # new iteration
    368     n_iterations += 1

File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/scipy/optimize/_differentiable_functions.py:285, in ScalarFunction.fun_and_grad(self, x)
    283 if not np.array_equal(x, self.x):
    284     self._update_x_impl(x)
--> 285 self._update_fun()
    286 self._update_grad()
    287 return self.f, self.g

File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/scipy/optimize/_differentiable_functions.py:251, in ScalarFunction._update_fun(self)
    249 def _update_fun(self):
    250     if not self.f_updated:
--> 251         self._update_fun_impl()
    252         self.f_updated = True

File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/scipy/optimize/_differentiable_functions.py:155, in ScalarFunction.__init__.<locals>.update_fun()
    154 def update_fun():
--> 155     self.f = fun_wrapped(self.x)

File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/scipy/optimize/_differentiable_functions.py:137, in ScalarFunction.__init__.<locals>.fun_wrapped(x)
    133 self.nfev += 1
    134 # Send a copy because the user may overwrite it.
    135 # Overwriting results in undefined behaviour because
    136 # fun(self.x) will change self.x, with the two no longer linked.
--> 137 fx = fun(np.copy(x), *args)
    138 # Make sure the function returns a true scalar
    139 if not np.isscalar(fx):

File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/scipy/optimize/_optimize.py:77, in MemoizeJac.__call__(self, x, *args)
     75 def __call__(self, x, *args):
     76     """ returns the function value """
---> 77     self._compute_if_needed(x, *args)
     78     return self._value

File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/scipy/optimize/_optimize.py:71, in MemoizeJac._compute_if_needed(self, x, *args)
     69 if not np.all(x == self.x) or self._value is None or self.jac is None:
     70     self.x = np.asarray(x).copy()
---> 71     fg = self.fun(x, *args)
     72     self.jac = fg[1]
     73     self._value = fg[0]

File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/botorch/optim/closures/core.py:150, in NdarrayOptimizationClosure.__call__(self, state, **kwargs)
    147     self.state = state
    149 try:
--> 150     value_tensor, grad_tensors = self.closure(**kwargs)
    151     value = self.as_array(value_tensor)
    152     grads = self._get_gradient_ndarray(fill_value=self.fill_value)

File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/botorch/optim/closures/core.py:66, in ForwardBackwardClosure.__call__(self, **kwargs)
     64 values = self.forward(**kwargs)
     65 value = values if self.reducer is None else self.reducer(values)
---> 66 self.backward(value)
     68 grads = tuple(param.grad for param in self.parameters.values())
     69 if self.callback:

File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/torch/_tensor.py:492, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    482 if has_torch_function_unary(self):
    483     return handle_torch_function(
    484         Tensor.backward,
    485         (self,),
   (...)
    490         inputs=inputs,
    491     )
--> 492 torch.autograd.backward(
    493     self, gradient, retain_graph, create_graph, inputs=inputs
    494 )

File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/torch/autograd/__init__.py:251, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    246     retain_graph = create_graph
    248 # The reason we repeat the same comment below is that
    249 # some Python versions print out the first line of a multi-line function
    250 # calls in the traceback and some print out the last line
--> 251 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    252     tensors,
    253     grad_tensors_,
    254     retain_graph,
    255     create_graph,
    256     inputs,
    257     allow_unreachable=True,
    258     accumulate_grad=True,
    259 )

KeyboardInterrupt:

## References

[1] Anna Gaulton, Louisa J Bellis, A Patricia Bento, Jon Chambers, Mark Davies, Anne Hersey, Yvonne Light, Shaun McGlinchey, David Michalovich, Bissan Al-Lazikani, et al. ChEMBL: A large-scale bioactivity database for drug discovery. Nucleic Acids Research, 2012.

[2] Bajusz, D., Rácz, A. and Héberger, K., 2015. Why is Tanimoto index an appropriate choice for fingerprint-based similarity calculations?. Journal of Cheminformatics, 7(1), pp.1-13.