{ "cells": [ { "cell_type": "markdown", "id": "92fef1c2-0074-4c7c-ba1a-faa5aee4816d", "metadata": {}, "source": [ "# Compilation\n", "\n", "Various libraries provide tools for tracing numeric functions and turning the resulting computation into a more efficient, compiled function. Notably:\n", "\n", "* [``jax.jit``](https://github.com/google/jax)\n", "* [``tensorflow.function``](https://www.tensorflow.org/api_docs/python/tf/function)\n", "* [``torch.jit.trace``](https://pytorch.org/docs/stable/jit.html)\n", "\n", "``autoray`` is well suited to these since it simply dispatches each function call to whichever library is doing the tracing, so functions written using ``autoray`` should be immediately compatible with all of them.\n", "\n", "**The `autojit` wrapper**\n", "\n", "Moreover, ``autoray`` provides a *unified interface* for compiling functions, so that the compilation backend can easily be switched or identified automatically:\n", "\n", "```python\n", "from autoray import autojit\n", "\n", "# wrap the (for-loop heavy) modified_gram_schmidt function, written using autoray\n", "mgs = autojit(modified_gram_schmidt)\n", "```\n", "\n", "Currently ``autojit`` supports functions with the signature ``fn(*args, **kwargs) -> array``, where both ``args`` and ``kwargs`` can be any nested combination of ``tuple``, ``list`` and ``dict`` objects containing arrays.\n", "We can compare different compiled versions of this simply by changing the ``backend`` option:\n", "\n", "```python\n", "from autoray import do\n", "\n", "x = do(\"random.normal\", size=(50, 50), like='numpy')\n", "\n", "# first the uncompiled version\n", "%timeit modified_gram_schmidt(x)\n", "# 23.5 ms ± 241 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n", "\n", "# 'python' mode unravels the computation into source code, then uses compile + exec\n", "%timeit mgs(x)  # backend='python'\n", "# 17.8 ms ± 191 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n", "\n", "%timeit mgs(x, backend='torch')\n", "# 11.9 ms ± 80.5 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n", "\n", "%timeit mgs(x, backend='tensorflow')\n", "# 1.87 ms ± 441 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n", "\n", "# configure jax to run on the same footing (64-bit floats, CPU only)\n", "import jax\n", "jax.config.update(\"jax_enable_x64\", True)\n", "jax.config.update('jax_platform_name', 'cpu')\n", "\n", "%timeit mgs(x, backend='jax')\n", "# 226 µs ± 14.8 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n", "\n", "# approximately the equivalent compiled 'C' version\n", "%timeit do('linalg.qr', x, like='numpy')[0]\n", "# 156 µs ± 32.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n", "```\n", "\n", "Here you can see *(with this very for-loop heavy function)* that there are significant gains to be made with all of the compilation options. Whilst ``jax``, for example, achieves fantastic performance, note that its compilation step takes a long time and scales badly (super-linearly) with the number of computational nodes." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3" } }, "nbformat": 4, "nbformat_minor": 2 }