BLAS++ 2024.05.31
BLAS C++ API
Loading...
Searching...
No Matches
her2k.hh
1// Copyright (c) 2017-2023, University of Tennessee. All rights reserved.
2// SPDX-License-Identifier: BSD-3-Clause
3// This program is free software: you can redistribute it and/or modify it under
4// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
5
6#ifndef BLAS_HER2K_HH
7#define BLAS_HER2K_HH
8
9#include "blas/util.hh"
10#include "blas/syr2k.hh"
11
12#include <limits>
13
14namespace blas {
15
16// =============================================================================
89
90template <typename TA, typename TB, typename TC>
91void her2k(
92 blas::Layout layout,
93 blas::Uplo uplo,
94 blas::Op trans,
95 int64_t n, int64_t k,
96 scalar_type<TA, TB, TC> alpha, // note: complex
97 TA const *A, int64_t lda,
98 TB const *B, int64_t ldb,
99 real_type<TA, TB, TC> beta, // note: real
100 TC *C, int64_t ldc )
101{
102 typedef blas::scalar_type<TA, TB, TC> scalar_t;
103
104 #define A(i_, j_) A[ (i_) + (j_)*lda ]
105 #define B(i_, j_) B[ (i_) + (j_)*ldb ]
106 #define C(i_, j_) C[ (i_) + (j_)*ldc ]
107
108 // constants
109 const scalar_t zero = 0;
110 const scalar_t one = 1;
111
112 // check arguments
113 blas_error_if( layout != Layout::ColMajor &&
114 layout != Layout::RowMajor );
115 blas_error_if( uplo != Uplo::Lower &&
116 uplo != Uplo::Upper &&
117 uplo != Uplo::General );
118 blas_error_if( n < 0 );
119 blas_error_if( k < 0 );
120
121 // check and interpret argument trans
122 if (trans == Op::Trans) {
123 blas_error_if_msg(
126 "trans == Op::Trans && "
127 "( blas::is_complex<TA>::value ||"
128 " blas::is_complex<TB>::value )" );
129 trans = Op::ConjTrans;
130 }
131 else {
132 blas_error_if( trans != Op::NoTrans &&
133 trans != Op::ConjTrans );
134 }
135
136 // adapt if row major
137 if (layout == Layout::RowMajor) {
138 if (uplo == Uplo::Lower)
139 uplo = Uplo::Upper;
140 else if (uplo == Uplo::Upper)
141 uplo = Uplo::Lower;
142 trans = (trans == Op::NoTrans)
143 ? Op::ConjTrans
144 : Op::NoTrans;
145 alpha = conj(alpha);
146 }
147
148 // check remaining arguments
149 blas_error_if( lda < ((trans == Op::NoTrans) ? n : k) );
150 blas_error_if( ldb < ((trans == Op::NoTrans) ? n : k) );
151 blas_error_if( ldc < n );
152
153 // quick return
154 if (n == 0 || k == 0)
155 return;
156
157 // alpha == zero
158 if (alpha == zero) {
159 if (beta == zero) {
160 if (uplo != Uplo::Upper) {
161 for (int64_t j = 0; j < n; ++j) {
162 for (int64_t i = 0; i <= j; ++i)
163 C(i, j) = zero;
164 }
165 }
166 else if (uplo != Uplo::Lower) {
167 for (int64_t j = 0; j < n; ++j) {
168 for (int64_t i = j; i < n; ++i)
169 C(i, j) = zero;
170 }
171 }
172 else {
173 for (int64_t j = 0; j < n; ++j) {
174 for (int64_t i = 0; i < n; ++i)
175 C(i, j) = zero;
176 }
177 }
178 }
179 else if (beta != one) {
180 if (uplo != Uplo::Upper) {
181 for (int64_t j = 0; j < n; ++j) {
182 for (int64_t i = 0; i < j; ++i)
183 C(i, j) *= beta;
184 C(j, j) = beta * real( C(j, j) );
185 }
186 }
187 else if (uplo != Uplo::Lower) {
188 for (int64_t j = 0; j < n; ++j) {
189 C(j, j) = beta * real( C(j, j) );
190 for (int64_t i = j+1; i < n; ++i)
191 C(i, j) *= beta;
192 }
193 }
194 else {
195 for (int64_t j = 0; j < n; ++j) {
196 for (int64_t i = 0; i < j; ++i)
197 C(i, j) *= beta;
198 C(j, j) = beta * real( C(j, j) );
199 for (int64_t i = j+1; i < n; ++i)
200 C(i, j) *= beta;
201 }
202 }
203 }
204 return;
205 }
206
207 // alpha != zero
208 if (trans == Op::NoTrans) {
209 if (uplo != Uplo::Lower) {
210 // uplo == Uplo::Upper or uplo == Uplo::General
211 for (int64_t j = 0; j < n; ++j) {
212
213 for (int64_t i = 0; i < j; ++i)
214 C(i, j) *= beta;
215 C(j, j) = beta * real( C(j, j) );
216
217 for (int64_t l = 0; l < k; ++l) {
218
219 scalar_t alpha_conj_Bjl = alpha*conj( B(j, l) );
220 scalar_t conj_alpha_Ajl = conj( alpha*A(j, l) );
221
222 for (int64_t i = 0; i < j; ++i) {
223 C(i, j) += A(i, l)*alpha_conj_Bjl
224 + B(i, l)*conj_alpha_Ajl;
225 }
226 C(j, j) += 2 * real( A(j, l) * alpha_conj_Bjl );
227 }
228 }
229 }
230 else { // uplo == Uplo::Lower
231 for (int64_t j = 0; j < n; ++j) {
232
233 C(j, j) = beta * real( C(j, j) );
234 for (int64_t i = j+1; i < n; ++i)
235 C(i, j) *= beta;
236
237 for (int64_t l = 0; l < k; ++l) {
238
239 scalar_t alpha_conj_Bjl = alpha*conj( B(j, l) );
240 scalar_t conj_alpha_Ajl = conj( alpha*A(j, l) );
241
242 C(j, j) += 2 * real( A(j, l) * alpha_conj_Bjl );
243 for (int64_t i = j+1; i < n; ++i) {
244 C(i, j) += A(i, l) * alpha_conj_Bjl
245 + B(i, l) * conj_alpha_Ajl;
246 }
247 }
248 }
249 }
250 }
251 else { // trans == Op::ConjTrans
252 if (uplo != Uplo::Lower) {
253 // uplo == Uplo::Upper or uplo == Uplo::General
254 for (int64_t j = 0; j < n; ++j) {
255 for (int64_t i = 0; i <= j; ++i) {
256
257 scalar_t sum1 = zero;
258 scalar_t sum2 = zero;
259 for (int64_t l = 0; l < k; ++l) {
260 sum1 += conj( A(l, i) ) * B(l, j);
261 sum2 += conj( B(l, i) ) * A(l, j);
262 }
263
264 C(i, j) = (i < j)
265 ? alpha*sum1 + conj(alpha)*sum2 + beta*C(i, j)
266 : real( alpha*sum1 + conj(alpha)*sum2 )
267 + beta*real( C(i, j) );
268 }
269
270 }
271 }
272 else {
273 // uplo == Uplo::Lower
274 for (int64_t j = 0; j < n; ++j) {
275 for (int64_t i = j; i < n; ++i) {
276
277 scalar_t sum1 = zero;
278 scalar_t sum2 = zero;
279 for (int64_t l = 0; l < k; ++l) {
280 sum1 += conj( A(l, i) ) * B(l, j);
281 sum2 += conj( B(l, i) ) * A(l, j);
282 }
283
284 C(i, j) = (i > j)
285 ? alpha*sum1 + conj(alpha)*sum2 + beta*C(i, j)
286 : real( alpha*sum1 + conj(alpha)*sum2 )
287 + beta*real( C(i, j) );
288 }
289
290 }
291 }
292 }
293
294 if (uplo == Uplo::General) {
295 for (int64_t j = 0; j < n; ++j) {
296 for (int64_t i = j+1; i < n; ++i)
297 C(i, j) = conj( C(j, i) );
298 }
299 }
300
301 #undef A
302 #undef B
303 #undef C
304}
305
306} // namespace blas
307
308#endif // #ifndef BLAS_HER2K_HH
void her2k(blas::Layout layout, blas::Uplo uplo, blas::Op trans, int64_t n, int64_t k, float alpha, float const *A, int64_t lda, float const *B, int64_t ldb, float beta, float *C, int64_t ldc, blas::Queue &queue)
GPU device, float version.
Definition device_her2k.cc:100
True if T is std::complex<T2> for some type T2.
Definition util.hh:349