Source code for tensortrax.math.linalg._linalg_tensor

"""
tensorTRAX: Math on (Hyper-Dual) Tensors with Trailing Axes.
"""

import numpy as np

from ..._tensor import Tensor, Δ, Δδ, einsum, f, matmul, δ
from .._math_tensor import exp, sqrt, sum, transpose
from ..special._special_tensor import ddot
from . import _linalg_array as linalg

dot = matmul


[docs] def det(A): "Determinant of a 2x2 or 3x3 Tensor." if isinstance(A, Tensor): x = linalg.det(f(A)) B = transpose(linalg.inv(f(A))) δx = x * ddot(B, δ(A)) Δx = x * ddot(B, Δ(A)) ΔB = -matmul(matmul(B, transpose(Δ(A))), B) Δδx = Δx * δx / x + x * ddot(ΔB, δ(A)) + x * ddot(B, Δδ(A)) return Tensor( x=x, δx=δx, Δx=Δx, Δδx=Δδx, ntrax=A.ntrax, ) else: return linalg.det(A)
[docs] def inv(A, inverse=linalg.inv): "Inverse of a 2x2 or 3x3 Tensor." if isinstance(A, Tensor): x = inverse(f(A)) invA = inverse(f(A)) δx = -matmul(matmul(invA, δ(A)), invA) Δx = -matmul(matmul(invA, Δ(A)), invA) Δδx = -( matmul(matmul(Δx, δ(A)), invA) + matmul(matmul(invA, δ(A)), Δx) + matmul(matmul(invA, Δδ(A)), invA) ) return Tensor( x=x, δx=δx, Δx=Δx, Δδx=Δδx, ntrax=A.ntrax, ) else: return linalg.inv(A)
[docs] def pinv(A): "Pseudo-Inverse of a 2x2 or 3x3 Tensor." return inv(A, inverse=linalg.pinv)
[docs] def eigvalsh(A, eps=np.sqrt(np.finfo(float).eps)): "Eigenvalues of a symmetric Tensor." if isinstance(A, Tensor): A[0, 0] += eps A[1, 1] -= eps λ, N = [x.T for x in np.linalg.eigh(f(A).T)] M = einsum("ai...,aj...->aij...", N, N) dim = len(λ) δλ = einsum("aij...,ij...->a...", M, δ(A)) Δλ = einsum("aij...,ij...->a...", M, Δ(A)) # alpha = [0, 1, 2] # beta = [(1, 2), (2, 0), (0, 1)] alpha = np.arange(dim) beta = [ np.concatenate([np.arange(a + 1, dim), np.arange(a)]) for a in np.arange(dim) ] δN = [] for α in alpha: δNα = [] for β in beta[α]: Mαβ = einsum("i...,j...->ij...", N[α], N[β]) δAαβ = einsum("ij...,ij...->...", Mαβ, δ(A)) λαβ = λ[α] - λ[β] δNα.append(1 / λαβ * N[β] * δAαβ) δN.append(sum(δNα, axis=0)) δM = einsum("ai...,aj...->aij...", δN, N) + einsum("ai...,aj...->aij...", N, δN) Δδλ = einsum("aij...,ij...->a...", δM, Δ(A)) + einsum( "aij...,ij...->a...", M, Δδ(A) ) return Tensor( x=λ, δx=δλ, Δx=Δλ, Δδx=Δδλ, ntrax=A.ntrax, ) else: return np.linalg.eigvalsh(A.T).T
[docs] def eigh(A, eps=np.sqrt(np.finfo(float).eps)): "Eigenvalues and -bases of a symmetric Tensor." if isinstance(A, Tensor): A[0, 0] += eps A[1, 1] -= eps λ, N = [x.T for x in np.linalg.eigh(f(A).T)] M = einsum("ai...,aj...->aij...", N, N) dim = len(λ) δλ = einsum("aij...,ij...->a...", M, δ(A)) Δλ = einsum("aij...,ij...->a...", M, Δ(A)) # alpha = [0, 1, 2] # beta = [(1, 2), (2, 0), (0, 1)] alpha = np.arange(dim) beta = [ np.concatenate([np.arange(a + 1, dim), np.arange(a)]) for a in np.arange(dim) ] δN = [] ΔN = [] for α in alpha: δNα = [] ΔNα = [] for β in beta[α]: Mαβ = einsum("i...,j...->ij...", N[α], N[β]) δAαβ = einsum("ij...,ij...->...", Mαβ, δ(A)) ΔAαβ = einsum("ij...,ij...->...", Mαβ, Δ(A)) λαβ = λ[α] - λ[β] δNα.append(1 / λαβ * N[β] * δAαβ) ΔNα.append(1 / λαβ * N[β] * ΔAαβ) δN.append(sum(δNα, axis=0)) ΔN.append(sum(ΔNα, axis=0)) ΔδN = [] for α in alpha: ΔδNα = [] for β in beta[α]: Mαβ = einsum("i...,j...->ij...", N[α], N[β]) δAαβ = einsum("ij...,ij...->...", Mαβ, δ(A)) ΔδAαβ = einsum("ij...,ij...->...", Mαβ, Δδ(A)) λαβ = λ[α] - λ[β] Δλαβ = Δλ[α] - Δλ[β] ΔδNα.append( -(λαβ**-2) * Δλαβ * N[β] * δAαβ + 1 / λαβ * ΔN[β] * δAαβ + 1 / λαβ * N[β] * ΔδAαβ ) ΔδN.append(sum(ΔδNα, axis=0)) δM = einsum("ai...,aj...->aij...", δN, N) + einsum("ai...,aj...->aij...", N, δN) ΔM = einsum("ai...,aj...->aij...", ΔN, N) + einsum("ai...,aj...->aij...", N, ΔN) Δδλ = einsum("aij...,ij...->a...", δM, Δ(A)) + einsum( "aij...,ij...->a...", M, Δδ(A) ) ΔδM = ( einsum("ai...,aj...->aij...", δN, ΔN) + einsum("ai...,aj...->aij...", ΔN, δN) + einsum("ai...,aj...->aij...", ΔδN, N) + einsum("ai...,aj...->aij...", N, ΔδN) ) return ( Tensor( x=λ, δx=δλ, Δx=Δλ, Δδx=Δδλ, ntrax=A.ntrax, ), Tensor( x=M, δx=δM, Δx=ΔM, Δδx=ΔδM, ntrax=A.ntrax, ), ) else: λ, N = [x.T for x in np.linalg.eigh(A.T)] M = einsum("ai...,aj...->aij...", N, N) return λ, M
[docs] def expm(A): "Compute the matrix exponential of a symmetric array." λ, M = eigh(A) return einsum("a...,aij...->ij...", exp(λ), M)
[docs] def sqrtm(A): "Compute the matrix square root of a symmetric array." λ, M = eigh(A) return einsum("a...,aij...->ij...", sqrt(λ), M)