Welcome to Part 1 of testing the models for this week’s assignment. We will perform decoding using the BERT Loss model. In this notebook we’ll take an input, mask (hide) random words in it, and see how well the model recovers the “target” answers.
Colab
Since this ungraded lab takes a long time to run on Coursera, we have prepared a Colab notebook for you as an alternative.
If you run into a page similar to the one below, with an Open with option, it means you need to install the Colaboratory app. You can do so via Open with -> Connect more apps -> type "Colaboratory" in the search bar -> Install.
After installation it should look like this. Click on Open with Google Colaboratory.
Implement the Bidirectional Encoder Representations from Transformers (BERT) loss.
Use a pretrained version of the model you created in the assignment for inference.
Part 1: Getting ready
Run the code cells below to import the necessary libraries and to define some functions which will be useful for decoding. The code and the functions are the same as the ones you previously ran in the graded assignment.
import pickle
import string
import ast
import numpy as np
import trax
from trax.supervised import decoding
import textwrap

wrapper = textwrap.TextWrapper(width=70)
example_jsons = list(map(ast.literal_eval, open('data.txt')))
natural_language_texts = [example_json['text'] for example_json in example_jsons]

PAD, EOS, UNK = 0, 1, 2

def detokenize(np_array):
    return trax.data.detokenize(
        np_array,
        vocab_type='sentencepiece',
        vocab_file='sentencepiece.model',
        vocab_dir='.')

def tokenize(s):
    return next(trax.data.tokenize(
        iter([s]),
        vocab_type='sentencepiece',
        vocab_file='sentencepiece.model',
        vocab_dir='.'))

vocab_size = trax.data.vocab_size(
    vocab_type='sentencepiece',
    vocab_file='sentencepiece.model',
    vocab_dir='.')

def get_sentinels(vocab_size, display=False):
    sentinels = {}
    for i, char in enumerate(reversed(string.ascii_letters), 1):
        decoded_text = detokenize([vocab_size - i])
        # Sentinels, ex: <Z> - <a>
        sentinels[decoded_text] = f'<{char}>'
        if display:
            print(f'The sentinel is <{char}> and the decoded token is:', decoded_text)
    return sentinels

sentinels = get_sentinels(vocab_size, display=False)

def pretty_decode(encoded_str_list, sentinels=sentinels):
    # If already a string, just do the replacements.
    if isinstance(encoded_str_list, (str, bytes)):
        for token, char in sentinels.items():
            encoded_str_list = encoded_str_list.replace(token, char)
        return encoded_str_list
    # We need to decode and then prettify it.
    return pretty_decode(detokenize(encoded_str_list))

inputs_targets_pairs = []

# Here you are reading already computed input/target pairs from a file.
with open('inputs_targets_pairs_file.txt', 'rb') as fp:
    inputs_targets_pairs = pickle.load(fp)

def display_input_target_pairs(inputs_targets_pairs):
    for i, inp_tgt_pair in enumerate(inputs_targets_pairs, 1):
        inps, tgts = inp_tgt_pair
        inps, tgts = pretty_decode(inps), pretty_decode(tgts)
        print(f'[{i}]\n'
              f'inputs:\n{wrapper.fill(text=inps)}\n\n'
              f'targets:\n{wrapper.fill(text=tgts)}\n\n\n\n')

display_input_target_pairs(inputs_targets_pairs)
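To make the masking format concrete, here is a made-up illustration (not one of the pairs stored in inputs_targets_pairs_file.txt) of what display_input_target_pairs prints: each masked span in the input is replaced by a sentinel such as <Z>, <Y>, ..., and the target lists each sentinel followed by the text it replaced.

inputs:
Thank you <Z> me to your party <Y> week.

targets:
<Z> for inviting <Y> last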
Part 2: BERT Loss
Now that you have created the encoder, we will not make you train it. Training it could easily cost you a few days, depending on which GPUs/TPUs you are using. Very few people train a full transformer from scratch. Instead, most people load a pretrained model and fine-tune it on a specific task. That is exactly what you are about to do. Let’s start by initializing the model and then loading in the pretrained weights.
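The loading cell below calls init_from_file on a model object, so the Transformer you built in the assignment must already be instantiated in 'predict' mode. As a minimal sketch only, with placeholder hyperparameters that are assumptions and would need to match the ones used to produce model.pkl.gz, the initialization could look like this:

# Minimal sketch, not the exact assignment architecture: every hyperparameter
# below is an assumption and must match the saved checkpoint model.pkl.gz.
model = trax.models.Transformer(
    input_vocab_size=vocab_size,  # vocabulary size computed in Part 1
    d_model=512,                  # assumed embedding / hidden size
    d_ff=2048,                    # assumed feed-forward size
    n_encoder_layers=6,           # assumed number of encoder layers
    n_decoder_layers=6,           # assumed number of decoder layers
    n_heads=8,                    # assumed number of attention heads
    max_len=2048,                 # maximum sequence length
    mode='predict')               # fast autoregressive decoding mode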
# Now load in the model.
# This takes about 1 minute.
shape11 = trax.shapes.ShapeDtype((1, 1), dtype=np.int32)  # Needed in predict mode.
model.init_from_file('model.pkl.gz',
                     weights_only=True,
                     input_signature=(shape11, shape11))
2.1 Decoding
Now you will use one of the inputs_targets_pairs as the input and target. Next you will use pretty_decode to display the input and target. The code to perform all of this is provided below.
# Using the 3rd example.
c4_input = inputs_targets_pairs[2][0]
c4_target = inputs_targets_pairs[2][1]

print('pretty_decoded input: \n\n', pretty_decode(c4_input))
print('\npretty_decoded target: \n\n', pretty_decode(c4_target))
print('\nc4_input:\n\n', c4_input)
print('\nc4_target:\n\n', c4_target)
print(len(c4_target))
print(len(pretty_decode(c4_target)))
Run the cell below to decode.
Note: This will take some time to run
# Temperature is a parameter for sampling.
#   * 0.0: same as argmax, always pick the most probable token
#   * 1.0: sampling from the distribution (can sometimes say random things)
#   * values in between can trade off diversity and quality, try it out!
output = decoding.autoregressive_sample(model, inputs=np.array(c4_input)[None, :],
                                        temperature=0.0, max_length=5)  # originally max_length = 50
print(wrapper.fill(pretty_decode(output[0])))
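If you want to see the effect of the temperature parameter described in the comments above, you can pass a nonzero value instead (restart the runtime first if you have already decoded once; see the memory note below). For example, with an arbitrary illustrative temperature of 0.8:

# Sampled (non-greedy) decoding; 0.8 is an arbitrary value chosen for illustration.
output = decoding.autoregressive_sample(model, inputs=np.array(c4_input)[None, :],
                                        temperature=0.8, max_length=5)
print(wrapper.fill(pretty_decode(output[0])))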
At this point the RAM is almost full because the model and the decoding are memory heavy, so you can run decoding just once. Running it a second time with another example might give you an answer that makes no sense, or repetitive words. If that happens, restart the runtime (see how at the start of the notebook) and run all the cells again.
You should also be aware that the quality of the decoding is not very good here because max_length was reduced from 50 to 5 so that it runs faster in this environment. The Colab version uses the original max_length, so check that one for the actual decoding.