BLAS++ 2024.05.31
BLAS C++ API
Loading...
Searching...
No Matches
util.hh
1// Copyright (c) 2017-2023, University of Tennessee. All rights reserved.
2// SPDX-License-Identifier: BSD-3-Clause
3// This program is free software: you can redistribute it and/or modify it under
4// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
5
6#ifndef BLAS_UTIL_HH
7#define BLAS_UTIL_HH
8
9#include <exception>
10#include <complex>
11#include <cstdarg>
12#include <limits>
13#include <vector>
14#include <algorithm>
15
16#include <assert.h>
17
18namespace blas {
19
/// Marks a variable or expression as intentionally unused by casting it to
/// void. The argument is parenthesized so that a compound expression such as
/// `a + b` is cast as a whole rather than only its first operand.
#define blas_unused( var ) ((void)(var))
22
23// For printf, int64_t could be long (%ld), which is >= 32 bits,
24// or long long (%lld), guaranteed >= 64 bits.
25// Cast to llong to ensure printing 64 bits.
26using llong = long long;
27
28//------------------------------------------------------------------------------
/// Exception class for BLAS errors.
/// Stores a message string and returns it from what().
class Error: public std::exception {
public:
    /// Constructs BLAS error with an empty message.
    Error():
        std::exception()
    {}

    /// Constructs BLAS error with message.
    /// @param[in] msg  Text returned by what().
    Error( std::string const& msg ):
        std::exception(),
        msg_( msg )
    {}

    /// Constructs BLAS error with message: "msg, in function func".
    /// @param[in] msg   Description of the error.
    /// @param[in] func  Name of the function reporting the error.
    Error( const char* msg, const char* func ):
        std::exception(),
        msg_( std::string(msg) + ", in function " + func )
    {}

    /// Returns BLAS error message.
    /// The pointer refers to storage owned by this object.
    virtual const char* what() const noexcept override
        { return msg_.c_str(); }

private:
    std::string msg_;
};
56
57// -----------------------------------------------------------------------------
// Each enumerator's underlying char value is its single-letter code,
// so converting an enum to char yields that letter directly.
enum class Layout : char {
    ColMajor = 'C',
    RowMajor = 'R',
};

enum class Op : char {
    NoTrans   = 'N',
    Trans     = 'T',
    ConjTrans = 'C',
};

enum class Uplo : char {
    Upper   = 'U',
    Lower   = 'L',
    General = 'G',
};

enum class Diag : char {
    NonUnit = 'N',
    Unit    = 'U',
};

enum class Side : char {
    Left  = 'L',
    Right = 'R',
};
63
64extern const char* Layout_help;
65extern const char* Op_help;
66extern const char* Uplo_help;
67extern const char* Diag_help;
68extern const char* Side_help;
69
70// -----------------------------------------------------------------------------
71// Convert enum to LAPACK-style char.
72
73inline char to_char( Layout value ) { return char( value ); }
74inline char to_char( Op value ) { return char( value ); }
75inline char to_char( Uplo value ) { return char( value ); }
76inline char to_char( Diag value ) { return char( value ); }
77inline char to_char( Side value ) { return char( value ); }
78
79[[deprecated("use to_char. To be removed 2025-05.")]]
80inline char layout2char( Layout value ) { return char( value ); }
81
82[[deprecated("use to_char. To be removed 2025-05.")]]
83inline char op2char( Op value ) { return char( value ); }
84
85[[deprecated("use to_char. To be removed 2025-05.")]]
86inline char uplo2char( Uplo value ) { return char( value ); }
87
88[[deprecated("use to_char. To be removed 2025-05.")]]
89inline char diag2char( Diag value ) { return char( value ); }
90
91[[deprecated("use to_char. To be removed 2025-05.")]]
92inline char side2char( Side value ) { return char( value ); }
93
94//------------------------------------------------------------------------------
95// Convert enum to LAPACK-style C string (const char*).
96
97inline const char* to_c_string( Layout value )
98{
99 switch (value) {
100 case Layout::ColMajor: return "col";
101 case Layout::RowMajor: return "row";
102 }
103 return "?";
104}
105
106inline const char* to_c_string( Op value )
107{
108 switch (value) {
109 case Op::NoTrans: return "notrans";
110 case Op::Trans: return "trans";
111 case Op::ConjTrans: return "conj";
112 }
113 return "?";
114}
115
116inline const char* to_c_string( Uplo value )
117{
118 switch (value) {
119 case Uplo::Lower: return "lower";
120 case Uplo::Upper: return "upper";
121 case Uplo::General: return "general";
122 }
123 return "?";
124}
125
126inline const char* to_c_string( Diag value )
127{
128 switch (value) {
129 case Diag::NonUnit: return "nonunit";
130 case Diag::Unit: return "unit";
131 }
132 return "?";
133}
134
135inline const char* to_c_string( Side value )
136{
137 switch (value) {
138 case Side::Left: return "left";
139 case Side::Right: return "right";
140 }
141 return "?";
142}
143
144//------------------------------------------------------------------------------
145// Convert enum to LAPACK-style C++ string.
146
147inline std::string to_string( Layout value )
148{
149 return to_c_string( value );
150}
151
152inline std::string to_string( Op value )
153{
154 return to_c_string( value );
155}
156
157inline std::string to_string( Uplo value )
158{
159 return to_c_string( value );
160}
161
162inline std::string to_string( Diag value )
163{
164 return to_c_string( value );
165}
166
167inline std::string to_string( Side value )
168{
169 return to_c_string( value );
170}
171
172//------------------------------------------------------------------------------
173// Convert enum to LAPACK-style C string.
174
175[[deprecated("use to_string or to_c_string. To be removed 2025-05.")]]
176inline const char* layout2str( Layout value )
177{
178 return to_c_string( value );
179}
180
181[[deprecated("use to_string or to_c_string. To be removed 2025-05.")]]
182inline const char* op2str( Op value )
183{
184 return to_c_string( value );
185}
186
187[[deprecated("use to_string or to_c_string. To be removed 2025-05.")]]
188inline const char* uplo2str( Uplo value )
189{
190 return to_c_string( value );
191}
192
193[[deprecated("use to_string or to_c_string. To be removed 2025-05.")]]
194inline const char* diag2str( Diag value )
195{
196 return to_c_string( value );
197}
198
199[[deprecated("use to_string or to_c_string. To be removed 2025-05.")]]
200inline const char* side2str( Side value )
201{
202 return to_c_string( value );
203}
204
205//------------------------------------------------------------------------------
206// Convert LAPACK-style char or string to enum.
207
208inline void from_string( std::string const& str, Layout* val )
209{
210 std::string str_ = str;
211 std::transform( str_.begin(), str_.end(), str_.begin(), ::tolower );
212 if (str_ == "c" || str_ == "colmajor")
213 *val = Layout::ColMajor;
214 else if (str_ == "r" || str_ == "rowmajor")
215 *val = Layout::RowMajor;
216 else
217 throw Error( "unknown Layout: " + str );
218}
219
220inline void from_string( std::string const& str, Op* val )
221{
222 std::string str_ = str;
223 std::transform( str_.begin(), str_.end(), str_.begin(), ::tolower );
224 if (str_ == "n" || str_ == "notrans")
225 *val = Op::NoTrans;
226 else if (str_ == "t" || str_ == "trans")
227 *val = Op::Trans;
228 else if (str_ == "c" || str_ == "conjtrans")
229 *val = Op::ConjTrans;
230 else
231 throw Error( "unknown Op: " + str );
232}
233
234inline void from_string( std::string const& str, Uplo* val )
235{
236 std::string str_ = str;
237 std::transform( str_.begin(), str_.end(), str_.begin(), ::tolower );
238 if (str_ == "l" || str_ == "lower")
239 *val = Uplo::Lower;
240 else if (str_ == "u" || str_ == "upper")
241 *val = Uplo::Upper;
242 else if (str_ == "g" || str_ == "general")
243 *val = Uplo::General;
244 else
245 throw Error( "unknown Uplo: " + str );
246}
247
248inline void from_string( std::string const& str, Diag* val )
249{
250 std::string str_ = str;
251 std::transform( str_.begin(), str_.end(), str_.begin(), ::tolower );
252 if (str_ == "n" || str_ == "nonunit")
253 *val = Diag::NonUnit;
254 else if (str_ == "u" || str_ == "unit")
255 *val = Diag::Unit;
256 else
257 throw Error( "unknown Diag: " + str );
258}
259
260inline void from_string( std::string const& str, Side* val )
261{
262 std::string str_ = str;
263 std::transform( str_.begin(), str_.end(), str_.begin(), ::tolower );
264 if (str_ == "l" || str_ == "left")
265 *val = Side::Left;
266 else if (str_ == "r" || str_ == "right")
267 *val = Side::Right;
268 else
269 throw Error( "unknown Side: " + str );
270}
271
273// Convert LAPACK-style char to enum.
274
275[[deprecated("use from_string. To be removed 2025-05.")]]
276inline Layout char2layout( char layout )
277{
278 layout = (char) toupper( layout );
279 assert( layout == 'C' || layout == 'R' );
280 return Layout( layout );
281}
282
283[[deprecated("use from_string. To be removed 2025-05.")]]
284inline Op char2op( char op )
285{
286 op = (char) toupper( op );
287 assert( op == 'N' || op == 'T' || op == 'C' );
288 return Op( op );
289}
290
291[[deprecated("use from_string. To be removed 2025-05.")]]
292inline Uplo char2uplo( char uplo )
293{
294 uplo = (char) toupper( uplo );
295 assert( uplo == 'L' || uplo == 'U' || uplo == 'G' );
296 return Uplo( uplo );
297}
298
299[[deprecated("use from_string. To be removed 2025-05.")]]
300inline Diag char2diag( char diag )
301{
302 diag = (char) toupper( diag );
303 assert( diag == 'N' || diag == 'U' );
304 return Diag( diag );
305}
306
307[[deprecated("use from_string. To be removed 2025-05.")]]
308inline Side char2side( char side )
309{
310 side = (char) toupper( side );
311 assert( side == 'L' || side == 'R' );
312 return Side( side );
313}
314
315// -----------------------------------------------------------------------------
316// 1-norm absolute value, |Re(x)| + |Im(x)|
/// 1-norm absolute value of a non-complex value: abs( x ).
/// abs is called unqualified after `using std::abs`.
template <typename T>
T abs1( T x )
{
    using std::abs;
    const auto magnitude = abs( x );
    return magnitude;
}

/// 1-norm absolute value of a complex value: |Re(x)| + |Im(x)|.
template <typename T>
T abs1( std::complex<T> x )
{
    using std::abs;
    const auto re_mag = abs( real( x ) );
    const auto im_mag = abs( imag( x ) );
    return re_mag + im_mag;
}
330
331// -----------------------------------------------------------------------------
332// common_type_t is defined in C++14; here's a C++11 definition
#if __cplusplus >= 201402L
    // C++14 and later: reuse the standard aliases.
    using std::common_type_t;
    using std::decay_t;
#else
    // C++11 fallback with the same meaning as the C++14 aliases.
    template <typename... Ts>
    using common_type_t = typename std::common_type< Ts... >::type;

    // std::decay takes exactly one type, so this alias is not variadic;
    // it then has the same form as std::decay_t in the branch above.
    template <typename T>
    using decay_t = typename std::decay< T >::type;
#endif
343
344//------------------------------------------------------------------------------
/// True if T is std::complex<T2> for some type T2.
/// Primary template: value is false.
template <typename T>
struct is_complex:
    std::integral_constant<bool, false>
{};

// specialize for std::complex: value is true
template <typename T>
struct is_complex< std::complex<T> >:
    std::integral_constant<bool, true>
{};
356
357// -----------------------------------------------------------------------------
358// Previously extended real and imag to real types. Belatedly discovered that
359// C++11 extends std::real and std::imag to float and integer types,
360// so just use those now.
361using std::real;
362using std::imag;
363
373template <typename T>
374inline T conj( T x )
375{
376 static_assert(
377 ! is_complex<T>::value,
378 "Usage: using blas::conj; y = conj(x); NOT: y = blas::conj(x);" );
379 return x;
380}
381
382// -----------------------------------------------------------------------------
383// Based on C++14 common_type implementation from
384// http://www.cplusplus.com/reference/type_traits/common_type/
385// Adds promotion of complex types based on the common type of the associated
386// real types. This fixes various cases:
387//
388// std::common_type_t< double, complex<float> > is complex<float> (wrong)
389// scalar_type< double, complex<float> > is complex<double> (right)
390//
391// std::common_type_t< int, complex<long> > is not defined (compile error)
392// scalar_type< int, complex<long> > is complex<long> (right)
393
394// for zero types
// Primary template, declared only: no `type` exists for an empty type list.
template <typename... Types>
struct scalar_type_traits;

// scalar_type< Ts... > names scalar_type_traits< Ts... >::type.
template <typename... Types>
using scalar_type = typename scalar_type_traits< Types... >::type;

// One type: the decayed type itself.
template <typename Scalar>
struct scalar_type_traits< Scalar >
{
    using type = typename std::decay< Scalar >::type;
};

// Two types: the decayed type of a ?: expression over both operands,
// i.e., the common type of its two arguments.
template <typename Lhs, typename Rhs>
struct scalar_type_traits< Lhs, Rhs >
{
    using type = typename std::decay<
        decltype( true ? std::declval<Lhs>() : std::declval<Rhs>() ) >::type;
};

// Complex first operand: complex of the common type of the real types.
template <typename LhsReal, typename Rhs>
struct scalar_type_traits< std::complex<LhsReal>, Rhs >
{
    using type = std::complex< typename std::common_type< LhsReal, Rhs >::type >;
};

// Complex second operand: complex of the common type of the real types.
template <typename Lhs, typename RhsReal>
struct scalar_type_traits< Lhs, std::complex<RhsReal> >
{
    using type = std::complex< typename std::common_type< Lhs, RhsReal >::type >;
};

// Both operands complex: complex of the common type of the real types.
template <typename LhsReal, typename RhsReal>
struct scalar_type_traits< std::complex<LhsReal>, std::complex<RhsReal> >
{
    using type = std::complex< typename std::common_type< LhsReal, RhsReal >::type >;
};

// Three or more types: combine the first two, then recurse on the rest.
template <typename First, typename Second, typename... Rest>
struct scalar_type_traits< First, Second, Rest... >
{
    using type = scalar_type< scalar_type< First, Second >, Rest... >;
};
443
444// -----------------------------------------------------------------------------
445// for any combination of types, determine associated real, scalar,
446// and complex types.
447//
448// real_type< float > is float
449// real_type< float, double, complex<float> > is double
450//
451// scalar_type< float > is float
452// scalar_type< float, complex<float> > is complex<float>
453// scalar_type< float, double, complex<float> > is complex<double>
454//
455// complex_type< float > is complex<float>
456// complex_type< float, double > is complex<double>
457// complex_type< float, double, complex<float> > is complex<double>
458
459// for zero types
460template <typename... Types>
461struct real_type_traits;
462
463// define real_type<> type alias
464template <typename... Types>
465using real_type = typename real_type_traits< Types... >::real_t;
466
467// define complex_type<> type alias
468template <typename... Types>
469using complex_type = std::complex< real_type< Types... > >;
470
471// for one type
472template <typename T>
473struct real_type_traits<T>
474{
475 using real_t = T;
476};
477
478// for one complex type, strip complex
479template <typename T>
480struct real_type_traits< std::complex<T> >
481{
482 using real_t = T;
483};
484
485// for two or more types
486template <typename T1, typename... Types>
487struct real_type_traits< T1, Types... >
488{
489 using real_t = scalar_type< real_type<T1>, real_type< Types... > >;
490};
491
492// -----------------------------------------------------------------------------
493// max that works with different data types: int64_t = max( int, int64_t )
494// and any number of arguments: max( a, b, c, d )
495
496// one argument
497template <typename T>
498T max( T x )
499{
500 return x;
501}
502
503// two arguments
504template <typename T1, typename T2>
505scalar_type< T1, T2 >
506 max( T1 x, T2 y )
507{
508 return (x >= y ? x : y);
509}
510
511// three or more arguments
512template <typename T1, typename... Types>
513scalar_type< T1, Types... >
514 max( T1 first, Types... args )
515{
516 return max( first, max( args... ) );
517}
518
519// -----------------------------------------------------------------------------
520// min that works with different data types: int64_t = min( int, int64_t )
521// and any number of arguments: min( a, b, c, d )
522
523// one argument
524template <typename T>
525T min( T x )
526{
527 return x;
528}
529
530// two arguments
531template <typename T1, typename T2>
532scalar_type< T1, T2 >
533 min( T1 x, T2 y )
534{
535 return (x <= y ? x : y);
536}
537
538// three or more arguments
539template <typename T1, typename... Types>
540scalar_type< T1, Types... >
541 min( T1 first, Types... args )
542{
543 return min( first, min( args... ) );
544}
545
546// -----------------------------------------------------------------------------
547// Generate a scalar from real and imaginary parts.
548// For real scalars, the imaginary part is ignored.
549
550// For real scalar types.
551template <typename real_t>
552struct MakeScalarTraits {
553 static real_t make( real_t re, real_t im )
554 { return re; }
555};
556
557// For complex scalar types.
558template <typename real_t>
559struct MakeScalarTraits< std::complex<real_t> > {
560 static std::complex<real_t> make( real_t re, real_t im )
561 { return std::complex<real_t>( re, im ); }
562};
563
564template <typename scalar_t>
565scalar_t make_scalar( blas::real_type<scalar_t> re,
566 blas::real_type<scalar_t> im=0 )
567{
568 return MakeScalarTraits<scalar_t>::make( re, im );
569}
570
571// -----------------------------------------------------------------------------
/// Sign of val.
/// @return 1 if 0 < val, -1 if val < 0, and 0 otherwise.
template <typename real_t>
int sgn( real_t val )
{
    const real_t zero = real_t(0);
    const bool is_positive = (zero < val);
    const bool is_negative = (val < zero);
    return int( is_positive ) - int( is_negative );
}
580
581// -----------------------------------------------------------------------------
582// Functions to compute scaling constants
583//
584// __Further details__
585//
586// Anderson E (2017) Algorithm 978: Safe scaling in the level 1 BLAS.
587// ACM Trans Math Softw 44:. https://doi.org/10.1145/3061665
588
/// @return std::numeric_limits< real_t >::epsilon().
template <typename real_t>
inline const real_t ulp()
{
    using limits = std::numeric_limits< real_t >;
    return limits::epsilon();
}
595
/// Safe minimum scaling constant:
/// max( radix^(min_exponent-1), radix^(1-max_exponent) ), with radix and
/// exponents taken from std::numeric_limits< real_t >.
template <typename real_t>
inline const real_t safe_min()
{
    using limits = std::numeric_limits<real_t>;
    // Evaluate the powers in long double. With integer arguments pow computes
    // in double, so for a real_t wider than double the result would underflow
    // to zero.
    const long double fradix = limits::radix;
    const int expm = limits::min_exponent;
    const int expM = limits::max_exponent;

    const long double lo = std::pow( fradix, expm-1 );
    const long double hi = std::pow( fradix, 1-expM );
    return static_cast<real_t>( lo >= hi ? lo : hi );
}
606
/// Safe maximum scaling constant:
/// min( radix^(1-min_exponent), radix^(max_exponent-1) ), with radix and
/// exponents taken from std::numeric_limits< real_t >.
template <typename real_t>
inline const real_t safe_max()
{
    using limits = std::numeric_limits<real_t>;
    // Evaluate the powers in long double. With integer arguments pow computes
    // in double, so for a real_t wider than double the result would overflow
    // to infinity.
    const long double fradix = limits::radix;
    const int expm = limits::min_exponent;
    const int expM = limits::max_exponent;

    const long double lo = std::pow( fradix, 1-expm );
    const long double hi = std::pow( fradix, expM-1 );
    return static_cast<real_t>( lo <= hi ? lo : hi );
}
617
619template <typename real_t>
620inline const real_t root_min()
621{
622 return sqrt( safe_min<real_t>() / ulp<real_t>() );
623}
624
626template <typename real_t>
627inline const real_t root_max()
628{
629 return sqrt( safe_max<real_t>() * ulp<real_t>() );
630}
631
632//==============================================================================
633namespace internal {
634
635// -----------------------------------------------------------------------------
636// internal helper function; throws Error if cond is true
637// called by blas_error_if macro
638inline void throw_if( bool cond, const char* condstr, const char* func )
639{
640 if (cond) {
641 throw Error( condstr, func );
642 }
643}
644
645#if defined(_MSC_VER)
646 #define BLASPP_ATTR_FORMAT(I, F)
647#else
648 #define BLASPP_ATTR_FORMAT(I, F) __attribute__((format( printf, I, F )))
649#endif
650
651// -----------------------------------------------------------------------------
652// internal helper function; throws Error if cond is true
653// uses printf-style format for error message
654// called by blas_error_if_msg macro
655// condstr is ignored, but differentiates this from other version.
656inline void throw_if( bool cond, const char* condstr, const char* func, const char* format, ... )
657 BLASPP_ATTR_FORMAT(4, 5);
658
659inline void throw_if( bool cond, const char* condstr, const char* func, const char* format, ... )
660{
661 if (cond) {
662 char buf[80];
663 va_list va;
664 va_start( va, format );
665 vsnprintf( buf, sizeof(buf), format, va );
666 throw Error( buf, func );
667 }
668}
669
670// -----------------------------------------------------------------------------
671// internal helper function; aborts if cond is true
672// uses printf-style format for error message
673// called by blas_error_if_msg macro
674inline void abort_if( bool cond, const char* func, const char* format, ... )
675 BLASPP_ATTR_FORMAT(3, 4);
676
677inline void abort_if( bool cond, const char* func, const char* format, ... )
678{
679 if (cond) {
680 char buf[80];
681 va_list va;
682 va_start( va, format );
683 vsnprintf( buf, sizeof(buf), format, va );
684
685 fprintf( stderr, "Error: %s, in function %s\n", buf, func );
686 abort();
687 }
688}
689
690#undef BLASPP_ATTR_FORMAT
691
692} // namespace internal
693
694// -----------------------------------------------------------------------------
695// internal macros to handle error checks
696#if defined(BLAS_ERROR_NDEBUG) || (defined(BLAS_ERROR_ASSERT) && defined(NDEBUG))
697
698 // blaspp does no error checking;
699 // lower level BLAS may still handle errors via xerbla
700 #define blas_error_if( cond ) \
701 ((void)0)
702
703 #define blas_error_if_msg( cond, ... ) \
704 ((void)0)
705
706#elif defined(BLAS_ERROR_ASSERT)
707
708 // blaspp aborts on error
709 #define blas_error_if( cond ) \
710 blas::internal::abort_if( cond, __func__, "%s", #cond )
711
712 #define blas_error_if_msg( cond, ... ) \
713 blas::internal::abort_if( cond, __func__, __VA_ARGS__ )
714
715#else
716
717 // blaspp throws errors (default)
718 // internal macro to get string #cond; throws Error if cond is true
719 // ex: blas_error_if( a < b );
720 #define blas_error_if( cond ) \
721 blas::internal::throw_if( cond, #cond, __func__ )
722
723 // internal macro takes cond and printf-style format for error message.
724 // throws Error if cond is true.
725 // ex: blas_error_if_msg( a < b, "a %d < b %d", a, b );
726 #define blas_error_if_msg( cond, ... ) \
727 blas::internal::throw_if( cond, #cond, __func__, __VA_ARGS__ )
728
729#endif
730
731} // namespace blas
732
733#endif // #ifndef BLAS_UTIL_HH
Exception class for BLAS errors.
Definition util.hh:30
Error(std::string const &msg)
Constructs BLAS error with message.
Definition util.hh:38
Error(const char *msg, const char *func)
Constructs BLAS error with message: "msg, in function func".
Definition util.hh:44
virtual const char * what() const noexcept override
Returns BLAS error message.
Definition util.hh:50
Error()
Constructs BLAS error.
Definition util.hh:33
True if T is std::complex<T2> for some type T2.
Definition util.hh:349