Multitask GP Regression on Molecules#

An example notebook for multitask GP regression on a molecular dataset. We use a multioutput GP model, the intrinsic coregionalisation model (ICM) [1], on the Photoswitch Dataset [2], with a Tanimoto kernel applied to fragprint representations [2]. The paper and code for the dataset are available here:

Paper: https://pubs.rsc.org/en/content/articlelanding/2022/sc/d2sc04306h

Code: https://github.com/Ryan-Rhys/The-Photoswitch-Dataset

Multitask Learning with Gaussian Processes#

Multitask learning is concerned with using a shared representation to learn several tasks, the idea being that predictive performance on a given task may benefit from the training signals of related tasks. Multioutput Gaussian processes (MOGPs) are models that perform multitask learning in the Gaussian process framework.

Formally, we seek to carry out Bayesian inference over a stochastic function \(f: \mathbb{R}^D \to \mathbb{R}^P\), where \(P\) is the number of tasks, given observations \(\{(\mathbf{x}_{p1}, y_{p1}), \dotsc , (\mathbf{x}_{pN_p}, y_{pN_p})\}\) for each task \(p = 1, \dotsc, P\). Each task may have a different number of labelled inputs \(N_p\); in other words, for a given input we may only have labels for a subset of the tasks.
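In the model used later in this notebook, such data are handled by flattening all tasks into a single training set in which each observation carries an integer task index. The following minimal sketch illustrates this layout; the array names and sizes are purely illustrative.

import torch

# Hypothetical features/labels for two tasks with different numbers of labelled molecules
X_task0, y_task0 = torch.rand(5, 16), torch.rand(5)   # 5 molecules labelled for task 0
X_task1, y_task1 = torch.rand(3, 16), torch.rand(3)   # 3 molecules labelled for task 1

# Flatten into one training set: each row of X is paired with an integer task index
X = torch.cat([X_task0, X_task1])                            # shape (8, 16)
i = torch.cat([torch.zeros(5, 1), torch.ones(3, 1)]).long()  # task indices, shape (8, 1)
y = torch.cat([y_task0, y_task1])                            # shape (8,)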

To build an MOGP we define a kernel over input-task pairs, \(k\big((\mathbf{x}, i), (\mathbf{x'}, j)\big) = k_{\text{inputs}}(\mathbf{x}, \mathbf{x'}) \cdot B[i, j]\), where \(k_{\text{inputs}}\) is a kernel on the inputs (here the Tanimoto kernel) and \(B\) is a positive semidefinite \(P \times P\) matrix whose \((i, j)\)-th entry scales the covariance between the \(i\)-th function at \(\mathbf{x}\) and the \(j\)-th function at \(\mathbf{x'}\). \(B\) is often referred to as an index kernel because it indexes the tasks.

Inference proceeds in analogous fashion to vanilla Gaussian processes by substituting the new expression for the kernel into the equations for the predictive mean and variance.
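For reference, substituting the multitask kernel into the standard GP predictive equations gives, for a test input \(\mathbf{x}_*\) belonging to task \(i\),

\[\mu_* = \mathbf{k}_*^{\top}(K + \sigma^2 I)^{-1}\mathbf{y}, \qquad \sigma_*^2 = k\big((\mathbf{x}_*, i), (\mathbf{x}_*, i)\big) - \mathbf{k}_*^{\top}(K + \sigma^2 I)^{-1}\mathbf{k}_*,\]

where \(K\) is the train-train covariance matrix, \(\mathbf{k}_*\) is the vector of covariances between the test point and the training points (both computed with the multitask kernel above, i.e. including the \(B[i, j]\) factors), \(\mathbf{y}\) is the stacked vector of training labels and \(\sigma^2\) is the observation noise.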

Positive semi-definiteness of \(B\) is guaranteed by parametrising it via its Cholesky decomposition, \(B = LL^{\top}\), where the Cholesky factor \(L\) is a lower triangular matrix whose entries may be learned alongside the kernel hyperparameters by optimising the marginal likelihood.
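As a concrete illustration of this construction, the minimal sketch below (with an illustrative RBF base kernel and toy data, not the Tanimoto kernel used later) builds the ICM covariance matrix as the elementwise product of a base input kernel with the task covariances \(B[i, j]\), where \(B = LL^{\top}\).

import torch

P = 3                                  # number of tasks
L = torch.tril(torch.randn(P, P))      # Cholesky factor; learned alongside kernel hyperparameters in practice
B = L @ L.T                            # positive semidefinite index kernel

X = torch.randn(6, 4)                  # toy inputs
i = torch.tensor([0, 0, 1, 1, 2, 2])   # task index of each input

# Base input kernel (an illustrative RBF here)
K_inputs = torch.exp(-0.5 * torch.cdist(X, X) ** 2)

# ICM covariance: Hadamard product of input covariances and task covariances B[i, j]
K_icm = K_inputs * B[i][:, i]

This mirrors what the gpytorch.kernels.IndexKernel used in the model below computes (it additionally learns a diagonal term added to \(B\)), with the factor \(L\) learned by maximising the marginal likelihood.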

An example of what correlated tasks for continuous input spaces might look like is provided below. Data taken from the GPflow tutorial (https://gpflow.readthedocs.io/en/v1.5.1-docs/notebooks/advanced/coregionalisation.html).

[Image: example of correlated task outputs over a continuous input space]

[1]:
# Imports

import warnings
warnings.filterwarnings("ignore") # Turn off Graphein warnings

from botorch import fit_gpytorch_model
import gpytorch
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
import torch

from gauche.dataloader import MolPropLoader
from gauche.dataloader.data_utils import transform_data

We define our model. See

https://docs.gpytorch.ai/en/stable/examples/03_Multitask_Exact_GPs/Multitask_GP_Regression.html

for a tutorial on using the base multioutput GP on non-molecular data!

[2]:
# We define our MOGP model using the Tanimoto kernel

from gauche.kernels.fingerprint_kernels.tanimoto_kernel import TanimotoKernel

num_tasks = 4 # number of tasks i.e. labels
rank = 1 # increasing the rank hyperparameter allows the model to learn more expressive
         # correlations between objectives at the expense of increasing the number of
         # model hyperparameters and potentially overfitting.

class MultitaskGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(MultitaskGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = TanimotoKernel()

        # We learn an IndexKernel over the tasks
        # (i.e. a num_tasks x num_tasks covariance matrix encoding inter-task correlations)
        self.task_covar_module = gpytorch.kernels.IndexKernel(num_tasks=num_tasks, rank=rank)

    def forward(self, x, i):
        mean_x = self.mean_module(x)

        # Get input-input covariance
        covar_x = self.covar_module(x)
        # Get task-task covariance
        covar_i = self.task_covar_module(i)
        # Multiply the two together to get the covariance we want
        covar = covar_x.mul(covar_i)

        return gpytorch.distributions.MultivariateNormal(mean_x, covar)

We define our experiment parameters. In this case we are reproducing the results of the multioutput GP prediction task from https://pubs.rsc.org/en/content/articlelanding/2022/sc/d2sc04306h using 20 random splits in the ratio 80/20.

[3]:
# Regression experiment parameters, number of random splits and train/test split size

n_trials = 20
test_set_size = 0.2

Load the Photoswitch Dataset via the MolPropLoader class.

[4]:
# Load the Photoswitch dataset

loader = MolPropLoader()

# Define a utility function for dataloading

def load_task_data(task,
                   loader=MolPropLoader(),
                   representation='ecfp_fragprints'):
    """Load data for a given task.

    Args:
        task: str specifying the task to load data for.
        One of ['Photoswitch', 'Photoswitch_E_n_pi', 'Photoswitch_Z_pi_pi', 'Photoswitch_Z_n_pi']
        loader: MolPropLoader object
        representation: str specifying the representation. One of ['ecfp_fingerprints', 'ecfp_fragprints', 'fragments']

    Returns:
        X_task: tensor of features for the task
        y_task: tensor of labels for the task
    """

    if representation not in ['ecfp_fragprints', 'ecfp_fingerprints', 'fragments']:
        raise ValueError('representation not valid. '
                         'Please choose one of ecfp_fragprints, ecfp_fingerprints, fragments')

    if task not in ['Photoswitch', 'Photoswitch_E_n_pi', 'Photoswitch_Z_pi_pi', 'Photoswitch_Z_n_pi']:
        raise ValueError('task not valid. Please choose one of Photoswitch, '
                         'Photoswitch_E_n_pi, Photoswitch_Z_pi_pi, Photoswitch_Z_n_pi')

    # Load the benchmark corresponding to the requested task
    loader.load_benchmark(task)

    # Featurise the molecules.
    # We use the fragprint representation by default
    # (a concatenation of Morgan fingerprints and RDKit fragment features)
    loader.featurize(representation)
    X_task = torch.from_numpy(loader.features)
    y_task = torch.from_numpy(loader.labels)

    return X_task, y_task

# Load features X1-X4 and properties (tasks) y1-y4.

X1, y1 = load_task_data('Photoswitch')
X2, y2 = load_task_data('Photoswitch_E_n_pi')
X3, y3 = load_task_data('Photoswitch_Z_pi_pi')
X4, y4 = load_task_data('Photoswitch_Z_n_pi')
Found 13 invalid labels [nan nan nan nan nan nan nan nan nan nan nan nan nan] at indices [41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 158]
To turn validation off, use dataloader.read_csv(..., validate=False).
Found 13 invalid labels [nan nan nan nan nan nan nan nan nan nan nan nan nan] at indices [41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 158]
To turn validation off, use dataloader.read_csv(..., validate=False).
Found 13 invalid labels [nan nan nan nan nan nan nan nan nan nan nan nan nan] at indices [41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 158]
To turn validation off, use dataloader.read_csv(..., validate=False).
Found 13 invalid labels [nan nan nan nan nan nan nan nan nan nan nan nan nan] at indices [41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 158]
To turn validation off, use dataloader.read_csv(..., validate=False).
[ ]:
"""Helper function for model evaluation.
"""

def prevent_test_leakage(x1, x2, x3, y1, y2, y3, X_test):
    """
    Function to prevent test leakage in train/test splits for multitask learning, for example,
    for test data point x_i, do not provide the model with auxiliary tasks P2-P4 when predicting P1.

    param: x1, x2, x3: input molecules for other tasks
    param: y1, y2, y3: labels for other tasks
    param: X_test: the test molecules
    """

    other_tasks = [x1, x2, x3]
    other_labels = [y1, y2, y3]
    for i in range(len(other_tasks)):
        indices_to_delete = []
        for j in range(len(other_tasks[i])):
            other_mol = other_tasks[i][j]
            if np.any([np.array_equal(other_mol, mol) for mol in X_test]) == True:
                indices_to_delete.append(j)
        indices_to_delete.reverse()
        for index in indices_to_delete:
            other_tasks[i] = np.delete(other_tasks[i], index, axis=0)
            other_labels[i] = np.delete(other_labels[i], index, axis=0)

    x1, x2, x3 = other_tasks[0], other_tasks[1], other_tasks[2]
    y1, y2, y3 = other_labels[0], other_labels[1], other_labels[2]

    return x1, x2, x3, y1, y2, y3

# Experiment parameters, train/test split and task to run prediction for
test_set_size = 0.2
task = 'e_iso_pi'

r2_list = []
rmse_list = []
mae_list = []

print('\nBeginning training loop...')

for i in range(0, n_trials):

    print(f'Starting trial {i}')

    # Task to dataset mapping: e_iso_pi -> (X1, y1), e_iso_n -> (X2, y2),
    # z_iso_pi -> (X3, y3), z_iso_n -> (X4, y4)
    if task == 'e_iso_pi':
        X_task = X1
        y_task = y1
    elif task == 'z_iso_pi':
        X_task = X3
        y_task = y3
    elif task == 'e_iso_n':
        X_task = X2
        y_task = y2
    else:
        X_task = X4
        y_task = y4

    X_train, X_test, y_train, y_test = train_test_split(X_task, y_task, test_size=test_set_size, random_state=i)

    # Partition the training data into tasks (most difficult part of training a multioutput GP!)

    if task == 'e_iso_pi':

        # Modify the x-values for the other tasks to exclude X_test
        X2_new, X3_new, X4_new, y2_new, y3_new, y4_new = \
            prevent_test_leakage(X2, X3, X4, y2, y3, y4, X_test)

        train_i_task1 = torch.full((X_train.shape[0], 1), dtype=torch.long, fill_value=0)
        train_i_task2 = torch.full((X2_new.shape[0], 1), dtype=torch.long, fill_value=1)
        train_i_task3 = torch.full((X3_new.shape[0], 1), dtype=torch.long, fill_value=2)
        train_i_task4 = torch.full((X4_new.shape[0], 1), dtype=torch.long, fill_value=3)

        full_train_x = torch.cat([X_train, X2_new, X3_new, X4_new])
        full_train_y = torch.cat([y_train, y2_new, y3_new, y4_new]).flatten()

        test_i_task = torch.full((X_test.shape[0], 1), dtype=torch.long, fill_value=0)


    elif task == 'e_iso_n':

        # Modify the x-values for the other tasks to exclude X_test
        X1_new, X3_new, X4_new, y1_new, y3_new, y4_new = \
            prevent_test_leakage(X1, X3, X4, y1, y3, y4, X_test)

        train_i_task1 = torch.full((X1_new.shape[0], 1), dtype=torch.long, fill_value=0)
        train_i_task2 = torch.full((X_train.shape[0], 1), dtype=torch.long, fill_value=1)
        train_i_task3 = torch.full((X3_new.shape[0], 1), dtype=torch.long, fill_value=2)
        train_i_task4 = torch.full((X4_new.shape[0], 1), dtype=torch.long, fill_value=3)

        full_train_x = torch.cat([X1_new, X_train, X3_new, X4_new])
        full_train_y = torch.cat([y1_new, y_train, y3_new, y4_new]).flatten()

        test_i_task = torch.full((X_test.shape[0], 1), dtype=torch.long, fill_value=1)


    elif task == 'z_iso_pi':

        # Modify the x-values for the other tasks to exclude X_test
        X1_new, X2_new, X4_new, y1_new, y2_new, y4_new = \
            prevent_test_leakage(X1, X2, X4, y1, y2, y4, X_test)

        train_i_task1 = torch.full((X1_new.shape[0], 1), dtype=torch.long, fill_value=0)
        train_i_task2 = torch.full((X2_new.shape[0], 1), dtype=torch.long, fill_value=1)
        train_i_task3 = torch.full((X_train.shape[0], 1), dtype=torch.long, fill_value=2)
        train_i_task4 = torch.full((X4_new.shape[0], 1), dtype=torch.long, fill_value=3)

        full_train_x = torch.cat([X1_new, X2_new, X_train, X4_new])
        full_train_y = torch.cat([y1_new, y2_new, y_train, y4_new]).flatten()

        test_i_task = torch.full((X_test.shape[0], 1), dtype=torch.long, fill_value=2)


    else:

        # Modify the x-values for the other tasks to exclude X_test
        X1_new, X2_new, X3_new, y1_new, y2_new, y3_new = \
            prevent_test_leakage(X1, X2, X3, y1, y2, y3, X_test)

        train_i_task1 = torch.full((X1_new.shape[0], 1), dtype=torch.long, fill_value=0)
        train_i_task2 = torch.full((X2_new.shape[0], 1), dtype=torch.long, fill_value=1)
        train_i_task3 = torch.full((X3_new.shape[0], 1), dtype=torch.long, fill_value=2)
        train_i_task4 = torch.full((X_train.shape[0], 1), dtype=torch.long, fill_value=3)

        full_train_x = torch.cat([X1_new, X2_new, X3_new, X_train])
        full_train_y = torch.cat([y1_new, y2_new, y3_new, y_train]).flatten()

        test_i_task = torch.full((X_test.shape[0], 1), dtype=torch.long, fill_value=3)


    full_train_i = torch.cat([train_i_task1, train_i_task2, train_i_task3, train_i_task4])

    # Gaussian likelihood
    likelihood = gpytorch.likelihoods.GaussianLikelihood()

    # Here we have two items that we're passing in as train_inputs
    model = MultitaskGPModel((full_train_x.float(), full_train_i.float()), full_train_y.float(), likelihood)

    # "Loss" for GPs - the marginal log likelihood
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

    # Use the BoTorch utility for fitting GPs in order to use the LBFGS-B optimiser (recommended)
    # Set the jitter level larger than the default for the MOGP
    with gpytorch.settings.cholesky_jitter(1e-3):
        fit_gpytorch_model(mll)

    # Get into evaluation (predictive posterior) mode
    model.eval()
    likelihood.eval()

    # The gpytorch.settings.fast_pred_var flag activates LOVE (for fast variances)
    # See https://arxiv.org/abs/1803.06058
    with torch.no_grad(), gpytorch.settings.fast_pred_var():
        observed_pred_y = likelihood(model(X_test.float(), test_i_task.float()))

    y_pred = observed_pred_y.mean

    # Compute R^2, RMSE and MAE on Test set
    score = r2_score(y_test, y_pred)
    rmse = np.sqrt(mean_squared_error(y_test, y_pred))
    mae = mean_absolute_error(y_test, y_pred)

    print(rmse)

    r2_list.append(score)
    rmse_list.append(rmse)
    mae_list.append(mae)

r2_list = np.array(r2_list)
rmse_list = np.array(rmse_list)
mae_list = np.array(mae_list)

print("\nmean R^2: {:.4f} +- {:.4f}".format(np.mean(r2_list), np.std(r2_list)/np.sqrt(len(r2_list))))
print("mean RMSE: {:.4f} +- {:.4f}".format(np.mean(rmse_list), np.std(rmse_list)/np.sqrt(len(rmse_list))))
print("mean MAE: {:.4f} +- {:.4f}\n".format(np.mean(mae_list), np.std(mae_list)/np.sqrt(len(mae_list))))

Beginning training loop...
Starting trial 0
19.47147948595366
Starting trial 1
23.791256668510048
Starting trial 2
26.206456838126353
Starting trial 3
20.030068524271805
Starting trial 4
32.67818405073743
Starting trial 5
23.453557811764345
Starting trial 6
20.75033301782167
Starting trial 7
23.209958101999543
Starting trial 8
24.654899652465076
Starting trial 9
17.65168764479574
Starting trial 10
23.086582644605222
Starting trial 11
16.02576800479441
Starting trial 12
24.76230610058969
Starting trial 13
18.59783949898034
Starting trial 14
23.796203872804114
Starting trial 15
18.30003157481471
Starting trial 16
22.394982557795256
Starting trial 17
23.632087011071746
Starting trial 18
23.739771849691213
Starting trial 19

Multitask learning is especially powerful when there are correlations between different tasks, as is the case for photoswitch transition wavelengths. When seeking to predict the properties of a molecule for which correlated task labels are available, the multioutput Gaussian process can leverage this information to inform its predictions. For more on multitask learning on molecules, see Ramsundar et al. [3]. For more on multioutput Gaussian processes, see [4].

References#

[1] Bonilla, E.V., Chai, K. and Williams, C., Multi-task Gaussian process prediction. Advances in Neural Information Processing Systems, 20, 2007.

[2] Griffiths, R.R., Greenfield, J.L., Thawani, A.R., Jamasb, A.R., Moss, H.B., Bourached, A., Jones, P., McCorkindale, W., Aldrick, A.A. and Fuchter, M.J., 2022. Data-driven discovery of molecular photoswitches with multioutput Gaussian processes. Chemical Science.

[3] Ramsundar, B., Kearnes, S., Riley, P., Webster, D., Konerding, D. and Pande, V., 2015. Massively multitask networks for drug discovery. arXiv preprint arXiv:1502.02072.

[4] Gaussian Processes: from one to many outputs, Invenia Blog.