Source code for tensortrax._tensor

"""
tensorTRAX: Math on (Hyper-Dual) Tensors with Trailing Axes.
"""

from copy import deepcopy

import numpy as np

from ._helpers import Δ, Δδ, f, δ


class Tensor:
    r"""A Hyper-Dual Tensor.

    Parameters
    ----------
    x : array_like
        The data of the tensor.
    δx : array_like or None, optional
        (Dual) variation data (δ-operator) of the tensor.
    Δx : array_like or None, optional
        (Dual) variation data (Δ-operator) of the tensor.
    Δδx : array_like or None, optional
        (Dual) linearization data (Δδ-operator) of the tensor.
    ntrax : int, optional
        Total number of trailing axes including dual axes (default is 0).
    ndual : int, optional
        Number of axes containing dual data (default is 0).

    Examples
    --------
    A hyper-dual tensor consists of batches of tensors, where the
    elementwise-operating batches are located at the last dimensions (trailing
    axes). For example, consider two second-order tensors in 3d-space as shown in
    Eq. :eq:`2d-tensors`.

    ..  math::
        :label: 2d-tensors

        \boldsymbol{T}_1 &= \begin{bmatrix}
            0 & 3 & 6 \\ 1 & 4 & 7 \\ 2 & 5 & 8
        \end{bmatrix}

        \boldsymbol{T}_2 &= \begin{bmatrix}
            9 & 12 & 15 \\ 10 & 13 & 16 \\ 11 & 14 & 17
        \end{bmatrix}

    This tensor is created with one elementwise-operating trailing axis.

    ..  plot::
        :context:

        >>> import tensortrax as tr
        >>> import numpy as np
        >>>
        >>> T = tr.Tensor(np.arange(18).reshape(2, 3, 3).T, ntrax=1)
        >>> T
        <tensortrax tensor object>
         Shape / size of trailing axes: (2,) / 1
         Shape / size of tensor: (3, 3) / 9
        <BLANKLINE>
        Data (first batch):
        <BLANKLINE>
        array([[0, 3, 6],
               [1, 4, 7],
               [2, 5, 8]])

    A tensor, which is created without any dual data, is treated as constant and
    hence, all dual data arrays are initially filled with zeros.

    ..  tip::
        Broadcasting is used to enhance performance and to reduce memory
        consumption. If no dual data arrays are provided, the trailing axes of the
        dual arrays are compressed.

    ..  plot::
        :context:

        >>> T.x.shape
        (3, 3, 2)

        >>> T.δx.shape
        (3, 3, 1)

    If the gradient (or jacobian) w.r.t. this tensor should be recorded, the tensor
    must be initiated with the argument ``gradient=True``. This will add additional
    dual axes to the data arrays of the tensor. These axes are used for the
    gradient dimension.

    ..  plot::
        :context:

        >>> T.init(gradient=True)
        >>> T
        <tensortrax tensor object>
         Shape / size of trailing axes: (1, 1, 2) / 3
         Shape / size of tensor: (3, 3) / 9
        <BLANKLINE>
        Data (first batch):
        <BLANKLINE>
        array([[0, 3, 6],
               [1, 4, 7],
               [2, 5, 8]])

    Again, broadcasting is used, see the shapes of the tensor data arrays.

    ..  plot::
        :context:

        >>> T.x.shape
        (3, 3, 1, 1, 2)

        >>> T.δx.shape
        (3, 3, 3, 3, 1)

    If we look at the dual array, this is the (non-symmetric) fourth-order identity
    tensor, see Eq. :eq:`dual-identity`. This identity tensor has values of one
    located at the component to be differentiated and is filled with zeros
    otherwise.

    ..  math::
        :label: dual-identity

        \frac{\partial T_{ij}}{\partial T_{kl}} = \delta_{ik}\ \delta_{jl}

    ..  plot::
        :context:

        >>> T.δx[0, 1, ..., 0]
        array([[0., 1., 0.],
               [0., 0., 0.],
               [0., 0., 0.]])

        >>> T.δx[1, 2, ..., 0]
        array([[0., 0., 0.],
               [0., 0., 1.],
               [0., 0., 0.]])

        >>> T.δx.reshape(9, 9)
        array([[1., 0., 0., 0., 0., 0., 0., 0., 0.],
               [0., 1., 0., 0., 0., 0., 0., 0., 0.],
               [0., 0., 1., 0., 0., 0., 0., 0., 0.],
               [0., 0., 0., 1., 0., 0., 0., 0., 0.],
               [0., 0., 0., 0., 1., 0., 0., 0., 0.],
               [0., 0., 0., 0., 0., 1., 0., 0., 0.],
               [0., 0., 0., 0., 0., 0., 1., 0., 0.],
               [0., 0., 0., 0., 0., 0., 0., 1., 0.],
               [0., 0., 0., 0., 0., 0., 0., 0., 1.]])

    Now let's perform some math, e.g. the dot product as shown in Eq. :eq:`dot`.
    Along with the evaluation of the dot-product, the partial derivative of
    :math:`\boldsymbol{V}` w.r.t. :math:`\boldsymbol{T}` is recorded in the dual
    data array.

    ..  math::
        :label: dot

        \boldsymbol{V} &= \boldsymbol{T}^T\ \boldsymbol{T}

        \delta \boldsymbol{V} &= \delta \boldsymbol{T}^T\ \boldsymbol{T} +
            \boldsymbol{T}^T\ \delta \boldsymbol{T}

    ..  plot::
        :context:

        >>> V = T.T @ T
        >>> V
        <tensortrax tensor object>
         Shape / size of trailing axes: (1, 1, 2) / 3
         Shape / size of tensor: (3, 3) / 9
        <BLANKLINE>
        Data (first batch):
        <BLANKLINE>
        array([[  5,  14,  23],
               [ 14,  50,  86],
               [ 23,  86, 149]])

        >>> V.δx[1, 2, ..., 0]
        array([[0., 6., 3.],
               [0., 7., 4.],
               [0., 8., 5.]])

    See Also
    --------
    tensortrax.function : Evaluate a function.
    tensortrax.gradient : Evaluate the gradient of a scalar-valued function.
    tensortrax.jacobian : Evaluate the jacobian of a function.
    tensortrax.hessian : Evaluate the hessian of a scalar-valued function.
    """

    def __init__(self, x, δx=None, Δx=None, Δδx=None, ntrax=0, ndual=0):
        self.x = np.asarray(x)
        self.ntrax = ntrax
        self.ndual = ndual

        # the last ``ntrax`` axes of the data are trailing axes, all leading axes
        # belong to the tensor itself
        self.shape = self.x.shape[: len(self.x.shape) - ntrax]
        self.trax = self.x.shape[len(self.x.shape) - ntrax :]
        self.size = np.prod(self.shape, dtype=int)

        # missing dual data is replaced by zeros (see ``_init_and_reshape``)
        self.δx = self._init_and_reshape(δx)
        self.Δx = self._init_and_reshape(Δx)
        self.Δδx = self._init_and_reshape(Δδx)

    def _add(self, ndual=1):
        "Add additional trailing axes for dual values."

        # reset the tensor
        self.reset()

        # add new dual-related axes to the shape of the data
        ones = np.ones(ndual, dtype=int)
        x = self.x.reshape(*self.shape, *ones, *self.trax)

        # re-init the tensor
        self.__init__(
            x,
            δx=self.δx,
            Δx=self.Δx,
            Δδx=self.Δδx,
            ntrax=self.ntrax + ndual,
            ndual=ndual,
        )

    def reset(self):
        "Reset the Tensor (set all dual values to zero)."

        # reset the shape of the data (delete axes related to dual data)
        x = self.x.reshape(*self.shape, *self.trax[self.ndual :])

        # re-init the tensor
        self.__init__(x, ntrax=self.ntrax - self.ndual)

    def copy(self):
        "Copy the Tensor."
        return deepcopy(self)

    def init(self, gradient=False, hessian=False, sym=False, δx=None, Δx=None):
        """Re-Initialize tensor with dual values to keep track of the
        hessian and/or the gradient.

        Parameters
        ----------
        gradient : bool, optional
            Add ``len(self.shape)`` dual axes and seed ``δx`` (default is False).
            In gradient-only mode ``Δx`` is always a copy of the seeded ``δx``;
            the ``Δx`` argument is not used.
        hessian : bool, optional
            Add ``2 * len(self.shape)`` dual axes and seed both ``δx`` and ``Δx``
            (default is False). Takes precedence over ``gradient``.
        sym : bool, optional
            Halve the seed entries at the off-diagonal positions of a tensor with
            1, 3 or 6 components (default is False). Any other size raises a
            ``KeyError``.
        δx, Δx : array_like, bool or None, optional
            Seed for the dual data. ``None`` or ``True`` creates an identity seed,
            ``False`` a seed of zeros and an array is reshaped to the shape of
            the tensor data.
        """

        def seed(value, shape):
            "Return the initial dual data for a requested seed and shape."
            if value is None or isinstance(value, bool):
                if len(shape) == 0:
                    shape = (1,)
                if value is False:
                    return np.zeros(self.size**2).reshape(shape)
                return np.eye(self.size).reshape(shape)
            return value.reshape(*self.shape, *self.trax)

        if gradient and not hessian:
            # add additional element-wise acting axes for dual values
            self._add(ndual=len(self.shape))

            # create the dual values
            δx = seed(δx, (*self.shape, *self.shape))
            Δx = δx.copy()

        elif hessian:
            # add additional trailing axes for dual values
            self._add(ndual=2 * len(self.shape))

            # create the dual values, δx and Δx use different dual axes
            ones = np.ones(len(self.shape), dtype=int)
            δx = seed(δx, (*self.shape, *self.shape, *ones))
            Δx = seed(Δx, (*self.shape, *ones, *self.shape))

        if (gradient or hessian) and sym:
            idx_off_diag = {1: None, 3: [1], 6: [1, 2, 4]}[self.size]

            # A tensor with a single component has no off-diagonal entry. Indexing
            # with ``None`` would insert a new axis and halve the complete seed.
            if idx_off_diag is not None:
                δx[idx_off_diag] /= 2
                Δx[idx_off_diag] /= 2

        # re-init the tensor
        self.__init__(self.x, δx=δx, Δx=Δx, ntrax=self.ntrax, ndual=self.ndual)

    def _init_and_reshape(self, value):
        "Return dual data with as many axes as ``x`` (zeros if not given)."
        if value is None:
            value = np.zeros(self.shape)
        else:
            value = np.asarray(value)

        # append axes of length one until the number of axes matches the data
        if len(value.shape) != len(self.x.shape):
            newshape = (
                *value.shape,
                *np.ones(len(self.x.shape) - len(value.shape), dtype=int),
            )
            value = value.reshape(newshape)

        return value

    def __neg__(self):
        return Tensor(
            x=-self.x, δx=-self.δx, Δx=-self.Δx, Δδx=-self.Δδx, ntrax=self.ntrax
        )

    def __add__(self, B):
        A = self
        if isinstance(B, Tensor):
            x = f(A) + f(B)
            δx = δ(A) + δ(B)
            Δx = Δ(A) + Δ(B)
            Δδx = Δδ(A) + Δδ(B)
            ntrax = min(A.ntrax, B.ntrax)
        else:
            # B is treated as a constant, the dual data is unchanged
            x = f(A) + B
            δx = δ(A)
            Δx = Δ(A)
            Δδx = Δδ(A)
            ntrax = A.ntrax
        return Tensor(x=x, δx=δx, Δx=Δx, Δδx=Δδx, ntrax=ntrax)

    def __sub__(self, B):
        A = self
        if isinstance(B, Tensor):
            x = f(A) - f(B)
            δx = δ(A) - δ(B)
            Δx = Δ(A) - Δ(B)
            Δδx = Δδ(A) - Δδ(B)
            ntrax = min(A.ntrax, B.ntrax)
        else:
            # B is treated as a constant, the dual data is unchanged
            x = f(A) - B
            δx = δ(A)
            Δx = Δ(A)
            Δδx = Δδ(A)
            ntrax = A.ntrax
        return Tensor(x=x, δx=δx, Δx=Δx, Δδx=Δδx, ntrax=ntrax)

    def __rsub__(self, B):
        # B - A = -(A - B)
        return -self.__sub__(B)

    def __mul__(self, B):
        A = self
        if isinstance(B, Tensor):
            # product rule for the first- and second-order dual parts
            x = f(A) * f(B)
            δx = δ(A) * f(B) + f(A) * δ(B)
            Δx = Δ(A) * f(B) + f(A) * Δ(B)
            Δδx = Δ(A) * δ(B) + δ(A) * Δ(B) + Δδ(A) * f(B) + f(A) * Δδ(B)
            ntrax = min(A.ntrax, B.ntrax)
        else:
            x = f(A) * B
            δx = δ(A) * B
            Δx = Δ(A) * B
            Δδx = Δδ(A) * B
            ntrax = A.ntrax
        return Tensor(x=x, δx=δx, Δx=Δx, Δδx=Δδx, ntrax=ntrax)

    def __truediv__(self, B):
        A = self
        return A * B**-1

    def __rtruediv__(self, B):
        A = self
        return B * A**-1

    def __pow__(self, p):
        A = self
        x = f(A) ** p
        δx = p * f(A) ** (p - 1) * δ(A)
        Δx = p * f(A) ** (p - 1) * Δ(A)
        Δδx = p * f(A) ** (p - 1) * Δδ(A) + p * (p - 1) * f(A) ** (
            p - 2
        ) * δ(A) * Δ(A)
        return Tensor(x=x, δx=δx, Δx=Δx, Δδx=Δδx, ntrax=A.ntrax)

    # The comparison operators only compare the real data and return the result
    # of that comparison (not a Tensor).

    def __gt__(self, B):
        A = self
        return f(A) > (f(B) if isinstance(B, Tensor) else B)

    def __lt__(self, B):
        A = self
        return f(A) < (f(B) if isinstance(B, Tensor) else B)

    def __ge__(self, B):
        A = self
        return f(A) >= (f(B) if isinstance(B, Tensor) else B)

    def __le__(self, B):
        A = self
        return f(A) <= (f(B) if isinstance(B, Tensor) else B)

    def __eq__(self, B):
        A = self
        return f(A) == (f(B) if isinstance(B, Tensor) else B)

    def __matmul__(self, B):
        return matmul(self, B)

    def __rmatmul__(self, B):
        return matmul(B, self)

    def __getitem__(self, key):
        x = f(self)[key]
        Δx = Δ(self)[key]
        δx = δ(self)[key]
        Δδx = Δδ(self)[key]
        return Tensor(x=x, δx=δx, Δx=Δx, Δδx=Δδx, ntrax=self.ntrax)

    def __setitem__(self, key, value):
        """Assign ``value`` at ``key``.

        A Tensor value also assigns its dual data. Any other value is treated as
        a constant, hence the dual data at ``key`` is set to zero.
        """
        if isinstance(value, Tensor):
            shape = (*self.shape, *self.trax)
            self.x[key] = f(value)

            # expand compressed dual arrays if the shapes do not match
            if self.δx[key].shape != δ(value).shape:
                self.δx = broadcast_to(self.δx, shape).copy()
            if self.Δx[key].shape != Δ(value).shape:
                self.Δx = broadcast_to(self.Δx, shape).copy()
            if self.Δδx[key].shape != Δδ(value).shape:
                self.Δδx = broadcast_to(self.Δδx, shape).copy()

            self.δx[key] = δ(value)
            self.Δx[key] = Δ(value)
            self.Δδx[key] = Δδ(value)

        else:
            self.x[key] = value

            # Assign instead of ``array[key].fill(0)``: for integer-list, array
            # or mask keys ``array[key]`` is a copy and ``fill`` would leave the
            # stored dual values untouched.
            self.δx[key] = 0
            self.Δx[key] = 0
            self.Δδx[key] = 0

    def __repr__(self):
        header = "<tensortrax tensor object>"
        metadata = [
            f" Shape / size of trailing axes: {self.trax} / {self.ntrax}",
            f" Shape / size of tensor: {self.shape} / {self.size}",
            "",
            "Data (first batch):",
        ]

        # show the tensor at index zero of every trailing axis
        first_batch = (0,) * self.ntrax
        data = self.x[(..., *first_batch)].__repr__()

        return "\n".join([header, *metadata, "", data])

    @property
    def T(self):
        "The transposed Tensor."
        return transpose(self)

    def ravel(self, order="C"):
        "Return a contiguous flattened array."
        return ravel(self, order=order)

    def reshape(self, *shape, order="C"):
        "Gives a new shape to an array without changing its data."
        return reshape(self, shape, order=order)

    def squeeze(self, axis=None):
        "Remove axes of length one."
        return squeeze(self, axis=axis)

    def dual2real(self, like):
        "Same as :meth:`dual_to_real`."
        return dual_to_real(self, like=like)

    def dual_to_real(self, like):
        "Return a new Tensor with old-dual data as new-real values."
        return dual_to_real(self, like=like)

    def real_to_dual(self, x, mul=None):
        "Return a new Tensor with old-real data as new-dual values."
        return real_to_dual(self, x=x, mul=mul)

    def astype(self, dtype):
        "Return a new Tensor with all data arrays cast to ``dtype``."
        return Tensor(
            x=self.x.astype(dtype),
            δx=self.δx.astype(dtype),
            Δx=self.Δx.astype(dtype),
            Δδx=self.Δδx.astype(dtype),
            ntrax=self.ntrax,
        )

    __radd__ = __add__
    __rmul__ = __mul__
    __array_ufunc__ = None
def broadcast_to(A, shape):
    """Broadcast Array or Tensor to a new shape.

    For a Tensor, the real data and all dual data arrays are broadcast to
    ``shape`` and wrapped in a new Tensor with the same ``ntrax``. Any other
    input is passed directly to ``np.broadcast_to``.
    """
    if not isinstance(A, Tensor):
        return np.broadcast_to(A, shape=shape)

    def expand(values):
        return np.broadcast_to(values, shape=shape)

    return Tensor(
        x=expand(f(A)),
        δx=expand(δ(A)),
        Δx=expand(Δ(A)),
        Δδx=expand(Δδ(A)),
        ntrax=A.ntrax,
    )
def dual_to_real(A, like=None):
    """Return a new Tensor with old-dual data as new-real values, with `ntrax`
    derived by `like`.

    The δ-data of ``A`` becomes the real data and the Δδ-data of ``A`` is used
    for both the δ- and the Δ-data of the result.
    """
    # dual axes of `like` which are not needed for its own tensor shape
    dual_axes = like.ndual - len(like.shape)
    trailing_axes = A.ntrax - dual_axes

    return Tensor(
        x=A.δx,
        δx=A.Δδx,
        Δx=A.Δδx,
        ndual=min(dual_axes, trailing_axes),
        ntrax=trailing_axes,
    )
def real_to_dual(A, x, mul=None):
    r"""Return a new Tensor with old-real data as new-dual values.

    ..  math::

        \delta W(x) &= A(x)\ \delta x

        \Delta \delta W(x) &= \Delta A(x)\ \delta x + A(x)\ \Delta \delta x

    If no ``mul`` callable is given, the elementwise product ``A * B`` is used.
    The real data of the result is filled with NaN.
    """
    if mul is None:
        product = lambda left, right: left * right  # noqa: E731
    else:
        product = mul

    # multiply by NaN: only the dual data of the result carries values
    real = product(f(A), f(x)) * np.nan
    first_δ = product(f(A), δ(x))
    first_Δ = product(f(A), Δ(x))
    second = product(δ(A), Δ(x)) + product(f(A), Δδ(x))

    return Tensor(x=real, δx=first_δ, Δx=first_Δ, Δδx=second, ntrax=A.ntrax)
def ravel(A, order="C"):
    """Return a contiguous flattened array.

    For a Tensor only the tensor axes are flattened. The axes which follow the
    tensor axes are kept for the real data as well as for each dual data array.
    """
    if not isinstance(A, Tensor):
        return np.ravel(A, order=order)

    rank = len(A.shape)

    def flatten(values, trailing):
        return values.reshape(A.size, *trailing, order=order)

    # each dual array keeps its own (possibly compressed) trailing axes
    trailing_δ = δ(A).shape[rank:]
    trailing_Δ = Δ(A).shape[rank:]
    trailing_Δδ = Δδ(A).shape[rank:]

    return Tensor(
        x=flatten(f(A), A.trax),
        δx=flatten(δ(A), trailing_δ),
        Δx=flatten(Δ(A), trailing_Δ),
        Δδx=flatten(Δδ(A), trailing_Δδ),
        ntrax=A.ntrax,
    )
def squeeze(A, axis=None):
    """Remove axes of length one.

    For a Tensor and ``axis=None`` only tensor axes of length one are removed,
    the trailing axes are not considered.
    """
    if not isinstance(A, Tensor):
        return np.squeeze(A, axis=axis)

    if axis is None:
        # positions of all tensor axes of length one (empty if there are none)
        axis = tuple(
            position for position, length in enumerate(A.shape) if length == 1
        )

    def drop(values):
        return np.squeeze(values, axis=axis)

    return Tensor(
        x=drop(f(A)),
        δx=drop(δ(A)),
        Δx=drop(Δ(A)),
        Δδx=drop(Δδ(A)),
        ntrax=A.ntrax,
    )
def reshape(A, shape, order="C"):
    """Gives a new shape to an array without changing its data.

    For a Tensor, ``shape`` is the new shape of the tensor axes. The axes which
    follow the tensor axes are kept for the real data as well as for each dual
    data array.
    """
    if not isinstance(A, Tensor):
        return np.reshape(A, shape, order=order)

    rank = len(A.shape)

    def reshaped(values, trailing):
        return values.reshape(*shape, *trailing, order=order)

    # each dual array keeps its own (possibly compressed) trailing axes
    trailing_δ = δ(A).shape[rank:]
    trailing_Δ = Δ(A).shape[rank:]
    trailing_Δδ = Δδ(A).shape[rank:]

    return Tensor(
        x=reshaped(f(A), A.trax),
        δx=reshaped(δ(A), trailing_δ),
        Δx=reshaped(Δ(A), trailing_Δ),
        Δδx=reshaped(Δδ(A), trailing_Δδ),
        ntrax=A.ntrax,
    )
def _einsum_product_rule(subscripts, operands):
    """Evaluate ``np.einsum`` for a mix of Tensors and plain arrays.

    Operands which are not Tensors are treated as constants. For the Tensors,
    the dual data of the result follows from the product rule:

    * ``δx`` / ``Δx``: one term per Tensor, with the δ- / Δ-data of this Tensor
      and the real data of all other Tensors.
    * ``Δδx``: one term per Tensor with its Δδ-data, followed by two terms
      (δ-Δ and Δ-δ) for each pair of Tensors.

    The terms are created and summed from left to right, Tensor pairs are
    visited in lexicographic order of their operand positions. If no operand
    is a Tensor, the plain result of ``np.einsum`` is returned.
    """
    tensor_positions = [
        position
        for position, operand in enumerate(operands)
        if isinstance(operand, Tensor)
    ]

    if not tensor_positions:
        return np.einsum(subscripts, *operands)

    def contract(accessors):
        "Contract with the real data of each Tensor unless an accessor is given."
        arrays = list(operands)
        for position in tensor_positions:
            arrays[position] = accessors.get(position, f)(operands[position])
        return np.einsum(subscripts, *arrays)

    def accumulate(terms):
        "Sum the terms from left to right, one term at a time."
        iterator = iter(terms)
        total = next(iterator)
        for term in iterator:
            total = total + term
        return total

    def second_order_terms():
        "Yield the Δδ-terms of the product rule."
        for position in tensor_positions:
            yield contract({position: Δδ})
        for n, first in enumerate(tensor_positions):
            for second in tensor_positions[n + 1 :]:
                yield contract({first: δ, second: Δ})
                yield contract({first: Δ, second: δ})

    x = contract({})
    δx = accumulate(contract({position: δ}) for position in tensor_positions)
    Δx = accumulate(contract({position: Δ}) for position in tensor_positions)
    Δδx = accumulate(second_order_terms())
    ntrax = min(operands[position].ntrax for position in tensor_positions)

    return Tensor(x=x, δx=δx, Δx=Δx, Δδx=Δδx, ntrax=ntrax)


def einsum4(subscripts, *operands):
    "Einsum with four operands."
    # unpack to enforce exactly four operands
    A, B, C, D = operands
    return _einsum_product_rule(subscripts, (A, B, C, D))


def einsum3(subscripts, *operands):
    "Einsum with three operands."
    # unpack to enforce exactly three operands
    A, B, C = operands
    return _einsum_product_rule(subscripts, (A, B, C))


def einsum2(subscripts, *operands):
    "Einsum with two operands."
    # unpack to enforce exactly two operands
    A, B = operands
    return _einsum_product_rule(subscripts, (A, B))


def einsum1(subscripts, *operands):
    "Einsum with one operand."
    A = operands[0]

    if not isinstance(A, Tensor):
        # all given operands are forwarded to numpy
        return np.einsum(subscripts, *operands)

    return _einsum_product_rule(subscripts, (A,))
def einsum(subscripts, *operands):
    """Einsum limited to one, two, three or four operands.

    The call is forwarded to the implementation for the given number of
    operands. Any other number of operands raises a ``NotImplementedError``.
    """
    count = len(operands)

    if count not in (1, 2, 3, 4):
        raise NotImplementedError("More than four operands are not supported.")

    # the implementations are looked up only after the count is validated
    implementation = (einsum1, einsum2, einsum3, einsum4)[count - 1]
    return implementation(subscripts, *operands)
def transpose(A):
    """Returns an array with axes transposed.

    The order of the axes in ``A.shape`` is reversed, all following axes are
    kept in place by the ellipsis of the einsum-subscripts.
    """
    rank = len(A.shape)
    forward = "abcdefghijklmnopqrstuvwxyz"[:rank]
    backward = forward[::-1]
    subscripts = f"{forward}...->{backward}..."
    return einsum(subscripts, A)
def matmul(A, B):
    """Matrix product of two arrays.

    The einsum-subscripts are derived from the lengths of ``A.shape`` and
    ``B.shape``, the contraction runs over the index ``k``.
    """
    left = "ik"[2 - len(A.shape) :]
    right = "kj"[: len(B.shape)]

    # the result keeps all indices except the contracted one
    result = (left + right).replace("k", "")

    subscripts = f"{left}...,{right}...->{result}..."
    return einsum(subscripts, A, B)