BLAS++ 2024.05.31
BLAS C++ API
Loading...
Searching...
No Matches
syr2k.hh
1// Copyright (c) 2017-2023, University of Tennessee. All rights reserved.
2// SPDX-License-Identifier: BSD-3-Clause
3// This program is free software: you can redistribute it and/or modify it under
4// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
5
6#ifndef BLAS_SYR2K_HH
7#define BLAS_SYR2K_HH
8
9#include "blas/util.hh"
10
11#include <limits>
12
13namespace blas {
14
15// =============================================================================
88
89template <typename TA, typename TB, typename TC>
90void syr2k(
91 blas::Layout layout,
92 blas::Uplo uplo,
93 blas::Op trans,
94 int64_t n, int64_t k,
95 scalar_type<TA, TB, TC> alpha,
96 TA const *A, int64_t lda,
97 TB const *B, int64_t ldb,
98 scalar_type<TA, TB, TC> beta,
99 TC *C, int64_t ldc )
100{
101 typedef blas::scalar_type<TA, TB, TC> scalar_t;
102
103 #define A(i_, j_) A[ (i_) + (j_)*lda ]
104 #define B(i_, j_) B[ (i_) + (j_)*ldb ]
105 #define C(i_, j_) C[ (i_) + (j_)*ldc ]
106
107 // constants
108 const scalar_t zero = 0;
109 const scalar_t one = 1;
110
111 // check arguments
112 blas_error_if( layout != Layout::ColMajor &&
113 layout != Layout::RowMajor );
114 blas_error_if( uplo != Uplo::Lower &&
115 uplo != Uplo::Upper &&
116 uplo != Uplo::General );
117 blas_error_if( n < 0 );
118 blas_error_if( k < 0 );
119
120 // check and interpret argument trans
121 if (trans == Op::ConjTrans) {
122 blas_error_if_msg(
125 "trans == Op::ConjTrans && "
126 "( blas::is_complex<TA>::value ||"
127 " blas::is_complex<TB>::value )" );
128 trans = Op::Trans;
129 }
130 else {
131 blas_error_if( trans != Op::NoTrans &&
132 trans != Op::Trans );
133 }
134
135 // adapt if row major
136 if (layout == Layout::RowMajor) {
137 if (uplo == Uplo::Lower)
138 uplo = Uplo::Upper;
139 else if (uplo == Uplo::Upper)
140 uplo = Uplo::Lower;
141 trans = (trans == Op::NoTrans)
142 ? Op::Trans
143 : Op::NoTrans;
144 }
145
146 // check remaining arguments
147 blas_error_if( lda < ((trans == Op::NoTrans) ? n : k) );
148 blas_error_if( ldb < ((trans == Op::NoTrans) ? n : k) );
149 blas_error_if( ldc < n );
150
151 // quick return
152 if (n == 0 || k == 0)
153 return;
154
155 // alpha == zero
156 if (alpha == zero) {
157 if (beta == zero) {
158 if (uplo != Uplo::Upper) {
159 for (int64_t j = 0; j < n; ++j) {
160 for (int64_t i = 0; i <= j; ++i)
161 C(i, j) = zero;
162 }
163 }
164 else if (uplo != Uplo::Lower) {
165 for (int64_t j = 0; j < n; ++j) {
166 for (int64_t i = j; i < n; ++i)
167 C(i, j) = zero;
168 }
169 }
170 else {
171 for (int64_t j = 0; j < n; ++j) {
172 for (int64_t i = 0; i < n; ++i)
173 C(i, j) = zero;
174 }
175 }
176 }
177 else if (beta != one) {
178 if (uplo != Uplo::Upper) {
179 for (int64_t j = 0; j < n; ++j) {
180 for (int64_t i = 0; i <= j; ++i)
181 C(i, j) *= beta;
182 }
183 }
184 else if (uplo != Uplo::Lower) {
185 for (int64_t j = 0; j < n; ++j) {
186 for (int64_t i = j; i < n; ++i)
187 C(i, j) *= beta;
188 }
189 }
190 else {
191 for (int64_t j = 0; j < n; ++j) {
192 for (int64_t i = 0; i < n; ++i)
193 C(i, j) *= beta;
194 }
195 }
196 }
197 return;
198 }
199
200 // alpha != zero
201 if (trans == Op::NoTrans) {
202 if (uplo != Uplo::Lower) {
203 // uplo == Uplo::Upper or uplo == Uplo::General
204 for (int64_t j = 0; j < n; ++j) {
205
206 for (int64_t i = 0; i <= j; ++i)
207 C(i, j) *= beta;
208
209 for (int64_t l = 0; l < k; ++l) {
210 scalar_t alpha_Bjl = alpha*B(j, l);
211 scalar_t alpha_Ajl = alpha*A(j, l);
212 for (int64_t i = 0; i <= j; ++i)
213 C(i, j) += A(i, l)*alpha_Bjl + B(i, l)*alpha_Ajl;
214 }
215 }
216 }
217 else { // uplo == Uplo::Lower
218 for (int64_t j = 0; j < n; ++j) {
219
220 for (int64_t i = j; i < n; ++i)
221 C(i, j) *= beta;
222
223 for (int64_t l = 0; l < k; ++l) {
224 scalar_t alpha_Bjl = alpha*B(j, l);
225 scalar_t alpha_Ajl = alpha*A(j, l);
226 for (int64_t i = j; i < n; ++i)
227 C(i, j) += A(i, l)*alpha_Bjl + B(i, l)*alpha_Ajl;
228 }
229 }
230 }
231 }
232 else { // trans == Op::Trans
233 if (uplo != Uplo::Lower) {
234 // uplo == Uplo::Upper or uplo == Uplo::General
235 for (int64_t j = 0; j < n; ++j) {
236 for (int64_t i = 0; i <= j; ++i) {
237 scalar_t sum1 = zero;
238 scalar_t sum2 = zero;
239 for (int64_t l = 0; l < k; ++l) {
240 sum1 += A(l, i) * B(l, j);
241 sum2 += B(l, i) * A(l, j);
242 }
243 C(i, j) = alpha*sum1 + alpha*sum2 + beta*C(i, j);
244 }
245 }
246 }
247 else { // uplo == Uplo::Lower
248 for (int64_t j = 0; j < n; ++j) {
249 for (int64_t i = j; i < n; ++i) {
250 scalar_t sum1 = zero;
251 scalar_t sum2 = zero;
252 for (int64_t l = 0; l < k; ++l) {
253 sum1 += A(l, i) * B(l, j);
254 sum2 += B(l, i) * A(l, j);
255 }
256 C(i, j) = alpha*sum1 + alpha*sum2 + beta*C(i, j);
257 }
258 }
259 }
260 }
261
262 if (uplo == Uplo::General) {
263 for (int64_t j = 0; j < n; ++j) {
264 for (int64_t i = j+1; i < n; ++i)
265 C(i, j) = C(j, i);
266 }
267 }
268
269 #undef A
270 #undef B
271 #undef C
272}
273
274} // namespace blas
275
276#endif // #ifndef BLAS_SYR2K_HH
void syr2k(blas::Layout layout, blas::Uplo uplo, blas::Op trans, int64_t n, int64_t k, float alpha, float const *A, int64_t lda, float const *B, int64_t ldb, float beta, float *C, int64_t ldc, blas::Queue &queue)
GPU device, float version.
Definition device_syr2k.cc:107
True if T is std::complex<T2> for some type T2.
Definition util.hh:349