autoray.grad

Attributes

_GRAD_BACKENDS

_NO_GRAD_BACKENDS

Functions

jax_stop_gradient(x)

torch_stop_gradient(x)

tensorflow_stop_gradient(x)

paddle_stop_gradient(x)

autograd_stop_gradient(x)

do_nothing(x)

For backends without grad, keep x unchanged.

stop_gradient(x)

Stop gradient flow through array x.

Module Contents

autoray.grad._GRAD_BACKENDS
autoray.grad._NO_GRAD_BACKENDS
autoray.grad.jax_stop_gradient(x)
autoray.grad.torch_stop_gradient(x)
autoray.grad.tensorflow_stop_gradient(x)
autoray.grad.paddle_stop_gradient(x)
autoray.grad.autograd_stop_gradient(x)
autoray.grad.do_nothing(x)

For backends without grad, keep x unchanged.

autoray.grad.stop_gradient(x)

Stop gradient flow through array x.

In autodiff backends (JAX, PyTorch, TensorFlow, etc.), this detaches x from the computational graph so that no gradients are propagated through it. For non-autodiff backends (NumPy, CuPy, etc.), this is a no-op and returns x unchanged.

Parameters:

x (array) – The array to stop gradient flow through.

Returns:

An array with the same value as x but detached from the autodiff computational graph.

Return type:

array

Examples

With JAX:

>>> import jax, jax.numpy as jnp, autoray as ar
>>> x = jnp.array([1.0, 2.0, 3.0])
>>> y = ar.stop_gradient(x)  # equivalent to jax.lax.stop_gradient(x)

With PyTorch:

>>> import torch, autoray as ar
>>> x = torch.tensor([1.0, 2.0], requires_grad=True)
>>> y = ar.stop_gradient(x)  # equivalent to x.detach()
>>> y.requires_grad
False

With NumPy (no-op):

>>> import numpy as np, autoray as ar
>>> x = np.array([1.0, 2.0])
>>> ar.stop_gradient(x) is x
True
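To make "detached from the computational graph" concrete without installing any autodiff library, here is a toy scalar reverse-mode autodiff in plain Python. It is an illustration only: `Var`, `backward` and this `stop_gradient` are hypothetical and unrelated to autoray's internals. Stopping the gradient keeps the value but gives the new node no parents, so backpropagation never reaches x through it.

```python
# Toy reverse-mode autodiff on scalars, for illustration only.


class Var:
    """A scalar node; parents holds (parent_node, local_derivative) pairs."""

    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents
        self.grad = 0.0

    def __add__(self, other):
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        # d(a*b)/da = b and d(a*b)/db = a
        return Var(self.value * other.value,
                   ((self, other.value), (other, self.value)))


def stop_gradient(x):
    # Same value, but a fresh node with no parents: the graph is cut here.
    return Var(x.value)


def backward(out):
    """Accumulate d(out)/d(node) into node.grad for every upstream node."""
    order, seen = [], set()

    def visit(node):
        # Depth-first topological sort so each node is processed after
        # every node that depends on it (when iterating in reverse).
        if id(node) not in seen:
            seen.add(id(node))
            for parent, _ in node.parents:
                visit(parent)
            order.append(node)

    visit(out)
    out.grad = 1.0
    for node in reversed(order):
        for parent, local in node.parents:
            parent.grad += local * node.grad


x = Var(3.0)
backward(x * x)
# x.grad is 6.0: both factors contribute, d(x*x)/dx = 2x

x2 = Var(3.0)
backward(x2 * stop_gradient(x2))
# x2.grad is 3.0: only the first factor is differentiated
```

For z = x * stop_gradient(x), the value is still x squared, but the gradient with respect to x is x (here 3.0) rather than 2x, because the stopped factor is treated as a constant.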