BLAS++ 2024.05.31
BLAS C++ API
Loading...
Searching...
No Matches
gemm.hh
1// Copyright (c) 2017-2023, University of Tennessee. All rights reserved.
2// SPDX-License-Identifier: BSD-3-Clause
3// This program is free software: you can redistribute it and/or modify it under
4// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
5
6#ifndef BLAS_GEMM_HH
7#define BLAS_GEMM_HH
8
9#include "blas/util.hh"
10
11#include <limits>
12
13namespace blas {
14
15// =============================================================================
88
89template <typename TA, typename TB, typename TC>
90void gemm(
91 blas::Layout layout,
92 blas::Op transA,
93 blas::Op transB,
94 int64_t m, int64_t n, int64_t k,
95 scalar_type<TA, TB, TC> alpha,
96 TA const *A, int64_t lda,
97 TB const *B, int64_t ldb,
98 scalar_type<TA, TB, TC> beta,
99 TC *C, int64_t ldc )
100{
101 // redirect if row major
102 if (layout == Layout::RowMajor) {
103 return gemm(
104 Layout::ColMajor,
105 transB,
106 transA,
107 n, m, k,
108 alpha,
109 B, ldb,
110 A, lda,
111 beta,
112 C, ldc);
113 }
114 else {
115 // check layout
116 blas_error_if_msg( layout != Layout::ColMajor,
117 "layout != Layout::ColMajor && layout != Layout::RowMajor" );
118 }
119
120 typedef blas::scalar_type<TA, TB, TC> scalar_t;
121
122 #define A(i_, j_) A[ (i_) + (j_)*lda ]
123 #define B(i_, j_) B[ (i_) + (j_)*ldb ]
124 #define C(i_, j_) C[ (i_) + (j_)*ldc ]
125
126 // constants
127 const scalar_t zero = 0;
128 const scalar_t one = 1;
129
130 // check arguments
131 blas_error_if( transA != Op::NoTrans &&
132 transA != Op::Trans &&
133 transA != Op::ConjTrans );
134 blas_error_if( transB != Op::NoTrans &&
135 transB != Op::Trans &&
136 transB != Op::ConjTrans );
137 blas_error_if( m < 0 );
138 blas_error_if( n < 0 );
139 blas_error_if( k < 0 );
140
141 blas_error_if( lda < ((transA != Op::NoTrans) ? k : m) );
142 blas_error_if( ldb < ((transB != Op::NoTrans) ? n : k) );
143 blas_error_if( ldc < m );
144
145 // quick return
146 if (m == 0 || n == 0 || k == 0)
147 return;
148
149 // alpha == zero
150 if (alpha == zero) {
151 if (beta == zero) {
152 for (int64_t j = 0; j < n; ++j) {
153 for (int64_t i = 0; i < m; ++i)
154 C(i, j) = zero;
155 }
156 }
157 else if (beta != one) {
158 for (int64_t j = 0; j < n; ++j) {
159 for (int64_t i = 0; i < m; ++i)
160 C(i, j) *= beta;
161 }
162 }
163 return;
164 }
165
166 // alpha != zero
167 if (transA == Op::NoTrans) {
168 if (transB == Op::NoTrans) {
169 for (int64_t j = 0; j < n; ++j) {
170 for (int64_t i = 0; i < m; ++i)
171 C(i, j) *= beta;
172 for (int64_t l = 0; l < k; ++l) {
173 scalar_t alpha_Blj = alpha*B(l, j);
174 for (int64_t i = 0; i < m; ++i)
175 C(i, j) += A(i, l)*alpha_Blj;
176 }
177 }
178 }
179 else if (transB == Op::Trans) {
180 for (int64_t j = 0; j < n; ++j) {
181 for (int64_t i = 0; i < m; ++i)
182 C(i, j) *= beta;
183 for (int64_t l = 0; l < k; ++l) {
184 scalar_t alpha_Bjl = alpha*B(j, l);
185 for (int64_t i = 0; i < m; ++i)
186 C(i, j) += A(i, l)*alpha_Bjl;
187 }
188 }
189 }
190 else { // transB == Op::ConjTrans
191 for (int64_t j = 0; j < n; ++j) {
192 for (int64_t i = 0; i < m; ++i)
193 C(i, j) *= beta;
194 for (int64_t l = 0; l < k; ++l) {
195 scalar_t alpha_Bjl = alpha*conj(B(j, l));
196 for (int64_t i = 0; i < m; ++i)
197 C(i, j) += A(i, l)*alpha_Bjl;
198 }
199 }
200 }
201 }
202 else if (transA == Op::Trans) {
203 if (transB == Op::NoTrans) {
204 for (int64_t j = 0; j < n; ++j) {
205 for (int64_t i = 0; i < m; ++i) {
206 scalar_t sum = zero;
207 for (int64_t l = 0; l < k; ++l)
208 sum += A(l, i)*B(l, j);
209 C(i, j) = alpha*sum + beta*C(i, j);
210 }
211 }
212 }
213 else if (transB == Op::Trans) {
214 for (int64_t j = 0; j < n; ++j) {
215 for (int64_t i = 0; i < m; ++i) {
216 scalar_t sum = zero;
217 for (int64_t l = 0; l < k; ++l)
218 sum += A(l, i)*B(j, l);
219 C(i, j) = alpha*sum + beta*C(i, j);
220 }
221 }
222 }
223 else { // transB == Op::ConjTrans
224 for (int64_t j = 0; j < n; ++j) {
225 for (int64_t i = 0; i < m; ++i) {
226 scalar_t sum = zero;
227 for (int64_t l = 0; l < k; ++l)
228 sum += A(l, i)*conj(B(j, l));
229 C(i, j) = alpha*sum + beta*C(i, j);
230 }
231 }
232 }
233 }
234 else { // transA == Op::ConjTrans
235 if (transB == Op::NoTrans) {
236 for (int64_t j = 0; j < n; ++j) {
237 for (int64_t i = 0; i < m; ++i) {
238 scalar_t sum = zero;
239 for (int64_t l = 0; l < k; ++l)
240 sum += conj(A(l, i))*B(l, j);
241 C(i, j) = alpha*sum + beta*C(i, j);
242 }
243 }
244 }
245 else if (transB == Op::Trans) {
246 for (int64_t j = 0; j < n; ++j) {
247 for (int64_t i = 0; i < m; ++i) {
248 scalar_t sum = zero;
249 for (int64_t l = 0; l < k; ++l)
250 sum += conj(A(l, i))*B(j, l);
251 C(i, j) = alpha*sum + beta*C(i, j);
252 }
253 }
254 }
255 else { // transB == Op::ConjTrans
256 for (int64_t j = 0; j < n; ++j) {
257 for (int64_t i = 0; i < m; ++i) {
258 scalar_t sum = zero;
259 for (int64_t l = 0; l < k; ++l)
260 sum += A(l, i)*B(j, l); // little improvement here
261 C(i, j) = alpha*conj(sum) + beta*C(i, j);
262 }
263 }
264 }
265 }
266
267 #undef A
268 #undef B
269 #undef C
270}
271
272} // namespace blas
273
274#endif // #ifndef BLAS_GEMM_HH
void gemm(blas::Layout layout, blas::Op transA, blas::Op transB, int64_t m, int64_t n, int64_t k, float alpha, float const *A, int64_t lda, float const *B, int64_t ldb, float beta, float *C, int64_t ldc, blas::Queue &queue)
GPU device, float version.
Definition device_gemm.cc:119