10#include "blas/defines.h"
12#if defined( BLAS_HAVE_CUBLAS ) \
13 || defined( BLAS_HAVE_ROCBLAS ) \
14 || defined( BLAS_HAVE_SYCL )
15 #define BLAS_HAVE_DEVICE
18#ifdef BLAS_HAVE_CUBLAS
19 #include <cuda_runtime.h>
20 #include <cublas_v2.h>
22#elif defined(BLAS_HAVE_ROCBLAS)
24 #if ! defined(__HIP_PLATFORM_NVCC__) && ! defined(__HIP_PLATFORM_HCC__)
25 #define __HIP_PLATFORM_HCC__
26 #define BLAS_HIP_PLATFORM_HCC
29 #include <hip/hip_runtime.h>
32 #if HIP_VERSION >= 50200000
33 #include <rocblas/rocblas.h>
39 #ifdef BLAS_HIP_PLATFORM_HCC
40 #undef __HIP_PLATFORM_HCC__
41 #undef BLAS_HIP_PLATFORM_HCC
44#elif defined(BLAS_HAVE_SYCL)
45 #include <sycl/detail/cl.h>
55#ifdef BLAS_HAVE_CUBLAS
56 typedef int device_blas_int;
57#elif defined(BLAS_HAVE_ROCBLAS)
58 typedef int device_blas_int;
59#elif defined(BLAS_HAVE_SYCL)
60 typedef std::int64_t device_blas_int;
62 typedef int device_blas_int;
79enum class MemcpyKind : device_blas_int {
88#if defined(BLAS_HAVE_CUBLAS)
91 inline cudaMemcpyKind memcpy2cuda( MemcpyKind kind )
94 case MemcpyKind::HostToHost:
return cudaMemcpyHostToHost;
break;
95 case MemcpyKind::HostToDevice:
return cudaMemcpyHostToDevice;
break;
96 case MemcpyKind::DeviceToHost:
return cudaMemcpyDeviceToHost;
break;
97 case MemcpyKind::DeviceToDevice:
return cudaMemcpyDeviceToDevice;
break;
98 case MemcpyKind::Default:
return cudaMemcpyDefault;
99 default:
throw blas::Error(
"unknown memcpy direction" );
102#elif defined(BLAS_HAVE_ROCBLAS)
105 inline hipMemcpyKind memcpy2hip( MemcpyKind kind )
108 case MemcpyKind::HostToHost:
return hipMemcpyHostToHost;
break;
109 case MemcpyKind::HostToDevice:
return hipMemcpyHostToDevice;
break;
110 case MemcpyKind::DeviceToHost:
return hipMemcpyDeviceToHost;
break;
111 case MemcpyKind::DeviceToDevice:
return hipMemcpyDeviceToDevice;
break;
112 case MemcpyKind::Default:
return hipMemcpyDefault;
113 default:
throw blas::Error(
"unknown memcpy direction" );
116#elif defined(BLAS_HAVE_SYCL)
122 inline int64_t memcpy2sycl( MemcpyKind kind ) {
return 0; }
127const int MaxBatchChunk = 50000;
129#if defined( BLAS_HAVE_CUBLAS ) || defined( BLAS_HAVE_ROCBLAS )
130 const int MaxForkSize = 10;
133 const int MaxForkSize = 1;
146 #if defined( BLAS_HAVE_CUBLAS )
147 using stream_t = cudaStream_t;
148 using event_t = cudaEvent_t;
149 using handle_t = cublasHandle_t;
151 #elif defined( BLAS_HAVE_ROCBLAS )
152 using stream_t = hipStream_t;
153 using event_t = hipEvent_t;
154 using handle_t = rocblas_handle;
156 #elif defined( BLAS_HAVE_SYCL )
157 using stream_t = sycl::queue;
161 using stream_t =
void*;
167 Queue(
int device, stream_t& stream );
169 #if defined( BLAS_HAVE_CUBLAS ) || defined( BLAS_HAVE_ROCBLAS )
170 Queue(
int device, handle_t handle );
179 int device()
const {
return device_; }
183 void*
work() {
return (
void*) work_; }
186 template <
typename scalar_t>
187 size_t work_size()
const {
return lwork_ /
sizeof(scalar_t); }
189 template <
typename scalar_t>
193 void fork(
int num_streams=MaxForkSize );
201 #if defined( BLAS_HAVE_CUBLAS ) || defined( BLAS_HAVE_ROCBLAS )
203 void set_handle( handle_t& in_handle );
204 handle_t handle()
const {
return handle_; }
208 void set_stream( stream_t& in_stream );
212 #if defined( BLAS_HAVE_CUBLAS ) || defined( BLAS_HAVE_ROCBLAS )
213 return streams_[ current_stream_index_ ];
215 return streams_[ 0 ];
225 stream_t streams_[ MaxForkSize ];
227 #if defined( BLAS_HAVE_CUBLAS ) || defined( BLAS_HAVE_ROCBLAS )
231 event_t events_[ MaxForkSize ];
235 int num_active_streams_;
238 int current_stream_index_;
243 bool own_default_stream_;
252#ifdef BLAS_HAVE_CUBLAS
254inline bool is_device_error( cudaError_t error )
256 return (error != cudaSuccess);
259inline bool is_device_error( cublasStatus_t error )
261 return (error != CUBLAS_STATUS_SUCCESS);
264inline const char* device_error_string( cudaError_t error )
266 return cudaGetErrorString( error );
270const char* device_error_string( cublasStatus_t error );
276#ifdef BLAS_HAVE_ROCBLAS
278inline bool is_device_error( hipError_t error )
280 return (error != hipSuccess);
283inline bool is_device_error( rocblas_status error )
285 return (error != rocblas_status_success);
288inline const char* device_error_string( hipError_t error )
290 return hipGetErrorString( error );
293inline const char* device_error_string( rocblas_status error )
295 return rocblas_status_to_string( error );
302#if defined(BLAS_ERROR_NDEBUG) || (defined(BLAS_ERROR_ASSERT) && defined(NDEBUG))
305 #define blas_dev_call( error ) \
308#elif defined(BLAS_ERROR_ASSERT)
311 #if defined(BLAS_HAVE_SYCL)
312 #define blas_dev_call( error ) \
317 catch (sycl::exception const& e) { \
318 blas::internal::abort_if( true, __func__, \
321 catch (std::exception const& e) { \
322 blas::internal::abort_if( true, __func__, \
326 blas::internal::abort_if( true, __func__, \
327 "%s", "unknown exception" ); \
332 #define blas_dev_call( error ) \
335 blas::internal::abort_if( blas::is_device_error(e), __func__, \
336 "%s", blas::device_error_string(e) ); \
343 #if defined(BLAS_HAVE_SYCL)
344 #define blas_dev_call( error ) \
349 catch (sycl::exception const& e) { \
350 blas::internal::throw_if( true, \
351 e.what(), __func__ ); \
353 catch (std::exception const& e) { \
354 blas::internal::throw_if( true, \
355 e.what(), __func__ ); \
358 blas::internal::throw_if( true, \
359 "unknown exception", __func__ ); \
364 #define blas_dev_call( error ) \
367 blas::internal::throw_if( blas::is_device_error(e), \
368 blas::device_error_string(e), \
379void internal_set_device(
int device );
381int get_device_count();
388void host_free_pinned(
void* ptr,
blas::Queue &queue );
408 blas_error_if( nelements < 0 );
411 #ifdef BLAS_HAVE_CUBLAS
412 blas::internal_set_device( queue.device() );
414 cudaMalloc( (
void**)&ptr, nelements *
sizeof(T) ) );
416 #elif defined(BLAS_HAVE_ROCBLAS)
417 blas::internal_set_device( queue.device() );
419 hipMalloc( (
void**)&ptr, nelements *
sizeof(T) ) );
421 #elif defined(BLAS_HAVE_SYCL)
423 ptr = (T*)sycl::malloc_shared( nelements*
sizeof(T), queue.stream() ) );
426 throw blas::Error(
"device BLAS not available", __func__ );
446T* host_malloc_pinned(
449 blas_error_if( nelements < 0 );
452 #ifdef BLAS_HAVE_CUBLAS
454 cudaMallocHost( (
void**)&ptr, nelements *
sizeof(T) ) );
456 #elif defined(BLAS_HAVE_ROCBLAS)
458 hipHostMalloc( (
void**)&ptr, nelements *
sizeof(T) ) );
460 #elif defined(BLAS_HAVE_SYCL)
462 ptr = (T*)sycl::malloc_host( nelements*
sizeof(T), queue.stream() ) );
465 throw blas::Error(
"device BLAS not available", __func__ );
489 int value, int64_t nelements, Queue& queue)
491 blas_error_if( nelements < 0 );
493 #ifdef BLAS_HAVE_CUBLAS
494 blas::internal_set_device( queue.device() );
498 nelements *
sizeof(T), queue.stream() ) );
500 #elif defined(BLAS_HAVE_ROCBLAS)
501 blas::internal_set_device( queue.device() );
505 nelements *
sizeof(T), queue.stream() ) );
507 #elif defined(BLAS_HAVE_SYCL)
509 queue.stream().memset( ptr, value, nelements *
sizeof(T) ) );
512 throw blas::Error(
"device BLAS not available", __func__ );
529[[deprecated(
"Use device_memcpy without kind. To be removed 2025-05.")]]
533 int64_t nelements, MemcpyKind kind, Queue& queue)
535 blas_error_if( nelements < 0 );
537 #ifdef BLAS_HAVE_CUBLAS
538 blas::internal_set_device( queue.device() );
541 dst, src,
sizeof(T)*nelements,
542 memcpy2cuda(kind), queue.stream() ) );
544 #elif defined(BLAS_HAVE_ROCBLAS)
545 blas::internal_set_device( queue.device() );
548 dst, src,
sizeof(T)*nelements,
549 memcpy2hip(kind), queue.stream() ) );
551 #elif defined(BLAS_HAVE_SYCL)
553 queue.stream().memcpy( dst, src,
sizeof(T)*nelements ) );
556 throw blas::Error(
"device BLAS not available", __func__ );
582 int64_t nelements, Queue& queue)
584 blas_error_if( nelements < 0 );
586 #ifdef BLAS_HAVE_CUBLAS
587 blas::internal_set_device( queue.device() );
590 dst, src,
sizeof(T)*nelements,
591 cudaMemcpyDefault, queue.stream() ) );
593 #elif defined(BLAS_HAVE_ROCBLAS)
594 blas::internal_set_device( queue.device() );
597 dst, src,
sizeof(T)*nelements,
598 hipMemcpyDefault, queue.stream() ) );
600 #elif defined(BLAS_HAVE_SYCL)
602 queue.stream().memcpy( dst, src,
sizeof(T)*nelements ) );
605 throw blas::Error(
"device BLAS not available", __func__ );
622[[deprecated(
"Use device_memcpy_2d without kind. To be removed 2025-05.")]]
623void device_memcpy_2d(
624 T* dst, int64_t dst_pitch,
625 T
const* src, int64_t src_pitch,
626 int64_t width, int64_t height, MemcpyKind kind, Queue& queue)
628 blas_error_if( width < 0 );
629 blas_error_if( height < 0 );
630 blas_error_if( dst_pitch < width );
631 blas_error_if( src_pitch < width );
633 #ifdef BLAS_HAVE_CUBLAS
634 blas::internal_set_device( queue.device() );
637 dst,
sizeof(T)*dst_pitch,
638 src,
sizeof(T)*src_pitch,
639 sizeof(T)*width, height, memcpy2cuda(kind), queue.stream() ) );
641 #elif defined(BLAS_HAVE_ROCBLAS)
642 blas::internal_set_device( queue.device() );
645 dst,
sizeof(T)*dst_pitch,
646 src,
sizeof(T)*src_pitch,
647 sizeof(T)*width, height, memcpy2hip(kind), queue.stream() ) );
649 #elif defined(BLAS_HAVE_SYCL)
650 if (dst_pitch == width && src_pitch == width) {
653 queue.stream().memcpy( dst, src, width * height *
sizeof(T) ) );
658 for (int64_t i = 0; i < height; ++i) {
659 T* dst_row = dst + i*dst_pitch;
660 T
const* src_row = src + i*src_pitch;
662 queue.stream().memcpy( dst_row, src_row, width*
sizeof(T) ) );
666 throw blas::Error(
"device BLAS not available", __func__ );
714void device_memcpy_2d(
715 T* dst, int64_t dst_pitch,
716 T
const* src, int64_t src_pitch,
717 int64_t width, int64_t height, Queue& queue)
719 blas_error_if( width < 0 );
720 blas_error_if( height < 0 );
721 blas_error_if( dst_pitch < width );
722 blas_error_if( src_pitch < width );
724 #ifdef BLAS_HAVE_CUBLAS
725 blas::internal_set_device( queue.device() );
728 dst,
sizeof(T)*dst_pitch,
729 src,
sizeof(T)*src_pitch,
730 sizeof(T)*width, height,
731 cudaMemcpyDefault, queue.stream() ) );
733 #elif defined(BLAS_HAVE_ROCBLAS)
734 blas::internal_set_device( queue.device() );
737 dst,
sizeof(T)*dst_pitch,
738 src,
sizeof(T)*src_pitch,
739 sizeof(T)*width, height,
740 hipMemcpyDefault, queue.stream() ) );
742 #elif defined(BLAS_HAVE_SYCL)
743 if (dst_pitch == width && src_pitch == width) {
746 queue.stream().memcpy( dst, src, width * height *
sizeof(T) ) );
751 for (int64_t i = 0; i < height; ++i) {
752 T* dst_row = dst + i*dst_pitch;
753 T
const* src_row = src + i*src_pitch;
755 queue.stream().memcpy( dst_row, src_row, width*
sizeof(T) ) );
759 throw blas::Error(
"device BLAS not available", __func__ );
787void device_copy_vector(
789 T
const* src, int64_t inc_src,
790 T* dst, int64_t inc_dst, Queue& queue)
792 if (inc_src == 1 && inc_dst == 1) {
794 device_memcpy( dst, src, n, queue );
798 device_memcpy_2d( dst, inc_dst, src, inc_src, 1, n, queue );
834void device_copy_matrix(
835 int64_t m, int64_t n,
836 T
const* src, int64_t ld_src,
837 T* dst, int64_t ld_dst, Queue& queue)
839 device_memcpy_2d( dst, ld_dst, src, ld_src, m, n, queue );
851template <
typename scalar_t>
854 lwork *=
sizeof(scalar_t);
855 if (lwork > lwork_) {
858 device_free( work_, *
this );
860 lwork_ = max( lwork, 3*MaxBatchChunk*
sizeof(
void*) );
861 work_ = device_malloc<char>( lwork_, *
this );
Exception class for BLAS errors.
Definition util.hh:30
Queue for executing GPU device routines.
Definition device.hh:143
void work_ensure_size(size_t lwork)
Ensures GPU device workspace is of size at least lwork elements of scalar_t, synchronizing and reallocating the workspace if needed.
Definition device.hh:852
void sync()
Synchronize with queue.
Definition device_queue.cc:238
Queue()
Default constructor.
Definition device_queue.cc:19
void * work()
Definition device.hh:183
void fork(int num_streams=MaxForkSize)
Forks the kernel launches assigned to this queue to parallel streams.
Definition device_queue.cc:255
void join()
Switch executions on this queue back from parallel streams to the default stream.
Definition device_queue.cc:296
size_t work_size() const
Definition device.hh:187
void revolve()
In fork mode, switch execution to the next-in-line stream.
Definition device_queue.cc:322