BLAS++ 2024.05.31
BLAS C++ API
Loading...
Searching...
No Matches
hemm.hh
1// Copyright (c) 2017-2023, University of Tennessee. All rights reserved.
2// SPDX-License-Identifier: BSD-3-Clause
3// This program is free software: you can redistribute it and/or modify it under
4// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
5
6#ifndef BLAS_HEMM_HH
7#define BLAS_HEMM_HH
8
9#include "blas/util.hh"
10#include "blas/symm.hh"
11
12#include <limits>
13
14namespace blas {
15
16// =============================================================================
83
84template <typename TA, typename TB, typename TC>
85void hemm(
86 blas::Layout layout,
87 blas::Side side,
88 blas::Uplo uplo,
89 int64_t m, int64_t n,
90 scalar_type<TA, TB, TC> alpha,
91 TA const *A, int64_t lda,
92 TB const *B, int64_t ldb,
93 scalar_type<TA, TB, TC> beta,
94 TC *C, int64_t ldc )
95{
96 using std::swap;
97 using scalar_t = blas::scalar_type<TA, TB, TC>;
98
99 #define A(i_, j_) A[ (i_) + (j_)*lda ]
100 #define B(i_, j_) B[ (i_) + (j_)*ldb ]
101 #define C(i_, j_) C[ (i_) + (j_)*ldc ]
102
103 // constants
104 const scalar_t zero = 0;
105 const scalar_t one = 1;
106
107 // check arguments
108 blas_error_if( layout != Layout::ColMajor &&
109 layout != Layout::RowMajor );
110 blas_error_if( side != Side::Left &&
111 side != Side::Right );
112 blas_error_if( uplo != Uplo::Lower &&
113 uplo != Uplo::Upper &&
114 uplo != Uplo::General );
115 blas_error_if( m < 0 );
116 blas_error_if( n < 0 );
117
118 // adapt if row major
119 if (layout == Layout::RowMajor) {
120 side = (side == Side::Left)
121 ? Side::Right
122 : Side::Left;
123 if (uplo == Uplo::Lower)
124 uplo = Uplo::Upper;
125 else if (uplo == Uplo::Upper)
126 uplo = Uplo::Lower;
127 swap( m, n );
128 }
129
130 // check remaining arguments
131 blas_error_if( lda < ((side == Side::Left) ? m : n) );
132 blas_error_if( ldb < m );
133 blas_error_if( ldc < m );
134
135 // quick return
136 if (m == 0 || n == 0)
137 return;
138
139 // alpha == zero
140 if (alpha == zero) {
141 if (beta == zero) {
142 for (int64_t j = 0; j < n; ++j) {
143 for (int64_t i = 0; i < m; ++i)
144 C(i, j) = zero;
145 }
146 }
147 else if (beta != one) {
148 for (int64_t j = 0; j < n; ++j) {
149 for (int64_t i = 0; i < m; ++i)
150 C(i, j) *= beta;
151 }
152 }
153 return;
154 }
155
156 // alpha != zero
157 if (side == Side::Left) {
158 if (uplo != Uplo::Lower) {
159 // uplo == Uplo::Upper or uplo == Uplo::General
160 for (int64_t j = 0; j < n; ++j) {
161 for (int64_t i = 0; i < m; ++i) {
162
163 scalar_t alpha_Bij = alpha*B(i, j);
164 scalar_t sum = zero;
165
166 for (int64_t k = 0; k < i; ++k) {
167 C(k, j) += A(k, i) * alpha_Bij;
168 sum += conj( A(k, i) ) * B(k, j);
169 }
170 C(i, j) =
171 beta * C(i, j)
172 + real( A(i, i) ) * alpha_Bij
173 + alpha * sum;
174 }
175 }
176 }
177 else {
178 // uplo == Uplo::Lower
179 for (int64_t j = 0; j < n; ++j) {
180 for (int64_t i = m-1; i >= 0; --i) {
181
182 scalar_t alpha_Bij = alpha*B(i, j);
183 scalar_t sum = zero;
184
185 for (int64_t k = i+1; k < m; ++k) {
186 C(k, j) += A(k, i) * alpha_Bij;
187 sum += conj( A(k, i) ) * B(k, j);
188 }
189 C(i, j) =
190 beta * C(i, j)
191 + real( A(i, i) ) * alpha_Bij
192 + alpha * sum;
193 }
194 }
195 }
196 }
197 else { // side == Side::Right
198 if (uplo != Uplo::Lower) {
199 // uplo == Uplo::Upper or uplo == Uplo::General
200 for (int64_t j = 0; j < n; ++j) {
201
202 scalar_t alpha_Akj = alpha * real( A(j, j) );
203
204 for (int64_t i = 0; i < m; ++i)
205 C(i, j) = beta * C(i, j) + B(i, j) * alpha_Akj;
206
207 for (int64_t k = 0; k < j; ++k) {
208 alpha_Akj = alpha*A(k, j);
209 for (int64_t i = 0; i < m; ++i)
210 C(i, j) += B(i, k) * alpha_Akj;
211 }
212
213 for (int64_t k = j+1; k < n; ++k) {
214 alpha_Akj = alpha * conj( A(j, k) );
215 for (int64_t i = 0; i < m; ++i)
216 C(i, j) += B(i, k) * alpha_Akj;
217 }
218 }
219 }
220 else {
221 // uplo == Uplo::Lower
222 for (int64_t j = 0; j < n; ++j) {
223
224 scalar_t alpha_Akj = alpha * real( A(j, j) );
225
226 for (int64_t i = 0; i < m; ++i)
227 C(i, j) = beta * C(i, j) + B(i, j) * alpha_Akj;
228
229 for (int64_t k = 0; k < j; ++k) {
230 alpha_Akj = alpha * conj( A(j, k) );
231 for (int64_t i = 0; i < m; ++i)
232 C(i, j) += B(i, k) * alpha_Akj;
233 }
234
235 for (int64_t k = j+1; k < n; ++k) {
236 alpha_Akj = alpha*A(k, j);
237 for (int64_t i = 0; i < m; ++i)
238 C(i, j) += B(i, k) * alpha_Akj;
239 }
240 }
241 }
242 }
243
244 #undef A
245 #undef B
246 #undef C
247}
248
249} // namespace blas
250
251#endif // #ifndef BLAS_HEMM_HH
void hemm(blas::Layout layout, blas::Side side, blas::Uplo uplo, int64_t m, int64_t n, float alpha, float const *A, int64_t lda, float const *B, int64_t ldb, float beta, float *C, int64_t ldc, blas::Queue &queue)
GPU device, float version.
Definition device_hemm.cc:107
void swap(int64_t n, float *x, int64_t incx, float *y, int64_t incy, blas::Queue &queue)
GPU device, float version.
Definition device_swap.cc:67