BLAS++ 2024.05.31
BLAS C++ API
Loading...
Searching...
No Matches
syr.hh
1// Copyright (c) 2017-2023, University of Tennessee. All rights reserved.
2// SPDX-License-Identifier: BSD-3-Clause
3// This program is free software: you can redistribute it and/or modify it under
4// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
5
6#ifndef BLAS_SYR_HH
7#define BLAS_SYR_HH
8
9#include "blas/util.hh"
10
11#include <limits>
12
13namespace blas {
14
15// =============================================================================
54
55template <typename TA, typename TX>
56void syr(
57 blas::Layout layout,
58 blas::Uplo uplo,
59 int64_t n,
60 blas::scalar_type<TA, TX> alpha,
61 TX const *x, int64_t incx,
62 TA *A, int64_t lda )
63{
64 typedef blas::scalar_type<TA, TX> scalar_t;
65
66 #define A(i_, j_) A[ (i_) + (j_)*lda ]
67
68 // constants
69 const scalar_t zero = 0;
70
71 // check arguments
72 blas_error_if( layout != Layout::ColMajor &&
73 layout != Layout::RowMajor );
74 blas_error_if( uplo != Uplo::Lower &&
75 uplo != Uplo::Upper );
76 blas_error_if( n < 0 );
77 blas_error_if( incx == 0 );
78 blas_error_if( lda < n );
79
80 // quick return
81 if (n == 0 || alpha == zero)
82 return;
83
84 // for row major, swap lower <=> upper
85 if (layout == Layout::RowMajor) {
86 uplo = (uplo == Uplo::Lower ? Uplo::Upper : Uplo::Lower);
87 }
88
89 int64_t kx = (incx > 0 ? 0 : (-n + 1)*incx);
90 if (uplo == Uplo::Upper) {
91 if (incx == 1) {
92 // unit stride
93 for (int64_t j = 0; j < n; ++j) {
94 // note: NOT skipping if x[j] is zero, for consistent NAN handling
95 scalar_t tmp = alpha * x[j];
96 for (int64_t i = 0; i <= j; ++i) {
97 A(i, j) += x[i] * tmp;
98 }
99 }
100 }
101 else {
102 // non-unit stride
103 int64_t jx = kx;
104 for (int64_t j = 0; j < n; ++j) {
105 scalar_t tmp = alpha * x[jx];
106 int64_t ix = kx;
107 for (int64_t i = 0; i <= j; ++i) {
108 A(i, j) += x[ix] * tmp;
109 ix += incx;
110 }
111 jx += incx;
112 }
113 }
114 }
115 else {
116 // lower triangle
117 if (incx == 1) {
118 // unit stride
119 for (int64_t j = 0; j < n; ++j) {
120 scalar_t tmp = alpha * x[j];
121 for (int64_t i = j; i < n; ++i) {
122 A(i, j) += x[i] * tmp;
123 }
124 }
125 }
126 else {
127 // non-unit stride
128 int64_t jx = kx;
129 for (int64_t j = 0; j < n; ++j) {
130 scalar_t tmp = alpha * x[jx];
131 int64_t ix = jx;
132 for (int64_t i = j; i < n; ++i) {
133 A(i, j) += x[ix] * tmp;
134 ix += incx;
135 }
136 jx += incx;
137 }
138 }
139 }
140
141 #undef A
142}
143
144} // namespace blas
145
146#endif // #ifndef BLAS_SYR_HH
void syr(blas::Layout layout, blas::Uplo uplo, int64_t n, blas::scalar_type< TA, TX > alpha, TX const *x, int64_t incx, TA *A, int64_t lda)
Symmetric matrix rank-1 update: \(A = \alpha x x^T + A\).
Definition syr.hh:56