# Compilation
Various libraries provide tools for tracing numeric functions and turning the resulting computation into a more efficient, compiled function, notably `jax.jit`, `tensorflow.function` and `torch.jit.trace`. autoray is very well suited to these, since it simply dispatches functions to whichever library is doing the tracing; functions written using autoray should be immediately compatible with all of them.

Moreover, autoray also provides a unified interface, the `autojit` wrapper, for compiling functions, so that the compilation backend can be easily switched or automatically identified:
```python
from autoray import autojit

mgs = autojit(modified_gram_schmidt)
```
Currently `autojit` supports functions with the signature `fn(*args, **kwargs) -> array`, where both `args` and `kwargs` can be any nested combination of `tuple`, `list` and `dict` objects containing arrays.
We can compare different compiled versions of this simply by changing the `backend` option:
```python
from autoray import do

x = do("random.normal", size=(50, 50), like='numpy')
```
```python
%%timeit
# first the uncompiled version
modified_gram_schmidt(x)
# 23.5 ms ± 241 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```
```python
%%timeit
# 'python' mode unravels the computation into source then uses compile+exec
mgs(x)  # backend='python'
# 17.8 ms ± 191 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```
```python
%%timeit
mgs(x, backend='torch')
# 11.9 ms ± 80.5 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
```python
%%timeit
mgs(x, backend='tensorflow')
# 1.87 ms ± 441 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
```python
# configure jax to run on the same footing (double precision, CPU)
from jax.config import config
config.update("jax_enable_x64", True)
config.update('jax_platform_name', 'cpu')
```

```python
%%timeit
mgs(x, backend='jax')
# 226 µs ± 14.8 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
```python
%%timeit
do('linalg.qr', x, like='numpy')[0]  # approximately the 'C' version
# 156 µs ± 32.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```
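As a sanity check that `linalg.qr` is a fair reference point: the Q factor of a QR decomposition spans the same nested subspaces as the Gram-Schmidt output, so the two agree up to a sign per vector. A quick pure-numpy check, using a standalone copy of row-wise modified Gram-Schmidt for illustration:

```python
import numpy as np

def modified_gram_schmidt(X):
    # orthogonalize the rows of X, one at a time
    Q = []
    for j in range(X.shape[0]):
        q = X[j, :]
        for i in range(j):
            q = q - (Q[i] @ q) * Q[i]
        Q.append(q / np.linalg.norm(q))
    return np.stack(Q)

x = np.random.default_rng(42).normal(size=(50, 50))
Q1 = modified_gram_schmidt(x)
# np.linalg.qr orthogonalizes columns, so transpose to orthogonalize rows
Q2 = np.linalg.qr(x.T)[0].T
# the two orthonormal bases match up to a sign per row
assert np.allclose(np.abs(np.diag(Q1 @ Q2.T)), 1.0, atol=1e-8)
```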
Here you can see (with this very for-loop heavy function) that there are significant gains to be made with all of the compilation options. While `jax`, for example, achieves fantastic performance, it should be noted that its compilation step takes a long time and scales badly (super-linearly) with the number of computational nodes.