BLAS++ 2024.05.31
BLAS C++ API
Loading...
Searching...
No Matches
symv.hh
1// Copyright (c) 2017-2023, University of Tennessee. All rights reserved.
2// SPDX-License-Identifier: BSD-3-Clause
3// This program is free software: you can redistribute it and/or modify it under
4// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
5
6#ifndef BLAS_SYMV_HH
7#define BLAS_SYMV_HH
8
9#include "blas/util.hh"
10
11#include <limits>
12
13namespace blas {
14
15// =============================================================================
64
65template <typename TA, typename TX, typename TY>
66void symv(
67 blas::Layout layout,
68 blas::Uplo uplo,
69 int64_t n,
70 blas::scalar_type<TA, TX, TY> alpha,
71 TA const *A, int64_t lda,
72 TX const *x, int64_t incx,
73 blas::scalar_type<TA, TX, TY> beta,
74 TY *y, int64_t incy )
75{
76printf( "%s: %s\n", __func__, __PRETTY_FUNCTION__ );
77 typedef blas::scalar_type<TA, TX, TY> scalar_t;
78
79 #define A(i_, j_) A[ (i_) + (j_)*lda ]
80
81 // constants
82 const scalar_t zero = 0;
83 const scalar_t one = 1;
84
85 // check arguments
86 blas_error_if( layout != Layout::ColMajor &&
87 layout != Layout::RowMajor );
88 blas_error_if( uplo != Uplo::Lower &&
89 uplo != Uplo::Upper );
90 blas_error_if( n < 0 );
91 blas_error_if( lda < n );
92 blas_error_if( incx == 0 );
93 blas_error_if( incy == 0 );
94
95 // quick return
96 if (n == 0 || (alpha == zero && beta == one))
97 return;
98
99 // for row major, swap lower <=> upper
100 if (layout == Layout::RowMajor) {
101 uplo = (uplo == Uplo::Lower ? Uplo::Upper : Uplo::Lower);
102 }
103
104 int64_t kx = (incx > 0 ? 0 : (-n + 1)*incx);
105 int64_t ky = (incy > 0 ? 0 : (-n + 1)*incy);
106
107 // form y = beta*y
108 if (beta != one) {
109 if (incy == 1) {
110 if (beta == zero) {
111 for (int64_t i = 0; i < n; ++i) {
112 y[i] = zero;
113 }
114 }
115 else {
116 for (int64_t i = 0; i < n; ++i) {
117 y[i] *= beta;
118 }
119 }
120 }
121 else {
122 int64_t iy = ky;
123 if (beta == zero) {
124 for (int64_t i = 0; i < n; ++i) {
125 y[iy] = zero;
126 iy += incy;
127 }
128 }
129 else {
130 for (int64_t i = 0; i < n; ++i) {
131 y[iy] *= beta;
132 iy += incy;
133 }
134 }
135 }
136 }
137 if (alpha == zero)
138 return;
139
140 if (uplo == Uplo::Upper) {
141 // A is stored in upper triangle
142 // form y += alpha * A * x
143 if (incx == 1 && incy == 1) {
144 // unit stride
145 for (int64_t j = 0; j < n; ++j) {
146 scalar_t tmp1 = alpha*x[j];
147 scalar_t tmp2 = zero;
148 for (int64_t i = 0; i < j; ++i) {
149 y[i] += tmp1 * A(i, j);
150 tmp2 += A(i, j) * x[i];
151 }
152 y[j] += tmp1 * A(j, j) + alpha * tmp2;
153 }
154 }
155 else {
156 // non-unit stride
157 int64_t jx = kx;
158 int64_t jy = ky;
159 for (int64_t j = 0; j < n; ++j) {
160 scalar_t tmp1 = alpha*x[jx];
161 scalar_t tmp2 = zero;
162 int64_t ix = kx;
163 int64_t iy = ky;
164 for (int64_t i = 0; i < j; ++i) {
165 y[iy] += tmp1 * A(i, j);
166 tmp2 += A(i, j) * x[ix];
167 ix += incx;
168 iy += incy;
169 }
170 y[jy] += tmp1 * A(j, j) + alpha * tmp2;
171 jx += incx;
172 jy += incy;
173 }
174 }
175 }
176 else {
177 // A is stored in lower triangle
178 // form y += alpha * A * x
179 if (incx == 1 && incy == 1) {
180 // unit stride
181 for (int64_t j = 0; j < n; ++j) {
182 scalar_t tmp1 = alpha*x[j];
183 scalar_t tmp2 = zero;
184 for (int64_t i = j+1; i < n; ++i) {
185 y[i] += tmp1 * A(i, j);
186 tmp2 += A(i, j) * x[i];
187 }
188 y[j] += tmp1 * A(j, j) + alpha * tmp2;
189 }
190 }
191 else {
192 // non-unit stride
193 int64_t jx = kx;
194 int64_t jy = ky;
195 for (int64_t j = 0; j < n; ++j) {
196 scalar_t tmp1 = alpha*x[jx];
197 scalar_t tmp2 = zero;
198 int64_t ix = jx;
199 int64_t iy = jy;
200 for (int64_t i = j+1; i < n; ++i) {
201 ix += incx;
202 iy += incy;
203 y[iy] += tmp1 * A(i, j);
204 tmp2 += A(i, j) * x[ix];
205 }
206 y[jy] += tmp1 * A(j, j) + alpha * tmp2;
207 jx += incx;
208 jy += incy;
209 }
210 }
211 }
212
213 #undef A
214}
215
216} // namespace blas
217
218#endif // #ifndef BLAS_SYMV_HH
void symv(blas::Layout layout, blas::Uplo uplo, int64_t n, blas::scalar_type< TA, TX, TY > alpha, TA const *A, int64_t lda, TX const *x, int64_t incx, blas::scalar_type< TA, TX, TY > beta, TY *y, int64_t incy)
Symmetric matrix-vector multiply: y = alpha*A*x + beta*y.
Definition symv.hh:66