autoray.grad
============

.. py:module:: autoray.grad


Attributes
----------

.. autoapisummary::

   autoray.grad._GRAD_BACKENDS
   autoray.grad._NO_GRAD_BACKENDS


Functions
---------

.. autoapisummary::

   autoray.grad.jax_stop_gradient
   autoray.grad.torch_stop_gradient
   autoray.grad.tensorflow_stop_gradient
   autoray.grad.paddle_stop_gradient
   autoray.grad.autograd_stop_gradient
   autoray.grad.do_nothing
   autoray.grad.stop_gradient


Module Contents
---------------

.. py:data:: _GRAD_BACKENDS

.. py:data:: _NO_GRAD_BACKENDS

.. py:function:: jax_stop_gradient(x)

.. py:function:: torch_stop_gradient(x)

.. py:function:: tensorflow_stop_gradient(x)

.. py:function:: paddle_stop_gradient(x)

.. py:function:: autograd_stop_gradient(x)

.. py:function:: do_nothing(x)

   For backends without autodiff support, return ``x`` unchanged.


.. py:function:: stop_gradient(x)

   Stop gradient flow through array ``x``.

   In autodiff backends (JAX, PyTorch, TensorFlow, etc.), this detaches ``x``
   from the computational graph so that no gradients are propagated through
   it. For non-autodiff backends (NumPy, CuPy, etc.), this is a no-op and
   returns ``x`` unchanged.

   :param x: The array to stop gradient flow through.
   :type x: array

   :returns: An array with the same values as ``x``, detached from the
      autodiff computational graph (for non-autodiff backends, ``x`` itself).
   :rtype: array

   .. rubric:: Examples

   With JAX::

       >>> import jax, jax.numpy as jnp, autoray as ar
       >>> x = jnp.array([1.0, 2.0, 3.0])
       >>> y = ar.stop_gradient(x)  # equivalent to jax.lax.stop_gradient(x)

   With PyTorch::

       >>> import torch, autoray as ar
       >>> x = torch.tensor([1.0, 2.0], requires_grad=True)
       >>> y = ar.stop_gradient(x)  # equivalent to x.detach()
       >>> y.requires_grad
       False

   With NumPy (no-op)::

       >>> import numpy as np, autoray as ar
       >>> x = np.array([1.0, 2.0])
       >>> ar.stop_gradient(x) is x
       True