autoray.grad
Attributes
| _GRAD_BACKENDS | |
| _NO_GRAD_BACKENDS | For backends without grad, keep x unchanged. |
Functions
| stop_gradient(x) | Stop gradient flow through array x. |
Module Contents
- autoray.grad._GRAD_BACKENDS
- autoray.grad._NO_GRAD_BACKENDS
- autoray.grad.stop_gradient(x)
Stop gradient flow through array x.
In autodiff backends (JAX, PyTorch, TensorFlow, etc.), this returns a version of x that is detached from the computational graph, so no gradients are propagated back through it. For non-autodiff backends (NumPy, CuPy, etc.), this is a no-op and returns x unchanged.
- Parameters:
x (array) – The array to stop gradient flow through.
- Returns:
An array with the same value as x but detached from the autodiff computational graph.
- Return type:
array
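Consider f(x) = x · stop_gradient(x). Autodiff treats the second factor as a constant, so f′(3) = 3 rather than the 6 obtained for x · x.

The toy forward-mode autodiff below reproduces this. It is built on dual numbers and written only for illustration. Dual, toy_stop_gradient and derivative are hypothetical names, not autoray or JAX APIs. Stopping the gradient keeps the value and zeroes the derivative, which has the same effect on the result as blocking gradient flow in reverse mode.

```python
from dataclasses import dataclass


@dataclass
class Dual:
    """Toy forward-mode autodiff value: primal value plus its derivative."""
    value: float
    deriv: float = 0.0

    def __mul__(self, other):
        # Product rule: (uv)' = u'v + uv'
        return Dual(self.value * other.value,
                    self.deriv * other.value + self.value * other.deriv)


def toy_stop_gradient(x):
    # Hypothetical helper: keep the value, drop the derivative,
    # so nothing propagates through this factor.
    return Dual(x.value, 0.0)


def derivative(f, at):
    # Seed dx/dx = 1 and read off df/dx at the given point.
    return f(Dual(at, 1.0)).deriv


print(derivative(lambda x: x * x, 3.0))                     # 6.0
print(derivative(lambda x: x * toy_stop_gradient(x), 3.0))  # 3.0
```

The forward value is identical in both cases (9.0). Only the derivative changes, matching "same value as x but detached".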
Examples
With JAX:
>>> import jax, jax.numpy as jnp, autoray as ar
>>> x = jnp.array([1.0, 2.0, 3.0])
>>> ar.stop_gradient(x)  # equivalent to jax.lax.stop_gradient(x)
With PyTorch:
>>> import torch, autoray as ar
>>> x = torch.tensor([1.0, 2.0], requires_grad=True)
>>> y = ar.stop_gradient(x)  # equivalent to x.detach()
>>> y.requires_grad
False
With NumPy (no-op):
>>> import numpy as np, autoray as ar
>>> x = np.array([1.0, 2.0])
>>> ar.stop_gradient(x) is x
True
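The examples above show a single entry point that detaches for autodiff backends and passes other arrays through untouched. That behaviour can be sketched as a small dispatch on the array's type.

This is a hypothetical illustration, not autoray's implementation. The sets below only borrow the role of the module's _GRAD_BACKENDS / _NO_GRAD_BACKENDS attributes, whose real contents are not documented here. Inferring the backend from the top-level package of type(x) is also an assumption.

```python
import numpy as np

# Hypothetical backend sets: they mimic the role of autoray.grad's
# _GRAD_BACKENDS / _NO_GRAD_BACKENDS, but their real contents may differ.
GRAD_BACKENDS = {"jax", "jaxlib", "torch", "tensorflow"}
NO_GRAD_BACKENDS = {"numpy", "cupy", "builtins"}


def infer_backend_name(x):
    """Guess the backend from the top-level package of x's class (assumption)."""
    return type(x).__module__.split(".")[0]


def stop_gradient_sketch(x):
    backend = infer_backend_name(x)
    if backend in NO_GRAD_BACKENDS:
        # No computational graph to cut: return the very same object.
        return x
    if backend in GRAD_BACKENDS:
        if backend in ("jax", "jaxlib"):
            import jax
            return jax.lax.stop_gradient(x)
        if backend == "torch":
            return x.detach()
        import tensorflow as tf
        return tf.stop_gradient(x)
    raise TypeError(f"unknown backend {backend!r} for {type(x).__name__}")


x = np.array([1.0, 2.0])
print(stop_gradient_sketch(x) is x)  # True
```

Importing the autodiff libraries lazily, inside their branches, keeps the pass-through path usable when JAX, PyTorch or TensorFlow are not installed.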