GP Regression on Protein Sequences: Subsequence String Kernel

An example notebook for string kernel-based GP regression on a dataset of protein sequences using the subsequence string kernel (SSK) of [1, 2]. For the bag-of-amino-acids representation of the protein sequences (analogous to the bag-of-SMILES model for molecules), see the ‘protein fitness prediction - bag of amino acids’ notebook. The dataset consists of 151 protein sequences, each labelled with a ‘fitness’ value (the regression target): the melting point in degrees Celsius. The values are collated from references [3, 4, 5]. Each sequence has length 290, so using a GPU with the SSK is recommended.
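As background, the gap-weighted subsequence kernel of [1] compares two strings s and t over an alphabet Σ (here the 20 amino acids) by summing over every length-n subsequence u they have in common, including non-contiguous ones. Each occurrence, at index tuple i = (i_1, ..., i_n), is down-weighted by a decay λ ∈ (0, 1) raised to the length of the span it covers, and the kernel is usually normalized so that every string has unit self-similarity:

```latex
K_n(s, t) = \sum_{u \in \Sigma^n} \; \sum_{\mathbf{i}:\, s[\mathbf{i}] = u} \; \sum_{\mathbf{j}:\, t[\mathbf{j}] = u} \lambda^{\ell(\mathbf{i}) + \ell(\mathbf{j})},
\qquad \ell(\mathbf{i}) = i_n - i_1 + 1,
\qquad \hat{K}_n(s, t) = \frac{K_n(s, t)}{\sqrt{K_n(s, s)\, K_n(t, t)}}
```

Judging from its hyperparameter names (match decay, gap decay, order coefficients), the gauche kernel used below appears to generalize this form, decaying matched and skipped characters separately and combining several orders n with learned weights; the exact formula is not reproduced here.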

In contrast to the bag-of-amino-acids notebook, we do not report results averaged over 20 random train/test splits, because this would be too computationally intensive with the SSK. A single train/test split is used instead.

[1]:
"""Imports"""

# Turn off Graphein warnings
import warnings
warnings.filterwarnings("ignore")

from botorch import fit_gpytorch_model
from botorch.models import SingleTaskGP
from botorch.models.transforms import Normalize, Standardize
from botorch.models.fully_bayesian import MIN_INFERRED_NOISE_LEVEL
from gpytorch.constraints import GreaterThan
from gpytorch.kernels import ScaleKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.mlls import ExactMarginalLogLikelihood
from gpytorch.priors import GammaPrior
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_absolute_error
import torch

from gauche.dataloader.data_utils import transform_data
from gauche.kernels.string_kernels.sskkernel import pad, encode_string, build_one_hot, SubsequenceStringKernel
[2]:
"""CPU/GPU"""

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tkwargs = {"dtype": torch.float, "device": device}
print(tkwargs)
{'dtype': torch.float32, 'device': device(type='cpu')}

The PETase Dataset

The dataset consists of PETase protein sequences, each an amino acid chain of length 290. An example sequence is given below:

MNFPRASRLMQAAVLGGLMAVSAAATAQTNPYARGPPPTAASLEASAGPFTVRSFTVSRPSGYGAGTVYYPTNAGGTVGAIAIVPGYTARQSSIKWWGPRLASHGFVVITIDTNSTLDQPSSRSSQQMAALRQVASLNGTSSSPIYGKVDTARMGVMGWSMGGGGSLISAANNPSLKAAAPQAPWDSSTNFSSVTVPTLIFACENDSIAPVNSSALPIYDSMSRNAKQFLEINGGSHSCANSGNSNQALIGKKGVAWMKRFMDNDTRYSTFACENPNSTRVSDFRTANCS

For sequences this long the SSK is computationally demanding, which is why the cheaper bag-of-amino-acids model is provided as an alternative in the companion notebook.
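To see where the cost comes from, here is a minimal pure-Python sketch of the dynamic program from [1] for the unnormalized kernel K_n with a single decay λ. It is an illustration under that simplifying assumption, not gauche's implementation (which uses batched PyTorch tensor operations and separate match and gap decays); the function names `ssk` and `ssk_normalized` are our own. Every order i requires a full pass over an |s| × |t| table, so the cost per pair of strings is O(n |s| |t|): for two 290-residue sequences that is about 84,000 table cells per order, repeated for every pair of sequences in the Gram matrix.

```python
import math

import numpy as np


def ssk(s: str, t: str, n: int, lam: float) -> float:
    """Unnormalized gap-weighted subsequence kernel K_n(s, t) (Lodhi et al., 2002)."""
    ls, lt = len(s), len(t)
    # Kp[a, b] stores K'_i(s[:a], t[:b]) for the current order i; K'_0 == 1 everywhere.
    Kp = np.ones((ls + 1, lt + 1))
    for i in range(1, n):
        Kpp = np.zeros((ls + 1, lt + 1))      # auxiliary K''_i table
        Kp_next = np.zeros((ls + 1, lt + 1))  # K'_i table, zero when a prefix is shorter than i
        for a in range(i, ls + 1):
            for b in range(i, lt + 1):
                # A matching last character extends every common subsequence of order i - 1.
                match = lam**2 * Kp[a - 1, b - 1] if s[a - 1] == t[b - 1] else 0.0
                Kpp[a, b] = lam * Kpp[a, b - 1] + match
                Kp_next[a, b] = lam * Kp_next[a - 1, b] + Kpp[a, b]
        Kp = Kp_next
    # Order n: each matching pair (s[a-1], t[b-1]) closes a common subsequence of length n.
    k = 0.0
    for a in range(1, ls + 1):
        for b in range(1, lt + 1):
            if s[a - 1] == t[b - 1]:
                k += lam**2 * Kp[a - 1, b - 1]
    return k


def ssk_normalized(s: str, t: str, n: int, lam: float) -> float:
    """Normalized kernel K_n(s, t) / sqrt(K_n(s, s) * K_n(t, t))."""
    return ssk(s, t, n, lam) / math.sqrt(ssk(s, s, n, lam) * ssk(t, t, n, lam))


# Classic example from [1]: "cat" and "car" share only the subsequence "ca" at n = 2,
# so the normalized kernel equals 1 / (2 + lam**2), about 0.444 for lam = 0.5.
print(ssk_normalized("cat", "car", n=2, lam=0.5))
```

The nested Python loops make the quadratic dependence on sequence length explicit; vectorizing the inner table updates, as a tensor implementation does, changes the constant factor but not this scaling.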

[3]:
"""Regression experiment parameters: number of random splits and test set size"""

n_trials = 20  # unused here: only a single train/test split is evaluated (see above)
test_set_size = 0.2
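For reference, the repeated-split protocol that n_trials and test_set_size would drive (and that this notebook skips for cost reasons) can be sketched as below. The helper names `repeated_splits` and `mean_and_stderr` are hypothetical, and reporting the mean plus or minus the standard error is a common convention assumed here rather than taken from the companion notebook. It also confirms the split sizes for 151 sequences with test_set_size = 0.2.

```python
import numpy as np
from sklearn.model_selection import train_test_split


def repeated_splits(n_samples, n_trials, test_size):
    """Yield (train_idx, test_idx) index arrays, one split per random seed."""
    indices = np.arange(n_samples)
    for seed in range(n_trials):
        yield train_test_split(indices, test_size=test_size, random_state=seed)


def mean_and_stderr(scores):
    """Summarize per-split scores as mean and standard error of the mean."""
    scores = np.asarray(scores, dtype=float)
    return scores.mean(), scores.std(ddof=1) / np.sqrt(len(scores))


splits = list(repeated_splits(151, n_trials=20, test_size=0.2))
train_idx, test_idx = splits[0]
print(len(train_idx), len(test_idx))  # 120 31
```

scikit-learn rounds the test set size up (ceil(0.2 × 151) = 31), leaving 120 sequences for training; with the SSK, each of the 20 splits would need its own model fit.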
[4]:
"""Load the petase dataset"""

import sys
sys.path.append('..')

df = pd.read_csv('../gauche/datasets/proteins/petase_151_mutants.csv')
x = df['sequence'].to_list()
y = df['fitness'].to_numpy().reshape(-1, 1)
print(f'len(sequences) {len(x)} | len(targets) {len(y)}')
len(sequences) 151 | len(targets) 151
[5]:
"""Compute the required sequence properties for modelling with the SSK kernel GP."""

maxlen = np.max([len(seq) for seq in x])
# get alphabet of characters used in candidate set (to init SSK)
alphabet = list({l for word in x for l in word})
print(f'alphabet \n {alphabet} \n length of alphabet {len(alphabet)}')
print(f'maxlen {maxlen}')
alphabet
 ['F', 'N', 'E', 'Y', 'A', 'K', 'S', 'P', 'M', 'G', 'R', 'I', 'H', 'Q', 'C', 'W', 'L', 'D', 'V', 'T']
 length of alphabet 20
maxlen 290

GP Regression on the PETase Dataset

First we define the GP model for protein sequences.

[6]:
"""Process the inputs x to the string kernel GPs"""

# Compute one-hot encodings and an integer index for the given amino acid alphabet
embds, index = build_one_hot(alphabet)
embds = embds.to(**tkwargs)

# Process the string inputs to the SSK model
x = torch.cat([pad(encode_string(seq, index), maxlen).unsqueeze(0) for seq in x], dim=0)
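The gauche helpers build_one_hot, encode_string and pad are not documented in this notebook. As a hypothetical stand-in, the sketch below shows what such an encoding plausibly does: map each amino acid to an integer id, right-pad every sequence to maxlen, and look ids up in a one-hot table. The names `build_one_hot_sketch` and `encode_and_pad`, and the choice of reserving id 0 for padding, are our assumptions and may not match gauche's conventions.

```python
import torch


def build_one_hot_sketch(alphabet):
    """Return a one-hot embedding table and a character -> integer id mapping."""
    chars = sorted(alphabet)
    # Assumption: id 0 is reserved for padding, characters get ids 1..len(chars).
    index = {c: i + 1 for i, c in enumerate(chars)}
    embds = torch.zeros(len(chars) + 1, len(chars))
    embds[1:] = torch.eye(len(chars))  # row 0 (padding) stays all zeros
    return embds, index


def encode_and_pad(seq, index, maxlen):
    """Integer-encode a string and right-pad it with the padding id 0 up to maxlen."""
    ids = torch.tensor([index[c] for c in seq], dtype=torch.long)
    padding = torch.zeros(maxlen - len(ids), dtype=torch.long)
    return torch.cat([ids, padding])


embds_demo, index_demo = build_one_hot_sketch(["A", "C", "G"])
x_demo = torch.stack([encode_and_pad(s, index_demo, 4) for s in ["GA", "CAGC"]])
one_hot = embds_demo[x_demo]  # shape (2, 4, 3); padded positions are all-zero rows
```

Stacking the padded sequences into one integer tensor mirrors the torch.cat call above and lets the kernel process a whole batch of equal-length inputs.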
[7]:
"""Compute the train/test split."""

X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=test_set_size, random_state=0)
X_train = X_train.to(**tkwargs)
X_test = X_test.to(**tkwargs)
y_train = torch.tensor(y_train, **tkwargs)
y_test = torch.tensor(y_test, **tkwargs)
[8]:
"""Initialize and fit the model"""

# Likelihood function
likelihood = GaussianLikelihood(
    noise_prior=GammaPrior(torch.tensor(0.9, **tkwargs), torch.tensor(10.0, **tkwargs)),
    noise_constraint=GreaterThan(MIN_INFERRED_NOISE_LEVEL),
)

# Covariance function
covar_module = ScaleKernel(SubsequenceStringKernel(embds, index, alphabet, maxlen, **tkwargs))


ssk_gp_model = SingleTaskGP(
    train_X=X_train,
    train_Y=y_train,
    outcome_transform=Standardize(1),
    likelihood=likelihood,
    covar_module=covar_module,
)

mll = ExactMarginalLogLikelihood(model=ssk_gp_model, likelihood=ssk_gp_model.likelihood)
# Ideally we would also optimize the string kernel hyperparameters. However, the
# batched (GPU) version of the kernel has high GPU memory usage, while the
# non-batched version is slow to evaluate. We therefore freeze the SSK
# hyperparameters below and fit only the remaining ones (noise, output scale, mean).
# Fortunately, the kernel tends to be robust to the choice of these hyperparameters.
mll.model.covar_module.base_kernel.raw_order_coefs.requires_grad = False
mll.model.covar_module.base_kernel.raw_match_decay.requires_grad = False
mll.model.covar_module.base_kernel.raw_gap_decay.requires_grad = False

fit_gpytorch_model(mll)
[8]:
ExactMarginalLogLikelihood(
  (likelihood): GaussianLikelihood(
    (noise_covar): HomoskedasticNoise(
      (noise_prior): GammaPrior()
      (raw_noise_constraint): GreaterThan(1.000E-04)
    )
  )
  (model): SingleTaskGP(
    (likelihood): GaussianLikelihood(
      (noise_covar): HomoskedasticNoise(
        (noise_prior): GammaPrior()
        (raw_noise_constraint): GreaterThan(1.000E-04)
      )
    )
    (mean_module): ConstantMean()
    (covar_module): ScaleKernel(
      (base_kernel): SubsequenceStringKernel(
        (raw_gap_decay_constraint): Interval(0.000E+00, 1.000E+00)
        (raw_match_decay_constraint): Interval(0.000E+00, 1.000E+00)
        (raw_order_coefs_constraint): Interval(0.000E+00, 1.000E+00)
      )
      (raw_outputscale_constraint): Positive()
    )
    (outcome_transform): Standardize()
  )
)
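The likelihood above places a GammaPrior(0.9, 10.0) on the observation noise, bounded below by MIN_INFERRED_NOISE_LEVEL (the GreaterThan(1.000E-04) constraint in the printout). Because Standardize(1) rescales the training targets to unit variance before fitting, this noise lives on the standardized scale. The sketch below uses PyTorch's Gamma distribution, parameterized by concentration and rate like GPyTorch's GammaPrior, to show what the prior implies.

```python
import torch
from torch.distributions import Gamma

# Same concentration and rate as the GammaPrior on the likelihood noise above.
concentration, rate = 0.9, 10.0
noise_prior = Gamma(torch.tensor(concentration), torch.tensor(rate))

prior_mean = noise_prior.mean.item()    # concentration / rate
prior_std = noise_prior.stddev.item()   # sqrt(concentration) / rate
# With concentration < 1 the density is unbounded at 0, so the mode sits at 0.
prior_mode = max(0.0, (concentration - 1.0) / rate)
print(f"mean {prior_mean:.3f} | std {prior_std:.3f} | mode {prior_mode:.1f}")
```

A prior mean of about 0.09 corresponds to roughly 9% of the target variance being attributed to noise a priori, with most of the mass concentrated near zero.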
[9]:
"""Evaluate the trained model."""

with torch.no_grad():
    posterior = ssk_gp_model.posterior(X_test)
    posterior_mean = posterior.mean.cpu()
    posterior_std = torch.sqrt(posterior.variance.cpu())

# Move targets to CPU/NumPy so the scikit-learn metrics also work when running on a GPU
y_test_np = y_test.cpu().numpy()
r2 = r2_score(y_test_np, posterior_mean.numpy())
mae = mean_absolute_error(y_test_np, posterior_mean.numpy())
print(f'Test set R^2 {r2:.2f} | MAE {mae:.2f} °C')
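The two metrics reported by r2_score and mean_absolute_error can be written out directly, which makes their interpretation explicit: R² compares the squared error against that of always predicting the test set mean (so a constant mean predictor scores 0 and worse models score below 0), while MAE is in the units of the targets. Since BoTorch untransforms the posterior through the Standardize outcome transform, the MAE here is in degrees Celsius. The function names `r2` and `mae` below are illustrative only.

```python
import numpy as np


def r2(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    y_true, y_pred = np.ravel(y_true), np.ravel(y_pred)
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return 1.0 - ss_res / ss_tot


def mae(y_true, y_pred):
    """Mean absolute error, in the units of the targets."""
    return np.mean(np.abs(np.ravel(y_true) - np.ravel(y_pred)))


print(r2([1, 2, 3, 4], [1, 2, 3, 5]), mae([1, 2, 3, 4], [1, 2, 3, 5]))
```

In the toy example the residual sum of squares is 1 against a total sum of squares of 5, giving R² = 0.8 and an MAE of 0.25.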
[ ]:
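"""Plot the test set predictions against the true values"""

y_true = y_test.cpu().numpy().ravel()
y_pred = posterior_mean.numpy().ravel()

fig, ax = plt.subplots(figsize=(8, 6))
ax.scatter(y_true, y_pred)
ax.set_title(f'Test set $R^2 = {r2:.2f}$')
ax.set_xlabel('True melting point (°C)')
ax.set_ylabel('Predicted melting point (°C)')
# Overlay the least-squares line of best fit through the predictions
x_line = np.unique(y_true)
ax.plot(x_line, np.poly1d(np.polyfit(y_true, y_pred, 1))(x_line), color='k', linewidth=0.4)
plt.show()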

References

[1] Lodhi, H., Saunders, C., Shawe-Taylor, J., Cristianini, N. and Watkins, C., 2002. Text classification using string kernels. The Journal of Machine Learning Research, pp.419-444.

[2] Cancedda, N., Gaussier, E., Goutte, C. and Renders, J.M., 2003. Word sequence kernels. The Journal of Machine Learning Research, pp.1059-1082.

[3] Cui, Y., Chen, Y., Liu, X., Dong, S., Tian, Y.E., Qiao, Y., Mitra, R., Han, J., Li, C., Han, X. and Liu, W., 2021. Computational redesign of a PETase for plastic biodegradation under ambient condition by the GRAPE strategy. ACS Catalysis, 11(3), pp.1340-1350.

[4] Liu, B., He, L., Wang, L., Li, T., Li, C., Liu, H., Luo, Y. and Bao, R., 2018. Protein crystallography and site‐direct mutagenesis analysis of the poly (ethylene terephthalate) hydrolase PETase from Ideonella sakaiensis. ChemBioChem, 19(14), pp.1471-1475.

[5] Joo, S., Cho, I.J., Seo, H., Son, H.F., Sagong, H.Y., Shin, T.J., Choi, S.Y., Lee, S.Y. and Kim, K.J., 2018. Structural insight into molecular mechanism of poly (ethylene terephthalate) degradation. Nature Communications, 9(1), p.382.