BLAS++ 2024.05.31
BLAS C++ API
Loading...
Searching...
No Matches
trsm.hh
1// Copyright (c) 2017-2023, University of Tennessee. All rights reserved.
2// SPDX-License-Identifier: BSD-3-Clause
3// This program is free software: you can redistribute it and/or modify it under
4// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
5
6#ifndef BLAS_TRSM_HH
7#define BLAS_TRSM_HH
8
9#include "blas/util.hh"
10
11#include <limits>
12
13namespace blas {
14
15// =============================================================================
91
92template <typename TA, typename TB>
93void trsm(
94 blas::Layout layout,
95 blas::Side side,
96 blas::Uplo uplo,
97 blas::Op trans,
98 blas::Diag diag,
99 int64_t m,
100 int64_t n,
101 blas::scalar_type<TA, TB> alpha,
102 TA const *A, int64_t lda,
103 TB *B, int64_t ldb )
104{
105 using std::swap;
106 using scalar_t = blas::scalar_type<TA, TB>;
107
108 #define A(i_, j_) A[ (i_) + (j_)*lda ]
109 #define B(i_, j_) B[ (i_) + (j_)*ldb ]
110
111 // constants
112 const scalar_t zero = 0;
113
114 // check arguments
115 blas_error_if( layout != Layout::ColMajor &&
116 layout != Layout::RowMajor );
117 blas_error_if( side != Side::Left &&
118 side != Side::Right );
119 blas_error_if( uplo != Uplo::Lower &&
120 uplo != Uplo::Upper );
121 blas_error_if( trans != Op::NoTrans &&
122 trans != Op::Trans &&
123 trans != Op::ConjTrans );
124 blas_error_if( diag != Diag::NonUnit &&
125 diag != Diag::Unit );
126 blas_error_if( m < 0 );
127 blas_error_if( n < 0 );
128
129 // adapt if row major
130 if (layout == Layout::RowMajor) {
131 side = (side == Side::Left)
132 ? Side::Right
133 : Side::Left;
134 if (uplo == Uplo::Lower)
135 uplo = Uplo::Upper;
136 else if (uplo == Uplo::Upper)
137 uplo = Uplo::Lower;
138 swap( m, n );
139 }
140
141 // check remaining arguments
142 blas_error_if( lda < ((side == Side::Left) ? m : n) );
143 blas_error_if( ldb < m );
144
145 // quick return
146 if (m == 0 || n == 0)
147 return;
148
149 // alpha == zero
150 if (alpha == zero) {
151 for (int64_t j = 0; j < n; ++j) {
152 for (int64_t i = 0; i < m; ++i)
153 B(i, j) = zero;
154 }
155 return;
156 }
157
158 // alpha != zero
159 if (side == Side::Left) {
160 if (trans == Op::NoTrans) {
161 if (uplo == Uplo::Upper) {
162 for (int64_t j = 0; j < n; ++j) {
163 for (int64_t i = 0; i < m; ++i)
164 B(i, j) *= alpha;
165 for (int64_t k = m-1; k >= 0; --k) {
166 if (diag == Diag::NonUnit)
167 B(k, j) /= A(k, k);
168 for (int64_t i = 0; i < k; ++i)
169 B(i, j) -= A(i, k)*B(k, j);
170 }
171 }
172 }
173 else { // uplo == Uplo::Lower
174 for (int64_t j = 0; j < n; ++j) {
175 for (int64_t i = 0; i < m; ++i)
176 B(i, j) *= alpha;
177 for (int64_t k = 0; k < m; ++k) {
178 if (diag == Diag::NonUnit)
179 B(k, j) /= A(k, k);
180 for (int64_t i = k+1; i < m; ++i)
181 B(i, j) -= A(i, k)*B(k, j);
182 }
183 }
184 }
185 }
186 else if (trans == Op::Trans) {
187 if (uplo == Uplo::Upper) {
188 for (int64_t j = 0; j < n; ++j) {
189 for (int64_t i = 0; i < m; ++i) {
190 scalar_t sum = alpha*B(i, j);
191 for (int64_t k = 0; k < i; ++k)
192 sum -= A(k, i)*B(k, j);
193 B(i, j) = (diag == Diag::NonUnit)
194 ? sum / A(i, i)
195 : sum;
196 }
197 }
198 }
199 else { // uplo == Uplo::Lower
200 for (int64_t j = 0; j < n; ++j) {
201 for (int64_t i = m-1; i >= 0; --i) {
202 scalar_t sum = alpha*B(i, j);
203 for (int64_t k = i+1; k < m; ++k)
204 sum -= A(k, i)*B(k, j);
205 B(i, j) = (diag == Diag::NonUnit)
206 ? sum / A(i, i)
207 : sum;
208 }
209 }
210 }
211 }
212 else { // trans == Op::ConjTrans
213 if (uplo == Uplo::Upper) {
214 for (int64_t j = 0; j < n; ++j) {
215 for (int64_t i = 0; i < m; ++i) {
216 scalar_t sum = alpha*B(i, j);
217 for (int64_t k = 0; k < i; ++k)
218 sum -= conj(A(k, i))*B(k, j);
219 B(i, j) = (diag == Diag::NonUnit)
220 ? sum / conj(A(i, i))
221 : sum;
222 }
223 }
224 }
225 else { // uplo == Uplo::Lower
226 for (int64_t j = 0; j < n; ++j) {
227 for (int64_t i = m-1; i >= 0; --i) {
228 scalar_t sum = alpha*B(i, j);
229 for (int64_t k = i+1; k < m; ++k)
230 sum -= conj(A(k, i))*B(k, j);
231 B(i, j) = (diag == Diag::NonUnit)
232 ? sum / conj(A(i, i))
233 : sum;
234 }
235 }
236 }
237 }
238 }
239 else { // side == Side::Right
240 if (trans == Op::NoTrans) {
241 if (uplo == Uplo::Upper) {
242 for (int64_t j = 0; j < n; ++j) {
243 for (int64_t i = 0; i < m; ++i)
244 B(i, j) *= alpha;
245 for (int64_t k = 0; k < j; ++k) {
246 for (int64_t i = 0; i < m; ++i)
247 B(i, j) -= B(i, k)*A(k, j);
248 }
249 if (diag == Diag::NonUnit) {
250 for (int64_t i = 0; i < m; ++i)
251 B(i, j) /= A(j, j);
252 }
253 }
254 }
255 else { // uplo == Uplo::Lower
256 for (int64_t j = n-1; j >= 0; --j) {
257 for (int64_t i = 0; i < m; ++i)
258 B(i, j) *= alpha;
259 for (int64_t k = j+1; k < n; ++k) {
260 for (int64_t i = 0; i < m; ++i)
261 B(i, j) -= B(i, k)*A(k, j);
262 }
263 if (diag == Diag::NonUnit) {
264 for (int64_t i = 0; i < m; ++i)
265 B(i, j) /= A(j, j);
266 }
267 }
268 }
269 }
270 else if (trans == Op::Trans) {
271 if (uplo == Uplo::Upper) {
272 for (int64_t k = n-1; k >= 0; --k) {
273 if (diag == Diag::NonUnit) {
274 for (int64_t i = 0; i < m; ++i)
275 B(i, k) /= A(k, k);
276 }
277 for (int64_t j = 0; j < k; ++j) {
278 for (int64_t i = 0; i < m; ++i)
279 B(i, j) -= B(i, k)*A(j, k);
280 }
281 for (int64_t i = 0; i < m; ++i)
282 B(i, k) *= alpha;
283 }
284 }
285 else { // uplo == Uplo::Lower
286 for (int64_t k = 0; k < n; ++k) {
287 if (diag == Diag::NonUnit) {
288 for (int64_t i = 0; i < m; ++i)
289 B(i, k) /= A(k, k);
290 }
291 for (int64_t j = k+1; j < n; ++j) {
292 for (int64_t i = 0; i < m; ++i)
293 B(i, j) -= B(i, k)*A(j, k);
294 }
295 for (int64_t i = 0; i < m; ++i)
296 B(i, k) *= alpha;
297 }
298 }
299 }
300 else { // trans == Op::ConjTrans
301 if (uplo == Uplo::Upper) {
302 for (int64_t k = n-1; k >= 0; --k) {
303 if (diag == Diag::NonUnit) {
304 for (int64_t i = 0; i < m; ++i)
305 B(i, k) /= conj(A(k, k));
306 }
307 for (int64_t j = 0; j < k; ++j) {
308 for (int64_t i = 0; i < m; ++i)
309 B(i, j) -= B(i, k)*conj(A(j, k));
310 }
311 for (int64_t i = 0; i < m; ++i)
312 B(i, k) *= alpha;
313 }
314 }
315 else { // uplo == Uplo::Lower
316 for (int64_t k = 0; k < n; ++k) {
317 if (diag == Diag::NonUnit) {
318 for (int64_t i = 0; i < m; ++i)
319 B(i, k) /= conj(A(k, k));
320 }
321 for (int64_t j = k+1; j < n; ++j) {
322 for (int64_t i = 0; i < m; ++i)
323 B(i, j) -= B(i, k)*conj(A(j, k));
324 }
325 for (int64_t i = 0; i < m; ++i)
326 B(i, k) *= alpha;
327 }
328 }
329 }
330 }
331
332 #undef A
333 #undef B
334}
335
336} // namespace blas
337
338#endif // #ifndef BLAS_TRSM_HH
void swap(int64_t n, float *x, int64_t incx, float *y, int64_t incy, blas::Queue &queue)
GPU device, float version.
Definition device_swap.cc:67
void trsm(blas::Layout layout, blas::Side side, blas::Uplo uplo, blas::Op trans, blas::Diag diag, int64_t m, int64_t n, float alpha, float const *A, int64_t lda, float *B, int64_t ldb, blas::Queue &queue)
GPU device, float version.
Definition device_trsm.cc:104