10#include "blas/syrk.hh"
79template <
typename TA,
typename TC>
85 real_type<TA, TC> alpha,
86 TA
const *A, int64_t lda,
87 real_type<TA, TC> beta,
90 typedef blas::scalar_type<TA, TC> scalar_t;
91 typedef blas::real_type<TA, TC> real_t;
93 #define A(i_, j_) A[ (i_) + (j_)*lda ]
94 #define C(i_, j_) C[ (i_) + (j_)*ldc ]
97 const scalar_t szero = 0;
98 const real_t zero = 0;
102 blas_error_if( layout != Layout::ColMajor &&
103 layout != Layout::RowMajor );
104 blas_error_if( uplo != Uplo::Lower &&
105 uplo != Uplo::Upper &&
106 uplo != Uplo::General );
107 blas_error_if( n < 0 );
108 blas_error_if( k < 0 );
111 if (trans == Op::Trans) {
114 "trans == Op::Trans && "
115 "blas::is_complex<TA>::value" );
116 trans = Op::ConjTrans;
119 blas_error_if( trans != Op::NoTrans &&
120 trans != Op::ConjTrans );
124 if (layout == Layout::RowMajor) {
125 if (uplo == Uplo::Lower)
127 else if (uplo == Uplo::Upper)
129 trans = (trans == Op::NoTrans)
136 blas_error_if( lda < ((trans == Op::NoTrans) ? n : k) );
137 blas_error_if( ldc < n );
140 if (n == 0 || k == 0)
146 if (uplo != Uplo::Upper) {
147 for (int64_t j = 0; j < n; ++j) {
148 for (int64_t i = 0; i <= j; ++i)
152 else if (uplo != Uplo::Lower) {
153 for (int64_t j = 0; j < n; ++j) {
154 for (int64_t i = j; i < n; ++i)
159 for (int64_t j = 0; j < n; ++j) {
160 for (int64_t i = 0; i < n; ++i)
165 else if (beta != one) {
166 if (uplo != Uplo::Upper) {
167 for (int64_t j = 0; j < n; ++j) {
168 for (int64_t i = 0; i < j; ++i)
170 C(j, j) = beta * real( C(j, j) );
173 else if (uplo != Uplo::Lower) {
174 for (int64_t j = 0; j < n; ++j) {
175 C(j, j) = beta * real( C(j, j) );
176 for (int64_t i = j+1; i < n; ++i)
181 for (int64_t j = 0; j < n; ++j) {
182 for (int64_t i = 0; i < j; ++i)
184 C(j, j) = beta * real( C(j, j) );
185 for (int64_t i = j+1; i < n; ++i)
194 if (trans == Op::NoTrans) {
195 if (uplo != Uplo::Lower) {
197 for (int64_t j = 0; j < n; ++j) {
199 for (int64_t i = 0; i < j; ++i)
201 C(j, j) = beta * real( C(j, j) );
203 for (int64_t l = 0; l < k; ++l) {
205 scalar_t alpha_conj_Ajl = alpha*conj( A(j, l) );
207 for (int64_t i = 0; i < j; ++i)
208 C(i, j) += A(i, l)*alpha_conj_Ajl;
209 C(j, j) += real( A(j, l) * alpha_conj_Ajl );
214 for (int64_t j = 0; j < n; ++j) {
216 C(j, j) = beta * real( C(j, j) );
217 for (int64_t i = j+1; i < n; ++i)
220 for (int64_t l = 0; l < k; ++l) {
222 scalar_t alpha_conj_Ajl = alpha*conj( A(j, l) );
224 C(j, j) += real( A(j, l) * alpha_conj_Ajl );
225 for (int64_t i = j+1; i < n; ++i) {
226 C(i, j) += A(i, l) * alpha_conj_Ajl;
233 if (uplo != Uplo::Lower) {
235 for (int64_t j = 0; j < n; ++j) {
236 for (int64_t i = 0; i < j; ++i) {
237 scalar_t sum = szero;
238 for (int64_t l = 0; l < k; ++l)
239 sum += conj( A(l, i) ) * A(l, j);
240 C(i, j) = alpha*sum + beta*C(i, j);
243 for (int64_t l = 0; l < k; ++l)
244 sum += real(A(l, j)) * real(A(l, j))
245 + imag(A(l, j)) * imag(A(l, j));
246 C(j, j) = alpha*sum + beta*real( C(j, j) );
251 for (int64_t j = 0; j < n; ++j) {
252 for (int64_t i = j+1; i < n; ++i) {
253 scalar_t sum = szero;
254 for (int64_t l = 0; l < k; ++l)
255 sum += conj( A(l, i) ) * A(l, j);
256 C(i, j) = alpha*sum + beta*C(i, j);
259 for (int64_t l = 0; l < k; ++l)
260 sum += real(A(l, j)) * real(A(l, j))
261 + imag(A(l, j)) * imag(A(l, j));
262 C(j, j) = alpha*sum + beta*real( C(j, j) );
267 if (uplo == Uplo::General) {
268 for (int64_t j = 0; j < n; ++j) {
269 for (int64_t i = j+1; i < n; ++i)
270 C(i, j) = conj( C(j, i) );
void herk(blas::Layout layout, blas::Uplo uplo, blas::Op trans, int64_t n, int64_t k, float alpha, float const *A, int64_t lda, float beta, float *C, int64_t ldc, blas::Queue &queue)
GPU device, float version.
Defined at device_herk.cc, line 92.
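True if T is std::complex<T2> for some type T2 (this appears to describe the blas::is_complex trait named in the listing's error-check message).
Defined at util.hh, line 349.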