92template <
typename TA,
typename TB>
101 blas::scalar_type<TA, TB> alpha,
102 TA
const *A, int64_t lda,
106 using scalar_t = blas::scalar_type<TA, TB>;
108 #define A(i_, j_) A[ (i_) + (j_)*lda ]
109 #define B(i_, j_) B[ (i_) + (j_)*ldb ]
112 const scalar_t zero = 0;
115 blas_error_if( layout != Layout::ColMajor &&
116 layout != Layout::RowMajor );
117 blas_error_if( side != Side::Left &&
118 side != Side::Right );
119 blas_error_if( uplo != Uplo::Lower &&
120 uplo != Uplo::Upper );
121 blas_error_if( trans != Op::NoTrans &&
122 trans != Op::Trans &&
123 trans != Op::ConjTrans );
124 blas_error_if( diag != Diag::NonUnit &&
125 diag != Diag::Unit );
126 blas_error_if( m < 0 );
127 blas_error_if( n < 0 );
130 if (layout == Layout::RowMajor) {
131 side = (side == Side::Left)
134 if (uplo == Uplo::Lower)
136 else if (uplo == Uplo::Upper)
142 blas_error_if( lda < ((side == Side::Left) ? m : n) );
143 blas_error_if( ldb < m );
146 if (m == 0 || n == 0)
151 for (int64_t j = 0; j < n; ++j) {
152 for (int64_t i = 0; i < m; ++i)
159 if (side == Side::Left) {
160 if (trans == Op::NoTrans) {
161 if (uplo == Uplo::Upper) {
162 for (int64_t j = 0; j < n; ++j) {
163 for (int64_t i = 0; i < m; ++i)
165 for (int64_t k = m-1; k >= 0; --k) {
166 if (diag == Diag::NonUnit)
168 for (int64_t i = 0; i < k; ++i)
169 B(i, j) -= A(i, k)*B(k, j);
174 for (int64_t j = 0; j < n; ++j) {
175 for (int64_t i = 0; i < m; ++i)
177 for (int64_t k = 0; k < m; ++k) {
178 if (diag == Diag::NonUnit)
180 for (int64_t i = k+1; i < m; ++i)
181 B(i, j) -= A(i, k)*B(k, j);
186 else if (trans == Op::Trans) {
187 if (uplo == Uplo::Upper) {
188 for (int64_t j = 0; j < n; ++j) {
189 for (int64_t i = 0; i < m; ++i) {
190 scalar_t sum = alpha*B(i, j);
191 for (int64_t k = 0; k < i; ++k)
192 sum -= A(k, i)*B(k, j);
193 B(i, j) = (diag == Diag::NonUnit)
200 for (int64_t j = 0; j < n; ++j) {
201 for (int64_t i = m-1; i >= 0; --i) {
202 scalar_t sum = alpha*B(i, j);
203 for (int64_t k = i+1; k < m; ++k)
204 sum -= A(k, i)*B(k, j);
205 B(i, j) = (diag == Diag::NonUnit)
213 if (uplo == Uplo::Upper) {
214 for (int64_t j = 0; j < n; ++j) {
215 for (int64_t i = 0; i < m; ++i) {
216 scalar_t sum = alpha*B(i, j);
217 for (int64_t k = 0; k < i; ++k)
218 sum -= conj(A(k, i))*B(k, j);
219 B(i, j) = (diag == Diag::NonUnit)
220 ? sum / conj(A(i, i))
226 for (int64_t j = 0; j < n; ++j) {
227 for (int64_t i = m-1; i >= 0; --i) {
228 scalar_t sum = alpha*B(i, j);
229 for (int64_t k = i+1; k < m; ++k)
230 sum -= conj(A(k, i))*B(k, j);
231 B(i, j) = (diag == Diag::NonUnit)
232 ? sum / conj(A(i, i))
240 if (trans == Op::NoTrans) {
241 if (uplo == Uplo::Upper) {
242 for (int64_t j = 0; j < n; ++j) {
243 for (int64_t i = 0; i < m; ++i)
245 for (int64_t k = 0; k < j; ++k) {
246 for (int64_t i = 0; i < m; ++i)
247 B(i, j) -= B(i, k)*A(k, j);
249 if (diag == Diag::NonUnit) {
250 for (int64_t i = 0; i < m; ++i)
256 for (int64_t j = n-1; j >= 0; --j) {
257 for (int64_t i = 0; i < m; ++i)
259 for (int64_t k = j+1; k < n; ++k) {
260 for (int64_t i = 0; i < m; ++i)
261 B(i, j) -= B(i, k)*A(k, j);
263 if (diag == Diag::NonUnit) {
264 for (int64_t i = 0; i < m; ++i)
270 else if (trans == Op::Trans) {
271 if (uplo == Uplo::Upper) {
272 for (int64_t k = n-1; k >= 0; --k) {
273 if (diag == Diag::NonUnit) {
274 for (int64_t i = 0; i < m; ++i)
277 for (int64_t j = 0; j < k; ++j) {
278 for (int64_t i = 0; i < m; ++i)
279 B(i, j) -= B(i, k)*A(j, k);
281 for (int64_t i = 0; i < m; ++i)
286 for (int64_t k = 0; k < n; ++k) {
287 if (diag == Diag::NonUnit) {
288 for (int64_t i = 0; i < m; ++i)
291 for (int64_t j = k+1; j < n; ++j) {
292 for (int64_t i = 0; i < m; ++i)
293 B(i, j) -= B(i, k)*A(j, k);
295 for (int64_t i = 0; i < m; ++i)
301 if (uplo == Uplo::Upper) {
302 for (int64_t k = n-1; k >= 0; --k) {
303 if (diag == Diag::NonUnit) {
304 for (int64_t i = 0; i < m; ++i)
305 B(i, k) /= conj(A(k, k));
307 for (int64_t j = 0; j < k; ++j) {
308 for (int64_t i = 0; i < m; ++i)
309 B(i, j) -= B(i, k)*conj(A(j, k));
311 for (int64_t i = 0; i < m; ++i)
316 for (int64_t k = 0; k < n; ++k) {
317 if (diag == Diag::NonUnit) {
318 for (int64_t i = 0; i < m; ++i)
319 B(i, k) /= conj(A(k, k));
321 for (int64_t j = k+1; j < n; ++j) {
322 for (int64_t i = 0; i < m; ++i)
323 B(i, j) -= B(i, k)*conj(A(j, k));
325 for (int64_t i = 0; i < m; ++i)
void swap(int64_t n, float *x, int64_t incx, float *y, int64_t incy, blas::Queue &queue)
GPU device, float version.
Definition device_swap.cc:67
void trsm(blas::Layout layout, blas::Side side, blas::Uplo uplo, blas::Op trans, blas::Diag diag, int64_t m, int64_t n, float alpha, float const *A, int64_t lda, float *B, int64_t ldb, blas::Queue &queue)
GPU device, float version.
Definition device_trsm.cc:104