BLAS++ 2024.05.31
BLAS C++ API
Loading...
Searching...
No Matches
trmv.hh
1// Copyright (c) 2017-2023, University of Tennessee. All rights reserved.
2// SPDX-License-Identifier: BSD-3-Clause
3// This program is free software: you can redistribute it and/or modify it under
4// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
5
6#ifndef BLAS_TRMV_HH
7#define BLAS_TRMV_HH
8
9#include "blas/util.hh"
10
11#include <limits>
12
13namespace blas {
14
15// =============================================================================
67
68template <typename TA, typename TX>
69void trmv(
70 blas::Layout layout,
71 blas::Uplo uplo,
72 blas::Op trans,
73 blas::Diag diag,
74 int64_t n,
75 TA const *A, int64_t lda,
76 TX *x, int64_t incx )
77{
78 #define A(i_, j_) A[ (i_) + (j_)*lda ]
79
80 // check arguments
81 blas_error_if( layout != Layout::ColMajor &&
82 layout != Layout::RowMajor );
83 blas_error_if( uplo != Uplo::Lower &&
84 uplo != Uplo::Upper );
85 blas_error_if( trans != Op::NoTrans &&
86 trans != Op::Trans &&
87 trans != Op::ConjTrans );
88 blas_error_if( diag != Diag::NonUnit &&
89 diag != Diag::Unit );
90 blas_error_if( n < 0 );
91 blas_error_if( lda < n );
92 blas_error_if( incx == 0 );
93
94 // quick return
95 if (n == 0)
96 return;
97
98 // for row major, swap lower <=> upper and
99 // A => A^T; A^T => A; A^H => A & conj
100 bool doconj = false;
101 if (layout == Layout::RowMajor) {
102 uplo = (uplo == Uplo::Lower ? Uplo::Upper : Uplo::Lower);
103 if (trans == Op::NoTrans) {
104 trans = Op::Trans;
105 }
106 else {
107 if (trans == Op::ConjTrans) {
108 doconj = true;
109 }
110 trans = Op::NoTrans;
111 }
112 }
113
114 bool nonunit = (diag == Diag::NonUnit);
115 int64_t kx = (incx > 0 ? 0 : (-n + 1)*incx);
116
117 if (trans == Op::NoTrans && ! doconj) {
118 // Form x := A*x
119 if (uplo == Uplo::Upper) {
120 // upper
121 if (incx == 1) {
122 // unit stride
123 for (int64_t j = 0; j < n; ++j) {
124 // note: NOT skipping if x[j] is zero, for consistent NAN handling
125 TX tmp = x[j];
126 for (int64_t i = 0; i <= j-1; ++i) {
127 x[i] += tmp * A(i, j);
128 }
129 if (nonunit) {
130 x[j] *= A(j, j);
131 }
132 }
133 }
134 else {
135 // non-unit stride
136 int64_t jx = kx;
137 for (int64_t j = 0; j < n; ++j) {
138 // note: NOT skipping if x[j] is zero ...
139 TX tmp = x[jx];
140 int64_t ix = kx;
141 for (int64_t i = 0; i <= j-1; ++i) {
142 x[ix] += tmp * A(i, j);
143 ix += incx;
144 }
145 if (nonunit) {
146 x[jx] *= A(j, j);
147 }
148 jx += incx;
149 }
150 }
151 }
152 else {
153 // lower
154 if (incx == 1) {
155 // unit stride
156 for (int64_t j = n-1; j >= 0; --j) {
157 // note: NOT skipping if x[j] is zero ...
158 TX tmp = x[j];
159 for (int64_t i = n-1; i >= j+1; --i) {
160 x[i] += tmp * A(i, j);
161 }
162 if (nonunit) {
163 x[j] *= A(j, j);
164 }
165 }
166 }
167 else {
168 // non-unit stride
169 kx += (n - 1)*incx;
170 int64_t jx = kx;
171 for (int64_t j = n-1; j >= 0; --j) {
172 // note: NOT skipping if x[j] is zero ...
173 TX tmp = x[jx];
174 int64_t ix = kx;
175 for (int64_t i = n-1; i >= j+1; --i) {
176 x[ix] += tmp * A(i, j);
177 ix -= incx;
178 }
179 if (nonunit) {
180 x[jx] *= A(j, j);
181 }
182 jx -= incx;
183 }
184 }
185 }
186 }
187 else if (trans == Op::NoTrans && doconj) {
188 // Form x := A*x
189 if (uplo == Uplo::Upper) {
190 // upper
191 if (incx == 1) {
192 // unit stride
193 for (int64_t j = 0; j < n; ++j) {
194 // note: NOT skipping if x[j] is zero, for consistent NAN handling
195 TX tmp = x[j];
196 for (int64_t i = 0; i <= j-1; ++i) {
197 x[i] += tmp * conj( A(i, j) );
198 }
199 if (nonunit) {
200 x[j] *= conj( A(j, j) );
201 }
202 }
203 }
204 else {
205 // non-unit stride
206 int64_t jx = kx;
207 for (int64_t j = 0; j < n; ++j) {
208 // note: NOT skipping if x[j] is zero ...
209 TX tmp = x[jx];
210 int64_t ix = kx;
211 for (int64_t i = 0; i <= j-1; ++i) {
212 x[ix] += tmp * conj( A(i, j) );
213 ix += incx;
214 }
215 if (nonunit) {
216 x[jx] *= conj( A(j, j) );
217 }
218 jx += incx;
219 }
220 }
221 }
222 else {
223 // lower
224 if (incx == 1) {
225 // unit stride
226 for (int64_t j = n-1; j >= 0; --j) {
227 // note: NOT skipping if x[j] is zero ...
228 TX tmp = x[j];
229 for (int64_t i = n-1; i >= j+1; --i) {
230 x[i] += tmp * conj( A(i, j) );
231 }
232 if (nonunit) {
233 x[j] *= conj( A(j, j) );
234 }
235 }
236 }
237 else {
238 // non-unit stride
239 kx += (n - 1)*incx;
240 int64_t jx = kx;
241 for (int64_t j = n-1; j >= 0; --j) {
242 // note: NOT skipping if x[j] is zero ...
243 TX tmp = x[jx];
244 int64_t ix = kx;
245 for (int64_t i = n-1; i >= j+1; --i) {
246 x[ix] += tmp * conj( A(i, j) );
247 ix -= incx;
248 }
249 if (nonunit) {
250 x[jx] *= conj( A(j, j) );
251 }
252 jx -= incx;
253 }
254 }
255 }
256 }
257 else if (trans == Op::Trans) {
258 // Form x := A^T * x
259 if (uplo == Uplo::Upper) {
260 // upper
261 if (incx == 1) {
262 // unit stride
263 for (int64_t j = n-1; j >= 0; --j) {
264 TX tmp = x[j];
265 if (nonunit) {
266 tmp *= A(j, j);
267 }
268 for (int64_t i = j - 1; i >= 0; --i) {
269 tmp += A(i, j) * x[i];
270 }
271 x[j] = tmp;
272 }
273 }
274 else {
275 // non-unit stride
276 int64_t jx = kx + (n - 1)*incx;
277 for (int64_t j = n-1; j >= 0; --j) {
278 TX tmp = x[jx];
279 int64_t ix = jx;
280 if (nonunit) {
281 tmp *= A(j, j);
282 }
283 for (int64_t i = j - 1; i >= 0; --i) {
284 ix -= incx;
285 tmp += A(i, j) * x[ix];
286 }
287 x[jx] = tmp;
288 jx -= incx;
289 }
290 }
291 }
292 else {
293 // lower
294 if (incx == 1) {
295 // unit stride
296 for (int64_t j = 0; j < n; ++j) {
297 TX tmp = x[j];
298 if (nonunit) {
299 tmp *= A(j, j);
300 }
301 for (int64_t i = j + 1; i < n; ++i) {
302 tmp += A(i, j) * x[i];
303 }
304 x[j] = tmp;
305 }
306 }
307 else {
308 // non-unit stride
309 int64_t jx = kx;
310 for (int64_t j = 0; j < n; ++j) {
311 TX tmp = x[jx];
312 int64_t ix = jx;
313 if (nonunit) {
314 tmp *= A(j, j);
315 }
316 for (int64_t i = j + 1; i < n; ++i) {
317 ix += incx;
318 tmp += A(i, j) * x[ix];
319 }
320 x[jx] = tmp;
321 jx += incx;
322 }
323 }
324 }
325 }
326 else {
327 // Form x := A^H * x
328 // same code as above A^T * x case, except add conj()
329 if (uplo == Uplo::Upper) {
330 // upper
331 if (incx == 1) {
332 // unit stride
333 for (int64_t j = n-1; j >= 0; --j) {
334 TX tmp = x[j];
335 if (nonunit) {
336 tmp *= conj( A(j, j) );
337 }
338 for (int64_t i = j - 1; i >= 0; --i) {
339 tmp += conj( A(i, j) ) * x[i];
340 }
341 x[j] = tmp;
342 }
343 }
344 else {
345 // non-unit stride
346 int64_t jx = kx + (n - 1)*incx;
347 for (int64_t j = n-1; j >= 0; --j) {
348 TX tmp = x[jx];
349 int64_t ix = jx;
350 if (nonunit) {
351 tmp *= conj( A(j, j) );
352 }
353 for (int64_t i = j - 1; i >= 0; --i) {
354 ix -= incx;
355 tmp += conj( A(i, j) ) * x[ix];
356 }
357 x[jx] = tmp;
358 jx -= incx;
359 }
360 }
361 }
362 else {
363 // lower
364 if (incx == 1) {
365 // unit stride
366 for (int64_t j = 0; j < n; ++j) {
367 TX tmp = x[j];
368 if (nonunit) {
369 tmp *= conj( A(j, j) );
370 }
371 for (int64_t i = j + 1; i < n; ++i) {
372 tmp += conj( A(i, j) ) * x[i];
373 }
374 x[j] = tmp;
375 }
376 }
377 else {
378 // non-unit stride
379 int64_t jx = kx;
380 for (int64_t j = 0; j < n; ++j) {
381 TX tmp = x[jx];
382 int64_t ix = jx;
383 if (nonunit) {
384 tmp *= conj( A(j, j) );
385 }
386 for (int64_t i = j + 1; i < n; ++i) {
387 ix += incx;
388 tmp += conj( A(i, j) ) * x[ix];
389 }
390 x[jx] = tmp;
391 jx += incx;
392 }
393 }
394 }
395 }
396
397 #undef A
398}
399
400} // namespace blas
401
402#endif // #ifndef BLAS_TRMV_HH
void trmv(blas::Layout layout, blas::Uplo uplo, blas::Op trans, blas::Diag diag, int64_t n, TA const *A, int64_t lda, TX *x, int64_t incx)
Triangular matrix-vector multiply: $x = \mathrm{op}(A)\,x$, where $\mathrm{op}(A)$ is $A$, $A^T$, or $A^H$.
Definition trmv.hh:69