GP Regression on Protein Sequences: Subsequence String Kernel#
An example notebook for string kernel-based GP regression on a dataset of protein sequences using the subsequence string kernel (SSK) model of [1, 2]. For the bag-of-amino-acids representation of the protein sequences (analogous to the bag-of-SMILES model for molecules), see the ‘protein fitness prediction - bag of amino acids’ notebook. The protein dataset consists of 151 sequences with a ‘fitness’ target label, the melting point in degrees Celsius, collated from values reported in references [3, 4, 5]. The sequences are each of length 290, so it is recommended that a GPU is used in conjunction with the SSK kernel.
In contrast to the bag of amino acids notebook, we do not report results on 20 random train/test splits because this would be too computationally intensive for the SSK kernel.
[1]:
"""Imports"""
# Turn off Graphein warnings
import warnings
warnings.filterwarnings("ignore")
from botorch import fit_gpytorch_model
from botorch.models import SingleTaskGP
from botorch.models.transforms import Normalize, Standardize
from botorch.models.fully_bayesian import MIN_INFERRED_NOISE_LEVEL
from gpytorch.constraints import GreaterThan
from gpytorch.kernels import ScaleKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.mlls import ExactMarginalLogLikelihood
from gpytorch.priors import GammaPrior
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_absolute_error
import torch
from gauche.dataloader.data_utils import transform_data
from gauche.kernels.string_kernels.sskkernel import pad, encode_string, build_one_hot, SubsequenceStringKernel
[2]:
"""CPU/GPU"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tkwargs = {"dtype": torch.float, "device": device}
print(tkwargs)
{'dtype': torch.float32, 'device': device(type='cpu')}
The Petase Dataset#
The dataset consists of PETase protein sequences with amino acid chains of length 290. An example sequence is given below:
MNFPRASRLMQAAVLGGLMAVSAAATAQTNPYARGPPPTAASLEASAGPFTVRSFTVSRPSGYGAGTVYYPTNAGGTVGAIAIVPGYTARQSSIKWWGPRLASHGFVVITIDTNSTLDQPSSRSSQQMAALRQVASLNGTSSSPIYGKVDTARMGVMGWSMGGGGSLISAANNPSLKAAAPQAPWDSSTNFSSVTVPTLIFACENDSIAPVNSSALPIYDSMSRNAKQFLEINGGSHSCANSGNSNQALIGKKGVAWMKRFMDNDTRYSTFACENPNSTRVSDFRTANCS
For such long sequences, the SSK kernel can struggle computationally, and so a “bag of amino acids” model is also compared; a sketch of that representation follows.
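A bag-of-amino-acids representation simply counts how often each amino acid occurs in a sequence, discarding positional information, which is what makes it cheap relative to the SSK kernel. A minimal sketch of such a featurization (illustrative only; built here with scikit-learn's CountVectorizer rather than the companion notebook's exact pipeline):
# Minimal bag-of-amino-acids featurization sketch (illustrative, not the
# companion notebook's exact pipeline): count each amino acid character.
from sklearn.feature_extraction.text import CountVectorizer

toy_sequences = ["MNFPRASR", "MNFPKASR"]  # toy example sequences
vectorizer = CountVectorizer(analyzer="char", lowercase=False)
counts = vectorizer.fit_transform(toy_sequences).toarray()
print(vectorizer.get_feature_names_out())  # amino acid "vocabulary"
print(counts)                              # per-sequence amino acid counts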
[3]:
"""Regression experiments parameters, number of random splits and split size"""
n_trials = 20
test_set_size = 0.2
[4]:
"""Load the petase dataset"""
import sys
sys.path.append('..')
df = pd.read_csv('../gauche/datasets/proteins/petase_151_mutants.csv')
x = df['sequence'].to_list()
y = df['fitness'].to_numpy().reshape(-1, 1)
print(f'len(sequences) {len(x)} | len(targets) {len(y)}')
len(sequences) 151 | len(targets) 151
[5]:
"""Compute the required sequence properties for modelling with the SSK kernel GP."""
maxlen = np.max([len(seq) for seq in x])
# get alphabet of characters used in candidate set (to init SSK)
alphabet = list({l for word in x for l in word})
print(f'alphabet \n {alphabet} \n length of alphabet {len(alphabet)}')
print(f'maxlen {maxlen}')
alphabet
['F', 'N', 'E', 'Y', 'A', 'K', 'S', 'P', 'M', 'G', 'R', 'I', 'H', 'Q', 'C', 'W', 'L', 'D', 'V', 'T']
length of alphabet 20
maxlen 290
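Since every sequence in this dataset has length 290 (equal to maxlen), the padding step in the next section is effectively a no-op; an optional sanity check along those lines:
# Optional sanity check: all sequences share the same length (290),
# so padding to maxlen leaves the encodings unchanged.
assert all(len(seq) == maxlen for seq in x), "sequences differ in length"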
GP Regression on the Petase Dataset#
First we define the GP model for protein sequences.
[6]:
"""Process the inputs x to the string kernel GPs"""
# Compute one-hot encodings and an integer index for the given amino acid alphabet
embds, index = build_one_hot(alphabet)
embds = embds.to(**tkwargs)
# Process the string inputs to the SSK model
x = torch.cat([pad(encode_string(seq, index), maxlen).unsqueeze(0) for seq in x], dim=0)
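If encode_string maps each amino acid to its integer index in the alphabet and pad right-pads each encoding to maxlen (an assumption about the GAUCHE helpers, not verified here), the processed input should be a single tensor with one row per sequence:
# Quick shape check (assumes integer-index encoding padded to maxlen).
print(x.shape)  # expected: torch.Size([151, 290]) under this assumption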
[7]:
"""Compute the train/test split."""
X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=test_set_size, random_state=0)
X_train = X_train.to(**tkwargs)
X_test = X_test.to(**tkwargs)
y_train = torch.tensor(y_train, **tkwargs)
y_test = torch.tensor(y_test, **tkwargs)
[8]:
"""Intialize and fit the models"""
# Likelihood function
likelihood = GaussianLikelihood(
noise_prior=GammaPrior(torch.tensor(0.9, **tkwargs), torch.tensor(10.0, **tkwargs)),
noise_constraint=GreaterThan(MIN_INFERRED_NOISE_LEVEL),
)
# Covariance function
covar_module = ScaleKernel(SubsequenceStringKernel(embds, index, alphabet, maxlen, **tkwargs))
ssk_gp_model = SingleTaskGP(
train_X=X_train,
train_Y=y_train,
outcome_transform=Standardize(1),
likelihood=likelihood,
covar_module=covar_module,
)
mll = ExactMarginalLogLikelihood(model=ssk_gp_model, likelihood=ssk_gp_model.likelihood)
# Ideally we would optimize over the hyperparameters of the string kernel;
# however, GPU memory usage of the batch (GPU) version of the kernel is quite high,
# while the standard non-batch version is relatively slow for kernel evaluation.
# Nevertheless, the kernel is very robust to the choice of these hyperparameters.
mll.model.covar_module.base_kernel.raw_order_coefs.requires_grad = False
mll.model.covar_module.base_kernel.raw_match_decay.requires_grad = False
mll.model.covar_module.base_kernel.raw_gap_decay.requires_grad = False
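# To optimize these string-kernel hyperparameters instead (at higher GPU memory
# or compute cost, as noted above), leave the corresponding requires_grad flags
# set to True, e.g.:
#   mll.model.covar_module.base_kernel.raw_gap_decay.requires_grad = True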
fit_gpytorch_model(mll)
[8]:
ExactMarginalLogLikelihood(
(likelihood): GaussianLikelihood(
(noise_covar): HomoskedasticNoise(
(noise_prior): GammaPrior()
(raw_noise_constraint): GreaterThan(1.000E-04)
)
)
(model): SingleTaskGP(
(likelihood): GaussianLikelihood(
(noise_covar): HomoskedasticNoise(
(noise_prior): GammaPrior()
(raw_noise_constraint): GreaterThan(1.000E-04)
)
)
(mean_module): ConstantMean()
(covar_module): ScaleKernel(
(base_kernel): SubsequenceStringKernel(
(raw_gap_decay_constraint): Interval(0.000E+00, 1.000E+00)
(raw_match_decay_constraint): Interval(0.000E+00, 1.000E+00)
(raw_order_coefs_constraint): Interval(0.000E+00, 1.000E+00)
)
(raw_outputscale_constraint): Positive()
)
(outcome_transform): Standardize()
)
)
[9]:
"""Evaluate the trained model."""
posterior = ssk_gp_model.posterior(X_test)
posterior_mean = posterior.mean.cpu().detach()
posterior_std = torch.sqrt(posterior.variance.cpu().detach())
r2 = r2_score(y_test.cpu().numpy(), posterior_mean.numpy())
mae = mean_absolute_error(y_test.cpu().numpy(), posterior_mean.numpy())
print(f'Test R^2: {r2:.3f} | Test MAE: {mae:.3f}')
---------------------------------------------------------------------------
KeyboardInterrupt Traceback (most recent call last)
Cell In[9], line 3
1 """Evaluate the trained model."""
----> 3 posterior = ssk_gp_model.posterior(X_test)
4 posterior_mean = posterior.mean.cpu().detach()
5 posterior_std = torch.sqrt(posterior.variance.cpu().detach())
File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/botorch/models/gpytorch.py:383, in BatchedMultiOutputGPyTorchModel.posterior(self, X, output_indices, observation_noise, posterior_transform, **kwargs)
377 X, output_dim_idx = add_output_dim(
378 X=X, original_batch_shape=self._input_batch_shape
379 )
380 # NOTE: BoTorch's GPyTorchModels also inherit from GPyTorch's ExactGP, thus
381 # self(X) calls GPyTorch's ExactGP's __call__, which computes the posterior,
382 # rather than e.g. SingleTaskGP's forward, which computes the prior.
--> 383 mvn = self(X)
384 if observation_noise is not False:
385 if self._num_outputs > 1:
File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/gpytorch/models/exact_gp.py:333, in ExactGP.__call__(self, *args, **kwargs)
328 # Make the prediction
329 with settings.cg_tolerance(settings.eval_cg_tolerance.value()):
330 (
331 predictive_mean,
332 predictive_covar,
--> 333 ) = self.prediction_strategy.exact_prediction(full_mean, full_covar)
335 # Reshape predictive mean to match the appropriate event shape
336 predictive_mean = predictive_mean.view(*batch_shape, *test_shape).contiguous()
File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/gpytorch/models/exact_prediction_strategies.py:289, in DefaultPredictionStrategy.exact_prediction(self, joint_mean, joint_covar)
285 test_test_covar = joint_covar[..., self.num_train :, self.num_train :]
286 test_train_covar = joint_covar[..., self.num_train :, : self.num_train]
288 return (
--> 289 self.exact_predictive_mean(test_mean, test_train_covar),
290 self.exact_predictive_covar(test_test_covar, test_train_covar),
291 )
File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/gpytorch/models/exact_prediction_strategies.py:306, in DefaultPredictionStrategy.exact_predictive_mean(self, test_mean, test_train_covar)
294 """
295 Computes the posterior predictive covariance of a GP
296
(...)
300 :return: The predictive posterior mean of the test points
301 """
302 # NOTE TO FUTURE SELF:
303 # You **cannot* use addmv here, because test_train_covar may not actually be a non lazy tensor even for an exact
304 # GP, and using addmv requires you to to_dense test_train_covar, which is obviously a huge no-no!
--> 306 if len(self.mean_cache.shape) == 4:
307 res = (test_train_covar @ self.mean_cache.squeeze(1).unsqueeze(-1)).squeeze(-1)
308 else:
File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/gpytorch/utils/memoize.py:59, in _cached.<locals>.g(self, *args, **kwargs)
57 kwargs_pkl = pickle.dumps(kwargs)
58 if not _is_in_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl):
---> 59 return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
60 return _get_from_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl)
File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/gpytorch/models/exact_prediction_strategies.py:256, in DefaultPredictionStrategy.mean_cache(self)
253 train_mean, train_train_covar = mvn.loc, mvn.lazy_covariance_matrix
255 train_labels_offset = (self.train_labels - train_mean).unsqueeze(-1)
--> 256 mean_cache = train_train_covar.evaluate_kernel().solve(train_labels_offset).squeeze(-1)
258 if settings.detach_test_caches.on():
259 mean_cache = mean_cache.detach()
File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/linear_operator/operators/added_diag_linear_operator.py:209, in AddedDiagLinearOperator.evaluate_kernel(self)
208 def evaluate_kernel(self):
--> 209 added_diag_linear_op = self.representation_tree()(*self.representation())
210 return added_diag_linear_op._linear_op + added_diag_linear_op._diag_tensor
File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/linear_operator/operators/_linear_operator.py:2064, in LinearOperator.representation_tree(self)
2054 def representation_tree(self) -> LinearOperatorRepresentationTree:
2055 """
2056 Returns a
2057 :obj:`linear_operator.operators.LinearOperatorRepresentationTree` tree
(...)
2062 including all subobjects. This is used internally.
2063 """
-> 2064 return LinearOperatorRepresentationTree(self)
File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/linear_operator/operators/linear_operator_representation_tree.py:15, in LinearOperatorRepresentationTree.__init__(self, linear_op)
13 for arg in itertools.chain(linear_op._args, linear_op._differentiable_kwargs.values()):
14 if hasattr(arg, "representation") and callable(arg.representation): # Is it a lazy tensor?
---> 15 representation_size = len(arg.representation())
16 self.children.append((slice(counter, counter + representation_size, None), arg.representation_tree()))
17 counter += representation_size
File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:397, in LazyEvaluatedKernelTensor.representation(self)
393 return super().representation()
394 # Otherwise, we'll evaluate the kernel (or at least its LinearOperator representation) and use its
395 # representation
396 else:
--> 397 return self.evaluate_kernel().representation()
File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/gpytorch/utils/memoize.py:59, in _cached.<locals>.g(self, *args, **kwargs)
57 kwargs_pkl = pickle.dumps(kwargs)
58 if not _is_in_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl):
---> 59 return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
60 return _get_from_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl)
File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:25, in recall_grad_state.<locals>.wrapped(self, *args, **kwargs)
22 @functools.wraps(method)
23 def wrapped(self, *args, **kwargs):
24 with torch.set_grad_enabled(self._is_grad_enabled):
---> 25 output = method(self, *args, **kwargs)
26 return output
File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:355, in LazyEvaluatedKernelTensor.evaluate_kernel(self)
353 temp_active_dims = self.kernel.active_dims
354 self.kernel.active_dims = None
--> 355 res = self.kernel(
356 x1,
357 x2,
358 diag=False,
359 last_dim_is_batch=self.last_dim_is_batch,
360 **self.params,
361 )
362 self.kernel.active_dims = temp_active_dims
364 # Check the size of the output
File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/gpytorch/kernels/kernel.py:530, in Kernel.__call__(self, x1, x2, diag, last_dim_is_batch, **params)
527 res = LazyEvaluatedKernelTensor(x1_, x2_, kernel=self, last_dim_is_batch=last_dim_is_batch, **params)
528 else:
529 res = to_linear_operator(
--> 530 super(Kernel, self).__call__(x1_, x2_, last_dim_is_batch=last_dim_is_batch, **params)
531 )
532 return res
File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/gpytorch/module.py:31, in Module.__call__(self, *inputs, **kwargs)
30 def __call__(self, *inputs, **kwargs) -> Union[Tensor, Distribution, LinearOperator]:
---> 31 outputs = self.forward(*inputs, **kwargs)
32 if isinstance(outputs, list):
33 return [_validate_module_outputs(output) for output in outputs]
File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/gpytorch/kernels/scale_kernel.py:109, in ScaleKernel.forward(self, x1, x2, last_dim_is_batch, diag, **params)
108 def forward(self, x1, x2, last_dim_is_batch=False, diag=False, **params):
--> 109 orig_output = self.base_kernel.forward(x1, x2, diag=diag, last_dim_is_batch=last_dim_is_batch, **params)
110 outputscales = self.outputscale
111 if last_dim_is_batch:
File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/gauche/kernels/string_kernels/sskkernel.py:192, in SubsequenceStringKernel.forward(self, X1, X2, diag, **params)
188 K[batch_idx, :, :] = self._compute_kernel(
189 X1[batch_idx], X2[batch_idx]
190 )
191 else:
--> 192 K = self._compute_kernel(X1, X2, **params)
193 if diag is True:
194 return torch.diag(K)
File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/gauche/kernels/string_kernels/sskkernel.py:221, in SubsequenceStringKernel._compute_kernel(self, X1, X2, **params)
219 X1_batch = X1.index_select(dim=0, index=X1_batch_indicies)
220 X2_batch = X2.index_select(dim=0, index=X2_batch_indicies)
--> 221 k_result = self._k(X1_batch, X2_batch)
222 for j in range(0, len(tuples_batch)):
223 if (
224 self.normalize
225 and X1_diag_Ks[tuples_batch[j][0]] != 0
226 and X2_diag_Ks[tuples_batch[j][1]] != 0
227 ):
File ~/miniconda3/envs/gauche/lib/python3.11/site-packages/gauche/kernels/string_kernels/sskkernel.py:267, in SubsequenceStringKernel._k(self, s1, s2)
265 aux = aux2 * match_sq
266 aux = aux.transpose(-2, -1) @ self.D
--> 267 Kp.append(aux.transpose(-2, -1))
269 Kp = torch.cat([x.unsqueeze(0) for x in Kp], dim=0)
270 final_aux1 = S * Kp
KeyboardInterrupt:
[ ]:
"""Plot the R^2"""
fig, ax = plt.subplots(figsize=(8, 6))
y_test_np = y_test.cpu().numpy().flatten()
pred_np = posterior_mean.numpy().flatten()
ax.scatter(y_test_np, pred_np)
ax.set_title(f'Test set $R^2 = {r2:.2f}$')
# Least-squares fit line through the predictions
ax.plot(np.unique(y_test_np), np.poly1d(np.polyfit(y_test_np, pred_np, 1))(np.unique(y_test_np)), color='k', linewidth=0.4)
ax.set_xlabel('Measured melting point (°C)')
ax.set_ylabel('Predicted melting point (°C)')
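The posterior standard deviation computed during evaluation is not used above; a short sketch (reusing y_test, posterior_mean, and posterior_std from the earlier cells) that adds ±2σ predictive error bars to the parity plot:
# Parity plot with +/- 2 standard deviation predictive error bars
# (sketch; reuses y_test, posterior_mean and posterior_std from the cells above).
fig, ax = plt.subplots(figsize=(8, 6))
ax.errorbar(
    y_test.cpu().numpy().flatten(),
    posterior_mean.numpy().flatten(),
    yerr=2 * posterior_std.numpy().flatten(),
    fmt="o", capsize=2, alpha=0.7,
)
lims = ax.get_xlim()
ax.plot(lims, lims, "k--", linewidth=0.5)  # y = x reference line
ax.set_xlabel("Measured melting point (°C)")
ax.set_ylabel("Predicted melting point (°C)")
plt.show()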
References#
[1] Lodhi, H., Saunders, C., Shawe-Taylor, J., Cristianini, N. and Watkins, C., 2002. Text classification using string kernels. The Journal of Machine Learning Research, 2, pp.419-444.
[2] Cancedda, N., Gaussier, E., Goutte, C. and Renders, J.M., 2003. Word sequence kernels. The Journal of Machine Learning Research, 3, pp.1059-1082.
[3] Cui, Y., Chen, Y., Liu, X., Dong, S., Tian, Y.E., Qiao, Y., Mitra, R., Han, J., Li, C., Han, X. and Liu, W., 2021. Computational redesign of a PETase for plastic biodegradation under ambient condition by the GRAPE strategy. ACS Catalysis, 11(3), pp.1340-1350.
[4] Liu, B., He, L., Wang, L., Li, T., Li, C., Liu, H., Luo, Y. and Bao, R., 2018. Protein crystallography and site-directed mutagenesis analysis of the poly(ethylene terephthalate) hydrolase PETase from Ideonella sakaiensis. ChemBioChem, 19(14), pp.1471-1475.
[5] Joo, S., Cho, I.J., Seo, H., Son, H.F., Sagong, H.Y., Shin, T.J., Choi, S.Y., Lee, S.Y. and Kim, K.J., 2018. Structural insight into molecular mechanism of poly(ethylene terephthalate) degradation. Nature Communications, 9(1), p.382.