{ "cells": [ { "attachments": {}, "cell_type": "markdown", "id": "89a08e5a-29ea-4717-9fb1-a64167baef1d", "metadata": {}, "source": [ "# Lazy computation\n", "\n", "Abstracting out the array interface using `autoray` also allows tracing through\n", "computations lazily. This is useful for a number of purposes, including:\n", "\n", "1. Investigating the computational graph, including cost and memory usage,\n", " of a calculation ahead of time.\n", "2. Doing basic computational graph optimizations such as **constant folding**\n", " and **intermediate sharing**.\n", "3. Extracting a flattened list of operations that can be compiled or\n", " translated to other libraries.\n", "\n", "This is implemented in a very lightweight fashion in `autoray` using the array\n", "backend found in [autoray.lazy](autoray.lazy).\n", "\n", "---\n", "\n", "As an illustration first let's define a simple autoray function:" ] }, { "cell_type": "code", "execution_count": 1, "id": "50ea547d-9e34-4a4d-a84e-bb472204624c", "metadata": {}, "outputs": [], "source": [ "%config InlineBackend.figure_formats = ['svg']\n", "\n", "from autoray import do, shape\n", "\n", "\n", "def modified_gram_schmidt(X):\n", " # n.b. performance-wise this particular function is *not*\n", " # a good candidate for a pure python implementation\n", "\n", " Q = []\n", " for j in range(0, shape(X)[0]):\n", " q = X[j, :]\n", " for i in range(0, j):\n", " rij = do(\"tensordot\", do(\"conj\", Q[i]), q, 1)\n", " q = q - rij * Q[i]\n", "\n", " rjj = do(\"linalg.norm\", q, 2)\n", " Q.append(q / rjj)\n", "\n", " return do(\"stack\", tuple(Q), axis=0)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "a3408af7-e335-4aec-8c51-358fe5290e01", "metadata": {}, "source": [ "This function automatically dispatches based on ``X``. 
Let's start with a\n", "`torch` tensor:" ] }, { "cell_type": "code", "execution_count": 2, "id": "f24903fa-d994-4017-a0a0-ba85f167a263", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[-0.5805, -0.4512, -0.2807, 0.3313, 0.4838, 0.1920],\n", " [ 0.3426, 0.2975, 0.3121, 0.3296, 0.7331, -0.2251],\n", " [-0.3725, 0.1051, 0.3351, -0.7794, 0.3566, 0.0568],\n", " [ 0.6077, -0.3415, -0.4634, -0.3796, 0.3007, 0.2547],\n", " [-0.0159, 0.5119, -0.0306, 0.1317, 0.0139, 0.8481],\n", " [ 0.1933, -0.5641, 0.7042, 0.1128, -0.1033, 0.3538]])" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# input array - can be anything autoray.do supports\n", "x = do(\"random.normal\", size=(6, 6), like=\"torch\")\n", "modified_gram_schmidt(x)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "17cab25d-87f6-49ca-991c-4b5292049be8", "metadata": {}, "source": [ "If instead we wanted to run the function lazily, we first call\n", "[`lazy.array`](autoray.lazy.array) to wrap `x`:" ] }, { "cell_type": "code", "execution_count": 3, "id": "a3c91b05-cd17-4cf6-9892-2d3a320eea82", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from autoray import lazy\n", "\n", "lx = lazy.array(x)\n", "ly = modified_gram_schmidt(lx)\n", "ly" ] }, { "attachments": {}, "cell_type": "markdown", "id": "03d992ad-f746-4cc2-9e56-14606bc827a3", "metadata": {}, "source": [ "[`LazyArray`](autoray.lazy.LazyArray) objects simply store the following:\n", "\n", "* The function to be called and the backend it came from\n", "* The `args` and `kwargs` to be passed to the function\n", "* A tuple of which of these are themselves `LazyArray` objects, known as\n", " *'dependencies'*\n", "* The shape of `fn(*args, **kwargs)` were it to be computed\n", "\n", "If a lazy array is an input (as with `lx`), or has been materialized /\n", "computed, then it simply stores the result 
and shape of the computation, and\n", "has no reference to *how* it was computed. This means you should do any\n", "inspection of the graph before performing computation." ] }, { "attachments": {}, "cell_type": "markdown", "id": "08f9156b-3d75-470a-a225-0c89e27de115", "metadata": {}, "source": [ "## Inspection\n", "\n", "For speed and simplicity, there is not an actual graph data structure, instead\n", "the `LazyArray` objects simply track their dependencies (and not their\n", "'children'). However from this we can still traverse the nodes and extract an\n", "actual graph if so desired. Useful methods are:\n", "\n", "* [`LazyArray.ascend`](autoray.lazy.ascend) - generate every unique node\n", " in the graph, yielding dependencies before their children (i.e. a topological\n", " sort). This is the computational order. Nodes are also sorted by their\n", " *'depth'*, i.e. the longest distance to an input.\n", "\n", "* [`LazyArray.descend`](autoray.lazy.descend) - generate every unique node\n", " in the graph, starting from the current node. Use this if order doesn't\n", " matter.\n", "\n", "```{hint}\n", "Both these can be called as methods but also have top level function versions\n", "that also accept a sequence of `LazyArray` objects - i.e. multiple outputs.\n", "```\n", "\n", "You can also extract an actual graph using the following method:\n", "\n", "- [`LazyArray.to_nx_digraph`](autoray.lazy.LazyArray.to_nx_digraph)\n", "\n", "\n", "Some built in graph inspection methods are illustrated below:" ] }, { "cell_type": "code", "execution_count": 4, "id": "65c5d01a-89ba-4414-b984-545bb8e043b5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " 0 stack[6, 6]\n", " 1 ├─truediv[6]\n", " 2 │ ├─getitem[6]\n", " 3 │ │ ╰─←[6, 6]\n", " 4 │ ╰─linalg_norm[]\n", " 5 │ ╰─ ... (getitem[6] from line 2)\n", " 6 ├─truediv[6]\n", " 7 │ ├─sub[6]\n", " 8 │ │ ├─getitem[6]\n", " 9 │ │ │ ╰─ ... 
(←[6, 6] from line 3)\n", " 10 │ │ ╰─mul[6]\n", " 11 │ │ ├─ ... (truediv[6] from line 1)\n", " 12 │ │ ╰─tensordot[]\n", " 13 │ │ ├─ ... (getitem[6] from line 8)\n", " 14 │ │ ╰─conj[6]\n", " 15 │ │ ╰─ ... (truediv[6] from line 1)\n", " 16 │ ╰─linalg_norm[]\n", " 17 │ ╰─ ... (sub[6] from line 7)\n", " 18 ├─truediv[6]\n", " 19 │ ├─sub[6]\n", " 20 │ │ ├─sub[6]\n", " 21 │ │ │ ├─getitem[6]\n", " 22 │ │ │ │ ╰─ ... (←[6, 6] from line 3)\n", " 23 │ │ │ ╰─mul[6]\n", " 24 │ │ │ ├─ ... (truediv[6] from line 1)\n", " 25 │ │ │ ╰─tensordot[]\n", " 26 │ │ │ ├─ ... (getitem[6] from line 21)\n", " 27 │ │ │ ╰─conj[6]\n", " 28 │ │ │ ╰─ ... (truediv[6] from line 1)\n", " 29 │ │ ╰─mul[6]\n", " 30 │ │ ├─ ... (truediv[6] from line 6)\n", " 31 │ │ ╰─tensordot[]\n", " 32 │ │ ├─ ... (sub[6] from line 20)\n", " 33 │ │ ╰─conj[6]\n", " 34 │ │ ╰─ ... (truediv[6] from line 6)\n", " 35 │ ╰─linalg_norm[]\n", " 36 │ ╰─ ... (sub[6] from line 19)\n", " 37 ├─truediv[6]\n", " 38 │ ├─sub[6]\n", " 39 │ │ ├─sub[6]\n" ] } ], "source": [ "# print the lazy computation graph\n", "ly.show(max_lines=40)" ] }, { "cell_type": "code", "execution_count": 5, "id": "d4512ece-9393-4e24-a362-5bf2777f2588", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'stack': 1,\n", " 'truediv': 6,\n", " 'linalg_norm': 6,\n", " 'sub': 15,\n", " 'mul': 15,\n", " 'getitem': 6,\n", " 'None': 1,\n", " 'tensordot': 15,\n", " 'conj': 15}" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# how many times each function was called\n", "ly.history_fn_frequencies()" ] }, { "cell_type": "code", "execution_count": 6, "id": "e8ce1d57-9457-49f9-a912-feac1c59b831", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "(
, )" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# the counts as a pie chart\n", "ly.plot_history_stats(fn=\"count\")" ] }, { "cell_type": "code", "execution_count": 7, "id": "d3e82d07-57a0-4587-8f10-21571f4f2ea1", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "36" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# the largest node encountered\n", "ly.history_max_size()" ] }, { "cell_type": "code", "execution_count": 8, "id": "f2d0388e-be9c-4d85-8216-3a66951c0209", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "text/plain": [ "(
, )" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# the sizes of all nodes encountered, in log 2\n", "ly.plot_history_functions(log=2)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "f3be1223-5f17-4795-9274-4e57278b88c5", "metadata": {}, "source": [ "In all the above you can also customize the function that is computed for each\n", "node, for instance to estimate FLOPs." ] }, { "cell_type": "code", "execution_count": 9, "id": "5d753bcf-8924-4904-a01f-e1c2c06730fe", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "72" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# the peak memory required for all intermediates when\n", "# traversing the graph in computational order\n", "ly.history_peak_size()" ] }, { "cell_type": "code", "execution_count": 10, "id": "a043523b-0e69-45a4-a35c-2a482023484a", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "text/plain": [ "(
, )" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# the total memory required for all intermediates when\n", "# traversing the graph in computational order\n", "ly.plot_history_size_footprint()" ] }, { "cell_type": "code", "execution_count": 11, "id": "417bd002-50f1-42cd-8444-db4ba83cc134", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "text/plain": [ "(
, )" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# plot the graph as a circuit diagram\n", "ly.plot_circuit()" ] }, { "cell_type": "code", "execution_count": 12, "id": "f5914632-a626-4932-8210-a5d70ea278cd", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "(
, )" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# plot the graph using networkx or pygraphviz\n", "ly.plot_graph(layout=\"sfdp\")" ] }, { "attachments": {}, "cell_type": "markdown", "id": "3bb234b6-e434-4636-951a-99ce8af9fd9b", "metadata": {}, "source": [ "## Computation\n", "\n", "When you are ready to actually perform the computation, you can call\n", "[`LazyArray.compute`](autoray.lazy.LazyArray.compute) on output nodes:" ] }, { "cell_type": "code", "execution_count": 13, "id": "26ceedb0-209d-41eb-b2d6-6a7d5abf904c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[-0.5805, -0.4512, -0.2807, 0.3313, 0.4838, 0.1920],\n", " [ 0.3426, 0.2975, 0.3121, 0.3296, 0.7331, -0.2251],\n", " [-0.3725, 0.1051, 0.3351, -0.7794, 0.3566, 0.0568],\n", " [ 0.6077, -0.3415, -0.4634, -0.3796, 0.3007, 0.2547],\n", " [-0.0159, 0.5119, -0.0306, 0.1317, 0.0139, 0.8481],\n", " [ 0.1933, -0.5641, 0.7042, 0.1128, -0.1033, 0.3538]])" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ly.compute()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "16182602-2385-4226-9db9-e3c6b82d8db0", "metadata": {}, "source": [ "```{note}\n", "[`LazyArray`](autoray.lazy.LazyArray) objects clear references to their\n", "dependencies once computed and simply store the result and shape. This is to\n", "aid garbage collection and reduce memory usage.\n", "```\n", "\n", "The computation is done non-recursively. 
You can compute multiple outputs at\n", "once with the function [`lazy.compute`](autoray.lazy.compute):" ] }, { "cell_type": "code", "execution_count": 14, "id": "8a3659a7-ac73-4e08-9c48-e3f6a6e23f76", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([-0.5805, -0.4512, -0.2807, 0.3313, 0.4838, 0.1920]),\n", " tensor([ 0.5805, 0.4512, 0.2807, -0.3313, -0.4838, -0.1920]))" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "lazy.compute([ly[0], -ly[0]])" ] }, { "attachments": {}, "cell_type": "markdown", "id": "2f13d343-7a55-40b6-9759-cf3df6670493", "metadata": {}, "source": [ "### Sharing intermediates\n", "\n", "A basic computational graph optimization that `autoray` can do is to\n", "automatically cache [`LazyArray`](autoray.lazy.LazyArray) objects that are\n", "computed with the same inputs. This is achieved with the context manager:\n", "\n", "* [`lazy.shared_intermediates`](autoray.lazy.shared_intermediates)" ] }, { "cell_type": "code", "execution_count": 15, "id": "974c05c1-6f16-4c03-847b-6bc3240cd78d", "metadata": {}, "outputs": [], "source": [ "with lazy.shared_intermediates():\n", " ly_shared = modified_gram_schmidt(lx)\n", "\n", "# reconstruct the non-shared lazy graph\n", "ly = modified_gram_schmidt(lx)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "dd8ce649-014c-404e-b05c-cfab4932e080", "metadata": {}, "source": [ "In this case you can see a slight reduction in the number of unique nodes:" ] }, { "cell_type": "code", "execution_count": 16, "id": "4772a614-9993-4596-aec2-f1a75a19c502", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(80, 70)" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ly.history_num_nodes(), ly_shared.history_num_nodes()" ] }, { "cell_type": "code", "execution_count": 17, "id": "c78ed2f8-f62a-42b0-908e-be1a54cdd0c1", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "text/plain": [ "(
, )" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ly_shared.plot_circuit()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "2fd8a573-36a4-491d-899c-a1e8dfdc071e", "metadata": {}, "source": [ "### `Function`, `Variable`, and constant folding\n", "\n", "Sometimes you may want to think of certain input nodes as variables, which\n", "might change from call to call, and any other inputs as constants. One option\n", "is to create 'empty'\n", "[`LazyArray`](autoray.lazy.LazyArray) instances with\n", "[`lazy.Variable`](autoray.lazy.Variable), which just\n", "takes a shape and optionally backend, and uses a placeholder for the data." ] }, { "cell_type": "code", "execution_count": 18, "id": "95ef490a-355d-4dea-8b40-f1decf0cd5b2", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "text/plain": [ "(
, )" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "lx = lazy.Variable((6, 6), backend=\"numpy\")\n", "\n", "ly = lazy.array(do(\"random.normal\", size=(6, 6), like=\"numpy\"))\n", "ly += ly.T\n", "ly **= 2\n", "\n", "lz = ly / (lx + 3)\n", "\n", "lz.plot_circuit()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "14a60f19-39a3-4fbd-9048-80cdaf1c25bb", "metadata": {}, "source": [ "If we tried to call `lz.compute()` now, we would get an error relating to\n", "attempting to use the placeholder data; we would need to substitute it in\n", "first.\n", "\n", "However, we can compute all the nodes that don't depend on the variable like so:" ] }, { "cell_type": "code", "execution_count": 19, "id": "18c89402-86a4-48c6-8c5b-d5a9896ff310", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "text/plain": [ "(
, )" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "lz.compute_constants(variables=[lx])\n", "\n", "# now all that remains are the parts of the computational graph that depend on lx\n", "lz.plot_circuit()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "f3911bc3-3c32-43b7-88fd-17c5c6e7c939", "metadata": {}, "source": [ "If one wants to extract the function that the computational graph represents,\n", "in order to call it repeatedly with different inputs, then one can create a\n", "[`lazy.Function`](autoray.lazy.Function):" ] }, { "cell_type": "code", "execution_count": 20, "id": "85b18331-8858-4f8b-8995-87f3071a7c64", "metadata": {}, "outputs": [ { "data": { "text/plain": [ " array_like>" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "f = lazy.Function(inputs=lx, outputs=lz)\n", "f" ] }, { "attachments": {}, "cell_type": "markdown", "id": "fb019d3c-0af3-430e-ac23-e2dc5f6f04c3", "metadata": {}, "source": [ "```{hint}\n", "By default, [`lazy.Function`](autoray.lazy.Function) will compute constants,\n", "as we did above; this can be disabled by passing `fold_constants=False`.\n", "```" ] }, { "cell_type": "code", "execution_count": 21, "id": "5134c1dd-8f30-484d-9064-03da90d5586a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[1.60841992e-04, 8.12123221e-01, 1.97831460e+00, 1.13263167e-01,\n", " 2.33488356e-01, 4.08718533e-03],\n", " [2.18269736e-01, 1.04576049e+00, 5.29610184e-01, 6.86544110e-01,\n", " 4.30559858e-03, 1.45473766e+00],\n", " [3.47257338e+00, 1.07335438e+00, 1.32390470e-01, 7.11823564e-01,\n", " 1.11092912e-01, 2.23358234e+00],\n", " [1.87099522e-01, 5.45491213e-01, 7.87780835e-01, 1.07803626e+00,\n", " 5.32992554e-03, 9.09088651e-01],\n", " [2.97641208e-01, 3.39287590e-03, 5.68403161e-02, 4.84043497e-03,\n", " 9.40752758e-01, 2.27492826e+00],\n", " [2.94115949e-03, 1.52417836e+00, 1.55115629e+00, 6.83576849e-01,\n", " 5.76864117e-01, 
1.09412185e-03]])" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# create a numpy array\n", "x = do(\"random.normal\", size=(6, 6), like=\"numpy\")\n", "\n", "# now we can call it on a raw numpy array\n", "f(x)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "e4d38a1f-615e-497a-8d5a-400d8b5d5e99", "metadata": {}, "source": [ "```{note}\n", "Such a function is created with the backend specific functions injected in, see\n", "[the compilation](compilation) section and [`autojit`](autoray.autojit) for\n", "creating such functions just in time for the right backend.\n", "```\n", "\n", "You can view the function's source code using\n", "[Function.print_source](autoray.lazy.Function.print_source), or extract it\n", "from [`LazyArray`](autoray.lazy.LazyArray) objects yourself with\n", "[lazy.get_source](autoray.lazy.get_source)." ] }, { "attachments": {}, "cell_type": "markdown", "id": "3f57cbb7-6e5e-4246-82e4-eeb263d357cb", "metadata": {}, "source": [ "## Comparison to alternatives:\n", "\n", "The main difference to other approaches is that `autoray` is super simple and\n", "lightweight, and is not concerned with complex optimizations or modes of\n", "execution.\n", "\n", "As demonstrated below, the dispatch mechanism in `autoray` is compatible with\n", "tensor objects from both libraries considered here (`dask` and `aesara`), so\n", "it is not an either/or situation.\n", "The comparison is only with regard to when you might want to use lazy\n", "computational graph tracing.\n", "\n", "\n", "### `dask`\n", "\n", "There are many reasons to use [dask](https://dask.org/), but it incurs a pretty\n", "large overhead for big computational graphs with comparatively small\n", "operations. 
Calling and computing the ``modified_gram_schmidt`` function for a\n", "100x100 matrix (20,102 computational nodes) with ``dask.array`` takes ~1min\n", "whereas with ``lazy.array`` it takes ~0.2s:" ] }, { "cell_type": "code", "execution_count": null, "id": "4da8d638-0916-46e0-8824-81ea93e5ded0", "metadata": {}, "outputs": [], "source": [ "import dask.array as da\n", "\n", "x = do(\"random.normal\", size=(100, 100), like=\"numpy\")" ] }, { "cell_type": "code", "execution_count": null, "id": "d77f1179-84d7-4ce7-96ad-65a09cdb7a7d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 58.5 s, sys: 315 ms, total: 58.8 s\n", "Wall time: 58.2 s\n" ] } ], "source": [ "%%time\n", "dx = da.array(x)\n", "dy = modified_gram_schmidt(dx)\n", "y = dy.compute()" ] }, { "cell_type": "code", "execution_count": null, "id": "f372f76e-17c5-4c15-970d-4463edaf3fba", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 204 ms, sys: 12.1 ms, total: 216 ms\n", "Wall time: 208 ms\n" ] } ], "source": [ "%%time\n", "lx = lazy.array(x)\n", "ly = modified_gram_schmidt(lx)\n", "y = ly.compute()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "895531f0-1853-404c-98ef-f4ed997d3f82", "metadata": {}, "source": [ "Moreover `autoray.lazy` can also lazily wrap around more backends such\n", "as `torch` due to the [automatic dispatch](automatic_dispatch) mechanism." ] }, { "attachments": {}, "cell_type": "markdown", "id": "0ac03a5a-5e0c-4f58-9efc-b3d6ab2273ae", "metadata": {}, "source": [ "### `aesara`\n", "\n", "[`aesara`](https://aesara.readthedocs.io) is another nice library, and the\n", "successor to [`theano`](https://github.com/Theano/Theano). It is much more\n", "heavyweight than `autoray` with a focus on optimizations, symbolic\n", "manipulations such as gradients, and compilation to\n", "specific targets (`jax`, `numba` or `C`). 
It also supports dynamic shapes,\n", "whereas `autoray` restricts itself to static shapes.\n", "\n", "`aesara` is indeed quite compatible with `autoray`, but the fact that\n", "it often falls back to dynamic/unknown shapes occasionally makes things\n", "tricky." ] }, { "cell_type": "code", "execution_count": 15, "id": "e9bb5a8d-4d69-41f4-8f84-47104fb07f10", "metadata": {}, "outputs": [], "source": [ "import aesara\n", "import aesara.tensor as at\n", "\n", "# create equivalent of a Variable\n", "ax = at.tensor(\"float64\", (10, 10))" ] }, { "cell_type": "code", "execution_count": 16, "id": "efd4216b-a6da-4b73-88ef-b36e02e3bfc3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 175 ms, sys: 0 ns, total: 175 ms\n", "Wall time: 174 ms\n" ] } ], "source": [ "%%time\n", "# construct the computational graph\n", "ay = modified_gram_schmidt(ax)" ] }, { "cell_type": "code", "execution_count": 17, "id": "0efdb1b4-1799-4d11-a744-ceeed57468aa", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Join.0 (None, None)\n" ] } ], "source": [ "# aesara falls back to dynamic shapes quite often, which can be tricky\n", "print(ay, shape(ay))" ] }, { "cell_type": "code", "execution_count": 26, "id": "2140367c-52c7-4548-ae58-709ab52bbce8", "metadata": {}, "outputs": [], "source": [ "# # if you want to view the graph, you could use pydotprint:\n", "# aesara.printing.pydotprint(ay)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "69f89224-457a-4034-9939-564eea0728e8", "metadata": {}, "source": [ "Actually compiling the graph can take quite a long time for anything but\n", "quite small graphs (similarly to `jax`/XLA):" ] }, { "cell_type": "code", "execution_count": 19, "id": "d65362f2-f7ac-4412-8b1f-f993556b6a06", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 5min 53s, sys: 189 ms, total: 5min 54s\n", "Wall time: 5min 54s\n" ] } ], "source": [ 
"%%time\n", "f = aesara.function([ax], ay)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "1386648e-27ad-4d02-b12b-820ae934f069", "metadata": {}, "source": [ "However, the function produced, should be heavily optimized, and ought to be\n", "much faster than a pure python function for computations not dominated by large\n", "linear algebra operations." ] }, { "cell_type": "code", "execution_count": 25, "id": "06a43f44-d56c-49b2-b50f-df9ca470dc65", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[ 0.08811928, 0.45177107, -0.23512127, 0.56826279, -0.10209893,\n", " 0.04483325, 0.00321893, -0.06721217, 0.23305057, 0.58194387],\n", " [ 0.20286314, -0.16805831, -0.22599417, 0.12787872, -0.0876818 ,\n", " -0.34810252, -0.36127286, -0.46075707, -0.61894859, 0.09165501],\n", " [ 0.13412721, 0.36994813, 0.700074 , -0.06636773, -0.23315896,\n", " -0.44683645, 0.13946412, -0.26698148, 0.07190328, -0.02673288],\n", " [ 0.1407475 , 0.29877271, 0.14226686, -0.02852857, 0.04154398,\n", " 0.75386262, 0.07329936, -0.44735181, -0.25730248, -0.16773645],\n", " [ 0.04966494, 0.17005891, -0.26316584, -0.51207777, 0.319918 ,\n", " -0.1089547 , -0.302915 , -0.43376454, 0.48719139, 0.07516769],\n", " [ 0.02527629, -0.27882853, 0.49141276, 0.35693951, 0.63621879,\n", " 0.07151982, -0.34254243, -0.019437 , 0.0824143 , 0.13538384],\n", " [ 0.76372311, 0.15424796, -0.13639341, -0.06633823, 0.40270286,\n", " -0.13394064, 0.31463354, 0.2651581 , -0.12780612, -0.06467988],\n", " [ 0.46141628, -0.31660808, -0.04249231, 0.30744742, -0.37957198,\n", " 0.08494901, -0.18073615, -0.13298267, 0.46522187, -0.41527374],\n", " [-0.27891892, -0.11882063, -0.1926552 , 0.33001037, 0.3000039 ,\n", " -0.2147232 , 0.58955971, -0.43279871, 0.09319829, -0.28700755],\n", " [ 0.19026913, -0.54471619, 0.12261167, -0.24092508, -0.14976652,\n", " 0.14350703, 0.39536722, -0.21896414, 0.07494096, 0.58403977]])" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], 
"source": [ "x = do(\"random.normal\", size=(10, 10), like=\"numpy\")\n", "f(x)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "a0964864-0efa-483f-ab08-6732f95fede0", "metadata": {}, "source": [ "Hopefully `aesara` will be another possible target for\n", "[`autoray.autojit`](autoray.autojit), eventually." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3" } }, "nbformat": 4, "nbformat_minor": 2 }