BLAS++ 2024.05.31
BLAS C++ API
Loading...
Searching...
No Matches
flops.hh
1// Copyright (c) 2017-2023, University of Tennessee. All rights reserved.
2// SPDX-License-Identifier: BSD-3-Clause
3// This program is free software: you can redistribute it and/or modify it under
4// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
5
6#ifndef BLAS_FLOPS_HH
7#define BLAS_FLOPS_HH
8
9#include "blas.hh"
10
11namespace blas {
12
13// =============================================================================
14// Level 1 BLAS
15
16// -----------------------------------------------------------------------------
/// Multiplies to compute asum of an n-vector: none (asum only sums |x_i|).
inline double fmuls_asum( double n )
    { return 0; }

/// Adds to compute asum of an n-vector: n-1 additions to sum n terms.
/// Guarded so an empty vector (n = 0) reports 0 flops instead of -1.
inline double fadds_asum( double n )
    { return (n >= 1 ? n - 1 : 0); }
22
23// -----------------------------------------------------------------------------
/// Multiplies for axpy (y = alpha*x + y): one per element, alpha*x[i].
inline double fmuls_axpy( double n )
{
    return n;
}

/// Adds for axpy: one per element, alpha*x[i] + y[i].
inline double fadds_axpy( double n )
{
    return n;
}
29
30// -----------------------------------------------------------------------------
/// Multiplies for iamax of an n-vector: none.
inline double fmuls_iamax( double n )
    { return 0; }

/// Adds for iamax: n-1 compares, which are essentially adds
/// (x > y is x - y > 0).
/// Guarded so an empty vector (n = 0) reports 0 flops instead of -1.
inline double fadds_iamax( double n )
    { return (n >= 1 ? n - 1 : 0); }
37
38// -----------------------------------------------------------------------------
/// Multiplies for nrm2 of an n-vector: one square per element.
inline double fmuls_nrm2( double n )
    { return n; }

/// Adds for nrm2: n-1 additions to accumulate n squares.
/// Guarded so an empty vector (n = 0) reports 0 flops instead of -1.
inline double fadds_nrm2( double n )
    { return (n >= 1 ? n - 1 : 0); }
44
45// -----------------------------------------------------------------------------
/// Multiplies for dot of two n-vectors: one x[i]*y[i] per element.
inline double fmuls_dot( double n )
    { return n; }

/// Adds for dot: n-1 additions to accumulate n products.
/// Guarded so an empty vector (n = 0) reports 0 flops instead of -1.
inline double fadds_dot( double n )
    { return (n >= 1 ? n - 1 : 0); }
51
52// -----------------------------------------------------------------------------
/// Multiplies for scal (x = alpha*x): one per element.
inline double fmuls_scal( double n )
{
    return n;
}

/// Adds for scal: none.
inline double fadds_scal( double n )
{
    return 0;
}
58
59// -----------------------------------------------------------------------------
/// Multiplies for rot: each of the n (x[i], y[i]) pairs takes 4 multiplies
/// (c and s applied to both components).
inline double fmuls_rot( double n )
{
    return n * 4;
}

/// Adds for rot: 2 adds per pair.
inline double fadds_rot( double n )
{
    return n * 2;
}
65
66// -----------------------------------------------------------------------------
/// Multiplies for rotm (modified plane rotation): 2 per element pair.
inline double fmuls_rotm( double n )
{
    return n * 2;
}

/// Adds for rotm: 2 per element pair.
inline double fadds_rotm( double n )
{
    return n * 2;
}
72
73// =============================================================================
74// Level 2 BLAS
75// most formulas assume alpha=1, beta=0 or 1; otherwise add lower-order terms.
76// i.e., this is minimum flops and bandwidth that could be consumed.
77
78// -----------------------------------------------------------------------------
/// Multiplies for gemv with an m-by-n matrix: one per matrix element.
inline double fmuls_gemv( double m, double n )
{
    return n * m;
}

/// Adds for gemv: one per matrix element.
inline double fadds_gemv( double m, double n )
{
    return n * m;
}
84
85// -----------------------------------------------------------------------------
/// Multiplies for trmv with an n-by-n triangular matrix:
/// one per stored element, n*(n+1)/2.
inline double fmuls_trmv( double n )
{
    return n*(n + 1)/2;
}

/// Adds for trmv: n*(n-1)/2.
inline double fadds_trmv( double n )
{
    return n*(n - 1)/2;
}
91
92// -----------------------------------------------------------------------------
/// Multiplies for ger, an m-by-n rank-1 update: one per matrix element.
inline double fmuls_ger( double m, double n )
{
    return n * m;
}

/// Adds for ger: one per matrix element.
inline double fadds_ger( double m, double n )
{
    return n * m;
}
98
99// -----------------------------------------------------------------------------
/// Multiplies for gemm, C (m-by-n) += A (m-by-k) * B (k-by-n): m*n*k.
inline double fmuls_gemm( double m, double n, double k )
{
    return m*n*k;
}

/// Adds for gemm: m*n*k.
inline double fadds_gemm( double m, double n, double k )
{
    return m*n*k;
}
105
106// -----------------------------------------------------------------------------
107// Assume gbmm is band matrix A (m-by-k) and general matrix B (k-by-n).
108// Usually, the bottom equation (m-kl <= k and k-ku <= m) calculates the flops,
109// but some matrices are too tall or too wide and require extra care.
110// This bottom equation fails because a triangle it subtracts extends beyond
111// the matrix, so it should subtract a trapezoid instead.
112// For the first corner (m-kl > k) case,
113// think rectangle minus trapezoid minus triangle and reduce:
114// (m*k - (m-kl+m-k-kl-1)/2*k - (k-ku-1)*(k-ku)/2)*n;
115// (m*k - (m-kl)*k+(k-1)*k/2 - (k-ku-1)*(k-ku)/2)*n;
116// (kl*k + (k+1)*k/2 - (k-ku-1)*(k-ku)/2)*n;
117// We are conveniently left with the geometric interpretation of
118// rectangle plus triangle minus triangle.
/// Multiplies for gbmm with band A (m-by-k, bandwidths kl, ku) times
/// general B (k-by-n). See the derivation in the comment above.
inline double fmuls_gbmm( double m, double n, double k, double kl, double ku )
{
    if (m-kl > k) {
        // Tall corner case: rectangle plus triangle minus triangle.
        return (kl*k + (k+1)*k/2. - (k-ku-1)*(k-ku)/2.)*n;
    }
    else if (k-ku > m) {
        // Wide corner case, symmetric to the tall one.
        return (ku*m - (m-kl-1)*(m-kl)/2. + (m+1)*m/2.)*n;
    }
    else {
        // Usual case: full rectangle minus the two excluded triangles.
        return (m*k - (m-kl-1)*(m-kl)/2. - (k-ku-1)*(k-ku)/2.)*n;
    }
}
127
128// Assuming alpha=1, beta=1, adds are same as muls.
129inline double fadds_gbmm( double m, double n, double k, double kl, double ku )
130{
131 return fmuls_gbmm( m, n, k, kl, ku );
132}
133
134// -----------------------------------------------------------------------------
135inline double fmuls_hemm( blas::Side side, double m, double n )
136 { return (side == blas::Side::Left ? m*m*n : m*n*n); }
137
138inline double fadds_hemm( blas::Side side, double m, double n )
139 { return (side == blas::Side::Left ? m*m*n : m*n*n); }
140
141// -----------------------------------------------------------------------------
/// Multiplies for herk, C (n-by-n) += A * A^H with A n-by-k:
/// k multiplies per stored triangular element, n*(n+1)/2 of them.
inline double fmuls_herk( double n, double k )
{
    return 0.5*k*n*(n+1);
}

/// Adds for herk: same count as multiplies.
inline double fadds_herk( double n, double k )
{
    return 0.5*k*n*(n+1);
}
147
148// -----------------------------------------------------------------------------
/// Multiplies for her2k, a rank-2k update of an n-by-n matrix: k*n*n.
inline double fmuls_her2k( double n, double k )
{
    return k*n*n;
}

/// Adds for her2k: same count as multiplies.
inline double fadds_her2k( double n, double k )
{
    return k*n*n;
}
154
155// -----------------------------------------------------------------------------
156inline double fmuls_trmm( blas::Side side, double m, double n )
157{
158 if (side == blas::Side::Left)
159 return 0.5*n*m*(m + 1);
160 else
161 return 0.5*m*n*(n + 1);
162}
163
164inline double fadds_trmm( blas::Side side, double m, double n )
165{
166 if (side == blas::Side::Left)
167 return 0.5*n*m*(m - 1);
168 else
169 return 0.5*m*n*(n - 1);
170}
171
172//==============================================================================
// template class. Example:
// Gbyte< float >::gemm( m, n, k ) yields data moved (in GB) for sgemm.
// Gbyte< std::complex<float> >::gemm( m, n, k ) yields data moved for cgemm.
176//==============================================================================
177template <typename T>
178class Gbyte
179{
180public:
181 // ----------------------------------------
182 // Level 1 BLAS
183 // read x
184 static double asum( double n )
185 { return 1e-9 * (n * sizeof(T)); }
186
187 // read x, y; write y
188 static double axpy( double n )
189 { return 1e-9 * (3*n * sizeof(T)); }
190
191 // read x; write y
192 static double copy( double n )
193 { return 1e-9 * (2*n * sizeof(T)); }
194
195 // read x
196 static double iamax( double n )
197 { return 1e-9 * (n * sizeof(T)); }
198
199 // read x
200 static double nrm2( double n )
201 { return 1e-9 * (n * sizeof(T)); }
202
203 // read x, y
204 static double dot( double n )
205 { return 1e-9 * (2*n * sizeof(T)); }
206
207 // read x; write x
208 static double scal( double n )
209 { return 1e-9 * (2*n * sizeof(T)); }
210
211 // read x, y; write x, y
212 static double swap( double n )
213 { return 1e-9 * (4*n * sizeof(T)); }
214
215 // ----------------------------------------
216 // Level 2 BLAS
217 // read A, x; write y
218 static double gemv( double m, double n )
219 { return 1e-9 * ((m*n + m + n) * sizeof(T)); }
220
221 // read A triangle, x; write y
222 static double hemv( double n )
223 { return 1e-9 * ((0.5*(n+1)*n + 2*n) * sizeof(T)); }
224
225 static double symv( double n )
226 { return hemv( n ); }
227
228 // read A triangle, x; write x
229 static double trmv( double n )
230 { return 1e-9 * ((0.5*(n+1)*n + 2*n) * sizeof(T)); }
231
232 static double trsv( double n )
233 { return trmv( n ); }
234
235 // read A, x, y; write A
236 static double ger( double m, double n )
237 { return 1e-9 * ((2*m*n + m + n) * sizeof(T)); }
238
239 // read A triangle, x; write A triangle
240 static double her( double n )
241 { return 1e-9 * (((n+1)*n + n) * sizeof(T)); }
242
243 static double syr( double n )
244 { return her( n ); }
245
246 // read A triangle, x, y; write A triangle
247 static double her2( double n )
248 { return 1e-9 * (((n+1)*n + n + n) * sizeof(T)); }
249
250 static double syr2( double n )
251 { return her2( n ); }
252
253 // read A; write B
254 static double copy_2d( double m, double n )
255 { return 1e-9 * (2*m*n * sizeof(T)); }
256
257 // ----------------------------------------
258 // Level 3 BLAS
259 // read A, B, C; write C
260 static double gemm( double m, double n, double k )
261 { return 1e-9 * ((m*k + k*n + 2*m*n) * sizeof(T)); }
262
263 static double hemm( blas::Side side, double m, double n )
264 {
265 // read A, B, C; write C
266 double sizeA = (side == blas::Side::Left ? 0.5*m*(m+1) : 0.5*n*(n+1));
267 return 1e-9 * ((sizeA + 3*m*n) * sizeof(T));
268 }
269
270 static double symm( blas::Side side, double m, double n )
271 { return hemm( side, m, n ); }
272
273 static double herk( double n, double k )
274 {
275 // read A, C; write C
276 double sizeC = 0.5*n*(n+1);
277 return 1e-9 * ((n*k + 2*sizeC) * sizeof(T));
278 }
279
280 static double syrk( double n, double k )
281 { return herk( n, k ); }
282
283 static double her2k( double n, double k )
284 {
285 // read A, B, C; write C
286 double sizeC = 0.5*n*(n+1);
287 return 1e-9 * ((2*n*k + 2*sizeC) * sizeof(T));
288 }
289
290 static double syr2k( double n, double k )
291 { return her2k( n, k ); }
292
293 static double trmm( blas::Side side, double m, double n )
294 {
295 // read A triangle, x; write x
296 if (side == blas::Side::Left)
297 return 1e-9 * ((0.5*(m+1)*m + 2*m*n) * sizeof(T));
298 else
299 return 1e-9 * ((0.5*(n+1)*n + 2*m*n) * sizeof(T));
300 }
301
302 static double trsm( blas::Side side, double m, double n )
303 { return trmm( side, m, n ); }
304};
305
306//==============================================================================
307// Traits to lookup number of operations per multiply and add.
template <typename T>
class FlopTraits
{
public:
    // Real scalar types: one machine operation per multiply and per add.
    static constexpr double mul_ops = 1;
    static constexpr double add_ops = 1;
};
315
316//------------------------------------------------------------------------------
317// specialization for complex
318// flops = 6*muls + 2*adds
319template <typename T>
320class FlopTraits< std::complex<T> >
321{
322public:
323 static constexpr double mul_ops = 6;
324 static constexpr double add_ops = 2;
325};
326
327//==============================================================================
328// template class. Example:
329// gflop< float >::gemm( m, n, k ) yields flops for sgemm.
330// gflop< std::complex<float> >::gemm( m, n, k ) yields flops for cgemm.
331//==============================================================================
332template <typename T>
333class Gflop
334{
335public:
336 static constexpr double mul_ops = FlopTraits<T>::mul_ops;
337 static constexpr double add_ops = FlopTraits<T>::add_ops;
338
339 // ----------------------------------------
340 // Level 1 BLAS
341 static double asum( double n )
342 { return 1e-9 * (mul_ops*fmuls_asum(n) +
343 add_ops*fadds_asum(n)); }
344
345 static double axpy( double n )
346 { return 1e-9 * (mul_ops*fmuls_axpy(n) +
347 add_ops*fadds_axpy(n)); }
348
349 static double copy( double n )
350 { return 0; }
351
352 static double iamax( double n )
353 { return 1e-9 * (mul_ops*fmuls_iamax(n) +
354 add_ops*fadds_iamax(n)); }
355
356 static double nrm2( double n )
357 { return 1e-9 * (mul_ops*fmuls_nrm2(n) +
358 add_ops*fadds_nrm2(n)); }
359
360 static double dot( double n )
361 { return 1e-9 * (mul_ops*fmuls_dot(n) +
362 add_ops*fadds_dot(n)); }
363
364 static double scal( double n )
365 { return 1e-9 * (mul_ops*fmuls_scal(n) +
366 add_ops*fadds_scal(n)); }
367
368 static double swap( double n )
369 { return 0; }
370
371 static double rot( double n )
372 { return 1e-9 * (mul_ops*fmuls_rot(n) +
373 add_ops*fadds_rot(n)); }
374
375 static double rotm( double n )
376 { return 1e-9 * (mul_ops*fmuls_rotm(n) +
377 add_ops*fadds_rotm(n)); }
378
379 // ----------------------------------------
380 // Level 2 BLAS
381 static double gemv(double m, double n)
382 { return 1e-9 * (mul_ops*fmuls_gemv(m, n) +
383 add_ops*fadds_gemv(m, n)); }
384
385 static double symv(double n)
386 { return gemv( n, n ); }
387
388 static double hemv(double n)
389 { return symv( n ); }
390
391 static double trmv( double n )
392 { return 1e-9 * (mul_ops*fmuls_trmv(n) +
393 add_ops*fadds_trmv(n)); }
394
395 static double trsv( double n )
396 { return trmv( n ); }
397
398 static double her( double n )
399 { return ger( n, n ); }
400
401 static double syr( double n )
402 { return her( n ); }
403
404 static double ger( double m, double n )
405 { return 1e-9 * (mul_ops*fmuls_ger(m, n) +
406 add_ops*fadds_ger(m, n)); }
407
408 static double her2( double n )
409 { return 2*ger( n, n ); }
410
411 static double syr2( double n )
412 { return her2( n ); }
413
414 // ----------------------------------------
415 // Level 3 BLAS
416 static double gemm(double m, double n, double k)
417 { return 1e-9 * (mul_ops*fmuls_gemm(m, n, k) +
418 add_ops*fadds_gemm(m, n, k)); }
419
420 static double gbmm(double m, double n, double k, double kl, double ku)
421 { return 1e-9 * (mul_ops*fmuls_gbmm(m, n, k, kl, ku) +
422 add_ops*fadds_gbmm(m, n, k, kl, ku)); }
423
424 static double hemm(blas::Side side, double m, double n)
425 { return 1e-9 * (mul_ops*fmuls_hemm(side, m, n) +
426 add_ops*fadds_hemm(side, m, n)); }
427
428 static double hbmm(double m, double n, double kd)
429 { return gbmm(m, n, m, kd, kd); }
430
431 static double symm(blas::Side side, double m, double n)
432 { return hemm( side, m, n ); }
433
434 static double herk(double n, double k)
435 { return 1e-9 * (mul_ops*fmuls_herk(n, k) +
436 add_ops*fadds_herk(n, k)); }
437
438 static double syrk(double n, double k)
439 { return herk( n, k ); }
440
441 static double her2k(double n, double k)
442 { return 1e-9 * (mul_ops*fmuls_her2k(n, k) +
443 add_ops*fadds_her2k(n, k)); }
444
445 static double syr2k(double n, double k)
446 { return her2k( n, k ); }
447
448 static double trmm(blas::Side side, double m, double n)
449 { return 1e-9 * (mul_ops*fmuls_trmm(side, m, n) +
450 add_ops*fadds_trmm(side, m, n)); }
451
452 static double trsm(blas::Side side, double m, double n)
453 { return trmm( side, m, n ); }
454
455};
456
457} // namespace blas
458
459#endif // #ifndef BLAS_FLOPS_HH
real_type< T > asum(int64_t n, T const *x, int64_t incx)
Definition asum.hh:35
void axpy(int64_t n, blas::scalar_type< TX, TY > alpha, TX const *x, int64_t incx, TY *y, int64_t incy)
Add scaled vector, y = alpha*x + y.
Definition axpy.hh:43
void copy(int64_t n, TX const *x, int64_t incx, TY *y, int64_t incy)
Copy vector, y = x.
Definition copy.hh:40
void dot(int64_t n, float const *x, int64_t incx, float const *y, int64_t incy, float *result, blas::Queue &queue)
GPU device, float version.
Definition device_dot.cc:139
void gemm(blas::Layout layout, blas::Op transA, blas::Op transB, int64_t m, int64_t n, int64_t k, float alpha, float const *A, int64_t lda, float const *B, int64_t ldb, float beta, float *C, int64_t ldc, blas::Queue &queue)
GPU device, float version.
Definition device_gemm.cc:119
void gemv(blas::Layout layout, blas::Op trans, int64_t m, int64_t n, blas::scalar_type< TA, TX, TY > alpha, TA const *A, int64_t lda, TX const *x, int64_t incx, blas::scalar_type< TA, TX, TY > beta, TY *y, int64_t incy)
General matrix-vector multiply:
Definition gemv.hh:79
void ger(blas::Layout layout, int64_t m, int64_t n, blas::scalar_type< TA, TX, TY > alpha, TX const *x, int64_t incx, TY const *y, int64_t incy, TA *A, int64_t lda)
General matrix rank-1 update:
Definition ger.hh:60
void hemm(blas::Layout layout, blas::Side side, blas::Uplo uplo, int64_t m, int64_t n, float alpha, float const *A, int64_t lda, float const *B, int64_t ldb, float beta, float *C, int64_t ldc, blas::Queue &queue)
GPU device, float version.
Definition device_hemm.cc:107
void hemv(blas::Layout layout, blas::Uplo uplo, int64_t n, blas::scalar_type< TA, TX, TY > alpha, TA const *A, int64_t lda, TX const *x, int64_t incx, blas::scalar_type< TA, TX, TY > beta, TY *y, int64_t incy)
Hermitian matrix-vector multiply:
Definition hemv.hh:69
void her2(blas::Layout layout, blas::Uplo uplo, int64_t n, blas::scalar_type< TA, TX, TY > alpha, TX const *x, int64_t incx, TY const *y, int64_t incy, TA *A, int64_t lda)
Hermitian matrix rank-2 update:
Definition her2.hh:66
void her2k(blas::Layout layout, blas::Uplo uplo, blas::Op trans, int64_t n, int64_t k, float alpha, float const *A, int64_t lda, float const *B, int64_t ldb, float beta, float *C, int64_t ldc, blas::Queue &queue)
GPU device, float version.
Definition device_her2k.cc:100
void her(blas::Layout layout, blas::Uplo uplo, int64_t n, blas::real_type< TA, TX > alpha, TX const *x, int64_t incx, TA *A, int64_t lda)
Hermitian matrix rank-1 update:
Definition her.hh:59
void herk(blas::Layout layout, blas::Uplo uplo, blas::Op trans, int64_t n, int64_t k, float alpha, float const *A, int64_t lda, float beta, float *C, int64_t ldc, blas::Queue &queue)
GPU device, float version.
Definition device_herk.cc:92
int64_t iamax(int64_t n, T const *x, int64_t incx)
Definition iamax.hh:34
void nrm2(int64_t n, float const *x, int64_t incx, float *result, blas::Queue &queue)
GPU device, float version.
Definition device_nrm2.cc:84
void rot(int64_t n, TX *x, int64_t incx, TY *y, int64_t incy, blas::real_type< TX, TY > c, blas::scalar_type< TX, TY > s)
Apply plane rotation:
Definition rot.hh:53
void rotm(int64_t n, TX *x, int64_t incx, TY *y, int64_t incy, blas::scalar_type< TX, TY > const param[5])
Apply modified (fast) plane rotation, H:
Definition rotm.hh:50
void scal(int64_t n, float alpha, float *x, int64_t incx, blas::Queue &queue)
GPU device, float version.
Definition device_scal.cc:65
void swap(int64_t n, float *x, int64_t incx, float *y, int64_t incy, blas::Queue &queue)
GPU device, float version.
Definition device_swap.cc:67
void symm(blas::Layout layout, blas::Side side, blas::Uplo uplo, int64_t m, int64_t n, float alpha, float const *A, int64_t lda, float const *B, int64_t ldb, float beta, float *C, int64_t ldc, blas::Queue &queue)
GPU device, float version.
Definition device_symm.cc:106
void symv(blas::Layout layout, blas::Uplo uplo, int64_t n, blas::scalar_type< TA, TX, TY > alpha, TA const *A, int64_t lda, TX const *x, int64_t incx, blas::scalar_type< TA, TX, TY > beta, TY *y, int64_t incy)
Symmetric matrix-vector multiply:
Definition symv.hh:66
void syr2(blas::Layout layout, blas::Uplo uplo, int64_t n, blas::scalar_type< TA, TX, TY > alpha, TX const *x, int64_t incx, TY const *y, int64_t incy, TA *A, int64_t lda)
Symmetric matrix rank-2 update:
Definition syr2.hh:63
void syr2k(blas::Layout layout, blas::Uplo uplo, blas::Op trans, int64_t n, int64_t k, float alpha, float const *A, int64_t lda, float const *B, int64_t ldb, float beta, float *C, int64_t ldc, blas::Queue &queue)
GPU device, float version.
Definition device_syr2k.cc:107
void syr(blas::Layout layout, blas::Uplo uplo, int64_t n, blas::scalar_type< TA, TX > alpha, TX const *x, int64_t incx, TA *A, int64_t lda)
Symmetric matrix rank-1 update:
Definition syr.hh:56
void syrk(blas::Layout layout, blas::Uplo uplo, blas::Op trans, int64_t n, int64_t k, float alpha, float const *A, int64_t lda, float beta, float *C, int64_t ldc, blas::Queue &queue)
GPU device, float version.
Definition device_syrk.cc:101
void trmm(blas::Layout layout, blas::Side side, blas::Uplo uplo, blas::Op trans, blas::Diag diag, int64_t m, int64_t n, float alpha, float const *A, int64_t lda, float *B, int64_t ldb, blas::Queue &queue)
GPU device, float version.
Definition device_trmm.cc:104
void trmv(blas::Layout layout, blas::Uplo uplo, blas::Op trans, blas::Diag diag, int64_t n, TA const *A, int64_t lda, TX *x, int64_t incx)
Triangular matrix-vector multiply:
Definition trmv.hh:69
void trsm(blas::Layout layout, blas::Side side, blas::Uplo uplo, blas::Op trans, blas::Diag diag, int64_t m, int64_t n, float alpha, float const *A, int64_t lda, float *B, int64_t ldb, blas::Queue &queue)
GPU device, float version.
Definition device_trsm.cc:104
void trsv(blas::Layout layout, blas::Uplo uplo, blas::Op trans, blas::Diag diag, int64_t n, TA const *A, int64_t lda, TX *x, int64_t incx)
Solve the triangular matrix-vector equation.
Definition trsv.hh:73