Source code for tensortrax.math._math_tensor

"""
tensorTRAX: Math on (Hyper-Dual) Tensors with Trailing Axes.
"""
import numpy as np

from .._tensor import Tensor, Δ, Δδ, broadcast_to, einsum, f, matmul, δ

# ``dot`` is an alias: the very same callable as ``matmul`` imported from ``.._tensor``.
dot = matmul


def array(object, dtype=None, like=None, shape=None):
    """Create a tensor or an array from another tensor, an array or from a
    list/tuple of tensors or arrays.

    Parameters
    ----------
    object : tensortrax.Tensor, array_like, list or tuple of tensortrax.Tensor or list or tuple of array_like
        The object from which the array is created.
    dtype : data-type or None, optional
        Data-type of the array(s). Default is None.
    like : tensortrax.Tensor or None, optional
        Reference tensor for shape and (number of) trailing axes. Default is
        None. Only considered if ``object`` is neither a tensor nor a list or
        tuple.
    shape : tuple of int or None, optional
        The shape of the data of the tensor (without shape of trailing axes).
        If None, the shape is taken from ``like``. Only considered if
        ``object`` is neither a tensor nor a list or tuple.

    Returns
    -------
    tensortrax.Tensor or ndarray
        A tensor if ``object`` is a tensor, a non-empty list/tuple whose first
        item is a tensor, or any other object combined with ``like``.
        Otherwise an ndarray.
    """
    if isinstance(object, Tensor):
        return Tensor(
            x=np.array(f(object), dtype=dtype),
            δx=np.array(δ(object), dtype=dtype),
            Δx=np.array(Δ(object), dtype=dtype),
            Δδx=np.array(Δδ(object), dtype=dtype),
            ntrax=object.ntrax,
        )

    if isinstance(object, (list, tuple)):
        # An empty sequence has no first item to inspect; convert it to a
        # plain (empty) ndarray instead of raising an IndexError.
        if len(object) > 0 and isinstance(object[0], Tensor):
            # the items are assumed to be tensors; each hyper-dual part is
            # stacked separately and the smallest number of trailing axes wins
            return Tensor(
                x=np.array([f(o) for o in object], dtype=dtype),
                δx=np.array([δ(o) for o in object], dtype=dtype),
                Δx=np.array([Δ(o) for o in object], dtype=dtype),
                Δδx=np.array([Δδ(o) for o in object], dtype=dtype),
                ntrax=min(o.ntrax for o in object),
            )
        return np.array(object, dtype=dtype)

    x = np.array(object, dtype=dtype)
    if like is None:
        return x

    if shape is None:
        shape = like.shape

    # append the trailing axes of the reference tensor to the data shape
    return Tensor(x=x.reshape(*shape, *like.trax), ntrax=like.ntrax)
def trace(A):
    """Return the sum along diagonals of the array.

    The einsum subscripts ``"ii...->..."`` sum the diagonal spanned by the
    first two axes; every remaining (trailing) axis is kept in the result.
    """
    subscripts = "ii...->..."
    return einsum(subscripts, A)
def transpose(A):
    """Returns an array with axes transposed.

    Only the first two axes are swapped (einsum subscripts
    ``"ij...->ji..."``); all remaining (trailing) axes keep their order.
    """
    subscripts = "ij...->ji..."
    return einsum(subscripts, A)
def sum(A, axis=0):
    """Sum of array elements over a given axis.

    For a tensor, each of the four hyper-dual parts is summed over ``axis``
    with ``numpy.sum``; the number of trailing axes is passed on unchanged.
    Any other input is forwarded to ``numpy.sum``.
    """
    if not isinstance(A, Tensor):
        return np.sum(A, axis=axis)

    x, δx, Δx, Δδx = [np.sum(part, axis=axis) for part in (f(A), δ(A), Δ(A), Δδ(A))]
    return Tensor(x=x, δx=δx, Δx=Δx, Δδx=Δδx, ntrax=A.ntrax)
def sign(A):
    """Returns an element-wise indication of the sign of a number.

    For a tensor, the derivative parts are zero; they are formed as
    ``0 * part`` so that each keeps the shape of the corresponding input part.
    """
    if not isinstance(A, Tensor):
        return np.sign(A)

    zeroed = [0 * part for part in (δ(A), Δ(A), Δδ(A))]
    return Tensor(
        x=np.sign(f(A)),
        δx=zeroed[0],
        Δx=zeroed[1],
        Δδx=zeroed[2],
        ntrax=A.ntrax,
    )
def abs(A):
    """Calculate the absolute value element-wise.

    For a tensor, the first derivative is ``sign(x)`` and the second
    derivative vanishes, so every derivative part is scaled by ``sign(x)``.
    """
    if not isinstance(A, Tensor):
        return np.abs(A)

    slope = np.sign(f(A))
    return Tensor(
        x=np.abs(f(A)),
        δx=slope * δ(A),
        Δx=slope * Δ(A),
        Δδx=slope * Δδ(A),
        ntrax=A.ntrax,
    )
def sqrt(A):
    """Return the non-negative square-root of an array, element-wise.

    For a tensor this delegates to the power operator ``A ** 0.5``.
    """
    return A**0.5 if isinstance(A, Tensor) else np.sqrt(A)
def sin(A):
    """Trigonometric sine, element-wise.

    Hyper-dual chain rule with ``sin' = cos`` and ``sin'' = -sin``.
    """
    if not isinstance(A, Tensor):
        return np.sin(A)

    a = f(A)
    s, c = np.sin(a), np.cos(a)
    return Tensor(
        x=s,
        δx=c * δ(A),
        Δx=c * Δ(A),
        Δδx=-s * δ(A) * Δ(A) + c * Δδ(A),
        ntrax=A.ntrax,
    )
def cos(A):
    """Cosine element-wise.

    Hyper-dual chain rule with ``cos' = -sin`` and ``cos'' = -cos``.
    """
    if not isinstance(A, Tensor):
        return np.cos(A)

    a = f(A)
    c, s = np.cos(a), np.sin(a)
    return Tensor(
        x=c,
        δx=-s * δ(A),
        Δx=-s * Δ(A),
        Δδx=-c * δ(A) * Δ(A) - s * Δδ(A),
        ntrax=A.ntrax,
    )
def tan(A):
    """Compute tangent element-wise.

    Hyper-dual chain rule with ``tan' = cos**-2`` and
    ``tan'' = 2 tan cos**-2``.
    """
    if not isinstance(A, Tensor):
        return np.tan(A)

    a = f(A)
    t = np.tan(a)
    sec2 = np.cos(a) ** -2
    return Tensor(
        x=t,
        δx=sec2 * δ(A),
        Δx=sec2 * Δ(A),
        Δδx=2 * t * sec2 * δ(A) * Δ(A) + sec2 * Δδ(A),
        ntrax=A.ntrax,
    )
def sinh(A):
    """Hyperbolic sine, element-wise.

    Hyper-dual chain rule with ``sinh' = cosh`` and ``sinh'' = sinh``.
    """
    if not isinstance(A, Tensor):
        return np.sinh(A)

    a = f(A)
    sh, ch = np.sinh(a), np.cosh(a)
    return Tensor(
        x=sh,
        δx=ch * δ(A),
        Δx=ch * Δ(A),
        Δδx=sh * δ(A) * Δ(A) + ch * Δδ(A),
        ntrax=A.ntrax,
    )
def cosh(A):
    """Hyperbolic cosine, element-wise.

    Hyper-dual chain rule with ``cosh' = sinh`` and ``cosh'' = cosh``.
    """
    if not isinstance(A, Tensor):
        return np.cosh(A)

    a = f(A)
    ch, sh = np.cosh(a), np.sinh(a)
    return Tensor(
        x=ch,
        δx=sh * δ(A),
        Δx=sh * Δ(A),
        Δδx=ch * δ(A) * Δ(A) + sh * Δδ(A),
        ntrax=A.ntrax,
    )
def tanh(A):
    """Compute hyperbolic tangent element-wise.

    Hyper-dual chain rule with ``tanh' = 1 - tanh**2`` and
    ``tanh'' = -2 tanh (1 - tanh**2)``.
    """
    if not isinstance(A, Tensor):
        return np.tanh(A)

    t = np.tanh(f(A))
    dt = 1 - t**2
    return Tensor(
        x=t,
        δx=dt * δ(A),
        Δx=dt * Δ(A),
        Δδx=-2 * t * dt * δ(A) * Δ(A) + dt * Δδ(A),
        ntrax=A.ntrax,
    )
def exp(A):
    """Calculate the exponential of all elements in the input array.

    Hyper-dual chain rule; the first and second derivative both equal
    ``exp(x)``.
    """
    if not isinstance(A, Tensor):
        return np.exp(A)

    e = np.exp(f(A))
    return Tensor(
        x=e,
        δx=e * δ(A),
        Δx=e * Δ(A),
        Δδx=e * δ(A) * Δ(A) + e * Δδ(A),
        ntrax=A.ntrax,
    )
def log(A):
    """Natural logarithm, element-wise.

    Hyper-dual chain rule with ``log' = 1 / x`` and ``log'' = -1 / x**2``.
    """
    if not isinstance(A, Tensor):
        return np.log(A)

    a = f(A)
    inv = 1 / a
    return Tensor(
        x=np.log(a),
        δx=inv * δ(A),
        Δx=inv * Δ(A),
        Δδx=-1 / a**2 * δ(A) * Δ(A) + inv * Δδ(A),
        ntrax=A.ntrax,
    )
def log10(A):
    """Return the base 10 logarithm of the input array, element-wise.

    Hyper-dual chain rule with ``log10' = 1 / (ln(10) x)`` and
    ``log10'' = -1 / (ln(10) x**2)``.
    """
    if not isinstance(A, Tensor):
        return np.log10(A)

    a = f(A)
    ln10 = np.log(10)
    first = 1 / (ln10 * a)
    return Tensor(
        x=np.log10(a),
        δx=first * δ(A),
        Δx=first * Δ(A),
        Δδx=-1 / (ln10 * a**2) * δ(A) * Δ(A) + first * Δδ(A),
        ntrax=A.ntrax,
    )
def diagonal(A, offset=0, axis1=0, axis2=1):
    """Return specified diagonals.

    The diagonal taken over ``axis1``/``axis2`` (see ``numpy.diagonal``) is
    returned as the FIRST axis of the result, followed by the remaining axes
    in their original order. Trailing axes therefore keep their order even
    when there is more than one of them.

    Parameters
    ----------
    A : tensortrax.Tensor or array_like
        Input tensor or array.
    offset, axis1, axis2 : int, optional
        Passed on to ``numpy.diagonal``.

    Returns
    -------
    tensortrax.Tensor or ndarray
        A tensor (same ``ntrax``) for tensor input, an ndarray otherwise.
    """
    kwargs = dict(offset=offset, axis1=axis1, axis2=axis2)

    def _diag(a):
        # numpy.diagonal appends the diagonal as the LAST axis. Move it to the
        # front instead of transposing with ``.T``: ``.T`` reverses all axes
        # and would scramble multiple trailing axes, e.g. (3, 3, N, M) would
        # become (3, M, N) instead of (3, N, M).
        return np.moveaxis(np.diagonal(a, **kwargs), -1, 0)

    if isinstance(A, Tensor):
        return Tensor(
            x=_diag(f(A)),
            δx=_diag(δ(A)),
            Δx=_diag(Δ(A)),
            Δδx=_diag(Δδ(A)),
            ntrax=A.ntrax,
        )
    return _diag(A)
def tile(A, reps):
    """Construct an array by repeating A the number of times given by reps.

    For a tensor, ``numpy.tile`` is applied to each hyper-dual part with the
    same ``reps``; the number of trailing axes is passed on unchanged.
    """
    if not isinstance(A, Tensor):
        return np.tile(A, reps=reps)

    x, δx, Δx, Δδx = (np.tile(part, reps=reps) for part in (f(A), δ(A), Δ(A), Δδ(A)))
    return Tensor(x=x, δx=δx, Δx=Δx, Δδx=Δδx, ntrax=A.ntrax)
def repeat(a, repeats, axis=None):
    """Repeat elements of an array.

    For a tensor, ``numpy.repeat`` is applied to each hyper-dual part with the
    same ``repeats`` and ``axis``; the number of trailing axes is passed on
    unchanged.
    """
    options = dict(repeats=repeats, axis=axis)
    if not isinstance(a, Tensor):
        return np.repeat(a, **options)

    x, δx, Δx, Δδx = (np.repeat(part, **options) for part in (f(a), δ(a), Δ(a), Δδ(a)))
    return Tensor(x=x, δx=δx, Δx=Δx, Δδx=Δδx, ntrax=a.ntrax)
def hstack(tup):
    """Stack arrays in sequence horizontally (column wise).

    Dispatch depends on the first item only: if it is a tensor, each
    hyper-dual part of every item is stacked with ``numpy.hstack`` and the
    result gets the smallest ``ntrax`` of the items.
    """
    if not isinstance(tup[0], Tensor):
        return np.hstack(tup)

    def join(part):
        return np.hstack([part(item) for item in tup])

    return Tensor(
        x=join(f),
        δx=join(δ),
        Δx=join(Δ),
        Δδx=join(Δδ),
        ntrax=min(item.ntrax for item in tup),
    )
def vstack(tup):
    """Stack arrays in sequence vertically (row wise).

    Dispatch depends on the first item only: if it is a tensor, each
    hyper-dual part of every item is stacked with ``numpy.vstack`` and the
    result gets the smallest ``ntrax`` of the items.
    """
    if not isinstance(tup[0], Tensor):
        return np.vstack(tup)

    def join(part):
        return np.vstack([part(item) for item in tup])

    return Tensor(
        x=join(f),
        δx=join(δ),
        Δx=join(Δ),
        Δδx=join(Δδ),
        ntrax=min(item.ntrax for item in tup),
    )
def stack(arrays, axis=0):
    """Join a sequence of arrays along a new axis.

    Dispatch depends on the first item only: if it is a tensor, each
    hyper-dual part of every item is stacked with ``numpy.stack`` along
    ``axis`` and the result gets the smallest ``ntrax`` of the items.
    """
    if not isinstance(arrays[0], Tensor):
        return np.stack(arrays, axis=axis)

    def join(part):
        return np.stack([part(item) for item in arrays], axis=axis)

    return Tensor(
        x=join(f),
        δx=join(δ),
        Δx=join(Δ),
        Δδx=join(Δδ),
        ntrax=min(item.ntrax for item in arrays),
    )
def concatenate(arrays, axis=0):
    """Join a sequence of arrays along an existing axis.

    Dispatch depends on the first item only: if it is a tensor, each
    hyper-dual part of every item is joined with ``numpy.concatenate`` along
    ``axis`` and the result gets the smallest ``ntrax`` of the items.
    """
    if not isinstance(arrays[0], Tensor):
        return np.concatenate(arrays, axis=axis)

    def join(part):
        return np.concatenate([part(item) for item in arrays], axis=axis)

    return Tensor(
        x=join(f),
        δx=join(δ),
        Δx=join(Δ),
        Δδx=join(Δδ),
        ntrax=min(item.ntrax for item in arrays),
    )
def split(ary, indices_or_sections, axis=0):
    """Split an array into multiple sub-arrays as views into ary.

    For a tensor, every hyper-dual part is split with ``numpy.split`` and the
    pieces are regrouped into a list of tensors with the input's ``ntrax``.
    """
    options = dict(indices_or_sections=indices_or_sections, axis=axis)
    if not isinstance(ary, Tensor):
        return np.split(ary, **options)

    # one list of pieces per hyper-dual part, in the order x, δx, Δx, Δδx
    pieces = [np.split(part(ary), **options) for part in (f, δ, Δ, Δδ)]
    return [
        Tensor(x, δx, Δx, Δδx, ntrax=ary.ntrax)
        for x, δx, Δx, Δδx in zip(*pieces)
    ]
def external(x, function, gradient, hessian, indices="ij", *args, **kwargs):
    """Evaluate the Tensor returned by an external scalar-valued function,
    evaluated at a given value `x`, with provided gradient and hessian which
    operates on the values of a tensor and optional arguments. All math
    methods inside the external function/gradient/hessian must handle
    arbitrary number of elementwise-operating trailing axes.

    ``function``, ``gradient`` and ``hessian`` are each called once with the
    real part ``f(x)`` plus ``*args``/``**kwargs``. Their results are then
    contracted with the hyper-dual parts of ``x``:

    - gradient: einsum subscripts ``"<indices>...,<indices>...->..."``
    - hessian: einsum subscripts
      ``"<indices><INDICES>...,<indices>...,<INDICES>...->..."``

    Here ``<indices>`` is ``indices`` lower-cased and ``<INDICES>`` is
    ``indices`` upper-cased.

    For non-tensor ``x`` only ``function(x, *args, **kwargs)`` is returned.

    Note: ``indices`` precedes ``*args``, so a fifth positional argument is
    bound to ``indices``.
    """
    if not isinstance(x, Tensor):
        return function(x, *args, **kwargs)

    # evaluate the external callables once, at the real part of the tensor
    value = function(f(x), *args, **kwargs)
    grad = gradient(f(x), *args, **kwargs)
    hess = hessian(f(x), *args, **kwargs)

    lower = indices.lower()
    upper = indices.upper()

    def contract_grad(g, v):
        "Gradient-vector product (contraction over all ``indices`` axes)."
        return einsum(f"{lower}...,{lower}...->...", g, v)

    def contract_hess(h, v, u):
        "Hessian-vectors product (contraction over both index groups)."
        return einsum(f"{lower}{upper}...,{lower}...,{upper}...->...", h, v, u)

    return Tensor(
        x=value,
        δx=contract_grad(grad, δ(x)),
        Δx=contract_grad(grad, Δ(x)),
        Δδx=contract_hess(hess, δ(x), Δ(x)) + contract_grad(grad, Δδ(x)),
        ntrax=x.ntrax,
    )
def if_else(cond, true, false):
    """Mask-based Condition for arrays and tensors.

    Take the entries of ``true`` where ``cond`` is True and those of ``false``
    elsewhere. The mask is applied to the last axes of the inputs
    (``out[..., mask]``).

    Parameters
    ----------
    cond : array_like of bool
        Boolean mask.
    true, false : ndarray or tensortrax.Tensor
        Candidate values; both must be arrays or both must be tensors.

    Returns
    -------
    ndarray or tensortrax.Tensor
        A new object; ``true`` and ``false`` are not modified.

    Raises
    ------
    NotImplementedError
        If ``true`` and ``false`` are not both arrays or both tensors.
    """
    mask = np.asarray(cond)

    # NOTE: no unconditional ``true.copy()`` up front. It raised
    # AttributeError for inputs without ``.copy`` (e.g. floats) and so hid
    # the NotImplementedError below. For tensors it also made a wasted copy.
    if isinstance(true, np.ndarray) and isinstance(false, np.ndarray):
        # the copy already holds ``true`` where the mask is set; only the
        # unmasked entries need to be taken from ``false``
        out = true.copy()
        out[..., ~mask] = false[..., ~mask]
    elif isinstance(true, Tensor) and isinstance(false, Tensor):
        # common shape: element-wise maximum of the shapes of all hyper-dual
        # parts (assumes they all have the same number of dims - TODO confirm)
        shape = np.maximum.reduce(
            [
                true.x.shape,
                true.δx.shape,
                true.Δx.shape,
                true.Δδx.shape,
                false.x.shape,
                false.δx.shape,
                false.Δx.shape,
                false.Δδx.shape,
            ]
        )
        out = broadcast_to(true, shape=shape).copy()
        mask = np.broadcast_to(mask, shape)
        out[..., ~mask] = broadcast_to(false, shape=shape)[..., ~mask]
    else:
        raise NotImplementedError(
            "`true` and `false` must be both arrays or both tensors."
        )
    return out
def maximum(x1, x2):
    """Element-wise maximum of array elements.

    The type of ``x1`` decides the dispatch. A tensor ``x1`` is handled with
    ``if_else(x1 > x2, x1, x2)``; anything else is forwarded to
    ``numpy.maximum``.
    """
    if not isinstance(x1, Tensor):
        return np.maximum(x1, x2)
    return if_else(x1 > x2, x1, x2)
def minimum(x1, x2):
    """Element-wise minimum of array elements.

    The type of ``x1`` decides the dispatch. A tensor ``x1`` is handled with
    ``if_else(x1 < x2, x1, x2)``; anything else is forwarded to
    ``numpy.minimum``.
    """
    if not isinstance(x1, Tensor):
        return np.minimum(x1, x2)
    return if_else(x1 < x2, x1, x2)