Compilation

Various libraries provide tools for tracing numeric functions and turning the resulting computation into a more efficient, compiled function. Notable examples, all used as backends below, are jax, tensorflow and torch.

autoray is well suited to these tools since it simply dispatches each function call to whichever library is doing the tracing, so functions written using autoray should be immediately compatible with all of them.

The autojit wrapper

autoray also provides a unified interface for compiling functions, so that the compilation backend can be easily switched or automatically identified:

from autoray import autojit

mgs = autojit(modified_gram_schmidt)

Currently autojit supports functions with the signature fn(*args, **kwargs) -> array, where both args and kwargs can be any nested combination of tuple, list and dict objects containing arrays. We can compare different compiled versions of this simply by changing the backend option:

x = do("random.normal", size=(50, 50), like='numpy')

# first the uncompiled version
%%timeit
modified_gram_schmidt(x)
# 23.5 ms ± 241 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

# 'python' mode unravels computation into source then uses compile+exec
%%timeit
mgs(x)  # backend='python'
# 17.8 ms ± 191 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

%%timeit
mgs(x, backend='torch')
# 11.9 ms ± 80.5 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

%%timeit
mgs(x, backend='tensorflow')
# 1.87 ms ± 441 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

# configure jax to run on the same footing (64-bit precision, CPU only)
import jax
jax.config.update("jax_enable_x64", True)
jax.config.update('jax_platform_name', 'cpu')

%%timeit
mgs(x, backend='jax')
# 226 µs ± 14.8 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

%%timeit
do('linalg.qr', x, like='numpy')[0]  # approximately the 'C' version
# 156 µs ± 32.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Here you can see that, for this very for-loop heavy function, there are significant gains to be made with all of the compilation options. While jax, for example, achieves fantastic performance, note that its compilation step takes a lot of time and scales badly (super-linearly) with the number of computational nodes.
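To make the idea of unravelling a computation into source and running it with compile+exec more concrete, here is a toy sketch. It is not autoray's implementation: the `Node` class, `trace_and_compile` and `poly` are hypothetical names invented for this illustration, and only `+` and `*` are supported. It also shows how a Python for-loop is unrolled into one computational node per operation.

```python
import numpy as np


class Node:
    """Toy traced value: records the source line that computes it."""

    def __init__(self, lines, expr=None, name=None):
        self.lines = lines  # list of source lines shared by the whole trace
        if name is None:
            # a new computational node: emit one assignment for it
            name = f"v{len(lines)}"
            lines.append(f"{name} = {expr}")
        self.name = name

    def _binary(self, op, other):
        # operands are either other nodes or plain Python numbers
        rhs = other.name if isinstance(other, Node) else repr(other)
        return Node(self.lines, expr=f"{self.name} {op} {rhs}")

    def __add__(self, other):
        return self._binary("+", other)

    def __mul__(self, other):
        return self._binary("*", other)


def trace_and_compile(fn, n_args):
    """Trace fn with placeholder inputs, then build a flat function."""
    lines = []
    inputs = [Node(lines, name=f"x{i}") for i in range(n_args)]
    out = fn(*inputs)  # running fn records one line per operation
    arg_list = ", ".join(node.name for node in inputs)
    body = "\n".join(f"    {line}" for line in lines)
    source = f"def compiled({arg_list}):\n{body}\n    return {out.name}\n"
    namespace = {}
    exec(compile(source, "<traced>", "exec"), namespace)
    return namespace["compiled"], source, len(lines)


def poly(x, y):
    # loop-heavy toy function: each iteration adds two nodes to the trace
    acc = x
    for _ in range(3):
        acc = acc * y + x
    return acc


compiled, source, n_nodes = trace_and_compile(poly, 2)

x = np.linspace(0.0, 1.0, 5)
y = np.linspace(1.0, 2.0, 5)
result = compiled(x, y)  # runs the generated straight-line code
```

In this sketch the generated `source` contains no loop at all, only a flat sequence of assignments, so the number of nodes grows with the number of loop iterations. That growth is the quantity the compilation cost mentioned above depends on.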