BLAS++ 2024.05.31
BLAS C++ API
Loading...
Searching...
No Matches
hemv.hh
1// Copyright (c) 2017-2023, University of Tennessee. All rights reserved.
2// SPDX-License-Identifier: BSD-3-Clause
3// This program is free software: you can redistribute it and/or modify it under
4// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
5
6#ifndef BLAS_HEMV_HH
7#define BLAS_HEMV_HH
8
9#include "blas/util.hh"
10#include "blas/symv.hh"
11
12#include <limits>
13
14namespace blas {
15
16// =============================================================================
67
68template <typename TA, typename TX, typename TY>
69void hemv(
70 blas::Layout layout,
71 blas::Uplo uplo,
72 int64_t n,
73 blas::scalar_type<TA, TX, TY> alpha,
74 TA const *A, int64_t lda,
75 TX const *x, int64_t incx,
76 blas::scalar_type<TA, TX, TY> beta,
77 TY *y, int64_t incy )
78{
79 typedef blas::scalar_type<TA, TX, TY> scalar_t;
80
81 #define A(i_, j_) A[ (i_) + (j_)*lda ]
82
83 // constants
84 const scalar_t zero = 0;
85 const scalar_t one = 1;
86
87 // check arguments
88 blas_error_if( layout != Layout::ColMajor &&
89 layout != Layout::RowMajor );
90 blas_error_if( uplo != Uplo::Lower &&
91 uplo != Uplo::Upper );
92 blas_error_if( n < 0 );
93 blas_error_if( lda < n );
94 blas_error_if( incx == 0 );
95 blas_error_if( incy == 0 );
96
97 // quick return
98 if (n == 0 || (alpha == zero && beta == one))
99 return;
100
101 int64_t kx = (incx > 0 ? 0 : (-n + 1)*incx);
102 int64_t ky = (incy > 0 ? 0 : (-n + 1)*incy);
103
104 // form y = beta*y
105 if (beta != one) {
106 if (incy == 1) {
107 if (beta == zero) {
108 for (int64_t i = 0; i < n; ++i) {
109 y[i] = zero;
110 }
111 }
112 else {
113 for (int64_t i = 0; i < n; ++i) {
114 y[i] *= beta;
115 }
116 }
117 }
118 else {
119 int64_t iy = ky;
120 if (beta == zero) {
121 for (int64_t i = 0; i < n; ++i) {
122 y[iy] = zero;
123 iy += incy;
124 }
125 }
126 else {
127 for (int64_t i = 0; i < n; ++i) {
128 y[iy] *= beta;
129 iy += incy;
130 }
131 }
132 }
133 }
134 if (alpha == zero)
135 return;
136
137 if (layout == Layout::ColMajor) {
138 if (uplo == Uplo::Upper) {
139 // A is stored in upper triangle
140 // form y += alpha * A * x
141 if (incx == 1 && incy == 1) {
142 // unit stride
143 for (int64_t j = 0; j < n; ++j) {
144 scalar_t tmp1 = alpha*x[j];
145 scalar_t tmp2 = zero;
146 for (int64_t i = 0; i < j; ++i) {
147 y[i] += tmp1 * A(i, j);
148 tmp2 += conj( A(i, j) ) * x[i];
149 }
150 y[j] += tmp1 * real( A(j, j) ) + alpha * tmp2;
151 }
152 }
153 else {
154 // non-unit stride
155 int64_t jx = kx;
156 int64_t jy = ky;
157 for (int64_t j = 0; j < n; ++j) {
158 scalar_t tmp1 = alpha*x[jx];
159 scalar_t tmp2 = zero;
160 int64_t ix = kx;
161 int64_t iy = ky;
162 for (int64_t i = 0; i < j; ++i) {
163 y[iy] += tmp1 * A(i, j);
164 tmp2 += conj( A(i, j) ) * x[ix];
165 ix += incx;
166 iy += incy;
167 }
168 y[jy] += tmp1 * real( A(j, j) ) + alpha * tmp2;
169 jx += incx;
170 jy += incy;
171 }
172 }
173 }
174 else if (uplo == Uplo::Lower) {
175 // A is stored in lower triangle
176 // form y += alpha * A * x
177 if (incx == 1 && incy == 1) {
178 for (int64_t j = 0; j < n; ++j) {
179 scalar_t tmp1 = alpha*x[j];
180 scalar_t tmp2 = zero;
181 for (int64_t i = j+1; i < n; ++i) {
182 y[i] += tmp1 * A(i, j);
183 tmp2 += conj( A(i, j) ) * x[i];
184 }
185 y[j] += tmp1 * real( A(j, j) ) + alpha * tmp2;
186 }
187 }
188 else {
189 int64_t jx = kx;
190 int64_t jy = ky;
191 for (int64_t j = 0; j < n; ++j) {
192 scalar_t tmp1 = alpha*x[jx];
193 scalar_t tmp2 = zero;
194 int64_t ix = jx;
195 int64_t iy = jy;
196 for (int64_t i = j+1; i < n; ++i) {
197 ix += incx;
198 iy += incy;
199 y[iy] += tmp1 * A(i, j);
200 tmp2 += conj( A(i, j) ) * x[ix];
201 }
202 y[jy] += tmp1 * real( A(j, j) ) + alpha * tmp2;
203 jx += incx;
204 jy += incy;
205 }
206 }
207 }
208 }
209 else {
210 if (uplo == Uplo::Lower) {
211 // A is stored in lower triangle
212 // form y += alpha * A * x
213 if (incx == 1 && incy == 1) {
214 // unit stride
215 for (int64_t j = 0; j < n; ++j) {
216 scalar_t tmp1 = alpha*x[j];
217 scalar_t tmp2 = zero;
218 for (int64_t i = 0; i < j; ++i) {
219 y[i] += tmp1 * conj( A(i, j) );
220 tmp2 += A(i, j) * x[i];
221 }
222 y[j] += tmp1 * real( A(j, j) ) + alpha * tmp2;
223 }
224 }
225 else {
226 // non-unit stride
227 int64_t jx = kx;
228 int64_t jy = ky;
229 for (int64_t j = 0; j < n; ++j) {
230 scalar_t tmp1 = alpha*x[jx];
231 scalar_t tmp2 = zero;
232 int64_t ix = kx;
233 int64_t iy = ky;
234 for (int64_t i = 0; i < j; ++i) {
235 y[iy] += tmp1 * conj( A(i, j) );
236 tmp2 += A(i, j) * x[ix];
237 ix += incx;
238 iy += incy;
239 }
240 y[jy] += tmp1 * real( A(j, j) ) + alpha * tmp2;
241 jx += incx;
242 jy += incy;
243 }
244 }
245 }
246 else if (uplo == Uplo::Upper) {
247 // A is stored in upper triangle
248 // form y += alpha * A * x
249 if (incx == 1 && incy == 1) {
250 for (int64_t j = 0; j < n; ++j) {
251 scalar_t tmp1 = alpha*x[j];
252 scalar_t tmp2 = zero;
253 for (int64_t i = j+1; i < n; ++i) {
254 y[i] += tmp1 * conj( A(i, j) );
255 tmp2 += A(i, j) * x[i];
256 }
257 y[j] += tmp1 * real( A(j, j) ) + alpha * tmp2;
258 }
259 }
260 else {
261 int64_t jx = kx;
262 int64_t jy = ky;
263 for (int64_t j = 0; j < n; ++j) {
264 scalar_t tmp1 = alpha*x[jx];
265 scalar_t tmp2 = zero;
266 int64_t ix = jx;
267 int64_t iy = jy;
268 for (int64_t i = j+1; i < n; ++i) {
269 ix += incx;
270 iy += incy;
271 y[iy] += tmp1 * conj( A(i, j) );
272 tmp2 += A(i, j) * x[ix];
273 }
274 y[jy] += tmp1 * real( A(j, j) ) + alpha * tmp2;
275 jx += incx;
276 jy += incy;
277 }
278 }
279 }
280 }
281
282 #undef A
283}
284
285} // namespace blas
286
287#endif // #ifndef BLAS_HEMV_HH
void hemv(blas::Layout layout, blas::Uplo uplo, int64_t n, blas::scalar_type< TA, TX, TY > alpha, TA const *A, int64_t lda, TX const *x, int64_t incx, blas::scalar_type< TA, TX, TY > beta, TY *y, int64_t incy)
Hermitian matrix-vector multiply: \( y = \alpha A x + \beta y \).
Definition hemv.hh:69