// Templated matrix-matrix multiply. The visible update statements have the
// form C(i, j) = alpha*sum + beta*C(i, j), where sum accumulates products of
// (optionally transposed / conjugated) entries of A and B over l = 0 .. k-1.
// NOTE(review): this excerpt looks like a partial source listing -- the
// embedded line numbers skip (89 -> 94 -> ... -> 261), so the function name,
// some parameters, braces, and several loop bodies are not visible here.
89template <
typename TA,
typename TB,
typename TC>
// m, n, k: problem dimensions; each is checked to be non-negative below.
94 int64_t m, int64_t n, int64_t k,
// alpha: scales the accumulated A*B product term in the update.
95 scalar_type<TA, TB, TC> alpha,
// A, lda: first operand and its leading dimension (validated against m or k).
96 TA
const *A, int64_t lda,
// B, ldb: second operand and its leading dimension (validated against k or n).
97 TB
const *B, int64_t ldb,
// beta: scales the existing C(i, j) value in the update.
98 scalar_type<TA, TB, TC> beta,
// Row-major input gets its own branch; the branch body is not visible in
// this excerpt.
102 if (layout == Layout::RowMajor) {
// Invoke blas_error_if_msg when layout is not ColMajor, i.e. it is neither
// of the two layouts named in the message.
116 blas_error_if_msg( layout != Layout::ColMajor,
117 "layout != Layout::ColMajor && layout != Layout::RowMajor" );
// Scalar type shared by alpha, beta, and the local temporaries below.
120 typedef blas::scalar_type<TA, TB, TC> scalar_t;
// Index helpers: element (i_, j_) lives at offset i_ + j_*ld, so consecutive
// i_ values are adjacent in memory (column-major addressing).
122 #define A(i_, j_) A[ (i_) + (j_)*lda ]
123 #define B(i_, j_) B[ (i_) + (j_)*ldb ]
124 #define C(i_, j_) C[ (i_) + (j_)*ldc ]
// Named constants in scalar_t; `one` is compared against beta further down.
127 const scalar_t zero = 0;
128 const scalar_t one = 1;
// Argument checks: each op must be one of NoTrans, Trans, or ConjTrans.
131 blas_error_if( transA != Op::NoTrans &&
132 transA != Op::Trans &&
133 transA != Op::ConjTrans );
134 blas_error_if( transB != Op::NoTrans &&
135 transB != Op::Trans &&
136 transB != Op::ConjTrans );
// Dimensions must be non-negative.
137 blas_error_if( m < 0 );
138 blas_error_if( n < 0 );
139 blas_error_if( k < 0 );
// Leading dimensions must cover the row index range actually used:
// A is indexed A(l, i) when transposed (rows run over k), else A(i, l)
// (rows run over m); B is indexed B(j, l) when transposed (rows run over n),
// else B(l, j) (rows run over k); C is always indexed C(i, j) with i < m.
141 blas_error_if( lda < ((transA != Op::NoTrans) ? k : m) );
142 blas_error_if( ldb < ((transB != Op::NoTrans) ? n : k) );
143 blas_error_if( ldc < m );
// Degenerate sizes are special-cased; the action taken is not visible in
// this excerpt.
146 if (m == 0 || n == 0 || k == 0)
// Two sweeps over every (i, j) of C, the second guarded by beta != one.
// NOTE(review): the enclosing condition and both loop bodies are missing
// from this excerpt -- confirm against the full source.
152 for (int64_t j = 0; j < n; ++j) {
153 for (int64_t i = 0; i < m; ++i)
157 else if (beta != one) {
158 for (int64_t j = 0; j < n; ++j) {
159 for (int64_t i = 0; i < m; ++i)
// A not transposed: for each column j and each l, add A(i, l) times the
// pre-scaled B entry into C(i, j). Multiplying alpha into the B entry once
// per (l, j) keeps that product out of the innermost i loop.
167 if (transA == Op::NoTrans) {
168 if (transB == Op::NoTrans) {
169 for (int64_t j = 0; j < n; ++j) {
// First pass over column j of C; its body is not visible in this excerpt.
170 for (int64_t i = 0; i < m; ++i)
172 for (int64_t l = 0; l < k; ++l) {
173 scalar_t alpha_Blj = alpha*B(l, j);
174 for (int64_t i = 0; i < m; ++i)
175 C(i, j) += A(i, l)*alpha_Blj;
// B transposed: same accumulation, reading B(j, l) instead of B(l, j).
179 else if (transB == Op::Trans) {
180 for (int64_t j = 0; j < n; ++j) {
181 for (int64_t i = 0; i < m; ++i)
183 for (int64_t l = 0; l < k; ++l) {
184 scalar_t alpha_Bjl = alpha*B(j, l);
185 for (int64_t i = 0; i < m; ++i)
186 C(i, j) += A(i, l)*alpha_Bjl;
// Remaining transB case (its `else` line is not visible; presumably
// ConjTrans given the checks above): B(j, l) is conjugated before scaling.
191 for (int64_t j = 0; j < n; ++j) {
192 for (int64_t i = 0; i < m; ++i)
194 for (int64_t l = 0; l < k; ++l) {
195 scalar_t alpha_Bjl = alpha*conj(B(j, l));
196 for (int64_t i = 0; i < m; ++i)
197 C(i, j) += A(i, l)*alpha_Bjl;
// A transposed: each C(i, j) is built from a sum over l of A(l, i) times an
// entry of B, then combined as alpha*sum + beta*C(i, j).
// NOTE(review): the declaration/initialization of `sum` is not visible here.
202 else if (transA == Op::Trans) {
203 if (transB == Op::NoTrans) {
204 for (int64_t j = 0; j < n; ++j) {
205 for (int64_t i = 0; i < m; ++i) {
207 for (int64_t l = 0; l < k; ++l)
208 sum += A(l, i)*B(l, j);
209 C(i, j) = alpha*sum + beta*C(i, j);
// B transposed as well: pairs A(l, i) with B(j, l).
213 else if (transB == Op::Trans) {
214 for (int64_t j = 0; j < n; ++j) {
215 for (int64_t i = 0; i < m; ++i) {
217 for (int64_t l = 0; l < k; ++l)
218 sum += A(l, i)*B(j, l);
219 C(i, j) = alpha*sum + beta*C(i, j);
// Remaining transB case (presumably ConjTrans): B(j, l) is conjugated.
224 for (int64_t j = 0; j < n; ++j) {
225 for (int64_t i = 0; i < m; ++i) {
227 for (int64_t l = 0; l < k; ++l)
228 sum += A(l, i)*conj(B(j, l));
229 C(i, j) = alpha*sum + beta*C(i, j);
// Remaining transA case (its `else` line is not visible; presumably
// ConjTrans): A(l, i) is conjugated in each product.
235 if (transB == Op::NoTrans) {
236 for (int64_t j = 0; j < n; ++j) {
237 for (int64_t i = 0; i < m; ++i) {
239 for (int64_t l = 0; l < k; ++l)
240 sum += conj(A(l, i))*B(l, j);
241 C(i, j) = alpha*sum + beta*C(i, j);
245 else if (transB == Op::Trans) {
246 for (int64_t j = 0; j < n; ++j) {
247 for (int64_t i = 0; i < m; ++i) {
249 for (int64_t l = 0; l < k; ++l)
250 sum += conj(A(l, i))*B(j, l);
251 C(i, j) = alpha*sum + beta*C(i, j);
// Both operands conjugated (presumably the ConjTrans/ConjTrans case): the
// products are accumulated unconjugated and the total is conjugated once,
// which is equivalent because conj(a)*conj(b) == conj(a*b) and conjugation
// distributes over the sum.
256 for (int64_t j = 0; j < n; ++j) {
257 for (int64_t i = 0; i < m; ++i) {
259 for (int64_t l = 0; l < k; ++l)
260 sum += A(l, i)*B(j, l);
261 C(i, j) = alpha*conj(sum) + beta*C(i, j);
void gemm(blas::Layout layout, blas::Op transA, blas::Op transB, int64_t m, int64_t n, int64_t k, float alpha, float const *A, int64_t lda, float const *B, int64_t ldb, float beta, float *C, int64_t ldc, blas::Queue &queue)
GPU device, float version.
Definition device_gemm.cc:119