// SPDX-License-Identifier: Apache-2.0 // // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // ------------------------------------------------------------------------ //! \addtogroup op_dot //! @{ //! for two arrays, generic version for non-complex values template arma_inline typename arma_not_cx::result op_dot::direct_dot_arma(const uword n_elem, const eT* const A, const eT* const B) { arma_debug_sigprint(); #if defined(__FAST_MATH__) { eT val = eT(0); for(uword i=0; i inline typename arma_cx_only::result op_dot::direct_dot_arma(const uword n_elem, const eT* const A, const eT* const B) { arma_debug_sigprint(); typedef typename get_pod_type::result T; T val_real = T(0); T val_imag = T(0); for(uword i=0; i& X = A[i]; const std::complex& Y = B[i]; const T a = X.real(); const T b = X.imag(); const T c = Y.real(); const T d = Y.imag(); val_real += (a*c) - (b*d); val_imag += (a*d) + (b*c); } return std::complex(val_real, val_imag); } //! for two arrays, float and double version template inline typename arma_real_only::result op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B) { arma_debug_sigprint(); if( n_elem <= 32u ) { return op_dot::direct_dot_arma(n_elem, A, B); } else { #if defined(ARMA_USE_ATLAS) { arma_debug_print("atlas::cblas_dot()"); return atlas::cblas_dot(n_elem, A, B); } #elif defined(ARMA_USE_BLAS) { arma_debug_print("blas::dot()"); return blas::dot(n_elem, A, B); } #else { return op_dot::direct_dot_arma(n_elem, A, B); } #endif } } //! for two arrays, complex version template inline typename arma_cx_only::result op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B) { if( n_elem <= 16u ) { return op_dot::direct_dot_arma(n_elem, A, B); } else { #if defined(ARMA_USE_ATLAS) { arma_debug_print("atlas::cblas_cx_dot()"); return atlas::cblas_cx_dot(n_elem, A, B); } #elif defined(ARMA_USE_BLAS) { arma_debug_print("blas::dot()"); return blas::dot(n_elem, A, B); } #else { return op_dot::direct_dot_arma(n_elem, A, B); } #endif } } //! for two arrays, integral version template inline typename arma_integral_only::result op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B) { return op_dot::direct_dot_arma(n_elem, A, B); } //! for three arrays template inline eT op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B, const eT* C) { arma_debug_sigprint(); eT val = eT(0); for(uword i=0; i inline typename T1::elem_type op_dot::apply(const T1& X, const T2& Y) { arma_debug_sigprint(); typedef typename T1::elem_type eT; if(is_subview_row::value && is_subview_row::value) { const subview_row& A = reinterpret_cast< const subview_row& >(X); const subview_row& B = reinterpret_cast< const subview_row& >(Y); if( (A.m.n_rows == 1) && (B.m.n_rows == 1) ) { arma_debug_print("op_dot::apply(): subview_row optimisation"); arma_conform_check( (A.n_elem != B.n_elem), "dot(): objects must have the same number of elements" ); const eT* A_mem = A.m.memptr(); const eT* B_mem = B.m.memptr(); return op_dot::direct_dot(A.n_elem, &A_mem[A.aux_col1], &B_mem[B.aux_col1]); } } if(is_subview::value || is_subview::value) { arma_debug_print("op_dot::apply(): subview optimisation"); const sv_keep_unwrap& UA(X); const sv_keep_unwrap& UB(Y); typedef typename sv_keep_unwrap::stored_type UA_M_type; typedef typename sv_keep_unwrap::stored_type UB_M_type; const UA_M_type& A = UA.M; const UB_M_type& B = UB.M; const uword A_n_rows = A.n_rows; const uword A_n_cols = A.n_cols; if( (A_n_rows == B.n_rows) && (A_n_cols == B.n_cols) ) { eT acc = eT(0); for(uword c=0; c < A_n_cols; ++c) { acc += op_dot::direct_dot(A_n_rows, A.colptr(c), B.colptr(c)); } return acc; } else { const quasi_unwrap UUA(A); const quasi_unwrap UUB(B); arma_conform_check( (UUA.M.n_elem != UUB.M.n_elem), "dot(): objects must have the same number of elements" ); return op_dot::direct_dot(UUA.M.n_elem, UUA.M.memptr(), UUB.M.memptr()); } } // if possible, bypass transposes of non-complex vectors if( (is_cx::no) && (resolves_to_vector::value) && (resolves_to_vector::value) && (partial_unwrap::is_fast) && (partial_unwrap::is_fast) ) { arma_debug_print("op_dot::apply(): vector optimisation"); const partial_unwrap UA(X); const partial_unwrap UB(Y); const typename partial_unwrap::stored_type& A = UA.M; const typename partial_unwrap::stored_type& B = UB.M; arma_conform_check( (A.n_elem != B.n_elem), "dot(): objects must have the same number of elements" ); const eT val = op_dot::direct_dot(A.n_elem, A.memptr(), B.memptr()); return (UA.do_times || UB.do_times) ? (val * UA.get_val() * UB.get_val()) : val; } constexpr bool proxy_is_mat = (is_Mat::stored_type>::value && is_Mat::stored_type>::value); constexpr bool use_at = (Proxy::use_at) || (Proxy::use_at); constexpr bool have_direct_mem = (quasi_unwrap::has_orig_mem) && (quasi_unwrap::has_orig_mem); if(proxy_is_mat || use_at || have_direct_mem) { arma_debug_print("op_dot::apply(): direct_mem optimisation"); const quasi_unwrap A(X); const quasi_unwrap B(Y); arma_conform_check( (A.M.n_elem != B.M.n_elem), "dot(): objects must have the same number of elements" ); return op_dot::direct_dot(A.M.n_elem, A.M.memptr(), B.M.memptr()); } const Proxy PA(X); const Proxy PB(Y); arma_conform_check( (PA.get_n_elem() != PB.get_n_elem()), "dot(): objects must have the same number of elements" ); return op_dot::apply_proxy_linear(PA,PB); } template inline typename arma_not_cx::result op_dot::apply_proxy_linear(const Proxy& PA, const Proxy& PB) { arma_debug_sigprint(); typedef typename T1::elem_type eT; typedef typename Proxy::ea_type ea_type1; typedef typename Proxy::ea_type ea_type2; const uword N = PA.get_n_elem(); ea_type1 A = PA.get_ea(); ea_type2 B = PB.get_ea(); eT val1 = eT(0); eT val2 = eT(0); uword i,j; for(i=0, j=1; j inline typename arma_cx_only::result op_dot::apply_proxy_linear(const Proxy& PA, const Proxy& PB) { arma_debug_sigprint(); typedef typename T1::elem_type eT; typedef typename get_pod_type::result T; typedef typename Proxy::ea_type ea_type1; typedef typename Proxy::ea_type ea_type2; const uword N = PA.get_n_elem(); ea_type1 A = PA.get_ea(); ea_type2 B = PB.get_ea(); T val_real = T(0); T val_imag = T(0); for(uword i=0; i xx = A[i]; const std::complex yy = B[i]; const T a = xx.real(); const T b = xx.imag(); const T c = yy.real(); const T d = yy.imag(); val_real += (a*c) - (b*d); val_imag += (a*d) + (b*c); } return std::complex(val_real, val_imag); } // // op_norm_dot template inline typename T1::elem_type op_norm_dot::apply(const T1& X, const T2& Y) { arma_debug_sigprint(); typedef typename T1::elem_type eT; typedef typename T1::pod_type T; const quasi_unwrap tmp1(X); const quasi_unwrap tmp2(Y); const Col A( const_cast(tmp1.M.memptr()), tmp1.M.n_elem, false ); const Col B( const_cast(tmp2.M.memptr()), tmp2.M.n_elem, false ); arma_conform_check( (A.n_elem != B.n_elem), "norm_dot(): objects must have the same number of elements" ); const T denom = norm(A,2) * norm(B,2); return (denom != T(0)) ? ( op_dot::apply(A,B) / denom ) : eT(0); } // // op_cdot template inline eT op_cdot::direct_cdot_arma(const uword n_elem, const eT* const A, const eT* const B) { arma_debug_sigprint(); typedef typename get_pod_type::result T; T val_real = T(0); T val_imag = T(0); for(uword i=0; i& X = A[i]; const std::complex& Y = B[i]; const T a = X.real(); const T b = X.imag(); const T c = Y.real(); const T d = Y.imag(); val_real += (a*c) + (b*d); val_imag += (a*d) - (b*c); } return std::complex(val_real, val_imag); } template inline eT op_cdot::direct_cdot(const uword n_elem, const eT* const A, const eT* const B) { arma_debug_sigprint(); if( n_elem <= 32u ) { return op_cdot::direct_cdot_arma(n_elem, A, B); } else { #if defined(ARMA_USE_BLAS) { arma_debug_print("blas::gemv()"); // using gemv() workaround due to compatibility issues with cdotc() and zdotc() const char trans = 'C'; const blas_int m = blas_int(n_elem); const blas_int n = 1; //const blas_int lda = (n_elem > 0) ? blas_int(n_elem) : blas_int(1); const blas_int inc = 1; const eT alpha = eT(1); const eT beta = eT(0); eT result[2]; // paranoia: using two elements instead of one //blas::gemv(&trans, &m, &n, &alpha, A, &lda, B, &inc, &beta, &result[0], &inc); blas::gemv(&trans, &m, &n, &alpha, A, &m, B, &inc, &beta, &result[0], &inc); return result[0]; } #else { return op_cdot::direct_cdot_arma(n_elem, A, B); } #endif } } template inline typename T1::elem_type op_cdot::apply(const T1& X, const T2& Y) { arma_debug_sigprint(); if(is_Mat::value && is_Mat::value) { return op_cdot::apply_unwrap(X,Y); } else { return op_cdot::apply_proxy(X,Y); } } template inline typename T1::elem_type op_cdot::apply_unwrap(const T1& X, const T2& Y) { arma_debug_sigprint(); typedef typename T1::elem_type eT; const unwrap tmp1(X); const unwrap tmp2(Y); const Mat& A = tmp1.M; const Mat& B = tmp2.M; arma_conform_check( (A.n_elem != B.n_elem), "cdot(): objects must have the same number of elements" ); return op_cdot::direct_cdot( A.n_elem, A.mem, B.mem ); } template inline typename T1::elem_type op_cdot::apply_proxy(const T1& X, const T2& Y) { arma_debug_sigprint(); typedef typename T1::elem_type eT; typedef typename get_pod_type::result T; typedef typename Proxy::ea_type ea_type1; typedef typename Proxy::ea_type ea_type2; constexpr bool use_at = (Proxy::use_at) || (Proxy::use_at); if(use_at == false) { const Proxy PA(X); const Proxy PB(Y); const uword N = PA.get_n_elem(); arma_conform_check( (N != PB.get_n_elem()), "cdot(): objects must have the same number of elements" ); ea_type1 A = PA.get_ea(); ea_type2 B = PB.get_ea(); T val_real = T(0); T val_imag = T(0); for(uword i=0; i AA = A[i]; const std::complex BB = B[i]; const T a = AA.real(); const T b = AA.imag(); const T c = BB.real(); const T d = BB.imag(); val_real += (a*c) + (b*d); val_imag += (a*d) - (b*c); } return std::complex(val_real, val_imag); } else { return op_cdot::apply_unwrap( X, Y ); } } template inline typename promote_type::result op_dot_mixed::apply(const T1& A, const T2& B) { arma_debug_sigprint(); typedef typename T1::elem_type in_eT1; typedef typename T2::elem_type in_eT2; typedef typename promote_type::result out_eT; const Proxy PA(A); const Proxy PB(B); const uword N = PA.get_n_elem(); arma_conform_check( (N != PB.get_n_elem()), "dot(): objects must have the same number of elements" ); out_eT acc = out_eT(0); for(uword i=0; i < N; ++i) { acc += upgrade_val::apply(PA[i]) * upgrade_val::apply(PB[i]); } return acc; } //! @}