from atai.core import *
atai
Atomic AI is a flexible, minimalist deep neural network training framework based on Jeremy Howard’s miniai from the fast.ai 2022 course.
Install
pip install atai
How to use
The following example demonstrates how the Atomic AI training framework can be used to train a custom model that predicts protein solubility.
Imports
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import init
from torch import optim
from torcheval.metrics import BinaryAccuracy, BinaryAUROC
from torcheval.metrics.functional import binary_auroc, binary_accuracy
from torchmetrics.classification import BinaryMatthewsCorrCoef
from torchmetrics.functional.classification import binary_matthews_corrcoef
import fastcore.all as fc
from functools import partial
Load the Protein Solubility Data
This example uses the dataset from the DeepSol paper by Khurana et al., which was obtained from https://zenodo.org/records/1162886. It consists of amino acid sequences of peptides along with solubility labels that are 1 if the peptide is soluble and 0 if the peptide is insoluble.
train_sqs = open('sol_data/train_src', 'r').read().splitlines()
train_tgs = list(map(int, open('sol_data/train_tgt', 'r').read().splitlines()))
valid_sqs = open('sol_data/val_src', 'r').read().splitlines()
valid_tgs = list(map(int, open('sol_data/val_tgt', 'r').read().splitlines()))

train_sqs[:2], train_tgs[:2]
(['GMILKTNLFGHTYQFKSITDVLAKANEEKSGDRLAGVAAESAEERVAAKVVLSKMTLGDLRNNPVVPYETDEVTRIIQDQVNDRIHDSIKNWTVEELREWILDHKTTDADIKRVARGLTSEIIAAVTKLMSNLDLIYGAKKIRVIAHANTTIGLPGTFSARLQPNHPTDDPDGILASLMEGLTYGIGDAVIGLNPVDDSTDSVVRLLNKFEEFRSKWDVPTQTCVLAHVKTQMEAMRRGAPTGLVFQSIAGSEKGNTAFGFDGATIEEARQLALQSGAATGPNVMYFETGQGSELSSDAHFGVDQVTMEARCYGFAKKFDPFLVNTVVGFIGPEYLYDSKQVIRAGLEDHFMGKLTGISMGCDVCYTNHMKADQNDVENLSVLLTAAGCNFIMGIPHGDDVMLNYQTTGYHETATLRELFGLKPIKEFDQWMEKMGFSENGKLTSRAGDASIFLK',
'MAHHHHHHMSFFRMKRRLNFVVKRGIEELWENSFLDNNVDMKKIEYSKTGDAWPCVLLRKKSFEDLHKLYYICLKEKNKLLGEQYFHLQNSTKMLQHGRLKKVKLTMKRILTVLSRRAIHDQCLRAKDMLKKQEEREFYEIQKFKLNEQLLCLKHKMNILKKYNSFSLEQISLTFSIKKIENKIQQIDIILNPLRKETMYLLIPHFKYQRKYSDLPGFISWKKQNIIALRNNMSKLHRLY'],
[1, 0])
len(train_sqs), len(train_tgs), len(valid_sqs), len(valid_tgs)
(62478, 62478, 6942, 6942)
Data Preparation
Create a sorted list of amino acids aas, including an empty string for padding, and determine the size of the vocabulary.
aas = sorted(list(set("".join(train_sqs))) + [""])
vocab_size = len(aas)
aas, vocab_size
(['',
'A',
'C',
'D',
'E',
'F',
'G',
'H',
'I',
'K',
'L',
'M',
'N',
'P',
'Q',
'R',
'S',
'T',
'V',
'W',
'Y'],
21)
Create dictionaries that translate between string and integer representations of amino acids and define the corresponding encode
and decode
functions.
str2int = {aa:i for i, aa in enumerate(aas)}
int2str = {i:aa for i, aa in enumerate(aas)}
encode = lambda s: [str2int[aa] for aa in s]
decode = lambda l: ''.join([int2str[i] for i in l])
print(encode("AYWCCCGGGHH"))
print(decode(encode("AYWCCCGGGHH")))
[1, 20, 19, 2, 2, 2, 6, 6, 6, 7, 7]
AYWCCCGGGHH
Figure out the range of lengths of the amino acid sequences in the dataset.
train_lens = list(map(len, train_sqs))
min(train_lens), max(train_lens)
(19, 1691)
Create a function that drops all sequences longer than a chosen threshold and also returns the indices of the sequences that meet the threshold, which can later be used to select the matching labels.
def drop_long_sqs(sqs, threshold=1200):
    new_sqs = []
    idx = []
    for i, sq in enumerate(sqs):
        if len(sq) <= threshold:
            new_sqs.append(sq)
            idx.append(i)
    return new_sqs, idx
Drop all sequences above your chosen threshold.
trnsqs, trnidx = drop_long_sqs(train_sqs, threshold=200)
vldsqs, vldidx = drop_long_sqs(valid_sqs, threshold=200)
len(trnidx), len(vldidx)
(18066, 1971)
max(map(len, trnsqs))
200
Create a function for zero-padding all sequences to a fixed length.
def zero_pad(sq, length=1200):
    new_sq = sq.copy()
    if len(new_sq) < length:
        new_sq.extend([0] * (length - len(new_sq)))
    return new_sq
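As a quick illustration (not part of the original notebook, but using the encode function and the vocabulary shown above), a short encoded sequence padded to an arbitrary length of 6 looks like this:

# illustrative only: "AYW" encodes to [1, 20, 19], then gets padded with zeros
zero_pad(encode("AYW"), length=6)
# [1, 20, 19, 0, 0, 0]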
Now encode and zero-pad all sequences and make sure that it worked correctly.
trn = list(map(encode, trnsqs))
vld = list(map(encode, vldsqs))
print(f"Length of the first two sequences before zero padding: {len(trn[0])}, {len(trn[1])}")
trn = list(map(partial(zero_pad, length=200), trn))
vld = list(map(partial(zero_pad, length=200), vld))
print(f"Length of the first two sequences after zero padding: {len(trn[0])}, {len(trn[1])}");
Length of the first two sequences before zero padding: 116, 135
Length of the first two sequences after zero padding: 200, 200
Convert the data to torch.tensors using dtype=torch.int64 and check for correctness.
trntns = torch.tensor(trn, dtype=torch.int64)
vldtns = torch.tensor(vld, dtype=torch.int64)
trntns.shape, trntns[0]
(torch.Size([18066, 200]),
tensor([11, 9, 1, 10, 2, 10, 10, 10, 10, 13, 18, 10, 6, 10, 10, 18, 16, 16, 9, 17, 10, 2, 16, 11, 4, 4, 1, 8, 12, 4, 15, 8, 14,
4, 18, 1, 6, 16, 10, 8, 5, 15, 1, 8, 16, 16, 8, 6, 10, 4, 2, 14, 16, 18, 17, 16, 15, 6, 3, 10, 1, 17, 2, 13, 15, 6,
5, 1, 18, 17, 6, 2, 17, 2, 6, 16, 1, 2, 6, 16, 19, 3, 18, 15, 1, 4, 17, 17, 2, 7, 2, 14, 2, 1, 6, 11, 3, 19, 17,
6, 1, 15, 2, 2, 15, 18, 14, 13, 10, 4, 7, 7, 7, 7, 7, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0]))
trntns.shape, vldtns.shape
(torch.Size([18066, 200]), torch.Size([1971, 200]))
Obtain the correct labels using the lists of indices obtained from the drop_long_sqs
function and convert the lists of labels to tensors in torch.float32
format.
trnlbs = torch.tensor(train_tgs, dtype=torch.float32)[trnidx]
vldlbs = torch.tensor(valid_tgs, dtype=torch.float32)[vldidx]
trnlbs.shape, vldlbs.shape
(torch.Size([18066]), torch.Size([1971]))
Calculate the ratios of soluble peptides in the train and valid data.
trnlbs.sum().item()/trnlbs.shape[0], vldlbs.sum().item()/vldlbs.shape[0]
(0.4722129967895494, 0.4657534246575342)
These ratios tell us that slightly less than half of the peptides in the training and validation data are soluble (and slightly more than half in the test set), so a useful model has to beat a majority-class baseline of roughly 53% accuracy.
Dataset and DataLoaders
Turn train and valid data into datasets using the Dataset
class.
trnds = Dataset(trntns, trnlbs)
vldds = Dataset(vldtns, vldlbs)
trnds[0]
(tensor([11, 9, 1, 10, 2, 10, 10, 10, 10, 13, 18, 10, 6, 10, 10, 18, 16, 16, 9, 17, 10, 2, 16, 11, 4, 4, 1, 8, 12, 4, 15, 8, 14,
4, 18, 1, 6, 16, 10, 8, 5, 15, 1, 8, 16, 16, 8, 6, 10, 4, 2, 14, 16, 18, 17, 16, 15, 6, 3, 10, 1, 17, 2, 13, 15, 6,
5, 1, 18, 17, 6, 2, 17, 2, 6, 16, 1, 2, 6, 16, 19, 3, 18, 15, 1, 4, 17, 17, 2, 7, 2, 14, 2, 1, 6, 11, 3, 19, 17,
6, 1, 15, 2, 2, 15, 18, 14, 13, 10, 4, 7, 7, 7, 7, 7, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0]),
tensor(0.))
Use the get_dls
function to obtain the dataloaders from the train and valid datasets.
dls = get_dls(trnds, vldds, bs=32)
next(iter(dls.train))[0][:2], next(iter(dls.train))[1][:2]
(tensor([[11, 1, 7, 7, 7, 7, 7, 7, 11, 11, 13, 15, 16, 4, 12, 14, 9, 4, 4, 4, 4, 19, 4, 10, 7, 13, 10, 5, 10, 16, 9, 8, 13,
12, 9, 9, 3, 8, 3, 9, 12, 13, 1, 10, 16, 1, 10, 8, 17, 10, 8, 12, 4, 4, 4, 4, 9, 4, 18, 5, 16, 20, 4, 13, 15, 15,
9, 12, 10, 9, 9, 9, 8, 17, 4, 9, 6, 14, 8, 8, 20, 9, 9, 3, 15, 15, 12, 10, 20, 4, 13, 20, 14, 14, 12, 9, 12, 3, 9,
8, 12, 20, 5, 4, 9, 9, 12, 5, 6, 6, 12, 1, 3, 8, 16, 9, 4, 4, 4, 18, 10, 3, 18, 4, 11, 3, 4, 4, 6, 17, 17, 18,
17, 17, 1, 4, 14, 6, 6, 3, 7, 16, 16, 14, 12, 18, 16, 12, 12, 14, 4, 1, 17, 3, 14, 17, 16, 8, 6, 4, 18, 10, 18, 2, 10,
16, 11, 11, 12, 10, 9, 14, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0],
[11, 10, 1, 11, 18, 12, 9, 18, 10, 3, 19, 8, 15, 16, 10, 5, 19, 9, 4, 4, 11, 4, 10, 17, 10, 18, 6, 10, 14, 12, 16, 6, 9,
17, 17, 5, 18, 12, 18, 8, 1, 16, 6, 14, 5, 17, 4, 3, 11, 8, 13, 17, 18, 6, 5, 12, 11, 15, 9, 8, 17, 9, 6, 12, 18, 17,
8, 9, 10, 19, 3, 8, 6, 6, 14, 13, 15, 5, 15, 16, 11, 19, 4, 15, 20, 2, 15, 6, 18, 12, 1, 8, 18, 5, 11, 18, 3, 1, 1,
3, 4, 4, 9, 10, 4, 1, 16, 15, 12, 4, 10, 11, 14, 10, 10, 3, 9, 13, 14, 10, 3, 1, 8, 13, 18, 10, 18, 10, 6, 12, 9, 9,
3, 10, 13, 6, 1, 10, 3, 4, 15, 14, 10, 8, 4, 15, 11, 12, 10, 16, 16, 8, 14, 12, 15, 4, 8, 2, 2, 20, 16, 8, 16, 2, 9,
4, 9, 4, 12, 8, 3, 8, 17, 10, 14, 19, 10, 8, 3, 7, 16, 9, 1, 14, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0]]),
tensor([1., 0.]))
Design Your Model
Let’s create a tiny model (~10k parameters) that uses a sequence of 1-dimensional convolutional layers with skip connections, Kaiming initialization, leaky ReLUs, batch norm, and dropout.
First, obtain a single batch from dls
to help design the model.
idx = next(iter(dls.train))[0]  # a single batch
idx, idx.shape
(tensor([[11, 9, 17, ..., 0, 0, 0],
[16, 12, 1, ..., 0, 0, 0],
[16, 1, 14, ..., 0, 0, 0],
...,
[11, 1, 7, ..., 0, 0, 0],
[11, 3, 13, ..., 0, 0, 0],
[16, 17, 12, ..., 0, 0, 0]]),
torch.Size([32, 200]))
Custom Modules
def conv1d(ni, nf, ks=3, stride=2, act=nn.ReLU, norm=None, bias=None):
    if bias is None: bias = not isinstance(norm, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d))
    layers = [nn.Conv1d(ni, nf, stride=stride, kernel_size=ks, padding=ks//2, bias=bias)]
    if norm: layers.append(norm(nf))
    if act: layers.append(act())
    return nn.Sequential(*layers)

def _conv1d_block(ni, nf, stride, act=nn.ReLU, norm=None, ks=3):
    return nn.Sequential(conv1d(ni, nf, stride=1, act=act, norm=norm, ks=ks),
                         conv1d(nf, nf, stride=stride, act=None, norm=norm, ks=ks))
class ResBlock1d(nn.Module):
    def __init__(self, ni, nf, stride=1, ks=3, act=nn.ReLU, norm=None):
        super().__init__()
        self.convs = _conv1d_block(ni, nf, stride=stride, ks=ks, act=act, norm=norm)
        self.idconv = fc.noop if ni==nf else conv1d(ni, nf, stride=1, ks=1, act=None)
        self.pool = fc.noop if stride==1 else nn.AvgPool1d(stride, ceil_mode=True)
        self.act = act()

    def forward(self, x): return self.act(self.convs(x) + self.pool(self.idconv(x)))
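As a quick sanity check (an illustrative example, not part of the original notebook), a stride-2 block halves the sequence length while changing the channel count:

# 16 input channels, 2 output channels, default ReLU and no norm (illustrative)
blk = ResBlock1d(16, 2, ks=15, stride=2)
blk(torch.randn(32, 16, 200)).shape   # torch.Size([32, 2, 100])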
The following module switches the tensor layout from BLC (batch, length, channels), as produced by the embedding layer, to BCL (batch, channels, length), as expected by nn.Conv1d.
class Reshape(nn.Module):
    def forward(self, x):
        B, L, C = x.shape
        return x.view(B, C, L)
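A quick check with a dummy tensor (illustrative only) confirms the new layout:

Reshape()(torch.randn(4, 200, 16)).shape   # torch.Size([4, 16, 200])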
Model Architecture
lr = 1e-2
epochs = 30
n_embd = 16
dls = get_dls(trnds, vldds, bs=32)
act_genrelu = partial(GeneralRelu, leak=0.1, sub=0.4)

model = nn.Sequential(nn.Embedding(vocab_size, n_embd, padding_idx=0), Reshape(),
                      ResBlock1d(n_embd, 2, ks=15, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),
                      ResBlock1d(2, 4, ks=13, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),
                      ResBlock1d(4, 4, ks=11, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),
                      ResBlock1d(4, 4, ks=9, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),
                      ResBlock1d(4, 8, ks=7, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),
                      ResBlock1d(8, 8, ks=5, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),
                      ResBlock1d(8, 16, ks=3, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),
                      ResBlock1d(16, 32, ks=3, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),
                      nn.Flatten(1, -1),
                      nn.Linear(32, 1),
                      nn.Flatten(0, -1),
                      nn.Sigmoid())
model(idx).shape
torch.Size([32])
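The eight stride-2 blocks reduce the padded length of 200 down to 1 (200 → 100 → 50 → 25 → 13 → 7 → 4 → 2 → 1), so nn.Flatten(1, -1) leaves a 32-dimensional feature vector per sequence for the final linear layer. A minimal check of this arithmetic (illustrative, not part of the original notebook):

length = 200
for _ in range(8):
    length = (length + 1) // 2   # a stride-2 conv with padding=ks//2 ceil-divides the length by 2
print(length)   # 1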
iw = partial(init_weights, leaky=0.1)
model = model.apply(iw)
metrics = MetricsCB(BinaryAccuracy(), BinaryMatthewsCorrCoef(), BinaryAUROC())
astats = ActivationStats(fc.risinstance(GeneralRelu))
cbs = [DeviceCB(), ProgressCB(plot=False), metrics, astats]
learn = TrainLearner(model, dls, F.binary_cross_entropy, lr=lr, cbs=cbs, opt_func=torch.optim.AdamW)
print(f"Parameters total: {sum(p.nelement() for p in model.parameters())}")
learn.lr_find(start_lr=1e-4, gamma=1.05, av_over=3, max_mult=5)
Parameters total: 10175
This is a pretty noisy training set, so the learning rate finder does not work very well. Yet it is possible to get a somewhat informative result using the av_over keyword argument, which tells lr_find to average over the specified number of batches for each learning rate tested. It also helps to dial the gamma value down from its default of 1.3.
Training
learn.fit(epochs)
BinaryAccuracy | BinaryMatthewsCorrCoef | BinaryAUROC | loss | epoch | train |
---|---|---|---|---|---|
0.510 | 0.008 | 0.507 | 0.722 | 0 | train |
0.527 | -0.033 | 0.494 | 0.691 | 0 | eval |
0.515 | 0.004 | 0.506 | 0.695 | 1 | train |
0.534 | 0.003 | 0.520 | 0.691 | 1 | eval |
0.537 | 0.054 | 0.526 | 0.692 | 2 | train |
0.530 | 0.058 | 0.550 | 0.693 | 2 | eval |
0.553 | 0.090 | 0.551 | 0.688 | 3 | train |
0.562 | 0.108 | 0.575 | 0.684 | 3 | eval |
0.589 | 0.170 | 0.607 | 0.671 | 4 | train |
0.619 | 0.234 | 0.663 | 0.647 | 4 | eval |
0.622 | 0.247 | 0.630 | 0.653 | 5 | train |
0.643 | 0.302 | 0.676 | 0.637 | 5 | eval |
0.627 | 0.260 | 0.642 | 0.650 | 6 | train |
0.649 | 0.313 | 0.675 | 0.628 | 6 | eval |
0.633 | 0.272 | 0.647 | 0.644 | 7 | train |
0.650 | 0.309 | 0.648 | 0.637 | 7 | eval |
0.634 | 0.276 | 0.652 | 0.640 | 8 | train |
0.650 | 0.314 | 0.683 | 0.620 | 8 | eval |
0.637 | 0.284 | 0.646 | 0.641 | 9 | train |
0.651 | 0.320 | 0.685 | 0.617 | 9 | eval |
0.639 | 0.287 | 0.661 | 0.636 | 10 | train |
0.645 | 0.308 | 0.671 | 0.647 | 10 | eval |
0.638 | 0.281 | 0.663 | 0.636 | 11 | train |
0.646 | 0.298 | 0.690 | 0.620 | 11 | eval |
0.641 | 0.286 | 0.670 | 0.632 | 12 | train |
0.648 | 0.296 | 0.685 | 0.619 | 12 | eval |
0.641 | 0.290 | 0.668 | 0.632 | 13 | train |
0.662 | 0.323 | 0.691 | 0.618 | 13 | eval |
0.643 | 0.290 | 0.675 | 0.630 | 14 | train |
0.641 | 0.286 | 0.677 | 0.627 | 14 | eval |
0.643 | 0.289 | 0.676 | 0.630 | 15 | train |
0.659 | 0.311 | 0.699 | 0.616 | 15 | eval |
0.644 | 0.291 | 0.677 | 0.629 | 16 | train |
0.652 | 0.314 | 0.690 | 0.613 | 16 | eval |
0.646 | 0.296 | 0.675 | 0.626 | 17 | train |
0.645 | 0.282 | 0.694 | 0.622 | 17 | eval |
0.642 | 0.288 | 0.678 | 0.626 | 18 | train |
0.648 | 0.332 | 0.677 | 0.639 | 18 | eval |
0.645 | 0.292 | 0.685 | 0.625 | 19 | train |
0.634 | 0.260 | 0.698 | 0.615 | 19 | eval |
0.649 | 0.302 | 0.689 | 0.621 | 20 | train |
0.651 | 0.344 | 0.710 | 0.617 | 20 | eval |
0.648 | 0.299 | 0.685 | 0.624 | 21 | train |
0.660 | 0.315 | 0.700 | 0.614 | 21 | eval |
0.648 | 0.297 | 0.691 | 0.620 | 22 | train |
0.563 | 0.168 | 0.679 | 0.672 | 22 | eval |
0.651 | 0.303 | 0.690 | 0.620 | 23 | train |
0.654 | 0.330 | 0.710 | 0.611 | 23 | eval |
0.650 | 0.304 | 0.691 | 0.620 | 24 | train |
0.668 | 0.344 | 0.711 | 0.599 | 24 | eval |
0.654 | 0.311 | 0.692 | 0.617 | 25 | train |
0.649 | 0.294 | 0.698 | 0.620 | 25 | eval |
0.650 | 0.301 | 0.690 | 0.617 | 26 | train |
0.642 | 0.320 | 0.697 | 0.611 | 26 | eval |
0.650 | 0.303 | 0.688 | 0.620 | 27 | train |
0.663 | 0.334 | 0.708 | 0.625 | 27 | eval |
0.652 | 0.308 | 0.694 | 0.616 | 28 | train |
0.672 | 0.356 | 0.711 | 0.597 | 28 | eval |
0.651 | 0.304 | 0.695 | 0.617 | 29 | train |
0.659 | 0.322 | 0.702 | 0.603 | 29 | eval |
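The functional metrics imported at the top can be used to double-check the trained model on the full validation set. The snippet below is only an illustrative sketch, not part of the original notebook; it assumes the trained network is available as learn.model and that the whole validation tensor fits in memory on the CPU:

# illustrative evaluation sketch (assumes learn.model holds the trained network)
model_cpu = learn.model.to('cpu').eval()
with torch.no_grad():
    preds = model_cpu(vldtns)   # predicted solubility probabilities, shape (1971,)
(binary_accuracy(preds, vldlbs.int()),
 binary_auroc(preds, vldlbs.int()),
 binary_matthews_corrcoef(preds, vldlbs.int()))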
Inspect Activations
dls = get_dls(trnds, vldds, bs=256)

model = nn.Sequential(nn.Embedding(vocab_size, n_embd, padding_idx=0), Reshape(),
                      ResBlock1d(n_embd, 2, ks=15, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),
                      ResBlock1d(2, 4, ks=13, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),
                      ResBlock1d(4, 4, ks=11, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),
                      ResBlock1d(4, 4, ks=9, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),
                      ResBlock1d(4, 8, ks=7, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),
                      ResBlock1d(8, 8, ks=5, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),
                      ResBlock1d(8, 16, ks=3, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),
                      ResBlock1d(16, 32, ks=3, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),
                      nn.Flatten(1, -1),
                      nn.Linear(32, 1),
                      nn.Flatten(0, -1),
                      nn.Sigmoid())

model = model.apply(iw)
astats = ActivationStats(fc.risinstance(GeneralRelu))
cbs = [DeviceCB(), astats]
learn = TrainLearner(model, dls, F.binary_cross_entropy, lr=lr, cbs=cbs, opt_func=torch.optim.AdamW)
print(f"Parameters total: {sum(p.nelement() for p in model.parameters())}")
learn.fit(1)
Parameters total: 10175
astats.color_dim()
astats.plot_stats()
astats.dead_chart()