#!pip install trax==1.3.1 Use this version for this notebook
In this notebook you’ll get to know about the Trax framework and learn about some of its basic building blocks.
Background
Why Trax and not TensorFlow or PyTorch?
TensorFlow
and PyTorch
are both extensive frameworks that can do almost anything in deep learning. They offer a lot of flexibility, but that often means verbosity of syntax and extra time to code.
Trax
is much more concise. It runs on a TensorFlow backend but allows you to train models with 1 line commands. Trax
also runs end to end, allowing you to get data, model and train all with a single terse statements. This means you can focus on learning, instead of spending hours on the idiosyncrasies of big framework implementation.
Why not Keras then?
Keras
is now part of Tensorflow
itself from 2.0 onwards. Also, Trax
is good for implementing new state of the art algorithms like Transformers, Reformers, BERT because it is actively maintained by Google Brain Team for advanced deep learning tasks. It runs smoothly on CPUs, GPUs and TPUs as well with comparatively lesser modifications in code.
How to Code in Trax
Building models in Trax
relies on 2 key concepts:
- layers and
- combinators.
Trax layers are simple objects that process data and perform computations. They can be chained together into composite layers using Trax combinators, allowing you to build layers and models of any complexity.
Trax, JAX, TensorFlow and Tensor2Tensor
You already know that Trax
uses Tensorflow
as a backend, but it also uses the JAX
library to speed up computation too. You can view JAX
as an enhanced and optimized version of numpy.
Watch out for assignments which import import trax.fastmath.numpy as np
. If you see this line, remember that when calling np
you are really calling Trax’s version of numpy that is compatible with JAX.
As a result of this, where you used to encounter the type numpy.ndarray
now you will find the type jax.interpreters.xla.DeviceArray
.
Tensor2Tensor is another name you might have heard. It started as an end to end solution much like how Trax is designed, but it grew unwieldy and complicated. So you can view Trax as the new improved version that operates much faster and simpler.
Resources
Installing Trax
Trax has dependencies on JAX and some libraries like JAX which are yet to be supported in Windows but work well in Ubuntu and MacOS. We would suggest that if you are working on Windows, try to install Trax on WSL2.
Official maintained documentation - trax-ml not to be confused with this TraX
Imports
import numpy as np # regular ol' numpy
from trax import layers as tl # core building block
from trax import shapes # data signatures: dimensionality and type
from trax import fastmath # uses jax, offers numpy on steroids
2025-02-10 16:53:07.573380: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1739199187.587843 120958 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1739199187.592157 120958 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
# Trax version 1.3.1 or better
!pip list | grep trax
trax 1.4.1
Layers
Layers are the core building blocks in Trax or as mentioned in the lectures, they are the base classes.
They take inputs, compute functions/custom calculations and return outputs.
You can also inspect layer properties. Let me show you some examples.
Relu Layer
First I’ll show you how to build a relu activation function as a layer. A layer like this is one of the simplest types. Notice there is no object initialization so it works just like a math function.
Note: Activation functions are also layers in Trax, which might look odd if you have been using other frameworks for a longer time.
# Layers
# Create a relu trax layer
= tl.Relu()
relu
# Inspect properties
print("-- Properties --")
print("name :", relu.name)
print("expected inputs :", relu.n_in)
print("promised outputs :", relu.n_out, "\n")
# Inputs
= np.array([-2, -1, 0, 1, 2])
x print("-- Inputs --")
print("x :", x, "\n")
# Outputs
= relu(x)
y print("-- Outputs --")
print("y :", y)
-- Properties --
name : Serial
expected inputs : 1
promised outputs : 1
-- Inputs --
x : [-2 -1 0 1 2]
-- Outputs --
y : [0 0 0 1 2]
Concatenate Layer
Now I’ll show you how to build a layer that takes 2 inputs. Notice the change in the expected inputs property from 1 to 2.
# Create a concatenate trax layer
= tl.Concatenate()
concat print("-- Properties --")
print("name :", concat.name)
print("expected inputs :", concat.n_in)
print("promised outputs :", concat.n_out, "\n")
# Inputs
= np.array([-10, -20, -30])
x1 = x1 / -10
x2 print("-- Inputs --")
print("x1 :", x1)
print("x2 :", x2, "\n")
# Outputs
= concat([x1, x2])
y print("-- Outputs --")
print("y :", y)
-- Properties --
name : Concatenate
expected inputs : 2
promised outputs : 1
-- Inputs --
x1 : [-10 -20 -30]
x2 : [1. 2. 3.]
-- Outputs --
y : [-10. -20. -30. 1. 2. 3.]
Layers are Configurable
You can change the default settings of layers. For example, you can change the expected inputs for a concatenate layer from 2 to 3 using the optional parameter n_items
.
# Configure a concatenate layer
= tl.Concatenate(n_items=3) # configure the layer's expected inputs
concat_3 print("-- Properties --")
print("name :", concat_3.name)
print("expected inputs :", concat_3.n_in)
print("promised outputs :", concat_3.n_out, "\n")
# Inputs
= np.array([-10, -20, -30])
x1 = x1 / -10
x2 = x2 * 0.99
x3 print("-- Inputs --")
print("x1 :", x1)
print("x2 :", x2)
print("x3 :", x3, "\n")
# Outputs
= concat_3([x1, x2, x3])
y print("-- Outputs --")
print("y :", y)
-- Properties --
name : Concatenate
expected inputs : 3
promised outputs : 1
-- Inputs --
x1 : [-10 -20 -30]
x2 : [1. 2. 3.]
x3 : [0.99 1.98 2.97]
-- Outputs --
y : [-10. -20. -30. 1. 2. 3. 0.99 1.98 2.97]
Note: At any point,if you want to refer the function help/ look up the documentation or use help function.
#help(tl.Concatenate) #Uncomment this to see the function docstring with explaination
Layers can have Weights
Some layer types include mutable weights and biases that are used in computation and training. Layers of this type require initialization before use.
For example the LayerNorm
layer calculates normalized data, that is also scaled by weights and biases. During initialization you pass the data shape and data type of the inputs, so the layer can initialize compatible arrays of weights and biases.
# Uncomment any of them to see information regarding the function
# help(tl.LayerNorm)
# help(shapes.signature)
# Layer initialization
= tl.LayerNorm()
norm # You first must know what the input data will look like
= np.array([0, 1, 2, 3], dtype="float")
x
# Use the input data signature to get shape and type for initializing weights and biases
# We need to convert the input datatype from usual tuple to trax ShapeDtype
norm.init(shapes.signature(x))
print("Normal shape:",x.shape, "Data Type:",type(x.shape))
print("Shapes Trax:",shapes.signature(x),"Data Type:",type(shapes.signature(x)))
# Inspect properties
print("-- Properties --")
print("name :", norm.name)
print("expected inputs :", norm.n_in)
print("promised outputs :", norm.n_out)
# Weights and biases
print("weights :", norm.weights[0])
print("biases :", norm.weights[1], "\n")
# Inputs
print("-- Inputs --")
print("x :", x)
# Outputs
= norm(x)
y print("-- Outputs --")
print("y :", y)
/home/oren/work/notes/notes-nlp/.venv/lib/python3.10/site-packages/trax/layers/normalization.py:141: UserWarning: Explicitly requested dtype float64 requested in ones is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
scale = jnp.ones(features, dtype=input_signature.dtype)
/home/oren/work/notes/notes-nlp/.venv/lib/python3.10/site-packages/trax/layers/normalization.py:142: UserWarning: Explicitly requested dtype float64 requested in zeros is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
bias = jnp.zeros(features, dtype=input_signature.dtype)
((Array([1., 1., 1., 1.], dtype=float32),
Array([0., 0., 0., 0.], dtype=float32)),
())
Normal shape: (4,) Data Type: <class 'tuple'>
Shapes Trax: ShapeDtype{shape:(4,), dtype:float64} Data Type: <class 'trax.shapes.ShapeDtype'>
-- Properties --
name : LayerNorm
expected inputs : 1
promised outputs : 1
weights : [1. 1. 1. 1.]
biases : [0. 0. 0. 0.]
-- Inputs --
x : [0. 1. 2. 3.]
-- Outputs --
y : [-1.3416404 -0.44721344 0.44721344 1.3416404 ]
Custom Layers
This is where things start getting more interesting! You can create your own custom layers too and define custom functions for computations by using tl.Fn
. Let me show you how.
help(tl.Fn)
Help on function Fn in module trax.layers.base:
Fn(name, f, n_out=1)
Returns a layer with no weights that applies the function `f`.
`f` can take and return any number of arguments, and takes only positional
arguments -- no default or keyword arguments. It often uses JAX-numpy (`jnp`).
The following, for example, would create a layer that takes two inputs and
returns two outputs -- element-wise sums and maxima:
`Fn('SumAndMax', lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)), n_out=2)`
The layer's number of inputs (`n_in`) is automatically set to number of
positional arguments in `f`, but you must explicitly set the number of
outputs (`n_out`) whenever it's not the default value 1.
Args:
name: Class-like name for the resulting layer; for use in debugging.
f: Pure function from input tensors to output tensors, where each input
tensor is a separate positional arg, e.g., `f(x0, x1) --> x0 + x1`.
Output tensors must be packaged as specified in the `Layer` class
docstring.
n_out: Number of outputs promised by the layer; default value 1.
Returns:
Layer executing the function `f`.
# Define a custom layer
# In this example you will create a layer to calculate the input times 2
def TimesTwo():
= "TimesTwo" #don't forget to give your custom layer a name to identify
layer_name
# Custom function for the custom layer
def func(x):
return x * 2
return tl.Fn(layer_name, func)
# Test it
= TimesTwo()
times_two
# Inspect properties
print("-- Properties --")
print("name :", times_two.name)
print("expected inputs :", times_two.n_in)
print("promised outputs :", times_two.n_out, "\n")
# Inputs
= np.array([1, 2, 3])
x print("-- Inputs --")
print("x :", x, "\n")
# Outputs
= times_two(x)
y print("-- Outputs --")
print("y :", y)
-- Properties --
name : TimesTwo
expected inputs : 1
promised outputs : 1
-- Inputs --
x : [1 2 3]
-- Outputs --
y : [2 4 6]
Combinators
You can combine layers to build more complex layers. Trax provides a set of objects named combinator layers to make this happen. Combinators are themselves layers, so behavior commutes.
Serial Combinator
This is the most common and easiest to use. For example could build a simple neural network by combining layers into a single layer using the Serial
combinator. This new layer then acts just like a single layer, so you can inspect intputs, outputs and weights. Or even combine it into another layer! Combinators can then be used as trainable models. Try adding more layers
Note:As you must have guessed, if there is serial combinator, there must be a parallel combinator as well. Do try to explore about combinators and other layers from the trax documentation and look at the repo to understand how these layers are written.
# help(tl.Serial)
# help(tl.Parallel)
# Serial combinator
= tl.Serial(
serial # normalize input
tl.LayerNorm(), # convert negative values to zero
tl.Relu(), # the custom layer you created above, multiplies the input recieved from above by 2
times_two,
### START CODE HERE
# tl.Dense(n_units=2), # try adding more layers. eg uncomment these lines
# tl.Dense(n_units=1), # Binary classification, maybe? uncomment at your own peril
# tl.LogSoftmax() # Yes, LogSoftmax is also a layer
### END CODE HERE
)
# Initialization
= np.array([-2, -1, 0, 1, 2]) #input
x #initialising serial instance
serial.init(shapes.signature(x))
print("-- Serial Model --")
print(serial,"\n")
print("-- Properties --")
print("name :", serial.name)
print("sublayers :", serial.sublayers)
print("expected inputs :", serial.n_in)
print("promised outputs :", serial.n_out)
print("weights & biases:", serial.weights, "\n")
# Inputs
print("-- Inputs --")
print("x :", x, "\n")
# Outputs
= serial(x)
y print("-- Outputs --")
print("y :", y)
/home/oren/work/notes/notes-nlp/.venv/lib/python3.10/site-packages/trax/layers/normalization.py:141: UserWarning: Explicitly requested dtype int64 requested in ones is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
scale = jnp.ones(features, dtype=input_signature.dtype)
/home/oren/work/notes/notes-nlp/.venv/lib/python3.10/site-packages/trax/layers/normalization.py:142: UserWarning: Explicitly requested dtype int64 requested in zeros is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
bias = jnp.zeros(features, dtype=input_signature.dtype)
(((Array([1, 1, 1, 1, 1], dtype=int32), Array([0, 0, 0, 0, 0], dtype=int32)),
((), (), ()),
()),
((), ((), (), ()), ()))
-- Serial Model --
Serial[
LayerNorm
Serial[
Relu
]
TimesTwo
]
-- Properties --
name : Serial
sublayers : [LayerNorm, Serial[
Relu
], TimesTwo]
expected inputs : 1
promised outputs : 1
weights & biases: ((Array([1, 1, 1, 1, 1], dtype=int32), Array([0, 0, 0, 0, 0], dtype=int32)), ((), (), ()), ())
-- Inputs --
x : [-2 -1 0 1 2]
-- Outputs --
y : [0. 0. 0. 1.4142132 2.8284264]
JAX
Just remember to lookout for which numpy you are using, the regular ol’ numpy or Trax’s JAX compatible numpy. Both tend to use the alias np so watch those import blocks.
Note:There are certain things which are still not possible in fastmath.numpy which can be done in numpy so you will see in assignments we will switch between them to get our work done.
# Numpy vs fastmath.numpy have different data types
# Regular ol' numpy
= np.array([1, 2, 3])
x_numpy print("good old numpy : ", type(x_numpy), "\n")
# Fastmath and jax numpy
= fastmath.numpy.array([1, 2, 3])
x_jax print("jax trax numpy : ", type(x_jax))
good old numpy : <class 'numpy.ndarray'>
jax trax numpy : <class 'jaxlib.xla_extension.ArrayImpl'>
Summary
Trax is a concise framework, built on TensorFlow, for end to end machine learning. The key building blocks are layers and combinators. This notebook is just a taste, but sets you up with some key inuitions to take forward into the rest of the course and assignments where you will build end to end models.
Citation
@online{bochman2020,
author = {Bochman, Oren},
title = {Trax : {Ungraded} {Lecture} {Notebook}},
date = {2020-11-07},
url = {https://orenbochman.github.io/notes-nlp/notes/c3w1/lab01.html},
langid = {en}
}