BLAS++ 2024.05.31
BLAS C++ API
Loading...
Searching...
No Matches
gemv.hh
1// Copyright (c) 2017-2023, University of Tennessee. All rights reserved.
2// SPDX-License-Identifier: BSD-3-Clause
3// This program is free software: you can redistribute it and/or modify it under
4// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
5
6#ifndef BLAS_GEMV_HH
7#define BLAS_GEMV_HH
8
9#include "blas/util.hh"
10
11#include <limits>
12
13namespace blas {
14
15// =============================================================================
77
78template <typename TA, typename TX, typename TY>
79void gemv(
80 blas::Layout layout,
81 blas::Op trans,
82 int64_t m, int64_t n,
83 blas::scalar_type<TA, TX, TY> alpha,
84 TA const *A, int64_t lda,
85 TX const *x, int64_t incx,
86 blas::scalar_type<TA, TX, TY> beta,
87 TY *y, int64_t incy )
88{
89 using std::swap;
90 using scalar_t = blas::scalar_type<TA, TX, TY>;
91
92 #define A(i_, j_) A[ (i_) + (j_)*lda ]
93
94 // constants
95 const scalar_t zero = 0;
96 const scalar_t one = 1;
97
98 // check arguments
99 blas_error_if( layout != Layout::ColMajor &&
100 layout != Layout::RowMajor );
101 blas_error_if( trans != Op::NoTrans &&
102 trans != Op::Trans &&
103 trans != Op::ConjTrans );
104 blas_error_if( m < 0 );
105 blas_error_if( n < 0 );
106
107 if (layout == Layout::ColMajor)
108 blas_error_if( lda < m );
109 else
110 blas_error_if( lda < n );
111
112 blas_error_if( incx == 0 );
113 blas_error_if( incy == 0 );
114
115 // quick return
116 if (m == 0 || n == 0 || (alpha == zero && beta == one))
117 return;
118
119 bool doconj = false;
120 if (layout == Layout::RowMajor) {
121 // A => A^T; A^T => A; A^H => A & conj
122 swap( m, n );
123 if (trans == Op::NoTrans) {
124 trans = Op::Trans;
125 }
126 else {
127 if (trans == Op::ConjTrans) {
128 doconj = true;
129 }
130 trans = Op::NoTrans;
131 }
132 }
133
134 int64_t lenx = (trans == Op::NoTrans ? n : m);
135 int64_t leny = (trans == Op::NoTrans ? m : n);
136 int64_t kx = (incx > 0 ? 0 : (-lenx + 1)*incx);
137 int64_t ky = (incy > 0 ? 0 : (-leny + 1)*incy);
138
139 // ----------
140 // form y = beta*y
141 if (beta != one) {
142 if (incy == 1) {
143 if (beta == zero) {
144 for (int64_t i = 0; i < leny; ++i) {
145 y[i] = zero;
146 }
147 }
148 else {
149 for (int64_t i = 0; i < leny; ++i) {
150 y[i] *= beta;
151 }
152 }
153 }
154 else {
155 int64_t iy = ky;
156 if (beta == zero) {
157 for (int64_t i = 0; i < leny; ++i) {
158 y[iy] = zero;
159 iy += incy;
160 }
161 }
162 else {
163 for (int64_t i = 0; i < leny; ++i) {
164 y[iy] *= beta;
165 iy += incy;
166 }
167 }
168 }
169 }
170 if (alpha == zero)
171 return;
172
173 // ----------
174 if (trans == Op::NoTrans && ! doconj) {
175 // form y += alpha * A * x
176 int64_t jx = kx;
177 if (incy == 1) {
178 for (int64_t j = 0; j < n; ++j) {
179 scalar_t tmp = alpha*x[jx];
180 jx += incx;
181 for (int64_t i = 0; i < m; ++i) {
182 y[i] += tmp * A(i, j);
183 }
184 }
185 }
186 else {
187 for (int64_t j = 0; j < n; ++j) {
188 scalar_t tmp = alpha*x[jx];
189 jx += incx;
190 int64_t iy = ky;
191 for (int64_t i = 0; i < m; ++i) {
192 y[iy] += tmp * A(i, j);
193 iy += incy;
194 }
195 }
196 }
197 }
198 else if (trans == Op::NoTrans && doconj) {
199 // form y += alpha * conj( A ) * x
200 // this occurs for row-major A^H * x
201 int64_t jx = kx;
202 if (incy == 1) {
203 for (int64_t j = 0; j < n; ++j) {
204 scalar_t tmp = alpha*x[jx];
205 jx += incx;
206 for (int64_t i = 0; i < m; ++i) {
207 y[i] += tmp * conj(A(i, j));
208 }
209 }
210 }
211 else {
212 for (int64_t j = 0; j < n; ++j) {
213 scalar_t tmp = alpha*x[jx];
214 jx += incx;
215 int64_t iy = ky;
216 for (int64_t i = 0; i < m; ++i) {
217 y[iy] += tmp * conj(A(i, j));
218 iy += incy;
219 }
220 }
221 }
222 }
223 else if (trans == Op::Trans) {
224 // form y += alpha * A^T * x
225 int64_t jy = ky;
226 if (incx == 1) {
227 for (int64_t j = 0; j < n; ++j) {
228 scalar_t tmp = zero;
229 for (int64_t i = 0; i < m; ++i) {
230 tmp += A(i, j) * x[i];
231 }
232 y[jy] += alpha*tmp;
233 jy += incy;
234 }
235 }
236 else {
237 for (int64_t j = 0; j < n; ++j) {
238 scalar_t tmp = zero;
239 int64_t ix = kx;
240 for (int64_t i = 0; i < m; ++i) {
241 tmp += A(i, j) * x[ix];
242 ix += incx;
243 }
244 y[jy] += alpha*tmp;
245 jy += incy;
246 }
247 }
248 }
249 else {
250 // form y += alpha * A^H * x
251 int64_t jy = ky;
252 if (incx == 1) {
253 for (int64_t j = 0; j < n; ++j) {
254 scalar_t tmp = zero;
255 for (int64_t i = 0; i < m; ++i) {
256 tmp += conj(A(i, j)) * x[i];
257 }
258 y[jy] += alpha*tmp;
259 jy += incy;
260 }
261 }
262 else {
263 for (int64_t j = 0; j < n; ++j) {
264 scalar_t tmp = zero;
265 int64_t ix = kx;
266 for (int64_t i = 0; i < m; ++i) {
267 tmp += conj(A(i, j)) * x[ix];
268 ix += incx;
269 }
270 y[jy] += alpha*tmp;
271 jy += incy;
272 }
273 }
274 }
275
276 #undef A
277}
278
279} // namespace blas
280
281#endif // #ifndef BLAS_GEMV_HH
void gemv(blas::Layout layout, blas::Op trans, int64_t m, int64_t n, blas::scalar_type< TA, TX, TY > alpha, TA const *A, int64_t lda, TX const *x, int64_t incx, blas::scalar_type< TA, TX, TY > beta, TY *y, int64_t incy)
General matrix-vector multiply: y = alpha * op(A) * x + beta * y, where op(A) is A, A^T, or A^H.
Definition gemv.hh:79
void swap(int64_t n, float *x, int64_t incx, float *y, int64_t incy, blas::Queue &queue)
GPU device, float version.
Definition device_swap.cc:67