10#include "blas/syr2k.hh"
// Templated rank-2k update kernel over element types TA (matrix A), TB (matrix B)
// and TC (matrix C). From the visible accumulation statements, the routine adds
//   alpha * A * conj(B)^T  +  conj(alpha) * B * conj(A)^T   (NoTrans section), or
//   alpha * conj(A)^T * B  +  conj(alpha) * conj(B)^T * A   (other section)
// to beta * C, keeping the diagonal of C real.
//
// NOTE(review): this listing is incomplete. The function name, the declarations of
// layout / uplo / trans / n / k / C / ldc (all used below), most closing braces,
// several conditions and most single-statement loop bodies are not visible; the
// number fused to the start of each line is the listing's own line number and shows
// where lines are missing. Presumably this is her2k — confirm against the full file.
90template <
typename TA,
typename TB,
typename TC>
// alpha has the combined scalar type of TA/TB/TC; beta (below) has the combined real type.
96 scalar_type<TA, TB, TC> alpha,
97 TA
const *A, int64_t lda,
98 TB
const *B, int64_t ldb,
99 real_type<TA, TB, TC> beta,
102 typedef blas::scalar_type<TA, TB, TC> scalar_t;
// Element accessors: entry (i, j) lives at offset i + j*ld, i.e. columns are contiguous.
// These macros shadow the pointer parameters of the same name.
104 #define A(i_, j_) A[ (i_) + (j_)*lda ]
105 #define B(i_, j_) B[ (i_) + (j_)*ldb ]
106 #define C(i_, j_) C[ (i_) + (j_)*ldc ]
// Constants used in the comparisons and accumulator initialisations below.
109 const scalar_t zero = 0;
110 const scalar_t one = 1;
// Argument checks: layout must be ColMajor or RowMajor; uplo must be Lower, Upper or
// General; n and k must be non-negative. (blas_error_if is defined elsewhere; presumably
// it reports an error when its condition is true — confirm.)
113 blas_error_if( layout != Layout::ColMajor &&
114 layout != Layout::RowMajor );
115 blas_error_if( uplo != Uplo::Lower &&
116 uplo != Uplo::Upper &&
117 uplo != Uplo::General );
118 blas_error_if( n < 0 );
119 blas_error_if( k < 0 );
// Op::Trans is rewritten to Op::ConjTrans. The message text below indicates that
// Trans combined with a complex TA or TB is treated as an error; the call that uses
// this message is not visible in the listing.
122 if (trans == Op::Trans) {
126 "trans == Op::Trans && "
127 "( blas::is_complex<TA>::value ||"
128 " blas::is_complex<TB>::value )" );
129 trans = Op::ConjTrans;
// Any other value must be NoTrans or ConjTrans.
132 blas_error_if( trans != Op::NoTrans &&
133 trans != Op::ConjTrans );
// Row-major input: uplo and trans are adjusted so the column-major code below can be
// reused. The assigned values (and any change to alpha) are not visible here;
// presumably Lower/Upper are swapped and NoTrans/ConjTrans are flipped — confirm.
137 if (layout == Layout::RowMajor) {
138 if (uplo == Uplo::Lower)
140 else if (uplo == Uplo::Upper)
142 trans = (trans == Op::NoTrans)
// Leading dimensions: lda and ldb must be at least n for NoTrans and at least k
// otherwise; ldc must be at least n.
149 blas_error_if( lda < ((trans == Op::NoTrans) ? n : k) );
150 blas_error_if( ldb < ((trans == Op::NoTrans) ? n : k) );
151 blas_error_if( ldc < n );
// Degenerate sizes; the statement executed here is not visible (presumably an early return).
154 if (n == 0 || k == 0)
// Three loop nests over C whose enclosing condition and bodies are not visible
// (presumably the alpha == zero, beta == zero path that clears C — confirm).
// NOTE(review): the uplo tests here look inverted relative to the loop bounds.
// `uplo != Uplo::Upper` guards i <= j (upper-triangle indices) and
// `uplo != Uplo::Lower` guards i >= j (lower-triangle indices), whereas the NoTrans
// and transposed sections further down pair `uplo != Uplo::Lower` with the
// upper-triangle loops. If the third nest is the final `else` of this chain it also
// appears unreachable: the second test is only evaluated when uplo == Upper, where it
// is always true. Verify against the complete source before relying on this path.
160 if (uplo != Uplo::Upper) {
161 for (int64_t j = 0; j < n; ++j) {
162 for (int64_t i = 0; i <= j; ++i)
166 else if (uplo != Uplo::Lower) {
167 for (int64_t j = 0; j < n; ++j) {
168 for (int64_t i = j; i < n; ++i)
173 for (int64_t j = 0; j < n; ++j) {
174 for (int64_t i = 0; i < n; ++i)
// beta != one: C is scaled in place. Each diagonal entry is replaced by
// beta * real(C(j, j)), which discards any imaginary part. The off-diagonal loop
// bodies are not visible. NOTE(review): same uplo-test/loop-bound pairing concern as
// in the preceding section.
179 else if (beta != one) {
180 if (uplo != Uplo::Upper) {
181 for (int64_t j = 0; j < n; ++j) {
182 for (int64_t i = 0; i < j; ++i)
184 C(j, j) = beta * real( C(j, j) );
187 else if (uplo != Uplo::Lower) {
188 for (int64_t j = 0; j < n; ++j) {
189 C(j, j) = beta * real( C(j, j) );
190 for (int64_t i = j+1; i < n; ++i)
195 for (int64_t j = 0; j < n; ++j) {
196 for (int64_t i = 0; i < j; ++i)
198 C(j, j) = beta * real( C(j, j) );
199 for (int64_t i = j+1; i < n; ++i)
// NoTrans case: A and B are indexed as (row of C, l) with l running over k.
208 if (trans == Op::NoTrans) {
// uplo is Upper or General: work on entries with i <= j, one column j at a time.
209 if (uplo != Uplo::Lower) {
211 for (int64_t j = 0; j < n; ++j) {
// Per-column pre-scaling: strictly-upper loop (body not visible), then the diagonal
// is scaled by beta with its imaginary part dropped.
213 for (int64_t i = 0; i < j; ++i)
215 C(j, j) = beta * real( C(j, j) );
217 for (int64_t l = 0; l < k; ++l) {
// Factors that depend only on (j, l), computed once per l and reused in the i loop.
219 scalar_t alpha_conj_Bjl = alpha*conj( B(j, l) );
220 scalar_t conj_alpha_Ajl = conj( alpha*A(j, l) );
// Strictly-upper entries of column j accumulate both rank-1 contributions.
222 for (int64_t i = 0; i < j; ++i) {
223 C(i, j) += A(i, l)*alpha_conj_Bjl
224 + B(i, l)*conj_alpha_Ajl;
// Diagonal entry: only twice the real part of one term is added.
226 C(j, j) += 2 * real( A(j, l) * alpha_conj_Bjl );
// Counterpart for entries with i >= j (the branch line introducing it is not visible;
// presumably the uplo == Lower case).
231 for (int64_t j = 0; j < n; ++j) {
233 C(j, j) = beta * real( C(j, j) );
234 for (int64_t i = j+1; i < n; ++i)
237 for (int64_t l = 0; l < k; ++l) {
239 scalar_t alpha_conj_Bjl = alpha*conj( B(j, l) );
240 scalar_t conj_alpha_Ajl = conj( alpha*A(j, l) );
242 C(j, j) += 2 * real( A(j, l) * alpha_conj_Bjl );
243 for (int64_t i = j+1; i < n; ++i) {
244 C(i, j) += A(i, l) * alpha_conj_Bjl
245 + B(i, l) * conj_alpha_Ajl;
// Remaining trans case (the `else` line is not visible; after the checks above this is
// presumably ConjTrans): A and B are indexed as (l, column), and each entry of C is
// built from two dot products over l.
252 if (uplo != Uplo::Lower) {
254 for (int64_t j = 0; j < n; ++j) {
255 for (int64_t i = 0; i <= j; ++i) {
257 scalar_t sum1 = zero;
258 scalar_t sum2 = zero;
259 for (int64_t l = 0; l < k; ++l) {
260 sum1 += conj( A(l, i) ) * B(l, j);
261 sum2 += conj( B(l, i) ) * A(l, j);
// Two arms of a conditional expression whose condition and assignment target are not
// visible: the first keeps the full complex value, the second keeps only real parts
// (presumably off-diagonal vs. diagonal entries — confirm).
265 ? alpha*sum1 + conj(alpha)*sum2 + beta*C(i, j)
266 : real( alpha*sum1 + conj(alpha)*sum2 )
267 + beta*real( C(i, j) );
// Same computation for entries with i >= j.
274 for (int64_t j = 0; j < n; ++j) {
275 for (int64_t i = j; i < n; ++i) {
277 scalar_t sum1 = zero;
278 scalar_t sum2 = zero;
279 for (int64_t l = 0; l < k; ++l) {
280 sum1 += conj( A(l, i) ) * B(l, j);
281 sum2 += conj( B(l, i) ) * A(l, j);
285 ? alpha*sum1 + conj(alpha)*sum2 + beta*C(i, j)
286 : real( alpha*sum1 + conj(alpha)*sum2 )
287 + beta*real( C(i, j) );
// uplo == General: fill each strictly-lower entry C(i, j), i > j, with the conjugate of
// its mirror entry C(j, i).
294 if (uplo == Uplo::General) {
295 for (int64_t j = 0; j < n; ++j) {
296 for (int64_t i = j+1; i < n; ++i)
297 C(i, j) = conj( C(j, i) );
void her2k(blas::Layout layout, blas::Uplo uplo, blas::Op trans, int64_t n, int64_t k, float alpha, float const *A, int64_t lda, float const *B, int64_t ldb, float beta, float *C, int64_t ldc, blas::Queue &queue)
GPU device, float version.
Definition device_her2k.cc:100
True if T is std::complex<T2> for some type T2.
Definition util.hh:349