87template <
typename TA,
typename TB>
96 blas::scalar_type<TA, TB> alpha,
97 TA
const *A, int64_t lda,
101 using scalar_t = blas::scalar_type<TA, TB>;
103 #define A(i_, j_) A[ (i_) + (j_)*lda ]
104 #define B(i_, j_) B[ (i_) + (j_)*ldb ]
107 const scalar_t zero = 0;
110 blas_error_if( layout != Layout::ColMajor &&
111 layout != Layout::RowMajor );
112 blas_error_if( side != Side::Left &&
113 side != Side::Right );
114 blas_error_if( uplo != Uplo::Lower &&
115 uplo != Uplo::Upper );
116 blas_error_if( trans != Op::NoTrans &&
117 trans != Op::Trans &&
118 trans != Op::ConjTrans );
119 blas_error_if( diag != Diag::NonUnit &&
120 diag != Diag::Unit );
121 blas_error_if( m < 0 );
122 blas_error_if( n < 0 );
125 if (layout == Layout::RowMajor) {
126 side = (side == Side::Left)
129 if (uplo == Uplo::Lower)
131 else if (uplo == Uplo::Upper)
137 blas_error_if( lda < ((side == Side::Left) ? m : n) );
138 blas_error_if( ldb < m );
141 if (m == 0 || n == 0)
146 for (int64_t j = 0; j < n; ++j) {
147 for (int64_t i = 0; i < m; ++i)
154 if (side == Side::Left) {
155 if (trans == Op::NoTrans) {
156 if (uplo == Uplo::Upper) {
157 for (int64_t j = 0; j < n; ++j) {
158 for (int64_t k = 0; k < m; ++k) {
159 scalar_t alpha_Bkj = alpha*B(k, j);
160 for (int64_t i = 0; i < k; ++i)
161 B(i, j) += A(i, k)*alpha_Bkj;
162 B(k, j) = (diag == Diag::NonUnit)
169 for (int64_t j = 0; j < n; ++j) {
170 for (int64_t k = m-1; k >= 0; --k) {
171 scalar_t alpha_Bkj = alpha*B(k, j);
172 B(k, j) = (diag == Diag::NonUnit)
175 for (int64_t i = k+1; i < m; ++i)
176 B(i, j) += A(i, k)*alpha_Bkj;
181 else if (trans == Op::Trans) {
182 if (uplo == Uplo::Upper) {
183 for (int64_t j = 0; j < n; ++j) {
184 for (int64_t i = m-1; i >= 0; --i) {
185 scalar_t sum = (diag == Diag::NonUnit)
188 for (int64_t k = 0; k < i; ++k)
189 sum += A(k, i)*B(k, j);
190 B(i, j) = alpha * sum;
195 for (int64_t j = 0; j < n; ++j) {
196 for (int64_t i = 0; i < m; ++i) {
197 scalar_t sum = (diag == Diag::NonUnit)
200 for (int64_t k = i+1; k < m; ++k)
201 sum += A(k, i)*B(k, j);
202 B(i, j) = alpha * sum;
208 if (uplo == Uplo::Upper) {
209 for (int64_t j = 0; j < n; ++j) {
210 for (int64_t i = m-1; i >= 0; --i) {
211 scalar_t sum = (diag == Diag::NonUnit)
212 ? conj(A(i, i))*B(i, j)
214 for (int64_t k = 0; k < i; ++k)
215 sum += conj(A(k, i))*B(k, j);
216 B(i, j) = alpha * sum;
221 for (int64_t j = 0; j < n; ++j) {
222 for (int64_t i = 0; i < m; ++i) {
223 scalar_t sum = (diag == Diag::NonUnit)
224 ? conj(A(i, i))*B(i, j)
226 for (int64_t k = i+1; k < m; ++k)
227 sum += conj(A(k, i))*B(k, j);
228 B(i, j) = alpha * sum;
235 if (trans == Op::NoTrans) {
236 if (uplo == Uplo::Upper) {
237 for (int64_t j = n-1; j >= 0; --j) {
239 scalar_t alpha_Akj = (diag == Diag::NonUnit)
242 for (int64_t i = 0; i < m; ++i)
243 B(i, j) *= alpha_Akj;
245 for (int64_t k = 0; k < j; ++k) {
246 alpha_Akj = alpha*A(k, j);
247 for (int64_t i = 0; i < m; ++i)
248 B(i, j) += B(i, k)*alpha_Akj;
253 for (int64_t j = 0; j < n; ++j) {
255 scalar_t alpha_Akj = (diag == Diag::NonUnit)
258 for (int64_t i = 0; i < m; ++i)
259 B(i, j) *= alpha_Akj;
261 for (int64_t k = j+1; k < n; ++k) {
262 alpha_Akj = alpha*A(k, j);
263 for (int64_t i = 0; i < m; ++i)
264 B(i, j) += B(i, k)*alpha_Akj;
269 else if (trans == Op::Trans) {
270 if (uplo == Uplo::Upper) {
271 for (int64_t k = 0; k < n; ++k) {
272 for (int64_t j = 0; j < k; ++j) {
273 scalar_t alpha_Ajk = alpha*A(j, k);
274 for (int64_t i = 0; i < m; ++i)
275 B(i, j) += B(i, k)*alpha_Ajk;
278 scalar_t alpha_Akk = (diag == Diag::NonUnit)
281 for (int64_t i = 0; i < m; ++i)
282 B(i, k) *= alpha_Akk;
286 for (int64_t k = n-1; k >= 0; --k) {
287 for (int64_t j = k+1; j < n; ++j) {
288 scalar_t alpha_Ajk = alpha*A(j, k);
289 for (int64_t i = 0; i < m; ++i)
290 B(i, j) += B(i, k)*alpha_Ajk;
293 scalar_t alpha_Akk = (diag == Diag::NonUnit)
296 for (int64_t i = 0; i < m; ++i)
297 B(i, k) *= alpha_Akk;
302 if (uplo == Uplo::Upper) {
303 for (int64_t k = 0; k < n; ++k) {
304 for (int64_t j = 0; j < k; ++j) {
305 scalar_t alpha_Ajk = alpha*conj(A(j, k));
306 for (int64_t i = 0; i < m; ++i)
307 B(i, j) += B(i, k)*alpha_Ajk;
310 scalar_t alpha_Akk = (diag == Diag::NonUnit)
311 ? alpha*conj(A(k, k))
313 for (int64_t i = 0; i < m; ++i)
314 B(i, k) *= alpha_Akk;
318 for (int64_t k = n-1; k >= 0; --k) {
319 for (int64_t j = k+1; j < n; ++j) {
320 scalar_t alpha_Ajk = alpha*conj(A(j, k));
321 for (int64_t i = 0; i < m; ++i)
322 B(i, j) += B(i, k)*alpha_Ajk;
325 scalar_t alpha_Akk = (diag == Diag::NonUnit)
326 ? alpha*conj(A(k, k))
328 for (int64_t i = 0; i < m; ++i)
329 B(i, k) *= alpha_Akk;
// Referenced declaration:
//   void swap(int64_t n, float *x, int64_t incx, float *y, int64_t incy, blas::Queue &queue)
//   GPU device, float version.
//   Defined in device_swap.cc, line 67.
// Referenced declaration:
//   void trmm(blas::Layout layout, blas::Side side, blas::Uplo uplo, blas::Op trans, blas::Diag diag, int64_t m, int64_t n, float alpha, float const *A, int64_t lda, float *B, int64_t ldb, blas::Queue &queue)
//   GPU device, float version.
//   Defined in device_trmm.cc, line 104.