10#include "blas/symm.hh"
84template <
typename TA,
typename TB,
typename TC>
90 scalar_type<TA, TB, TC> alpha,
91 TA
const *A, int64_t lda,
92 TB
const *B, int64_t ldb,
93 scalar_type<TA, TB, TC> beta,
97 using scalar_t = blas::scalar_type<TA, TB, TC>;
99 #define A(i_, j_) A[ (i_) + (j_)*lda ]
100 #define B(i_, j_) B[ (i_) + (j_)*ldb ]
101 #define C(i_, j_) C[ (i_) + (j_)*ldc ]
104 const scalar_t zero = 0;
105 const scalar_t one = 1;
108 blas_error_if( layout != Layout::ColMajor &&
109 layout != Layout::RowMajor );
110 blas_error_if( side != Side::Left &&
111 side != Side::Right );
112 blas_error_if( uplo != Uplo::Lower &&
113 uplo != Uplo::Upper &&
114 uplo != Uplo::General );
115 blas_error_if( m < 0 );
116 blas_error_if( n < 0 );
119 if (layout == Layout::RowMajor) {
120 side = (side == Side::Left)
123 if (uplo == Uplo::Lower)
125 else if (uplo == Uplo::Upper)
131 blas_error_if( lda < ((side == Side::Left) ? m : n) );
132 blas_error_if( ldb < m );
133 blas_error_if( ldc < m );
136 if (m == 0 || n == 0)
142 for (int64_t j = 0; j < n; ++j) {
143 for (int64_t i = 0; i < m; ++i)
147 else if (beta != one) {
148 for (int64_t j = 0; j < n; ++j) {
149 for (int64_t i = 0; i < m; ++i)
157 if (side == Side::Left) {
158 if (uplo != Uplo::Lower) {
160 for (int64_t j = 0; j < n; ++j) {
161 for (int64_t i = 0; i < m; ++i) {
163 scalar_t alpha_Bij = alpha*B(i, j);
166 for (int64_t k = 0; k < i; ++k) {
167 C(k, j) += A(k, i) * alpha_Bij;
168 sum += conj( A(k, i) ) * B(k, j);
172 + real( A(i, i) ) * alpha_Bij
179 for (int64_t j = 0; j < n; ++j) {
180 for (int64_t i = m-1; i >= 0; --i) {
182 scalar_t alpha_Bij = alpha*B(i, j);
185 for (int64_t k = i+1; k < m; ++k) {
186 C(k, j) += A(k, i) * alpha_Bij;
187 sum += conj( A(k, i) ) * B(k, j);
191 + real( A(i, i) ) * alpha_Bij
198 if (uplo != Uplo::Lower) {
200 for (int64_t j = 0; j < n; ++j) {
202 scalar_t alpha_Akj = alpha * real( A(j, j) );
204 for (int64_t i = 0; i < m; ++i)
205 C(i, j) = beta * C(i, j) + B(i, j) * alpha_Akj;
207 for (int64_t k = 0; k < j; ++k) {
208 alpha_Akj = alpha*A(k, j);
209 for (int64_t i = 0; i < m; ++i)
210 C(i, j) += B(i, k) * alpha_Akj;
213 for (int64_t k = j+1; k < n; ++k) {
214 alpha_Akj = alpha * conj( A(j, k) );
215 for (int64_t i = 0; i < m; ++i)
216 C(i, j) += B(i, k) * alpha_Akj;
222 for (int64_t j = 0; j < n; ++j) {
224 scalar_t alpha_Akj = alpha * real( A(j, j) );
226 for (int64_t i = 0; i < m; ++i)
227 C(i, j) = beta * C(i, j) + B(i, j) * alpha_Akj;
229 for (int64_t k = 0; k < j; ++k) {
230 alpha_Akj = alpha * conj( A(j, k) );
231 for (int64_t i = 0; i < m; ++i)
232 C(i, j) += B(i, k) * alpha_Akj;
235 for (int64_t k = j+1; k < n; ++k) {
236 alpha_Akj = alpha*A(k, j);
237 for (int64_t i = 0; i < m; ++i)
238 C(i, j) += B(i, k) * alpha_Akj;
void hemm(blas::Layout layout, blas::Side side, blas::Uplo uplo, int64_t m, int64_t n, float alpha, float const *A, int64_t lda, float const *B, int64_t ldb, float beta, float *C, int64_t ldc, blas::Queue &queue)
GPU device, float version.
Definition device_hemm.cc:107
void swap(int64_t n, float *x, int64_t incx, float *y, int64_t incy, blas::Queue &queue)
GPU device, float version.
Definition device_swap.cc:67