core

The core package contains all functions, variables, and classes needed to train a deep neural network model and inspect its activations.

Load Protein Solubility Data

The example dataset is from the DeepSol paper by Khurana et al. and was obtained from https://zenodo.org/records/1162886.

train_sqs = open('sol_data/train_src', 'r').read().splitlines()
train_tgs = list(map(int, open('sol_data/train_tgt', 'r').read().splitlines()))
train_sqs[:2], train_tgs[:2]
(['GMILKTNLFGHTYQFKSITDVLAKANEEKSGDRLAGVAAESAEERVAAKVVLSKMTLGDLRNNPVVPYETDEVTRIIQDQVNDRIHDSIKNWTVEELREWILDHKTTDADIKRVARGLTSEIIAAVTKLMSNLDLIYGAKKIRVIAHANTTIGLPGTFSARLQPNHPTDDPDGILASLMEGLTYGIGDAVIGLNPVDDSTDSVVRLLNKFEEFRSKWDVPTQTCVLAHVKTQMEAMRRGAPTGLVFQSIAGSEKGNTAFGFDGATIEEARQLALQSGAATGPNVMYFETGQGSELSSDAHFGVDQVTMEARCYGFAKKFDPFLVNTVVGFIGPEYLYDSKQVIRAGLEDHFMGKLTGISMGCDVCYTNHMKADQNDVENLSVLLTAAGCNFIMGIPHGDDVMLNYQTTGYHETATLRELFGLKPIKEFDQWMEKMGFSENGKLTSRAGDASIFLK',
  'MAHHHHHHMSFFRMKRRLNFVVKRGIEELWENSFLDNNVDMKKIEYSKTGDAWPCVLLRKKSFEDLHKLYYICLKEKNKLLGEQYFHLQNSTKMLQHGRLKKVKLTMKRILTVLSRRAIHDQCLRAKDMLKKQEEREFYEIQKFKLNEQLLCLKHKMNILKKYNSFSLEQISLTFSIKKIENKIQQIDIILNPLRKETMYLLIPHFKYQRKYSDLPGFISWKKQNIIALRNNMSKLHRLY'],
 [1, 0])
valid_sqs = open('sol_data/val_src', 'r').read().splitlines()
valid_tgs = list(map(int, open('sol_data/val_tgt', 'r').read().splitlines()))
valid_sqs[:2], valid_tgs[:2]
(['SRLYRHNLMEDVFNMENESFMQETRLMENEYSVNLPTRFYYKKRWNNGFVNIVNIFRACMVIGTPGSGKSYAIVNSYIRQLIAKGFAIYIYDYKFDDLSTIAYNSLLKNMDKYEVKPRFYVINFDDPRRSHRCNPINPEFMTDISDAYEASYTIMLNLNRTWIEKQGDFFVESPIILLAAIIWYLKIYKNGIYCTFPHAVELLNKPYSDLFTILTSYPELENYLSPFMDAWKGNAQDQLQGQIASAKIPLTRMISPQLYWVMTGNDFSLDINNPKEPKLLCVGNNPDRQNIYSAALGLYNSRIVKLINKKKQLKCAVIIDELPTIYFRGLDNLIATARSNKVGVLLGFQDFSQLTRDYGEKESKVIQNTVGNIFSGQVVGETAKTLSERFGKVLQQRQSVSINRQDVSTSINTQLDSLIPASKIANLSQGTFVGAVADNFDERIEQKIFHAEIVVDHTKISAEEKAYQKIPVINDFKDRNGNDIMMQQIQRNYDQIKADAQAIINEEMRRIKNDPELRKRLGLEDEKGKDPDKS',
  'ATTYNAVVSKSSSDGKTFKTIADAIASAPAGSTPFVILIKNGVYNERLTITRNNLHLKGESRNGAVIAAATAAGTLKSDGSKWGTAGSSTITISAKDFSAQSLTIRNDFDFPANQAKSDSDSSKIKDTQAVALYVTKSGDRAYFKDVSLVGYQATLYVSGGRSFFSDCRISGTVDFIFGDGTALFNNCDLVSRYRADVKSGNVSGYLTAPSTNINQKYGLVITNSRVIRESDSVPAKSYGLGRPWHPTTTFSDGRYADPNAIGQTVFLNTSMDNHIYGWDKMSGKDKNGNTIWFNPEDSRFFEYKSYGAGATVSKDRRQLTDAQAAEYTQSKVLGDWTPTLP'],
 [0, 1])
test_sqs = open('sol_data/test_src', 'r').read().splitlines()
test_tgs = list(map(int, open('sol_data/test_tgt', 'r').read().splitlines()))
test_sqs[:2], test_tgs[:2]
(['MLSVRIAAAVARALPRRAGLVSKNALGSSFVGTRNLHASNTRLQKTGTAEMSSILEERILGADTSVDLEETGRVLSIGDGIARVHGLRNVQAEEMVEFSSGLKGMSLNLEPDNVGVVVFGNDKLIKEGDIVKRTGAIVDVPVGDELLGRVVDALGNAIDGKGPVGSKIRRRVGLKAPGIIPRISVREPMQTGIKAVDSLVPIGRGQRELIIGDRQTGKTSIAIDTIINQKRFNDGTDEKKKLYCIYVAIGQKRSTVAQLVKRLTDADAMKYTIVVSATASDAAPLQYLAPYSGCSMGEYFRDNGKHALIIYDDLSKQAVAYRQMSLLLRRPPGREAYPGDVFYLHSRLLERAAKMNDSFGGGSLTALPVIETQAGDVSAYIPTNVISITDGQIFLETELFYKGIRPAINVGLSVSRVGSAAQTRAMKQVAGTMKLELAQYREVAAFAQFGSDLDAATQQLLSRGVRLTELLKQGQYSPMAIEEQVAVIYAGVRGYLDKLEPSKITKFESAFLSHVVSQHQSLLGNIRSDGKISEQSDAKLKEIVTNFLAGFEP',
  'MDHMISENGETSAEGSICGYDSLHQLLSANLKPELYQEVNRLLLGRNCGRSLEQIVLPESAKALSSKHDFDLQAASFSADKEQMRNPRVVRVGLIQNSIALPTTAPFSDQTRGIFDKLKPIIDAAGVAGVNILCLQEAWTMPFAFCTRERRWCEFAEPVDGESTKFLQELAKKYNMVIVSPILERDIDHGEVLWNTAVIIGNNGNIIGKHRKNHIPRVGDFNESTYYMEGDTGHPVFETVFGKIAVNICYGRHHPLNWLAFGLNGAEIVFNPSATVGELSEPMWPIEARNAAIANSYFVGSINRVGTEVFPNPFTSGDGKPQHNDFGHFYGSSHFSAPDASCTPSLSRYKDGLLISDMDLNLCRQYKDKWGFRMTARYEVYADLLAKYIKPDFKPQVVSDPLLHKNST'],
 [1, 1])
len(train_sqs), len(train_tgs), len(valid_sqs), len(valid_tgs), len(test_sqs), len(test_tgs)
(62478, 62478, 6942, 6942, 1999, 1999)

Create a sorted list of amino acids aas, including an empty string for padding, and determine the size of the vocabulary.

aas = sorted(list(set("".join(train_sqs))) + [""])
vocab_size = len(aas)
aas, vocab_size
(['',
  'A',
  'C',
  'D',
  'E',
  'F',
  'G',
  'H',
  'I',
  'K',
  'L',
  'M',
  'N',
  'P',
  'Q',
  'R',
  'S',
  'T',
  'V',
  'W',
  'Y'],
 21)

Create dictionaries that translate between string and integer representations of amino acids and define the corresponding encode and decode functions.

str2int = {aa:i for i, aa in enumerate(aas)}
int2str = {i:aa for i, aa in enumerate(aas)}
encode = lambda s: [str2int[aa] for aa in s]
decode = lambda l: ''.join([int2str[i] for i in l])

print(encode("AYWCCCGGGHH"))
print(decode(encode("AYWCCCGGGHH")))
[1, 20, 19, 2, 2, 2, 6, 6, 6, 7, 7]
AYWCCCGGGHH

Determine the lengths of the amino acid sequences in the training set and inspect the longest sequence.

train_lens = list(map(len, train_sqs))
max(train_lens)
1691
longest = train_sqs[np.argmax(train_lens)]
longest
'MSGEVRLRQLEQFILDGPAQTNGQCFSVETLLDILICLYDECNNSPLRREKNILEYLEWAKPFTSKVKQMRLHREDFEILKVIGRGAFGEVAVVKLKNADKVFAMKILNKWEMLKRAETACFREERDVLVNGDNKWITTLHYAFQDDNNLYLVMDYYVGGDLLTLLSKFEDRLPEDMARFYLAEMVIAIDSVHQLHYVHRDIKPDNILMDMNGHIRLADFGSCLKLMEDGTVQSSVAVGTPDYISPEILQAMEDGKGRYGPECDWWSLGVCMYEMLYGETPFYAESLVETYGKIMNHKERFQFPAQVTDVSENAKDLIRRLICSREHRLGQNGIEDFKKHPFFSGIDWDNIRNCEAPYIPEVSSPTDTSNFDVDDDCLKNSETMPPPTHTAFSGHHLPFVGFTYTSSCVLSDRSCLRVTAGPTSLDLDVNVQRTLDNNLATEAYERRIKRLEQEKLELSRKLQESTQTVQALQYSTVDGPLTASKDLEIKNLKEEIEKLRKQVTESSHLEQQLEEANAVRQELDDAFRQIKAYEKQIKTLQQEREDLNKELVQASERLKNQSKELKDAHCQRKLAMQEFMEINERLTELHTQKQKLARHVRDKEEEVDLVMQKVESLRQELRRTERAKKELEVHTEALAAEASKDRKLREQSEHYSKQLENELEGLKQKQISYSPGVCSIEHQQEITKLKTDLEKKSIFYEEELSKREGIHANEIKNLKKELHDSEGQQLALNKEIMILKDKLEKTRRESQSEREEFESEFKQQYEREKVLLTEENKKLTSELDKLTTLYENLSIHNQQLEEEVKDLADKKESVAHWEAQITEIIQWVSDEKDARGYLQALASKMTEELEALRNSSLGTRATDMPWKMRRFAKLDMSARLELQSALDAEIRAKQAIQEELNKVKASNIITECKLKDSEKKNLELLSEIEQLIKDTEELRSEKGIEHQDSQHSFLAFLNTPTDALDQFERKTHQFFVKSFTTPTKCHQCTSLMVGLIRQGCSCEVCGFSCHITCVNKAPTTCPVPPEQTKGPLGIDPQKGIGTAYEGHVRIPKPAGVKKGWQRALAIVCDFKLFLYDIAEGKASQPSVVISQVIDMRDEEFSVSSVLASDVIHASRKDIPCIFRVTASQLSASNNKCSILMLADTENEKNKWVGVLSELHKILKKNKFRDRSVYVPKEAYDSTLPLIKTTQAAAIIDHERIALGNEEGLFVVHVTKDEIIRVGDNKKIHQIELIPNDQLVAVISGRNRHVRLFPMSALDGRETDFYKLSETKGCQTVTSGKVRHGALTCLCVAMKRQVLCYELFQSKTRHRKFKEIQVPYNVQWMAIFSEQLCVGFQSGFLRYPLNGEGNPYSMLHSNDHTLSFIAHQPMDAICAVEISSKEYLLCFNSIGIYTDCQGRRSRQQELMWPANPSSCCYNAPYLSVYSENAVDIFDVNSMEWIQTLPLKKVRPLNNEGSLNLLGLETIRLIYFKNKMAEGDELVVPETSDNSRKQMVRNINNKRRYSFRVPEEERMQQRREMLRDPEMRNKLISNPTNFNHIAHMGPGDGIQILKDLPMNPRPQESRTVFSGSVSIPSITKSRPEPGRSMSASSGLSARSSAQNGSALKREFSGGSYSAKRQPMPSPSEGSLSSGGMDQGSDAPARDFDGEDSDSPRHSTASNSSNLSSPPSPVSPRKTKSLSLESTDRGSWDP'

Check how many sequences in the training set are longer than 1200 amino acids.

long_sqs = []
for sq in train_sqs:
    if len(sq) > 1200:
        long_sqs.append(sq)
len(long_sqs)
132

Create a function that drops all sequences longer than a chosen threshold and also returns the indices of the sequences that meet the threshold, so the corresponding labels can be retrieved.

def drop_long_sqs(sqs, threshold=1200):
    new_sqs = []
    idx = []
    for i, sq in enumerate(sqs):
        if len(sq) <= threshold:
            new_sqs.append(sq)
            idx.append(i)
    return new_sqs, idx

Drop all sequences above your threshold.

trnsqs, trnidx = drop_long_sqs(train_sqs, threshold=200)
vldsqs, vldidx = drop_long_sqs(valid_sqs, threshold=200)
tstsqs, tstidx = drop_long_sqs(test_sqs, threshold=200)
len(trnidx), len(vldidx), len(tstidx)
(18066, 1971, 699)

Make sure that it worked.

trnls = map(len, trnsqs)
vldls = map(len, vldsqs)
tstls = map(len, tstsqs)
max(trnls), max(vldls), max(tstls)
(200, 200, 200)

Create a function that zero-pads encoded sequences to a fixed length.

def zero_pad(sq, length=1200):
    new_sq = sq.copy()
    if len(new_sq) < length:
        new_sq.extend([0] * (length-len(new_sq)))
    return new_sq

Now encode and zero pad all sequences and make sure that it worked out correctly.

trn = list(map(encode, trnsqs))
vld = list(map(encode, vldsqs))
tst = list(map(encode, tstsqs))
print(f"Length of the first two sequences before zero padding: {len(trn[0])}, {len(trn[1])}")
trn = list(map(partial(zero_pad, length=200), trn))
vld = list(map(partial(zero_pad, length=200), vld))
tst = list(map(partial(zero_pad, length=200), tst))
print(f"Length of the first two sequences after zero padding:  {len(trn[0])}, {len(trn[1])}");
Length of the first two sequences before zero padding: 116, 135
Length of the first two sequences after zero padding:  200, 200

Convert the data to torch.tensors using dtype=torch.int64 and check for correctness.

trntns = torch.tensor(trn, dtype=torch.int64)
vldtns = torch.tensor(vld, dtype=torch.int64)
tsttns = torch.tensor(tst, dtype=torch.int64)
trntns.shape, trntns[0]
(torch.Size([18066, 200]),
 tensor([11,  9,  1, 10,  2, 10, 10, 10, 10, 13, 18, 10,  6, 10, 10, 18, 16, 16,  9, 17, 10,  2, 16, 11,  4,  4,  1,  8, 12,  4, 15,  8, 14,
          4, 18,  1,  6, 16, 10,  8,  5, 15,  1,  8, 16, 16,  8,  6, 10,  4,  2, 14, 16, 18, 17, 16, 15,  6,  3, 10,  1, 17,  2, 13, 15,  6,
          5,  1, 18, 17,  6,  2, 17,  2,  6, 16,  1,  2,  6, 16, 19,  3, 18, 15,  1,  4, 17, 17,  2,  7,  2, 14,  2,  1,  6, 11,  3, 19, 17,
          6,  1, 15,  2,  2, 15, 18, 14, 13, 10,  4,  7,  7,  7,  7,  7,  7,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0]))
trntns.shape, vldtns.shape, tsttns.shape
(torch.Size([18066, 200]), torch.Size([1971, 200]), torch.Size([699, 200]))

Obtain the correct labels using the index lists returned by drop_long_sqs, and convert the label lists to tensors of dtype torch.float32.

trnlbs = torch.tensor(train_tgs, dtype=torch.float32)[trnidx]
vldlbs = torch.tensor(valid_tgs, dtype=torch.float32)[vldidx]
tstlbs = torch.tensor(test_tgs, dtype=torch.float32)[tstidx]
trnlbs.shape, vldlbs.shape, tstlbs.shape
(torch.Size([18066]), torch.Size([1971]), torch.Size([699]))
trnlbs.sum().item()/trnlbs.shape[0], vldlbs.sum().item()/vldlbs.shape[0], tstlbs.sum().item()/tstlbs.shape[0]
(0.4722129967895494, 0.4657534246575342, 0.5665236051502146)

The ratios above show that slightly less than half of the proteins in the training and validation data are soluble, while slightly more than half of those in the test set are.

Dataset and DataLoaders

Create a Dataset class and combine tokens and labels into datasets.


source

Dataset

 Dataset (x, y)

Combines features and labels in a dataset.

trnds = Dataset(trntns, trnlbs)
vldds = Dataset(vldtns, vldlbs)
tstds = Dataset(tsttns, tstlbs)
trnds[0]
(tensor([11,  9,  1, 10,  2, 10, 10, 10, 10, 13, 18, 10,  6, 10, 10, 18, 16, 16,  9, 17, 10,  2, 16, 11,  4,  4,  1,  8, 12,  4, 15,  8, 14,
          4, 18,  1,  6, 16, 10,  8,  5, 15,  1,  8, 16, 16,  8,  6, 10,  4,  2, 14, 16, 18, 17, 16, 15,  6,  3, 10,  1, 17,  2, 13, 15,  6,
          5,  1, 18, 17,  6,  2, 17,  2,  6, 16,  1,  2,  6, 16, 19,  3, 18, 15,  1,  4, 17, 17,  2,  7,  2, 14,  2,  1,  6, 11,  3, 19, 17,
          6,  1, 15,  2,  2, 15, 18, 14, 13, 10,  4,  7,  7,  7,  7,  7,  7,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0]),
 tensor(0.))

Define a DataLoaders class and a function that creates your DataLoaders given a train dataset, a valid dataset, and a batch size.


source

DataLoaders

 DataLoaders (*dls)

Combines training and validation data in a DataLoaders object that can be passed to a learner.


source

get_dls

 get_dls (train_ds, valid_ds, bs=32)

Turn training and validation set into a DataLoaders object.
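For orientation, here is a minimal sketch of what get_dls might look like, assuming it simply wraps the datasets in standard torch.utils.data.DataLoader objects (the actual definition is in the source link above):

from torch.utils.data import DataLoader

def get_dls(train_ds, valid_ds, bs=32):
    # shuffle the training data; use a larger batch size for validation since no gradients are kept
    return DataLoaders(DataLoader(train_ds, batch_size=bs, shuffle=True),
                       DataLoader(valid_ds, batch_size=bs*2))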

Get the DataLoaders object and test it.

dls = get_dls(trnds, vldds)
next(iter(dls.train))[0][:5], next(iter(dls.train))[1][:5]
(tensor([[ 4, 16, 20,  1,  8,  5, 16, 10, 12,  1, 14,  5,  3,  1, 10, 17,  8, 12,  9,  1, 15,  9, 10, 20, 15,  6,  9, 17, 15, 10, 10, 14,  6,
          14, 15,  8,  4, 10, 16,  3, 19, 13,  4, 17, 16, 18,  4, 15, 17,  4,  5, 20, 15, 20, 10, 10, 12,  9, 12, 10,  1, 14, 11, 12,  1,  7,
          19,  1, 16, 10, 16,  5, 16,  6,  9,  1, 15, 13, 13,  9, 18,  8,  3, 14, 13, 16,  8,  4,  8, 10, 10,  1, 19, 11, 15,  4, 14, 13, 12,
          15,  8,  6, 20,  1, 13, 18,  3, 16, 10, 13, 16,  3,  1,  7, 18, 10, 20, 18, 18,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
           0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
           0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
           0,  0],
         [ 6,  8,  3, 13,  5, 17, 11, 10,  7,  8,  4,  5,  8, 17,  3, 10,  6,  1,  9, 18, 17, 18,  3, 18,  4, 16,  1,  3,  9, 10, 10,  3, 18,
          14, 15, 14, 20,  6, 15, 10,  6, 19, 17, 16,  6,  4, 18, 13, 18,  6,  6, 20, 14,  5, 13, 10,  4, 12,  4, 13,  3,  5,  3, 19, 16, 10,
           8,  6,  1, 15,  9, 19, 17, 12, 13,  4,  6,  4,  4, 11,  8, 10,  7, 15,  6,  7,  1, 20, 15, 15, 15,  4, 10,  4,  1, 18,  3, 16, 15,
           9, 11,  9, 10, 13,  1,  1, 18,  9, 20, 16, 15,  6,  1,  9, 12, 17,  3, 13,  4,  7, 18, 15,  4,  9,  1,  3,  6,  4,  5,  4, 20, 18,
          17, 10,  1,  8,  5, 15,  6,  6,  9, 15, 14,  4, 15, 20,  1, 18, 13,  6, 16, 12, 15, 13, 14,  1,  6,  1, 13,  1, 15, 16,  1,  1, 17,
          15,  1, 14,  6,  1, 15, 13,  6,  1, 18,  1, 18, 14,  3,  4,  4, 17, 13,  5,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
           0,  0],
         [ 6, 13,  6, 16, 11,  1,  3, 18,  1,  3, 13, 16, 16,  4, 10,  8, 12,  4, 10, 15,  1, 10, 18, 18, 16,  6, 16,  6, 10, 18, 13, 10,  3,
           4, 15,  9, 18,  7, 17, 18, 18,  4, 12, 11, 17,  3,  1,  5, 15,  7, 10,  4, 16,  8, 16, 15, 12, 13,  5,  1,  3, 13, 16, 14, 13, 20,
          20, 16,  6,  1, 11, 15,  5, 20,  9,  1,  9,  2, 10, 15,  3,  9, 15,  2, 18, 18,  1, 20, 10, 10, 19, 15, 14, 16, 14,  8, 17,  9, 16,
          19, 19,  4,  1, 15,  3, 12, 17,  8, 16, 12, 11, 10,  1, 13,  2,  4, 15, 17,  5, 10, 14,  3, 20, 12,  3, 18, 11, 18,  4, 20, 11, 17,
          16,  5,  1, 18, 13, 10,  3, 10, 15, 16,  5, 17, 19, 15, 13, 13, 16, 17, 14, 14, 10,  4, 18, 15,  6, 10, 18, 12,  7, 18,  5, 18, 16,
          16,  8, 17,  6,  1, 18,  8, 12, 10, 20,  9,  6,  9, 14,  8, 10, 10,  6,  5,  4,  4,  1,  4, 16, 10,  8, 14, 14,  6, 18, 18,  4, 10,
          18,  4],
         [11,  8, 16,  9, 17, 16,  8, 10, 10, 10,  6,  5,  1,  2,  1,  1, 18,  6,  1, 12, 20,  8,  4, 16,  1,  8,  5,  9,  5,  3,  3, 16,  2,
           4, 13, 10,  4,  4, 13, 15,  9, 13,  4, 18,  1,  9,  2, 20,  4, 13, 16, 17, 16, 10, 13,  9, 17, 10,  4,  4, 20, 15,  1, 11,  6, 20,
           3, 19,  6,  4,  5, 18,  7,  3,  3,  9,  1,  3, 14,  1, 13,  4,  6, 10,  3,  4, 18,  5, 17, 20, 13, 11, 18, 13,  5,  8,  6, 11, 16,
          10, 10,  1, 16,  9,  9,  5,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
           0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
           0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
           0,  0],
         [12, 10, 10,  3,  3,  1, 18,  9, 15,  8, 16,  4,  3, 13, 13,  2,  9,  2, 13, 17,  9,  5,  2, 18,  4, 15, 10, 16, 14,  6, 15, 20, 15,
          18,  6,  4,  9,  8, 10,  5,  8, 15, 11, 10,  7, 12,  9,  7, 18, 11, 18, 15, 18,  6,  6,  6, 19,  4, 17,  5,  1,  6, 20, 10, 10,  9,
           7,  3, 13,  2, 15, 11, 10, 14,  8, 16, 15, 18,  3,  6,  9, 17, 16, 13,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
           0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
           0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
           0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
           0,  0]]),
 tensor([1., 1., 1., 0., 1.]))

Learner Framework and Callbacks

The flexible callback learner, along with the useful callbacks and functions below, is taken from lesson 16 of the fast.ai 2022 course (see also on GitHub) and may be adapted depending on what I find most useful along the way.

Figure out which acceleration device is available and define a function that sends objects to that device.


source

to_device

 to_device (x, device='cpu')
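A minimal sketch of to_device, assuming it follows the fast.ai course version and recursively moves tensors, mappings, and sequences of tensors:

from collections.abc import Mapping

def to_device(x, device='cpu'):
    # send a tensor, or every tensor inside a dict/list/tuple, to the device
    if isinstance(x, torch.Tensor): return x.to(device)
    if isinstance(x, Mapping): return {k: v.to(device) for k, v in x.items()}
    return type(x)(to_device(o, device) for o in x)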

Define a function that sends objects to the cpu.


source

to_cpu

 to_cpu (x)

Define exceptions that end the learning process.


source

CancelEpochException

Common base class for all non-exit exceptions.


source

CancelBatchException

Common base class for all non-exit exceptions.


source

CancelFitException

Common base class for all non-exit exceptions.
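These are most likely plain Exception subclasses, raised by callbacks to skip the rest of a batch, an epoch, or the whole fit; a minimal sketch:

class CancelFitException(Exception): pass
class CancelBatchException(Exception): pass
class CancelEpochException(Exception): pass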

Define a callback class that assigns an order to each callback.


source

Callback

 Callback ()

Initialize self. See help(type(self)) for accurate signature.
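In the fast.ai course notebook the base Callback only carries an order attribute, which run_cbs uses to sort callbacks; a minimal sketch:

class Callback(): order = 0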

Define a class to be used in the Learner as a context manager to handle callbacks.


source

with_cbs

 with_cbs (nm)

Initialize self. See help(type(self)) for accurate signature.
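A sketch of with_cbs, roughly as in the fast.ai course notebook: it decorates a Learner method so that before/after callbacks run around it and the matching Cancel*Exception silently ends that stage.

class with_cbs:
    def __init__(self, nm): self.nm = nm
    def __call__(self, f):
        def _f(o, *args, **kwargs):
            try:
                o.callback(f'before_{self.nm}')
                f(o, *args, **kwargs)
                o.callback(f'after_{self.nm}')
            except globals()[f'Cancel{self.nm.title()}Exception']: pass
            finally: o.callback(f'cleanup_{self.nm}')
        return _f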

Define a function that runs callbacks in a list of callbacks.


source

run_cbs

 run_cbs (cbs, method_nm, learn=None)
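A sketch of run_cbs, assuming the fast.ai course version: it calls the named method on every callback that defines it, in order of the callbacks' order attribute.

from operator import attrgetter

def run_cbs(cbs, method_nm, learn=None):
    for cb in sorted(cbs, key=attrgetter('order')):
        method = getattr(cb, method_nm, None)
        if method is not None: method(learn)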

Define the learner class.


source

Learner

 Learner (model, dls=(0,), loss_func=<function mse_loss>, lr=0.1,
          cbs=None, opt_func=<class 'torch.optim.sgd.SGD'>)

Initialize self. See help(type(self)) for accurate signature.

Define a class that inherits from Learner and contains all the functions needed for training without requiring a training callback.


source

TrainLearner

 TrainLearner (model, dls=(0,), loss_func=<function mse_loss>, lr=0.1,
               cbs=None, opt_func=<class 'torch.optim.sgd.SGD'>)

Initialize self. See help(type(self)) for accurate signature.

Create a callback that assigns model and batches to the available acceleration device.


source

DeviceCB

 DeviceCB (device='cpu')

Initialize self. See help(type(self)) for accurate signature.
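A sketch of DeviceCB, assuming the fast.ai course version: it moves the model to the device before fitting and each batch to the device before it is processed.

class DeviceCB(Callback):
    def __init__(self, device='cpu'): self.device = device
    def before_fit(self, learn):
        if hasattr(learn.model, 'to'): learn.model.to(self.device)
    def before_batch(self, learn): learn.batch = to_device(learn.batch, device=self.device)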

Define a callback that runs for a single batch for testing purposes.


source

SingleBatchCB

 SingleBatchCB ()

Initialize self. See help(type(self)) for accurate signature.

Create a training callback that provides the learner with all functions necessary for training.


source

TrainCB

 TrainCB (n_inp=1)

Initialize self. See help(type(self)) for accurate signature.
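A sketch of TrainCB, roughly as in the fast.ai course notebook: it supplies the prediction, loss, backward, optimizer step, and zero_grad steps that the Learner calls during training. n_inp is the number of leading batch elements treated as model inputs.

class TrainCB(Callback):
    def __init__(self, n_inp=1): self.n_inp = n_inp
    def predict(self, learn): learn.preds = learn.model(*learn.batch[:self.n_inp])
    def get_loss(self, learn): learn.loss = learn.loss_func(learn.preds, *learn.batch[self.n_inp:])
    def backward(self, learn): learn.loss.backward()
    def step(self, learn): learn.opt.step()
    def zero_grad(self, learn): learn.opt.zero_grad()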

Define scheduler callbacks that adjust the learning rate according to a schedule, along with a callback that records the learning rate applied at every step.


source

RecorderCB

 RecorderCB (**d)

Initialize self. See help(type(self)) for accurate signature.


source

EpochSchedCB

 EpochSchedCB (sched=None)

Initialize self. See help(type(self)) for accurate signature.


source

BatchSchedCB

 BatchSchedCB (sched=None)

Initialize self. See help(type(self)) for accurate signature.


source

BaseSchedCB

 BaseSchedCB (sched=None)

Initialize self. See help(type(self)) for accurate signature.

Define a metrics callback that handles the calculation of metrics, along with a progress callback that displays metrics, losses, and plots of progress during training.


source

MetricsCB

 MetricsCB (*ms, **metrics)

Initialize self. See help(type(self)) for accurate signature.


source

ProgressCB

 ProgressCB (plot=False)

Initialize self. See help(type(self)) for accurate signature.

Finally, create a learning rate finder callback and add a method to the learner so the finder can be used with learn.lr_find() syntax.


source

LRFinderCB

 LRFinderCB (gamma=1.3, max_mult=3, av_over=1)

Initialize self. See help(type(self)) for accurate signature.


source

show_doc

 show_doc (sym, renderer=None, name:str|None=None, title_level:int=3)

Show signature and docstring for sym

Type Default Details
sym Symbol to document
renderer NoneType None Optional renderer (defaults to markdown)
name str | None None Optionally override displayed name of sym
title_level int 3 Heading level to use for symbol name

Functions for Convenient Plotting of Images


source

show_image

 show_image (im, ax=None, figsize=None, title=None, noframe=True,
             cmap=None, norm=None, aspect=None, interpolation=None,
             alpha=None, vmin=None, vmax=None, origin=None, extent=None,
             interpolation_stage=None, filternorm=True, filterrad=4.0,
             resample=None, url=None, data=None)

Show a PIL or PyTorch image on ax.

Type Default Details
im
ax NoneType None
figsize NoneType None
title NoneType None
noframe bool True
cmap NoneType None The Colormap instance or registered colormap name used to map scalar data
to colors.

This parameter is ignored if X is RGB(A).
norm NoneType None The normalization method used to scale scalar data to the [0, 1] range
before mapping to colors using cmap. By default, a linear scaling is
used, mapping the lowest value to 0 and the highest to 1.

If given, this can be one of the following:

- An instance of .Normalize or one of its subclasses
(see :ref:colormapnorms).
- A scale name, i.e. one of “linear”, “log”, “symlog”, “logit”, etc. For a
list of available scales, call matplotlib.scale.get_scale_names().
In that case, a suitable .Normalize subclass is dynamically generated
and instantiated.

This parameter is ignored if X is RGB(A).
aspect NoneType None The aspect ratio of the Axes. This parameter is particularly
relevant for images since it determines whether data pixels are
square.

This parameter is a shortcut for explicitly calling
.Axes.set_aspect. See there for further details.

- ‘equal’: Ensures an aspect ratio of 1. Pixels will be square
(unless pixel sizes are explicitly made non-square in data
coordinates using extent).
- ‘auto’: The Axes is kept fixed and the aspect is adjusted so
that the data fit in the Axes. In general, this will result in
non-square pixels.

Normally, None (the default) means to use :rc:image.aspect. However, if
the image uses a transform that does not contain the axes data transform,
then None means to not modify the axes aspect at all (in that case, directly
call .Axes.set_aspect if desired).
interpolation NoneType None The interpolation method used.

Supported values are ‘none’, ‘antialiased’, ‘nearest’, ‘bilinear’,
‘bicubic’, ‘spline16’, ‘spline36’, ‘hanning’, ‘hamming’, ‘hermite’,
‘kaiser’, ‘quadric’, ‘catrom’, ‘gaussian’, ‘bessel’, ‘mitchell’,
‘sinc’, ‘lanczos’, ‘blackman’.

The data X is resampled to the pixel size of the image on the
figure canvas, using the interpolation method to either up- or
downsample the data.

If interpolation is ‘none’, then for the ps, pdf, and svg
backends no down- or upsampling occurs, and the image data is
passed to the backend as a native image. Note that different ps,
pdf, and svg viewers may display these raw pixels differently. On
other backends, ‘none’ is the same as ‘nearest’.

If interpolation is the default ‘antialiased’, then ‘nearest’
interpolation is used if the image is upsampled by more than a
factor of three (i.e. the number of display pixels is at least
three times the size of the data array). If the upsampling rate is
smaller than 3, or the image is downsampled, then ‘hanning’
interpolation is used to act as an anti-aliasing filter, unless the
image happens to be upsampled by exactly a factor of two or one.

See
:doc:/gallery/images_contours_and_fields/interpolation_methods
for an overview of the supported interpolation methods, and
:doc:/gallery/images_contours_and_fields/image_antialiasing for
a discussion of image antialiasing.

Some interpolation methods require an additional radius parameter,
which can be set by filterrad. Additionally, the antigrain image
resize filter is controlled by the parameter filternorm.
alpha NoneType None The alpha blending value, between 0 (transparent) and 1 (opaque).
If alpha is an array, the alpha blending values are applied pixel
by pixel, and alpha must have the same shape as X.
vmin NoneType None
vmax NoneType None
origin NoneType None Place the [0, 0] index of the array in the upper left or lower
left corner of the Axes. The convention (the default) ‘upper’ is
typically used for matrices and images.

Note that the vertical axis points upward for ‘lower’
but downward for ‘upper’.

See the :ref:imshow_extent tutorial for
examples and a more detailed description.
extent NoneType None The bounding box in data coordinates that the image will fill.
These values may be unitful and match the units of the Axes.
The image is stretched individually along x and y to fill the box.

The default extent is determined by the following conditions.
Pixels have unit size in data coordinates. Their centers are on
integer coordinates, and their center coordinates range from 0 to
columns-1 horizontally and from 0 to rows-1 vertically.

Note that the direction of the vertical axis and thus the default
values for top and bottom depend on origin:

- For origin == 'upper' the default is
(-0.5, numcols-0.5, numrows-0.5, -0.5).
- For origin == 'lower' the default is
(-0.5, numcols-0.5, -0.5, numrows-0.5).

See the :ref:imshow_extent tutorial for
examples and a more detailed description.
interpolation_stage NoneType None If ‘data’, interpolation
is carried out on the data provided by the user. If ‘rgba’, the
interpolation is carried out after the colormapping has been
applied (visual interpolation).
filternorm bool True A parameter for the antigrain image resize filter (see the
antigrain documentation). If filternorm is set, the filter
normalizes integer values and corrects the rounding errors. It
doesn’t do anything with the source floating point values, it
corrects only integers according to the rule of 1.0 which means
that any sum of pixel weights must be equal to 1.0. So, the
filter function must produce a graph of the proper shape.
filterrad float 4.0 The filter radius for filters that have a radius parameter, i.e.
when interpolation is one of: ‘sinc’, ‘lanczos’ or ‘blackman’.
resample NoneType None When True, use a full resampling method. When False, only
resample when the output image is larger than the input image.
url NoneType None Set the url of the created .AxesImage. See .Artist.set_url.
data NoneType None

source

subplots

 subplots (nrows=1, ncols=1, figsize=None, imsize=3, suptitle=None,
           sharex:"bool|Literal['none','all','row','col']"=False,
           sharey:"bool|Literal['none','all','row','col']"=False,
           squeeze:bool=True, width_ratios:Sequence[float]|None=None,
           height_ratios:Sequence[float]|None=None,
           subplot_kw:dict[str,Any]|None=None,
           gridspec_kw:dict[str,Any]|None=None, **kwargs)

A figure and set of subplots to display images of imsize inches.

Type Default Details
nrows int 1 Number of rows in returned axes grid
ncols int 1 Number of columns in returned axes grid
figsize NoneType None Width, height in inches of the returned figure
imsize int 3 Size (in inches) of images that will be displayed in the returned figure
suptitle NoneType None Title to be set to returned figure
sharex bool | Literal[‘none’, ‘all’, ‘row’, ‘col’] False
sharey bool | Literal[‘none’, ‘all’, ‘row’, ‘col’] False
squeeze bool True
width_ratios Sequence[float] | None None
height_ratios Sequence[float] | None None
subplot_kw dict[str, Any] | None None
gridspec_kw dict[str, Any] | None None
kwargs

source

get_grid

 get_grid (n, nrows=None, ncols=None, title=None, weight='bold', size=14,
           figsize=None, imsize=3, suptitle=None,
           sharex:"bool|Literal['none','all','row','col']"=False,
           sharey:"bool|Literal['none','all','row','col']"=False,
           squeeze:bool=True, width_ratios:Sequence[float]|None=None,
           height_ratios:Sequence[float]|None=None,
           subplot_kw:dict[str,Any]|None=None,
           gridspec_kw:dict[str,Any]|None=None)

Return a grid of n axes, nrows by ncols.

Type Default Details
n Number of axes
nrows NoneType None Number of rows, defaulting to int(math.sqrt(n))
ncols NoneType None Number of columns, defaulting to ceil(n/rows)
title NoneType None If passed, title set to the figure
weight str bold Title font weight
size int 14 Title font size
figsize NoneType None Width, height in inches of the returned figure
imsize int 3 Size (in inches) of images that will be displayed in the returned figure
suptitle NoneType None Title to be set to returned figure
sharex bool | Literal[‘none’, ‘all’, ‘row’, ‘col’] False
sharey bool | Literal[‘none’, ‘all’, ‘row’, ‘col’] False
squeeze bool True
width_ratios Sequence[float] | None None
height_ratios Sequence[float] | None None
subplot_kw dict[str, Any] | None None
gridspec_kw dict[str, Any] | None None

source

show_images

 show_images (ims:list, nrows=1, ncols=None, titles=None, noframe=True,
              figsize=None, imsize=3, suptitle=None,
              sharex:"bool|Literal['none','all','row','col']"=False,
              sharey:"bool|Literal['none','all','row','col']"=False,
              squeeze:bool=True, width_ratios:Sequence[float]|None=None,
              height_ratios:Sequence[float]|None=None,
              subplot_kw:dict[str,Any]|None=None,
              gridspec_kw:dict[str,Any]|None=None)

Show all images ims as subplots with nrows using titles.

Type Default Details
ims list Images to show
nrows int 1 Number of rows in grid
ncols NoneType None Number of columns in grid (auto-calculated if None)
titles NoneType None Optional list of titles for each image
noframe bool True Hide axes, yes or no
figsize NoneType None Width, height in inches of the returned figure
imsize int 3 Size (in inches) of images that will be displayed in the returned figure
suptitle NoneType None Title to be set to returned figure
sharex bool | Literal[‘none’, ‘all’, ‘row’, ‘col’] False
sharey bool | Literal[‘none’, ‘all’, ‘row’, ‘col’] False
squeeze bool True
width_ratios Sequence[float] | None None
height_ratios Sequence[float] | None None
subplot_kw dict[str, Any] | None None
gridspec_kw dict[str, Any] | None None

Activation Statistics using Hooks


source

append_stats

 append_stats (hook, mod, inp, outp)

source

Hook

 Hook (m, f)

Initialize self. See help(type(self)) for accurate signature.
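A sketch of Hook, assuming the fast.ai course version: it registers a forward hook on a module and removes it when no longer needed.

class Hook():
    def __init__(self, m, f): self.hook = m.register_forward_hook(partial(f, self))
    def remove(self): self.hook.remove()
    def __del__(self): self.remove()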


source

Hooks

 Hooks (ms, f)

*Built-in mutable sequence.

If no argument is given, the constructor creates a new empty list. The argument must be an iterable if specified.*


source

HooksCallback

 HooksCallback (hookfunc, mod_filter=<function noop>, on_train=True,
                on_valid=False, mods=None)

Initialize self. See help(type(self)) for accurate signature.


source

append_stats

 append_stats (hook, mod, inp, outp)

source

get_hist

 get_hist (h)

source

get_min

 get_min (h)

source

ActivationStats

 ActivationStats (mod_filter=<function noop>)

Initialize self. See help(type(self)) for accurate signature.

Functions for Convenient Memory Management


source

clean_ipython_hist

 clean_ipython_hist ()

source

clean_tb

 clean_tb ()

source

clean_mem

 clean_mem ()

Weight Initialization and GeneralRelu


source

init_weights

 init_weights (m, leaky=0.0)

source

GeneralRelu

 GeneralRelu (leak=None, sub=None, maxv=None)

*Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call :meth:to, etc.

.. note:: As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool*

act_genrelu = partial(GeneralRelu, leak=0.1, sub=0.4)
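For reference, a sketch of what init_weights and GeneralRelu likely look like, following the fast.ai course versions (an assumption; see the source links above for the actual definitions). GeneralRelu is a leaky ReLU with an optional downward shift and an optional upper clamp.

def init_weights(m, leaky=0.0):
    # Kaiming initialization for conv and linear layers
    if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight, a=leaky)

class GeneralRelu(nn.Module):
    def __init__(self, leak=None, sub=None, maxv=None):
        super().__init__()
        self.leak, self.sub, self.maxv = leak, sub, maxv
    def forward(self, x):
        x = F.leaky_relu(x, self.leak) if self.leak is not None else F.relu(x)
        if self.sub is not None: x -= self.sub
        if self.maxv is not None: x.clamp_max_(self.maxv)
        return x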

Training a Model

Obtain a single batch from dls to help with model design.

idx = next(iter(dls.train))[0]
idx, idx.shape
(tensor([[16, 12, 13,  ...,  0,  0,  0],
         [11,  1, 16,  ...,  0,  0,  0],
         [11,  3,  6,  ...,  0,  0,  0],
         ...,
         [16, 19, 14,  ...,  0,  0,  0],
         [ 9, 18, 10,  ...,  0,  0,  0],
         [11, 17, 20,  ...,  0,  0,  0]]),
 torch.Size([32, 200]))

Tiny Resnet

Design the model architecture. Define a function that returns a 1D convolutional layer with an activation function and optional normalization, and define a 1D residual block class.


source

ResBlock1d

 ResBlock1d (ni, nf, stride=1, ks=3, act=<class
             'torch.nn.modules.activation.ReLU'>, norm=None)

*Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call :meth:to, etc.

.. note:: As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool*


source

conv1d

 conv1d (ni, nf, ks=3, stride=2, act=<class
         'torch.nn.modules.activation.ReLU'>, norm=None, bias=None)
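A minimal sketch of a conv1d helper consistent with the signature above (an assumption, adapted from the 2D version in the fast.ai course; the actual definition is in the source):

def conv1d(ni, nf, ks=3, stride=2, act=nn.ReLU, norm=None, bias=None):
    # omit the conv bias when a norm layer with its own shift follows
    if bias is None: bias = norm is None
    layers = [nn.Conv1d(ni, nf, kernel_size=ks, stride=stride, padding=ks//2, bias=bias)]
    if norm: layers.append(norm(nf))
    if act: layers.append(act())
    return nn.Sequential(*layers)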

Define a class that switches the axis order from BLC (batch, length, channels) to BCL (batch, channels, length).


source

Reshape

 Reshape (*args, **kwargs)

*Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call :meth:to, etc.

.. note:: As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool*
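A minimal sketch of what Reshape might do (an assumption): transpose the last two axes so the embedding output of shape (batch, length, channels) matches the (batch, channels, length) layout expected by nn.Conv1d.

class Reshape(nn.Module):
    def forward(self, x): return x.transpose(1, 2)  # (B, L, C) -> (B, C, L)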

Put the model together.

lr = 1e-2
epochs = 10
n_embd = 16
dls = get_dls(trnds, vldds, bs=32)

model = nn.Sequential(nn.Embedding(vocab_size, n_embd, padding_idx=0), Reshape(),
                      ResBlock1d(n_embd, 2, ks=15, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),
                      ResBlock1d(2, 4, ks=13, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),
                      ResBlock1d(4, 4, ks=11, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),
                      ResBlock1d(4, 4, ks=9, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),
                      ResBlock1d(4, 8, ks=7, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),
                      ResBlock1d(8, 8, ks=5, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),
                      ResBlock1d(8, 16, ks=3, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),
                      ResBlock1d(16, 32, ks=3, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),
                      nn.Flatten(1, -1),
                      nn.Linear(32, 1),
                      nn.Flatten(0, -1),
                      nn.Sigmoid())
model(idx).shape
torch.Size([32])
iw = partial(init_weights, leaky=0.1)
model = model.apply(iw)
metrics = MetricsCB(BinaryAccuracy(), BinaryMatthewsCorrCoef(), BinaryAUROC())
rec = RecorderCB(lr=_lr, beta1=_beta1, beta2=_beta2)
astats = ActivationStats(fc.risinstance(GeneralRelu))
cbs = [DeviceCB(), BatchSchedCB(), ProgressCB(plot=False), metrics, astats, rec]
# cbs = [DeviceCB(), ProgressCB(plot=False), metrics, astats, rec] # for lr_find()
learn = TrainLearner(model, dls, F.binary_cross_entropy, lr=lr, cbs=cbs, opt_func=torch.optim.AdamW)
print(f"Parameters total: {sum(p.nelement() for p in model.parameters())}")
# learn.lr_find(start_lr=1e-4, gamma=1.05, av_over=5, max_mult=5)
Parameters total: 10175
learn.fit(epochs)
BinaryAccuracy BinaryMatthewsCorrCoef BinaryAUROC loss epoch train
0.502 -0.003 0.499 0.732 0 train
0.507 -0.022 0.480 0.701 0 eval
0.513 0.003 0.503 0.698 1 train
0.501 -0.016 0.494 0.695 1 eval
0.528 0.031 0.518 0.692 2 train
0.473 -0.014 0.490 0.695 2 eval
0.540 0.058 0.529 0.689 3 train
0.530 0.026 0.511 0.689 3 eval
0.557 0.100 0.557 0.683 4 train
0.545 0.060 0.566 0.681 4 eval
0.595 0.185 0.612 0.667 5 train
0.628 0.247 0.657 0.639 5 eval
0.636 0.271 0.661 0.641 6 train
0.637 0.266 0.665 0.633 6 eval
0.643 0.291 0.675 0.631 7 train
0.643 0.287 0.679 0.626 7 eval
0.653 0.313 0.685 0.623 8 train
0.649 0.301 0.683 0.619 8 eval
0.657 0.322 0.693 0.617 9 train
0.647 0.291 0.683 0.621 9 eval

Inspect the learning rate schedule applied during training.

rec.plot()

Transformer Model with Skip Connections and LayerNorm

The transformer model below is adapted from the model built in Andrej Karpathy’s video Let’s build GPT: from scratch, in code, spelled out.


source

FeedForward

 FeedForward (n_embed)

A simple linear layer followed by a non-linearity.


source

MultiHeadAttention

 MultiHeadAttention (num_heads, head_size)

Multiple heads of self-attention in parallel.


source

Block

 Block (n_embd, n_head)

Transformer block: communication (attention) followed by computation.
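Since the model is adapted from Karpathy's GPT video, the transformer block presumably looks roughly like this pre-norm residual layout (a sketch, not the exact definition used here):

class Block(nn.Module):
    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)  # communication
        self.ffwd = FeedForward(n_embd)                  # computation
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
    def forward(self, x):
        x = x + self.sa(self.ln1(x))    # residual self-attention
        x = x + self.ffwd(self.ln2(x))  # residual feed-forward
        return x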


source

TransformerModel

 TransformerModel (device)

*Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call :meth:to, etc.

.. note:: As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool*

lr = 1e-3
block_size = 200
epochs = 10
n_embd = 16
n_head = 8
n_layer = 5
dropout = 0.2

model = TransformerModel(device='cpu')
model(idx).shape
torch.Size([32])
dls = get_dls(trnds, vldds, bs=32)
model = TransformerModel(device=def_device)
iw = partial(init_weights, leaky=0.1)
model = model.apply(iw)
metrics = MetricsCB(BinaryAccuracy(), BinaryMatthewsCorrCoef(), BinaryAUROC())
rec = RecorderCB(lr=_lr, beta1=_beta1, beta2=_beta2)
astats = ActivationStats(fc.risinstance(GeneralRelu))
cbs = [DeviceCB(), BatchSchedCB(), ProgressCB(plot=False), metrics, astats, rec]
learn = TrainLearner(model, dls, F.binary_cross_entropy_with_logits, lr=lr, cbs=cbs, opt_func=optim.AdamW)
print(f"Parameters total: {sum(p.nelement() for p in model.parameters())}")
#learn.lr_find(start_lr=1e-5, gamma=1.1, av_over=3, max_mult=5)
Parameters total: 22929
learn.fit(epochs)
BinaryAccuracy BinaryMatthewsCorrCoef BinaryAUROC loss epoch train
0.520 0.002 0.504 0.767 0 train
0.534 0.041 0.534 0.712 0 eval
0.535 0.051 0.537 0.713 1 train
0.539 0.222 0.626 0.692 1 eval
0.602 0.189 0.627 0.670 2 train
0.659 0.265 0.683 0.622 2 eval
0.632 0.250 0.668 0.639 3 train
0.654 0.310 0.697 0.610 3 eval
0.639 0.282 0.687 0.628 4 train
0.660 0.306 0.712 0.614 4 eval
0.651 0.308 0.704 0.616 5 train
0.658 0.333 0.721 0.607 5 eval
0.656 0.329 0.719 0.604 6 train
0.650 0.346 0.722 0.605 6 eval
0.667 0.354 0.737 0.592 7 train
0.662 0.355 0.726 0.596 7 eval
0.674 0.366 0.743 0.586 8 train
0.672 0.339 0.726 0.599 8 eval
0.672 0.372 0.748 0.583 9 train
0.662 0.354 0.725 0.596 9 eval
rec.plot()