Trax is a recent effort by the Google Brain team to build a deep learning framework on top of Jax.


Jax and Trax are quite different from PyTorch, TensorFlow, or any other deep learning framework. Jax/Trax follow a pure NumPy/Python-style syntax. You can simply work with tensors as if they were NumPy arrays and still keep track of your computational graph.
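
To make the NumPy-like style concrete, here is a minimal sketch in plain Jax (not Trax-specific). The `loss` function and the toy data are my own assumptions for illustration; the point is only that the code reads like NumPy while `jax.grad` still gives you gradients.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy loss: mean squared error of a linear model.
def loss(w, x, y):
    pred = jnp.dot(x, w)              # same call shape as numpy.dot
    return jnp.mean((pred - y) ** 2)

# Toy data, made up for this example.
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
w = jnp.zeros(2)

# jax.grad differentiates with respect to the first argument (w) by default.
grad_loss = jax.grad(loss)
g = grad_loss(w, x, y)
print(g)
```

The function body has no framework-specific tensor types or sessions; differentiation is applied from the outside by transforming the function.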


Beyond the syntax, Jax/Trax are also reported to be faster than other frameworks.


I contributed a Semantic Segmentation example for Trax. Here is the link to the notebook.


Currently, the Trax community is quite new. However, there has been a recent surge in contributors due to the popular Transformer and Reformer implementations in Trax. I believe the community will grow much more rapidly with time as Google starts distributing more codebases in Jax/Trax.

Personally, I feel the Trax community has worked mostly on NLP and therefore needs some tools in the vision domain. Hence, I chose semantic segmentation as my target.

Trax needs more tools for vision. For example, there are no image augmentation methods implemented.
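
As a rough idea of what such a method could look like, here is a minimal sketch of a random horizontal flip written in plain Jax. This is not an existing Trax API: the function name `random_horizontal_flip`, its parameters, and the assumed shapes (image as height x width x channels, mask as height x width) are all hypothetical. For segmentation, the image and its mask have to be flipped together.

```python
import jax
import jax.numpy as jnp

def random_horizontal_flip(key, image, mask, p=0.5):
    """Hypothetical helper: flip image and mask together with probability p."""
    # One random draw decides the flip for both arrays so they stay aligned.
    flip = jax.random.bernoulli(key, p)
    # Assumed layout: image is (H, W, C), mask is (H, W); axis 1 is the width.
    image = jnp.where(flip, image[:, ::-1, :], image)
    mask = jnp.where(flip, mask[:, ::-1], mask)
    return image, mask

# Toy inputs, made up for this example.
key = jax.random.PRNGKey(0)
image = jnp.arange(2 * 3 * 1, dtype=jnp.float32).reshape(2, 3, 1)
mask = jnp.arange(2 * 3).reshape(2, 3)

aug_image, aug_mask = random_horizontal_flip(key, image, mask)
```

Jax requires the random key to be passed explicitly, so the same key always gives the same augmentation, which keeps a data pipeline reproducible.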


Thanks to Lukasz Kaiser for reviewing the code.