BLAS++ 2024.05.31
BLAS C++ API
Loading...
Searching...
No Matches
trsv.hh
1// Copyright (c) 2017-2023, University of Tennessee. All rights reserved.
2// SPDX-License-Identifier: BSD-3-Clause
3// This program is free software: you can redistribute it and/or modify it under
4// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
5
6#ifndef BLAS_TRSV_HH
7#define BLAS_TRSV_HH
8
9#include "blas/util.hh"
10
11#include <limits>
12
13namespace blas {
14
15// =============================================================================
71
72template <typename TA, typename TX>
73void trsv(
74 blas::Layout layout,
75 blas::Uplo uplo,
76 blas::Op trans,
77 blas::Diag diag,
78 int64_t n,
79 TA const *A, int64_t lda,
80 TX *x, int64_t incx )
81{
82 #define A(i_, j_) A[ (i_) + (j_)*lda ]
83
84 // check arguments
85 blas_error_if( layout != Layout::ColMajor &&
86 layout != Layout::RowMajor );
87 blas_error_if( uplo != Uplo::Lower &&
88 uplo != Uplo::Upper );
89 blas_error_if( trans != Op::NoTrans &&
90 trans != Op::Trans &&
91 trans != Op::ConjTrans );
92 blas_error_if( diag != Diag::NonUnit &&
93 diag != Diag::Unit );
94 blas_error_if( n < 0 );
95 blas_error_if( lda < n );
96 blas_error_if( incx == 0 );
97
98 // quick return
99 if (n == 0)
100 return;
101
102 // for row major, swap lower <=> upper and
103 // A => A^T; A^T => A; A^H => A & conj
104 bool doconj = false;
105 if (layout == Layout::RowMajor) {
106 uplo = (uplo == Uplo::Lower ? Uplo::Upper : Uplo::Lower);
107 if (trans == Op::NoTrans) {
108 trans = Op::Trans;
109 }
110 else {
111 if (trans == Op::ConjTrans) {
112 doconj = true;
113 }
114 trans = Op::NoTrans;
115 }
116 }
117
118 bool nonunit = (diag == Diag::NonUnit);
119 int64_t kx = (incx > 0 ? 0 : (-n + 1)*incx);
120
121 if (trans == Op::NoTrans && ! doconj) {
122 // Form x := A^{-1} * x
123 if (uplo == Uplo::Upper) {
124 // upper
125 if (incx == 1) {
126 // unit stride
127 for (int64_t j = n - 1; j >= 0; --j) {
128 // note: NOT skipping if x[j] is zero, for consistent NAN handling
129 if (nonunit) {
130 x[j] /= A(j, j);
131 }
132 TX tmp = x[j];
133 for (int64_t i = j - 1; i >= 0; --i) {
134 x[i] -= tmp * A(i, j);
135 }
136 }
137 }
138 else {
139 // non-unit stride
140 int64_t jx = kx + (n - 1)*incx;
141 for (int64_t j = n - 1; j >= 0; --j) {
142 // note: NOT skipping if x[j] is zero ...
143 if (nonunit) {
144 x[jx] /= A(j, j);
145 }
146 TX tmp = x[jx];
147 int64_t ix = jx;
148 for (int64_t i = j - 1; i >= 0; --i) {
149 ix -= incx;
150 x[ix] -= tmp * A(i, j);
151 }
152 jx -= incx;
153 }
154 }
155 }
156 else {
157 // lower
158 if (incx == 1) {
159 // unit stride
160 for (int64_t j = 0; j < n; ++j) {
161 // note: NOT skipping if x[j] is zero ...
162 if (nonunit) {
163 x[j] /= A(j, j);
164 }
165 TX tmp = x[j];
166 for (int64_t i = j + 1; i < n; ++i) {
167 x[i] -= tmp * A(i, j);
168 }
169 }
170 }
171 else {
172 // non-unit stride
173 int64_t jx = kx;
174 for (int64_t j = 0; j < n; ++j) {
175 // note: NOT skipping if x[j] is zero ...
176 if (nonunit) {
177 x[jx] /= A(j, j);
178 }
179 TX tmp = x[jx];
180 int64_t ix = jx;
181 for (int64_t i = j+1; i < n; ++i) {
182 ix += incx;
183 x[ix] -= tmp * A(i, j);
184 }
185 jx += incx;
186 }
187 }
188 }
189 }
190 else if (trans == Op::NoTrans && doconj) {
191 // Form x := A^{-1} * x
192 if (uplo == Uplo::Upper) {
193 // upper
194 if (incx == 1) {
195 // unit stride
196 for (int64_t j = n - 1; j >= 0; --j) {
197 // note: NOT skipping if x[j] is zero, for consistent NAN handling
198 if (nonunit) {
199 x[j] /= conj( A(j, j) );
200 }
201 TX tmp = x[j];
202 for (int64_t i = j - 1; i >= 0; --i) {
203 x[i] -= tmp * conj( A(i, j) );
204 }
205 }
206 }
207 else {
208 // non-unit stride
209 int64_t jx = kx + (n - 1)*incx;
210 for (int64_t j = n - 1; j >= 0; --j) {
211 // note: NOT skipping if x[j] is zero ...
212 if (nonunit) {
213 x[jx] /= conj( A(j, j) );
214 }
215 TX tmp = x[jx];
216 int64_t ix = jx;
217 for (int64_t i = j - 1; i >= 0; --i) {
218 ix -= incx;
219 x[ix] -= tmp * conj( A(i, j) );
220 }
221 jx -= incx;
222 }
223 }
224 }
225 else {
226 // lower
227 if (incx == 1) {
228 // unit stride
229 for (int64_t j = 0; j < n; ++j) {
230 // note: NOT skipping if x[j] is zero ...
231 if (nonunit) {
232 x[j] /= conj( A(j, j) );
233 }
234 TX tmp = x[j];
235 for (int64_t i = j + 1; i < n; ++i) {
236 x[i] -= tmp * conj( A(i, j) );
237 }
238 }
239 }
240 else {
241 // non-unit stride
242 int64_t jx = kx;
243 for (int64_t j = 0; j < n; ++j) {
244 // note: NOT skipping if x[j] is zero ...
245 if (nonunit) {
246 x[jx] /= conj( A(j, j) );
247 }
248 TX tmp = x[jx];
249 int64_t ix = jx;
250 for (int64_t i = j+1; i < n; ++i) {
251 ix += incx;
252 x[ix] -= tmp * conj( A(i, j) );
253 }
254 jx += incx;
255 }
256 }
257 }
258 }
259 else if (trans == Op::Trans) {
260 // Form x := A^{-T} * x
261 if (uplo == Uplo::Upper) {
262 // upper
263 if (incx == 1) {
264 // unit stride
265 for (int64_t j = 0; j < n; ++j) {
266 TX tmp = x[j];
267 for (int64_t i = 0; i <= j - 1; ++i) {
268 tmp -= A(i, j) * x[i];
269 }
270 if (nonunit) {
271 tmp /= A(j, j);
272 }
273 x[j] = tmp;
274 }
275 }
276 else {
277 // non-unit stride
278 int64_t jx = kx;
279 for (int64_t j = 0; j < n; ++j) {
280 TX tmp = x[jx];
281 int64_t ix = kx;
282 for (int64_t i = 0; i <= j - 1; ++i) {
283 tmp -= A(i, j) * x[ix];
284 ix += incx;
285 }
286 if (nonunit) {
287 tmp /= A(j, j);
288 }
289 x[jx] = tmp;
290 jx += incx;
291 }
292 }
293 }
294 else {
295 // lower
296 if (incx == 1) {
297 // unit stride
298 for (int64_t j = n - 1; j >= 0; --j) {
299 TX tmp = x[j];
300 for (int64_t i = j + 1; i < n; ++i) {
301 tmp -= A(i, j) * x[i];
302 }
303 if (nonunit) {
304 tmp /= A(j, j);
305 }
306 x[j] = tmp;
307 }
308 }
309 else {
310 // non-unit stride
311 kx += (n - 1)*incx;
312 int64_t jx = kx;
313 for (int64_t j = n - 1; j >= 0; --j) {
314 int64_t ix = kx;
315 TX tmp = x[jx];
316 for (int64_t i = n - 1; i >= j + 1; --i) {
317 tmp -= A(i, j) * x[ix];
318 ix -= incx;
319 }
320 if (nonunit) {
321 tmp /= A(j, j);
322 }
323 x[jx] = tmp;
324 jx -= incx;
325 }
326 }
327 }
328 }
329 else {
330 // Form x := A^{-H} * x
331 // same code as above A^{-T} * x case, except add conj()
332 if (uplo == Uplo::Upper) {
333 // upper
334 if (incx == 1) {
335 // unit stride
336 for (int64_t j = 0; j < n; ++j) {
337 TX tmp = x[j];
338 for (int64_t i = 0; i <= j - 1; ++i) {
339 tmp -= conj( A(i, j) ) * x[i];
340 }
341 if (nonunit) {
342 tmp /= conj( A(j, j) );
343 }
344 x[j] = tmp;
345 }
346 }
347 else {
348 // non-unit stride
349 int64_t jx = kx;
350 for (int64_t j = 0; j < n; ++j) {
351 TX tmp = x[jx];
352 int64_t ix = kx;
353 for (int64_t i = 0; i <= j - 1; ++i) {
354 tmp -= conj( A(i, j) ) * x[ix];
355 ix += incx;
356 }
357 if (nonunit) {
358 tmp /= conj( A(j, j) );
359 }
360 x[jx] = tmp;
361 jx += incx;
362 }
363 }
364 }
365 else {
366 // lower
367 if (incx == 1) {
368 // unit stride
369 for (int64_t j = n - 1; j >= 0; --j) {
370 TX tmp = x[j];
371 for (int64_t i = j + 1; i < n; ++i) {
372 tmp -= conj( A(i, j) ) * x[i];
373 }
374 if (nonunit) {
375 tmp /= conj( A(j, j) );
376 }
377 x[j] = tmp;
378 }
379 }
380 else {
381 // non-unit stride
382 kx += (n - 1)*incx;
383 int64_t jx = kx;
384 for (int64_t j = n - 1; j >= 0; --j) {
385 int64_t ix = kx;
386 TX tmp = x[jx];
387 for (int64_t i = n - 1; i >= j + 1; --i) {
388 tmp -= conj( A(i, j) ) * x[ix];
389 ix -= incx;
390 }
391 if (nonunit) {
392 tmp /= conj( A(j, j) );
393 }
394 x[jx] = tmp;
395 jx -= incx;
396 }
397 }
398 }
399 }
400
401 #undef A
402}
403
404} // namespace blas
405
406#endif // #ifndef BLAS_TRSV_HH
void trsv(blas::Layout layout, blas::Uplo uplo, blas::Op trans, blas::Diag diag, int64_t n, TA const *A, int64_t lda, TX *x, int64_t incx)
Solve the triangular matrix-vector equation.
Definition trsv.hh:73