diff --git a/src/tensors/indexmanipulations.jl b/src/tensors/indexmanipulations.jl index 58f8985ed..17c23e026 100644 --- a/src/tensors/indexmanipulations.jl +++ b/src/tensors/indexmanipulations.jl @@ -553,8 +553,7 @@ Base.@deprecate( if p[1] === codomainind(tsrc) && p[2] === domainind(tsrc) add!(tdst, tsrc, α, β) else - I = sectortype(tdst) - if I === Trivial + if has_array_view(tdst) && has_array_view(tsrc) TO.tensoradd!(tdst[], tsrc[], p, false, α, β, backend, allocator) else ntasks = use_threaded_transform(tdst, transformer) ? get_num_transformer_threads() : 1 diff --git a/src/tensors/tensoroperations.jl b/src/tensors/tensoroperations.jl index 375b63768..52f4eb417 100644 --- a/src/tensors/tensoroperations.jl +++ b/src/tensors/tensoroperations.jl @@ -33,6 +33,13 @@ function _canonicalize(p::IndexTuple, t::AbstractTensorMap) return (p₁, p₂) end +# Whether a tensor can be viewed as a single contiguous array, such that +# the fusiontree machinery and act directly on the `t[]` view. +has_array_view(t) = has_array_view(typeof(t)) +has_array_view(::Type) = false +has_array_view(::Type{T}) where {T <: TensorMap} = sectortype(T) === Trivial +has_array_view(::Type{T}) where {T <: AdjointTensorMap} = has_array_view(parenttype(T)) + # tensoradd! function TO.tensoradd!( C::AbstractTensorMap, @@ -43,9 +50,9 @@ function TO.tensoradd!( if conjA A′ = adjoint(A) pA′ = adjointtensorindices(A, _canonicalize(pA, C)) - permute!(C, A′, pA′, α, β, backend) + permute!(C, A′, pA′, α, β, backend, allocator) else - permute!(C, A, _canonicalize(pA, C), α, β, backend) + permute!(C, A, _canonicalize(pA, C), α, β, backend, allocator) end return C end @@ -125,6 +132,10 @@ function TO.tensorcontract!( ) pAB′ = _canonicalize(pAB, C) @boundscheck spacecheck_contract(C, A, pA, conjA, B, pB, conjB, pAB′) + if has_array_view(C) && has_array_view(A) && has_array_view(B) + TO.tensorcontract!(C[], A[], pA, conjA, B[], pB, conjB, pAB′, α, β, backend, allocator) + return C + end if conjA && conjB A′ = A' pA′ = adjointtensorindices(A, pA) @@ -219,7 +230,7 @@ function trace_permute!( q₁ = $(q₁), q₂ = $(q₂)")) end - if I === Trivial + if has_array_view(tdst) && has_array_view(tsrc) TO.tensortrace!(tdst[], tsrc[], (p₁, p₂), (q₁, q₂), false, α, β, backend) else _trace_permute!(FusionStyle(I), tdst, tsrc, (p₁, p₂), (q₁, q₂), α, β, backend)