BLAS++ 2024.05.31
BLAS C++ API
Loading...
Searching...
No Matches
herk.hh
1// Copyright (c) 2017-2023, University of Tennessee. All rights reserved.
2// SPDX-License-Identifier: BSD-3-Clause
3// This program is free software: you can redistribute it and/or modify it under
4// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
5
6#ifndef BLAS_HERK_HH
7#define BLAS_HERK_HH
8
9#include "blas/util.hh"
10#include "blas/syrk.hh"
11
12#include <limits>
13
14namespace blas {
15
16// =============================================================================
78
79template <typename TA, typename TC>
80void herk(
81 blas::Layout layout,
82 blas::Uplo uplo,
83 blas::Op trans,
84 int64_t n, int64_t k,
85 real_type<TA, TC> alpha, // note: real
86 TA const *A, int64_t lda,
87 real_type<TA, TC> beta, // note: real
88 TC *C, int64_t ldc )
89{
90 typedef blas::scalar_type<TA, TC> scalar_t;
91 typedef blas::real_type<TA, TC> real_t;
92
93 #define A(i_, j_) A[ (i_) + (j_)*lda ]
94 #define C(i_, j_) C[ (i_) + (j_)*ldc ]
95
96 // constants
97 const scalar_t szero = 0;
98 const real_t zero = 0;
99 const real_t one = 1;
100
101 // check arguments
102 blas_error_if( layout != Layout::ColMajor &&
103 layout != Layout::RowMajor );
104 blas_error_if( uplo != Uplo::Lower &&
105 uplo != Uplo::Upper &&
106 uplo != Uplo::General );
107 blas_error_if( n < 0 );
108 blas_error_if( k < 0 );
109
110 // check and interpret argument trans
111 if (trans == Op::Trans) {
112 blas_error_if_msg(
114 "trans == Op::Trans && "
115 "blas::is_complex<TA>::value" );
116 trans = Op::ConjTrans;
117 }
118 else {
119 blas_error_if( trans != Op::NoTrans &&
120 trans != Op::ConjTrans );
121 }
122
123 // adapt if row major
124 if (layout == Layout::RowMajor) {
125 if (uplo == Uplo::Lower)
126 uplo = Uplo::Upper;
127 else if (uplo == Uplo::Upper)
128 uplo = Uplo::Lower;
129 trans = (trans == Op::NoTrans)
130 ? Op::ConjTrans
131 : Op::NoTrans;
132 alpha = conj(alpha);
133 }
134
135 // check remaining arguments
136 blas_error_if( lda < ((trans == Op::NoTrans) ? n : k) );
137 blas_error_if( ldc < n );
138
139 // quick return
140 if (n == 0 || k == 0)
141 return;
142
143 // alpha == zero
144 if (alpha == zero) {
145 if (beta == zero) {
146 if (uplo != Uplo::Upper) {
147 for (int64_t j = 0; j < n; ++j) {
148 for (int64_t i = 0; i <= j; ++i)
149 C(i, j) = szero;
150 }
151 }
152 else if (uplo != Uplo::Lower) {
153 for (int64_t j = 0; j < n; ++j) {
154 for (int64_t i = j; i < n; ++i)
155 C(i, j) = szero;
156 }
157 }
158 else {
159 for (int64_t j = 0; j < n; ++j) {
160 for (int64_t i = 0; i < n; ++i)
161 C(i, j) = szero;
162 }
163 }
164 }
165 else if (beta != one) {
166 if (uplo != Uplo::Upper) {
167 for (int64_t j = 0; j < n; ++j) {
168 for (int64_t i = 0; i < j; ++i)
169 C(i, j) *= beta;
170 C(j, j) = beta * real( C(j, j) );
171 }
172 }
173 else if (uplo != Uplo::Lower) {
174 for (int64_t j = 0; j < n; ++j) {
175 C(j, j) = beta * real( C(j, j) );
176 for (int64_t i = j+1; i < n; ++i)
177 C(i, j) *= beta;
178 }
179 }
180 else {
181 for (int64_t j = 0; j < n; ++j) {
182 for (int64_t i = 0; i < j; ++i)
183 C(i, j) *= beta;
184 C(j, j) = beta * real( C(j, j) );
185 for (int64_t i = j+1; i < n; ++i)
186 C(i, j) *= beta;
187 }
188 }
189 }
190 return;
191 }
192
193 // alpha != zero
194 if (trans == Op::NoTrans) {
195 if (uplo != Uplo::Lower) {
196 // uplo == Uplo::Upper or uplo == Uplo::General
197 for (int64_t j = 0; j < n; ++j) {
198
199 for (int64_t i = 0; i < j; ++i)
200 C(i, j) *= beta;
201 C(j, j) = beta * real( C(j, j) );
202
203 for (int64_t l = 0; l < k; ++l) {
204
205 scalar_t alpha_conj_Ajl = alpha*conj( A(j, l) );
206
207 for (int64_t i = 0; i < j; ++i)
208 C(i, j) += A(i, l)*alpha_conj_Ajl;
209 C(j, j) += real( A(j, l) * alpha_conj_Ajl );
210 }
211 }
212 }
213 else { // uplo == Uplo::Lower
214 for (int64_t j = 0; j < n; ++j) {
215
216 C(j, j) = beta * real( C(j, j) );
217 for (int64_t i = j+1; i < n; ++i)
218 C(i, j) *= beta;
219
220 for (int64_t l = 0; l < k; ++l) {
221
222 scalar_t alpha_conj_Ajl = alpha*conj( A(j, l) );
223
224 C(j, j) += real( A(j, l) * alpha_conj_Ajl );
225 for (int64_t i = j+1; i < n; ++i) {
226 C(i, j) += A(i, l) * alpha_conj_Ajl;
227 }
228 }
229 }
230 }
231 }
232 else { // trans == Op::ConjTrans
233 if (uplo != Uplo::Lower) {
234 // uplo == Uplo::Upper or uplo == Uplo::General
235 for (int64_t j = 0; j < n; ++j) {
236 for (int64_t i = 0; i < j; ++i) {
237 scalar_t sum = szero;
238 for (int64_t l = 0; l < k; ++l)
239 sum += conj( A(l, i) ) * A(l, j);
240 C(i, j) = alpha*sum + beta*C(i, j);
241 }
242 real_t sum = zero;
243 for (int64_t l = 0; l < k; ++l)
244 sum += real(A(l, j)) * real(A(l, j))
245 + imag(A(l, j)) * imag(A(l, j));
246 C(j, j) = alpha*sum + beta*real( C(j, j) );
247 }
248 }
249 else {
250 // uplo == Uplo::Lower
251 for (int64_t j = 0; j < n; ++j) {
252 for (int64_t i = j+1; i < n; ++i) {
253 scalar_t sum = szero;
254 for (int64_t l = 0; l < k; ++l)
255 sum += conj( A(l, i) ) * A(l, j);
256 C(i, j) = alpha*sum + beta*C(i, j);
257 }
258 real_t sum = zero;
259 for (int64_t l = 0; l < k; ++l)
260 sum += real(A(l, j)) * real(A(l, j))
261 + imag(A(l, j)) * imag(A(l, j));
262 C(j, j) = alpha*sum + beta*real( C(j, j) );
263 }
264 }
265 }
266
267 if (uplo == Uplo::General) {
268 for (int64_t j = 0; j < n; ++j) {
269 for (int64_t i = j+1; i < n; ++i)
270 C(i, j) = conj( C(j, i) );
271 }
272 }
273
274 #undef A
275 #undef C
276}
277
278} // namespace blas
279
280#endif // #ifndef BLAS_HERK_HH
void herk(blas::Layout layout, blas::Uplo uplo, blas::Op trans, int64_t n, int64_t k, float alpha, float const *A, int64_t lda, float beta, float *C, int64_t ldc, blas::Queue &queue)
GPU device, float version.
Definition device_herk.cc:92
True if T is std::complex<T2> for some type T2.
Definition util.hh:349