BLAS++ 2024.05.31
BLAS C++ API
Loading...
Searching...
No Matches
syrk.hh
1// Copyright (c) 2017-2023, University of Tennessee. All rights reserved.
2// SPDX-License-Identifier: BSD-3-Clause
3// This program is free software: you can redistribute it and/or modify it under
4// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
5
6#ifndef BLAS_SYRK_HH
7#define BLAS_SYRK_HH
8
9#include "blas/util.hh"
10
11#include <limits>
12
13namespace blas {
14
15// =============================================================================
77
78template <typename TA, typename TC>
79void syrk(
80 blas::Layout layout,
81 blas::Uplo uplo,
82 blas::Op trans,
83 int64_t n, int64_t k,
84 scalar_type<TA, TC> alpha,
85 TA const *A, int64_t lda,
86 scalar_type<TA, TC> beta,
87 TC *C, int64_t ldc )
88{
89 typedef blas::scalar_type<TA, TC> scalar_t;
90
91 #define A(i_, j_) A[ (i_) + (j_)*lda ]
92 #define C(i_, j_) C[ (i_) + (j_)*ldc ]
93
94 // constants
95 const scalar_t zero = 0;
96 const scalar_t one = 1;
97
98 // check arguments
99 blas_error_if( layout != Layout::ColMajor &&
100 layout != Layout::RowMajor );
101 blas_error_if( uplo != Uplo::Lower &&
102 uplo != Uplo::Upper &&
103 uplo != Uplo::General );
104 blas_error_if( n < 0 );
105 blas_error_if( k < 0 );
106
107 // check and interpret argument trans
108 if (trans == Op::ConjTrans) {
109 blas_error_if_msg(
111 "trans == Op::ConjTrans && "
112 "blas::is_complex<TA>::value" );
113 trans = Op::Trans;
114 }
115 else {
116 blas_error_if( trans != Op::NoTrans &&
117 trans != Op::Trans );
118 }
119
120 // adapt if row major
121 if (layout == Layout::RowMajor) {
122 if (uplo == Uplo::Lower)
123 uplo = Uplo::Upper;
124 else if (uplo == Uplo::Upper)
125 uplo = Uplo::Lower;
126 trans = (trans == Op::NoTrans)
127 ? Op::Trans
128 : Op::NoTrans;
129 }
130
131 // check remaining arguments
132 blas_error_if( lda < ((trans == Op::NoTrans) ? n : k) );
133 blas_error_if( ldc < n );
134
135 // quick return
136 if (n == 0 || k == 0)
137 return;
138
139 // alpha == zero
140 if (alpha == zero) {
141 if (beta == zero) {
142 if (uplo != Uplo::Upper) {
143 for (int64_t j = 0; j < n; ++j) {
144 for (int64_t i = 0; i <= j; ++i)
145 C(i, j) = zero;
146 }
147 }
148 else if (uplo != Uplo::Lower) {
149 for (int64_t j = 0; j < n; ++j) {
150 for (int64_t i = j; i < n; ++i)
151 C(i, j) = zero;
152 }
153 }
154 else {
155 for (int64_t j = 0; j < n; ++j) {
156 for (int64_t i = 0; i < n; ++i)
157 C(i, j) = zero;
158 }
159 }
160 }
161 else if (beta != one) {
162 if (uplo != Uplo::Upper) {
163 for (int64_t j = 0; j < n; ++j) {
164 for (int64_t i = 0; i <= j; ++i)
165 C(i, j) *= beta;
166 }
167 }
168 else if (uplo != Uplo::Lower) {
169 for (int64_t j = 0; j < n; ++j) {
170 for (int64_t i = j; i < n; ++i)
171 C(i, j) *= beta;
172 }
173 }
174 else {
175 for (int64_t j = 0; j < n; ++j) {
176 for (int64_t i = 0; i < n; ++i)
177 C(i, j) *= beta;
178 }
179 }
180 }
181 return;
182 }
183
184 // alpha != zero
185 if (trans == Op::NoTrans) {
186 if (uplo != Uplo::Lower) {
187 // uplo == Uplo::Upper or uplo == Uplo::General
188 for (int64_t j = 0; j < n; ++j) {
189
190 for (int64_t i = 0; i <= j; ++i)
191 C(i, j) *= beta;
192
193 for (int64_t l = 0; l < k; ++l) {
194 scalar_t alpha_Ajl = alpha*A(j, l);
195 for (int64_t i = 0; i <= j; ++i)
196 C(i, j) += A(i, l)*alpha_Ajl;
197 }
198 }
199 }
200 else { // uplo == Uplo::Lower
201 for (int64_t j = 0; j < n; ++j) {
202
203 for (int64_t i = j; i < n; ++i)
204 C(i, j) *= beta;
205
206 for (int64_t l = 0; l < k; ++l) {
207 scalar_t alpha_Ajl = alpha*A(j, l);
208 for (int64_t i = j; i < n; ++i)
209 C(i, j) += A(i, l)*alpha_Ajl;
210 }
211 }
212 }
213 }
214 else { // trans == Op::Trans
215 if (uplo != Uplo::Lower) {
216 // uplo == Uplo::Upper or uplo == Uplo::General
217 for (int64_t j = 0; j < n; ++j) {
218 for (int64_t i = 0; i <= j; ++i) {
219 scalar_t sum = zero;
220 for (int64_t l = 0; l < k; ++l)
221 sum += A(l, i) * A(l, j);
222 C(i, j) = alpha*sum + beta*C(i, j);
223 }
224 }
225 }
226 else { // uplo == Uplo::Lower
227 for (int64_t j = 0; j < n; ++j) {
228 for (int64_t i = j; i < n; ++i) {
229 scalar_t sum = zero;
230 for (int64_t l = 0; l < k; ++l) {
231 sum += A(l, i) * A(l, j);
232 }
233 C(i, j) = alpha*sum + beta*C(i, j);
234 }
235 }
236 }
237 }
238
239 if (uplo == Uplo::General) {
240 for (int64_t j = 0; j < n; ++j) {
241 for (int64_t i = j+1; i < n; ++i)
242 C(i, j) = C(j, i);
243 }
244 }
245
246 #undef A
247 #undef C
248}
249
250} // namespace blas
251
252#endif // #ifndef BLAS_SYRK_HH
void syrk(blas::Layout layout, blas::Uplo uplo, blas::Op trans, int64_t n, int64_t k, float alpha, float const *A, int64_t lda, float beta, float *C, int64_t ldc, blas::Queue &queue)
GPU device, float version.
Definition device_syrk.cc:101
True if T is std::complex<T2> for some type T2.
Definition util.hh:349