BLAS++ 2024.05.31
BLAS C++ API
Loading...
Searching...
No Matches
trmm.hh
1// Copyright (c) 2017-2023, University of Tennessee. All rights reserved.
2// SPDX-License-Identifier: BSD-3-Clause
3// This program is free software: you can redistribute it and/or modify it under
4// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
5
6#ifndef BLAS_TRMM_HH
7#define BLAS_TRMM_HH
8
9#include "blas/util.hh"
10
11#include <limits>
12
13namespace blas {
14
15// =============================================================================
86
87template <typename TA, typename TB>
88void trmm(
89 blas::Layout layout,
90 blas::Side side,
91 blas::Uplo uplo,
92 blas::Op trans,
93 blas::Diag diag,
94 int64_t m,
95 int64_t n,
96 blas::scalar_type<TA, TB> alpha,
97 TA const *A, int64_t lda,
98 TB *B, int64_t ldb )
99{
100 using std::swap;
101 using scalar_t = blas::scalar_type<TA, TB>;
102
103 #define A(i_, j_) A[ (i_) + (j_)*lda ]
104 #define B(i_, j_) B[ (i_) + (j_)*ldb ]
105
106 // constants
107 const scalar_t zero = 0;
108
109 // check arguments
110 blas_error_if( layout != Layout::ColMajor &&
111 layout != Layout::RowMajor );
112 blas_error_if( side != Side::Left &&
113 side != Side::Right );
114 blas_error_if( uplo != Uplo::Lower &&
115 uplo != Uplo::Upper );
116 blas_error_if( trans != Op::NoTrans &&
117 trans != Op::Trans &&
118 trans != Op::ConjTrans );
119 blas_error_if( diag != Diag::NonUnit &&
120 diag != Diag::Unit );
121 blas_error_if( m < 0 );
122 blas_error_if( n < 0 );
123
124 // adapt if row major
125 if (layout == Layout::RowMajor) {
126 side = (side == Side::Left)
127 ? Side::Right
128 : Side::Left;
129 if (uplo == Uplo::Lower)
130 uplo = Uplo::Upper;
131 else if (uplo == Uplo::Upper)
132 uplo = Uplo::Lower;
133 swap( m, n );
134 }
135
136 // check remaining arguments
137 blas_error_if( lda < ((side == Side::Left) ? m : n) );
138 blas_error_if( ldb < m );
139
140 // quick return
141 if (m == 0 || n == 0)
142 return;
143
144 // alpha == zero
145 if (alpha == zero) {
146 for (int64_t j = 0; j < n; ++j) {
147 for (int64_t i = 0; i < m; ++i)
148 B(i, j) = zero;
149 }
150 return;
151 }
152
153 // alpha != zero
154 if (side == Side::Left) {
155 if (trans == Op::NoTrans) {
156 if (uplo == Uplo::Upper) {
157 for (int64_t j = 0; j < n; ++j) {
158 for (int64_t k = 0; k < m; ++k) {
159 scalar_t alpha_Bkj = alpha*B(k, j);
160 for (int64_t i = 0; i < k; ++i)
161 B(i, j) += A(i, k)*alpha_Bkj;
162 B(k, j) = (diag == Diag::NonUnit)
163 ? A(k, k)*alpha_Bkj
164 : alpha_Bkj;
165 }
166 }
167 }
168 else { // uplo == Uplo::Lower
169 for (int64_t j = 0; j < n; ++j) {
170 for (int64_t k = m-1; k >= 0; --k) {
171 scalar_t alpha_Bkj = alpha*B(k, j);
172 B(k, j) = (diag == Diag::NonUnit)
173 ? A(k, k)*alpha_Bkj
174 : alpha_Bkj;
175 for (int64_t i = k+1; i < m; ++i)
176 B(i, j) += A(i, k)*alpha_Bkj;
177 }
178 }
179 }
180 }
181 else if (trans == Op::Trans) {
182 if (uplo == Uplo::Upper) {
183 for (int64_t j = 0; j < n; ++j) {
184 for (int64_t i = m-1; i >= 0; --i) {
185 scalar_t sum = (diag == Diag::NonUnit)
186 ? A(i, i)*B(i, j)
187 : B(i, j);
188 for (int64_t k = 0; k < i; ++k)
189 sum += A(k, i)*B(k, j);
190 B(i, j) = alpha * sum;
191 }
192 }
193 }
194 else { // uplo == Uplo::Lower
195 for (int64_t j = 0; j < n; ++j) {
196 for (int64_t i = 0; i < m; ++i) {
197 scalar_t sum = (diag == Diag::NonUnit)
198 ? A(i, i)*B(i, j)
199 : B(i, j);
200 for (int64_t k = i+1; k < m; ++k)
201 sum += A(k, i)*B(k, j);
202 B(i, j) = alpha * sum;
203 }
204 }
205 }
206 }
207 else { // trans == Op::ConjTrans
208 if (uplo == Uplo::Upper) {
209 for (int64_t j = 0; j < n; ++j) {
210 for (int64_t i = m-1; i >= 0; --i) {
211 scalar_t sum = (diag == Diag::NonUnit)
212 ? conj(A(i, i))*B(i, j)
213 : B(i, j);
214 for (int64_t k = 0; k < i; ++k)
215 sum += conj(A(k, i))*B(k, j);
216 B(i, j) = alpha * sum;
217 }
218 }
219 }
220 else { // uplo == Uplo::Lower
221 for (int64_t j = 0; j < n; ++j) {
222 for (int64_t i = 0; i < m; ++i) {
223 scalar_t sum = (diag == Diag::NonUnit)
224 ? conj(A(i, i))*B(i, j)
225 : B(i, j);
226 for (int64_t k = i+1; k < m; ++k)
227 sum += conj(A(k, i))*B(k, j);
228 B(i, j) = alpha * sum;
229 }
230 }
231 }
232 }
233 }
234 else { // side == Side::Right
235 if (trans == Op::NoTrans) {
236 if (uplo == Uplo::Upper) {
237 for (int64_t j = n-1; j >= 0; --j) {
238
239 scalar_t alpha_Akj = (diag == Diag::NonUnit)
240 ? alpha*A(j, j)
241 : alpha;
242 for (int64_t i = 0; i < m; ++i)
243 B(i, j) *= alpha_Akj;
244
245 for (int64_t k = 0; k < j; ++k) {
246 alpha_Akj = alpha*A(k, j);
247 for (int64_t i = 0; i < m; ++i)
248 B(i, j) += B(i, k)*alpha_Akj;
249 }
250 }
251 }
252 else { // uplo == Uplo::Lower
253 for (int64_t j = 0; j < n; ++j) {
254
255 scalar_t alpha_Akj = (diag == Diag::NonUnit)
256 ? alpha*A(j, j)
257 : alpha;
258 for (int64_t i = 0; i < m; ++i)
259 B(i, j) *= alpha_Akj;
260
261 for (int64_t k = j+1; k < n; ++k) {
262 alpha_Akj = alpha*A(k, j);
263 for (int64_t i = 0; i < m; ++i)
264 B(i, j) += B(i, k)*alpha_Akj;
265 }
266 }
267 }
268 }
269 else if (trans == Op::Trans) {
270 if (uplo == Uplo::Upper) {
271 for (int64_t k = 0; k < n; ++k) {
272 for (int64_t j = 0; j < k; ++j) {
273 scalar_t alpha_Ajk = alpha*A(j, k);
274 for (int64_t i = 0; i < m; ++i)
275 B(i, j) += B(i, k)*alpha_Ajk;
276 }
277
278 scalar_t alpha_Akk = (diag == Diag::NonUnit)
279 ? alpha*A(k, k)
280 : alpha;
281 for (int64_t i = 0; i < m; ++i)
282 B(i, k) *= alpha_Akk;
283 }
284 }
285 else { // uplo == Uplo::Lower
286 for (int64_t k = n-1; k >= 0; --k) {
287 for (int64_t j = k+1; j < n; ++j) {
288 scalar_t alpha_Ajk = alpha*A(j, k);
289 for (int64_t i = 0; i < m; ++i)
290 B(i, j) += B(i, k)*alpha_Ajk;
291 }
292
293 scalar_t alpha_Akk = (diag == Diag::NonUnit)
294 ? alpha*A(k, k)
295 : alpha;
296 for (int64_t i = 0; i < m; ++i)
297 B(i, k) *= alpha_Akk;
298 }
299 }
300 }
301 else { // trans == Op::ConjTrans
302 if (uplo == Uplo::Upper) {
303 for (int64_t k = 0; k < n; ++k) {
304 for (int64_t j = 0; j < k; ++j) {
305 scalar_t alpha_Ajk = alpha*conj(A(j, k));
306 for (int64_t i = 0; i < m; ++i)
307 B(i, j) += B(i, k)*alpha_Ajk;
308 }
309
310 scalar_t alpha_Akk = (diag == Diag::NonUnit)
311 ? alpha*conj(A(k, k))
312 : alpha;
313 for (int64_t i = 0; i < m; ++i)
314 B(i, k) *= alpha_Akk;
315 }
316 }
317 else { // uplo == Uplo::Lower
318 for (int64_t k = n-1; k >= 0; --k) {
319 for (int64_t j = k+1; j < n; ++j) {
320 scalar_t alpha_Ajk = alpha*conj(A(j, k));
321 for (int64_t i = 0; i < m; ++i)
322 B(i, j) += B(i, k)*alpha_Ajk;
323 }
324
325 scalar_t alpha_Akk = (diag == Diag::NonUnit)
326 ? alpha*conj(A(k, k))
327 : alpha;
328 for (int64_t i = 0; i < m; ++i)
329 B(i, k) *= alpha_Akk;
330 }
331 }
332 }
333 }
334
335 #undef A
336 #undef B
337}
338
339} // namespace blas
340
341#endif // #ifndef BLAS_TRMM_HH
void swap(int64_t n, float *x, int64_t incx, float *y, int64_t incy, blas::Queue &queue)
GPU device, float version.
Definition device_swap.cc:67
void trmm(blas::Layout layout, blas::Side side, blas::Uplo uplo, blas::Op trans, blas::Diag diag, int64_t m, int64_t n, float alpha, float const *A, int64_t lda, float *B, int64_t ldb, blas::Queue &queue)
GPU device, float version.
Definition device_trmm.cc:104