6#ifndef BLAS_BATCH_COMMON_HH
7#define BLAS_BATCH_COMMON_HH
// Sentinel starting value for the max-reduction over per-problem error codes,
// used when a single aggregated info value is requested. The codes assigned by
// the argument checks below are small negative values (-2 through -14), all
// greater than this sentinel. The reduction therefore ends on the sentinel only
// when no problem reported an error, and that case is mapped to info = 0.
16#define INTERNAL_INFO_DEFAULT (-1000)
// Returns the value that applies to batch problem `index`.
// A one-element vector is broadcast: its single entry is used for every
// problem. Otherwise the entry at position `index` is returned.
// No bounds checking is done here; the check functions below validate each
// vector's size against batchCount before calling this.
19T extract(std::vector<T>
const &ivector,
const int64_t index)
21 return (ivector.size() == 1) ? ivector[0] : ivector[index];
// Argument validation for a batched gemm-style call with per-problem
// transA, transB, m, n, k, alpha, A/lda, B/ldb, beta and C/ldc.
// Every vector argument either has one entry (shared by all problems) or
// one entry per problem; results are reported through `info`.
29 std::vector<blas::Op>
const &transA,
30 std::vector<blas::Op>
const &transB,
31 std::vector<int64_t>
const &m,
32 std::vector<int64_t>
const &n,
33 std::vector<int64_t>
const &k,
34 std::vector<T >
const &alpha,
35 std::vector<T*>
const &A, std::vector<int64_t>
const &lda,
36 std::vector<T*>
const &B, std::vector<int64_t>
const &ldb,
37 std::vector<T >
const &beta,
38 std::vector<T*>
const &C, std::vector<int64_t>
const &ldc,
39 const size_t batchCount, std::vector<int64_t> &info)
// Size checks. Option/scalar vectors must hold 1 or exactly batchCount
// entries. A and B must hold 1 or at least batchCount pointers, and C must
// hold at least batchCount. A single shared A, B or C pointer is rejected
// when arguments that affect that matrix vary across problems.
42 blas_error_if( (transA.size() != 1 && transA.size() != batchCount) );
43 blas_error_if( (transB.size() != 1 && transB.size() != batchCount) );
45 blas_error_if( (m.size() != 1 && m.size() != batchCount) );
46 blas_error_if( (n.size() != 1 && n.size() != batchCount) );
47 blas_error_if( (k.size() != 1 && k.size() != batchCount) );
49 blas_error_if( (alpha.size() != 1 && alpha.size() != batchCount) );
50 blas_error_if( (beta.size() != 1 && beta.size() != batchCount) );
52 blas_error_if( (lda.size() != 1 && lda.size() != batchCount) );
53 blas_error_if( (ldb.size() != 1 && ldb.size() != batchCount) );
54 blas_error_if( (ldc.size() != 1 && ldc.size() != batchCount) );
58 blas_error_if( (A.size() != 1 && A.size() < batchCount) );
59 blas_error_if( (B.size() != 1 && B.size() < batchCount) );
60 blas_error_if( (C.size() < batchCount) );
62 blas_error_if( A.size() == 1 && (m.size() > 1 || k.size() > 1 || lda.size() > 1) );
63 blas_error_if( B.size() == 1 && (k.size() > 1 || n.size() > 1 || ldb.size() > 1) );
64 blas_error_if( C.size() == 1 &&
65 (transA.size() > 1 || transB.size() > 1 ||
66 m.size() > 1 || n.size() > 1 || k.size() > 1 ||
67 alpha.size() > 1 || beta.size() > 1 ||
68 lda.size() > 1 || ldb.size() > 1 || ldc.size() > 1 ||
69 A.size() > 1 || B.size() > 1
// Per-problem codes go to internal_info. When a single aggregated info value
// is requested (info.size() == 1), this is a freshly new[]-allocated array of
// batchCount entries. Otherwise it is (presumably) info's own storage.
73 int64_t* internal_info;
74 if (info.size() == 1) {
75 internal_info =
new int64_t[batchCount];
78 internal_info = &info[0];
// Validate each problem independently in an OpenMP parallel loop.
// Iteration i writes only internal_info[i].
81 #pragma omp parallel for schedule(dynamic)
82 for (
size_t i = 0; i < batchCount; ++i) {
83 Op transA_ = extract<Op>(transA, i);
84 Op transB_ = extract<Op>(transB, i);
86 int64_t m_ = extract<int64_t>(m, i);
87 int64_t n_ = extract<int64_t>(n, i);
88 int64_t k_ = extract<int64_t>(k, i);
90 int64_t lda_ = extract<int64_t>(lda, i);
91 int64_t ldb_ = extract<int64_t>(ldb, i);
92 int64_t ldc_ = extract<int64_t>(ldc, i);
// Minimum leading dimensions:
//  - A needs m when exactly one of (transA_ == Op::NoTrans) and
//    (layout == Layout::RowMajor) holds, and k otherwise.
//  - B likewise needs k or n.
//  - C needs m for ColMajor and n otherwise.
94 int64_t nrowA_ = ((transA_ == Op::NoTrans) ^ (layout == Layout::RowMajor)) ? m_ : k_;
95 int64_t nrowB_ = ((transB_ == Op::NoTrans) ^ (layout == Layout::RowMajor)) ? k_ : n_;
96 int64_t nrowC_ = (layout == Layout::ColMajor) ? m_ : n_;
// The first failing check sets the code:
// -2 transA, -3 transB, -4 m, -5 n, -6 k, -8 lda, -11 ldb, -14 ldc
// (presumably the negated position of the offending argument).
// NOTE(review): unlike the other checks in this file, no
// `internal_info[i] = 0;` reset is present before this chain. If it is really
// absent, a problem that passes every check leaves internal_info[i] untouched:
// indeterminate in the new[] buffer, or a stale value when it aliases info.
// The aggregation below could then report a bogus code — confirm.
99 if (transA_ != Op::NoTrans &&
100 transA_ != Op::Trans &&
101 transA_ != Op::ConjTrans) {
102 internal_info[i] = -2;
104 else if (transB_ != Op::NoTrans &&
105 transB_ != Op::Trans &&
106 transB_ != Op::ConjTrans) {
107 internal_info[i] = -3;
109 else if (m_ < 0) internal_info[i] = -4;
110 else if (n_ < 0) internal_info[i] = -5;
111 else if (k_ < 0) internal_info[i] = -6;
112 else if (lda_ < nrowA_) internal_info[i] = -8;
113 else if (ldb_ < nrowB_) internal_info[i] = -11;
114 else if (ldc_ < nrowC_) internal_info[i] = -14;
// Single info slot: max-reduce the per-problem codes into info[0].
// Entries equal to 0 are special-cased (presumably skipped), so the negative
// code closest to zero wins. If the reduction never moves off
// INTERNAL_INFO_DEFAULT, info[0] becomes 0. The temporary buffer is then
// released, and a nonzero info[0] raises blas_error_if_msg.
117 if (info.size() == 1) {
119 int64_t lerror = INTERNAL_INFO_DEFAULT;
120 #pragma omp parallel for reduction(max:lerror)
121 for (
size_t i = 0; i < batchCount; ++i) {
122 if (internal_info[i] == 0)
124 lerror = std::max(lerror, internal_info[i]);
126 info[0] = (lerror == INTERNAL_INFO_DEFAULT) ? 0 : lerror;
129 delete[] internal_info;
132 blas_error_if_msg( info[0] != 0,
"info = %lld", llong( info[0] ) );
// One info entry per problem: this reduction presumably accumulates the
// entries of info into info_. A nonzero total raises blas_error_if_msg.
136 #pragma omp parallel for reduction(+:info_)
137 for (
size_t i = 0; i < batchCount; ++i) {
140 blas_error_if_msg( info_ != 0,
"One or more non-zero entry in vector info");
// Argument validation for a batched triangular-matrix call with per-problem
// side, uplo, trans, diag, m, n, alpha, A/lda and B/ldb.
// Every vector argument either has one entry (shared by all problems) or
// one entry per problem; results are reported through `info`.
149 std::vector<blas::Side>
const &side,
150 std::vector<blas::Uplo>
const &uplo,
151 std::vector<blas::Op>
const &trans,
152 std::vector<blas::Diag>
const &diag,
153 std::vector<int64_t>
const &m,
154 std::vector<int64_t>
const &n,
155 std::vector<T>
const &alpha,
156 std::vector<T*>
const &A, std::vector<int64_t>
const &lda,
157 std::vector<T*>
const &B, std::vector<int64_t>
const &ldb,
158 const size_t batchCount, std::vector<int64_t> &info)
// Size checks. Option/scalar vectors must hold 1 or exactly batchCount
// entries. A must hold 1 or at least batchCount pointers, and B at least
// batchCount.
// A single shared A is rejected when lda varies, or when side[0] is Left and
// m varies, or side[0] is Right and n varies. A single shared B is rejected
// when any other argument varies.
161 blas_error_if( (side.size() != 1 && side.size() != batchCount) );
162 blas_error_if( (uplo.size() != 1 && uplo.size() != batchCount) );
163 blas_error_if( (trans.size() != 1 && trans.size() != batchCount) );
164 blas_error_if( (diag.size() != 1 && diag.size() != batchCount) );
166 blas_error_if( (m.size() != 1 && m.size() != batchCount) );
167 blas_error_if( (n.size() != 1 && n.size() != batchCount) );
171 blas_error_if( (A.size() != 1 && A.size() < batchCount) );
172 blas_error_if( B.size() < batchCount );
174 blas_error_if( (lda.size() != 1 && lda.size() != batchCount) );
175 blas_error_if( (ldb.size() != 1 && ldb.size() != batchCount) );
177 blas_error_if( (alpha.size() != 1 && alpha.size() != batchCount) );
179 blas_error_if( A.size() == 1 && ( lda.size() > 1 ||
181 (side[0] == Side::Left && m.size() > 1) ||
182 (side[0] == Side::Right && n.size() > 1) ));
183 blas_error_if( B.size() == 1 && ( side.size() > 1 || uplo.size() > 1 ||
184 trans.size() > 1 || diag.size() > 1 ||
185 m.size() > 1 || n.size() > 1 ||
186 alpha.size() > 1 || A.size() > 1 ||
187 lda.size() > 1 || ldb.size() > 1 ));
// Per-problem codes go to internal_info. When a single aggregated info value
// is requested (info.size() == 1), this is a freshly new[]-allocated array of
// batchCount entries. Otherwise it is (presumably) info's own storage.
189 int64_t* internal_info;
190 if (info.size() == 1) {
191 internal_info =
new int64_t[batchCount];
194 internal_info = &info[0];
// Validate each problem independently in an OpenMP parallel loop.
// Iteration i writes only internal_info[i].
197 #pragma omp parallel for schedule(dynamic)
198 for (
size_t i = 0; i < batchCount; ++i) {
199 Side side_ = extract<Side>( side, i );
200 Uplo uplo_ = extract<Uplo>( uplo, i );
201 Op trans_ = extract<Op >( trans, i );
202 Diag diag_ = extract<Diag>( diag, i );
204 int64_t m_ = extract<int64_t>(m, i);
205 int64_t n_ = extract<int64_t>(n, i);
207 int64_t lda_ = extract<int64_t>(lda, i);
208 int64_t ldb_ = extract<int64_t>(ldb, i);
// Minimum leading dimensions: A needs m when side_ is Left and n otherwise;
// B needs m for ColMajor and n otherwise.
210 int64_t nrowA_ = (side_ == Side::Left) ? m_ : n_;
211 int64_t nrowB_ = (layout == Layout::ColMajor) ? m_ : n_;
// Start from "no error". The first failing check sets the code:
// -2 side, -3 uplo, -4 trans, -5 diag, -6 m, -7 n, -10 lda, -12 ldb
// (presumably the negated position of the offending argument).
213 internal_info[i] = 0;
214 if (side_ != Side::Left && side_ != Side::Right) {
215 internal_info[i] = -2;
217 else if (uplo_ != Uplo::Lower && uplo_ != Uplo::Upper) {
218 internal_info[i] = -3;
220 else if (trans_ != Op::NoTrans && trans_ != Op::Trans && trans_ != Op::ConjTrans) {
221 internal_info[i] = -4;
223 else if (diag_ != Diag::NonUnit && diag_ != Diag::Unit) {
224 internal_info[i] = -5;
226 else if (m_ < 0) internal_info[i] = -6;
227 else if (n_ < 0) internal_info[i] = -7;
228 else if (lda_ < nrowA_) internal_info[i] = -10;
229 else if (ldb_ < nrowB_) internal_info[i] = -12;
// Single info slot: max-reduce the per-problem codes into info[0].
// Entries equal to 0 are special-cased (presumably skipped), so the negative
// code closest to zero wins. If the reduction never moves off
// INTERNAL_INFO_DEFAULT, info[0] becomes 0. The temporary buffer is then
// released, and a nonzero info[0] raises blas_error_if_msg.
232 if (info.size() == 1) {
234 int64_t lerror = INTERNAL_INFO_DEFAULT;
235 #pragma omp parallel for reduction(max:lerror)
236 for (
size_t i = 0; i < batchCount; ++i) {
237 if (internal_info[i] == 0)
239 lerror = std::max(lerror, internal_info[i]);
241 info[0] = (lerror == INTERNAL_INFO_DEFAULT) ? 0 : lerror;
244 delete[] internal_info;
247 blas_error_if_msg( info[0] != 0,
"info = %lld", llong( info[0] ) );
// One info entry per problem: this reduction presumably accumulates the
// entries of info into info_. A nonzero total raises blas_error_if_msg.
251 #pragma omp parallel for reduction(+:info_)
252 for (
size_t i = 0; i < batchCount; ++i) {
255 blas_error_if_msg( info_ != 0,
"One or more non-zero entry in vector info");
// Argument validation for a batched triangular-matrix call with per-problem
// side, uplo, trans, diag, m, n, alpha, A/lda and B/ldb.
// Every vector argument either has one entry (shared by all problems) or
// one entry per problem; results are reported through `info`.
264 std::vector<blas::Side>
const &side,
265 std::vector<blas::Uplo>
const &uplo,
266 std::vector<blas::Op>
const &trans,
267 std::vector<blas::Diag>
const &diag,
268 std::vector<int64_t>
const &m,
269 std::vector<int64_t>
const &n,
270 std::vector<T>
const &alpha,
271 std::vector<T*>
const &A, std::vector<int64_t>
const &lda,
272 std::vector<T*>
const &B, std::vector<int64_t>
const &ldb,
273 const size_t batchCount, std::vector<int64_t> &info)
// Size checks. Option/scalar vectors must hold 1 or exactly batchCount
// entries. A must hold 1 or at least batchCount pointers, and B at least
// batchCount.
// A single shared A is rejected when lda varies, or when side[0] is Left and
// m varies, or side[0] is Right and n varies. A single shared B is rejected
// when any other argument varies.
276 blas_error_if( (side.size() != 1 && side.size() != batchCount) );
277 blas_error_if( (uplo.size() != 1 && uplo.size() != batchCount) );
278 blas_error_if( (trans.size() != 1 && trans.size() != batchCount) );
279 blas_error_if( (diag.size() != 1 && diag.size() != batchCount) );
281 blas_error_if( (m.size() != 1 && m.size() != batchCount) );
282 blas_error_if( (n.size() != 1 && n.size() != batchCount) );
286 blas_error_if( (A.size() != 1 && A.size() < batchCount) );
287 blas_error_if( B.size() < batchCount );
289 blas_error_if( (lda.size() != 1 && lda.size() != batchCount) );
290 blas_error_if( (ldb.size() != 1 && ldb.size() != batchCount) );
292 blas_error_if( (alpha.size() != 1 && alpha.size() != batchCount) );
294 blas_error_if( A.size() == 1 && ( lda.size() > 1 ||
296 (side[0] == Side::Left && m.size() > 1) ||
297 (side[0] == Side::Right && n.size() > 1) ));
298 blas_error_if( B.size() == 1 && ( side.size() > 1 || uplo.size() > 1 ||
299 trans.size() > 1 || diag.size() > 1 ||
300 m.size() > 1 || n.size() > 1 ||
301 alpha.size() > 1 || A.size() > 1 ||
302 lda.size() > 1 || ldb.size() > 1 ));
// Per-problem codes go to internal_info. When a single aggregated info value
// is requested (info.size() == 1), this is a freshly new[]-allocated array of
// batchCount entries. Otherwise it is (presumably) info's own storage.
304 int64_t* internal_info;
305 if (info.size() == 1) {
306 internal_info =
new int64_t[batchCount];
309 internal_info = &info[0];
// Validate each problem independently in an OpenMP parallel loop.
// Iteration i writes only internal_info[i].
312 #pragma omp parallel for schedule(dynamic)
313 for (
size_t i = 0; i < batchCount; ++i) {
314 Side side_ = extract<Side>( side, i );
315 Uplo uplo_ = extract<Uplo>( uplo, i );
316 Op trans_ = extract<Op >( trans, i );
317 Diag diag_ = extract<Diag>( diag, i );
319 int64_t m_ = extract<int64_t>(m, i);
320 int64_t n_ = extract<int64_t>(n, i);
322 int64_t lda_ = extract<int64_t>(lda, i);
323 int64_t ldb_ = extract<int64_t>(ldb, i);
// Minimum leading dimensions: A needs m when side_ is Left and n otherwise;
// B needs m for ColMajor and n otherwise.
325 int64_t nrowA_ = (side_ == Side::Left) ? m_ : n_;
326 int64_t nrowB_ = (layout == Layout::ColMajor) ? m_ : n_;
// Start from "no error". The first failing check sets the code:
// -2 side, -3 uplo, -4 trans, -5 diag, -6 m, -7 n, -10 lda, -12 ldb
// (presumably the negated position of the offending argument).
328 internal_info[i] = 0;
329 if (side_ != Side::Left && side_ != Side::Right) {
330 internal_info[i] = -2;
332 else if (uplo_ != Uplo::Lower && uplo_ != Uplo::Upper) {
333 internal_info[i] = -3;
335 else if (trans_ != Op::NoTrans && trans_ != Op::Trans && trans_ != Op::ConjTrans) {
336 internal_info[i] = -4;
338 else if (diag_ != Diag::NonUnit && diag_ != Diag::Unit) {
339 internal_info[i] = -5;
341 else if (m_ < 0) internal_info[i] = -6;
342 else if (n_ < 0) internal_info[i] = -7;
343 else if (lda_ < nrowA_) internal_info[i] = -10;
344 else if (ldb_ < nrowB_) internal_info[i] = -12;
// Single info slot: max-reduce the per-problem codes into info[0].
// Entries equal to 0 are special-cased (presumably skipped), so the negative
// code closest to zero wins. If the reduction never moves off
// INTERNAL_INFO_DEFAULT, info[0] becomes 0. The temporary buffer is then
// released, and a nonzero info[0] raises blas_error_if_msg.
347 if (info.size() == 1) {
349 int64_t lerror = INTERNAL_INFO_DEFAULT;
350 #pragma omp parallel for reduction(max:lerror)
351 for (
size_t i = 0; i < batchCount; ++i) {
352 if (internal_info[i] == 0)
354 lerror = std::max(lerror, internal_info[i]);
356 info[0] = (lerror == INTERNAL_INFO_DEFAULT) ? 0 : lerror;
359 delete[] internal_info;
362 blas_error_if_msg( info[0] != 0,
"info = %lld", llong( info[0] ) );
// One info entry per problem: this reduction presumably accumulates the
// entries of info into info_. A nonzero total raises blas_error_if_msg.
366 #pragma omp parallel for reduction(+:info_)
367 for (
size_t i = 0; i < batchCount; ++i) {
370 blas_error_if_msg( info_ != 0,
"One or more non-zero entry in vector info");
// Argument validation for a batched call with per-problem side, uplo, m, n,
// alpha, A/lda, B/ldb, beta and C/ldc.
// Every vector argument either has one entry (shared by all problems) or
// one entry per problem; results are reported through `info`.
379 std::vector<blas::Side>
const &side,
380 std::vector<blas::Uplo>
const &uplo,
381 std::vector<int64_t>
const &m,
382 std::vector<int64_t>
const &n,
383 std::vector<T>
const &alpha,
384 std::vector<T*>
const &A, std::vector<int64_t>
const &lda,
385 std::vector<T*>
const &B, std::vector<int64_t>
const &ldb,
386 std::vector<T>
const &beta,
387 std::vector<T*>
const &C, std::vector<int64_t>
const &ldc,
388 const size_t batchCount, std::vector<int64_t> &info)
// Size checks. Option/scalar vectors must hold 1 or exactly batchCount
// entries. A and B must hold 1 or at least batchCount pointers, and C must
// hold at least batchCount.
// Further checks reject a single shared A, B or C pointer when arguments that
// affect that matrix vary across problems. For example, a shared A is
// rejected when side[0] is Left and m varies, or side[0] is Right and n
// varies.
391 blas_error_if( (side.size() != 1 && side.size() != batchCount) );
392 blas_error_if( (uplo.size() != 1 && uplo.size() != batchCount) );
394 blas_error_if( (m.size() != 1 && m.size() != batchCount) );
395 blas_error_if( (n.size() != 1 && n.size() != batchCount) );
399 blas_error_if( (A.size() != 1 && A.size() < batchCount) );
400 blas_error_if( (B.size() != 1 && B.size() < batchCount) );
401 blas_error_if( C.size() < batchCount );
403 blas_error_if( (lda.size() != 1 && lda.size() != batchCount) );
404 blas_error_if( (ldb.size() != 1 && ldb.size() != batchCount) );
405 blas_error_if( (ldc.size() != 1 && ldc.size() != batchCount) );
407 blas_error_if( (alpha.size() != 1 && alpha.size() != batchCount) );
408 blas_error_if( (beta.size() != 1 && beta.size() != batchCount) );
410 blas_error_if( A.size() == 1 &&
413 (side[0] == Side::Left && m.size() > 1) ||
414 (side[0] == Side::Right && n.size() > 1) ));
416 blas_error_if( B.size() == 1 &&
421 blas_error_if( C.size() == 1 &&
// Per-problem codes go to internal_info. When a single aggregated info value
// is requested (info.size() == 1), this is a freshly new[]-allocated array of
// batchCount entries. Otherwise it is (presumably) info's own storage.
434 int64_t* internal_info;
435 if (info.size() == 1) {
436 internal_info =
new int64_t[batchCount];
439 internal_info = &info[0];
// Validate each problem independently in an OpenMP parallel loop.
// Iteration i writes only internal_info[i].
442 #pragma omp parallel for schedule(dynamic)
443 for (
size_t i = 0; i < batchCount; ++i) {
444 Side side_ = extract<Side>( side, i );
445 Uplo uplo_ = extract<Uplo>( uplo, i );
447 int64_t m_ = extract<int64_t>(m, i);
448 int64_t n_ = extract<int64_t>(n, i);
450 int64_t lda_ = extract<int64_t>(lda, i);
451 int64_t ldb_ = extract<int64_t>(ldb, i);
452 int64_t ldc_ = extract<int64_t>(ldc, i);
// Minimum leading dimensions: A needs m when side_ is Left and n otherwise;
// B and C both need m for ColMajor and n otherwise.
454 int64_t nrowA_ = (side_ == Side::Left) ? m_ : n_;
455 int64_t nrowB_ = (layout == Layout::ColMajor) ? m_ : n_;
456 int64_t nrowC_ = (layout == Layout::ColMajor) ? m_ : n_;
// Start from "no error". The first failing check sets the code:
// -2 side, -3 uplo, -4 m, -5 n, -8 lda, -10 ldb, -13 ldc
// (presumably the negated position of the offending argument).
458 internal_info[i] = 0;
459 if (side_ != Side::Left && side_ != Side::Right) {
460 internal_info[i] = -2;
462 else if (uplo_ != Uplo::Lower && uplo_ != Uplo::Upper) {
463 internal_info[i] = -3;
465 else if (m_ < 0) internal_info[i] = -4;
466 else if (n_ < 0) internal_info[i] = -5;
467 else if (lda_ < nrowA_) internal_info[i] = -8;
468 else if (ldb_ < nrowB_) internal_info[i] = -10;
469 else if (ldc_ < nrowC_) internal_info[i] = -13;
// Single info slot: max-reduce the per-problem codes into info[0].
// Entries equal to 0 are special-cased (presumably skipped), so the negative
// code closest to zero wins. If the reduction never moves off
// INTERNAL_INFO_DEFAULT, info[0] becomes 0. The temporary buffer is then
// released, and a nonzero info[0] raises blas_error_if_msg.
472 if (info.size() == 1) {
474 int64_t lerror = INTERNAL_INFO_DEFAULT;
475 #pragma omp parallel for reduction(max:lerror)
476 for (
size_t i = 0; i < batchCount; ++i) {
477 if (internal_info[i] == 0)
479 lerror = std::max(lerror, internal_info[i]);
481 info[0] = (lerror == INTERNAL_INFO_DEFAULT) ? 0 : lerror;
484 delete[] internal_info;
487 blas_error_if_msg( info[0] != 0,
"info = %lld", llong( info[0] ) );
// One info entry per problem: this reduction presumably accumulates the
// entries of info into info_. A nonzero total raises blas_error_if_msg.
491 #pragma omp parallel for reduction(+:info_)
492 for (
size_t i = 0; i < batchCount; ++i) {
495 blas_error_if_msg( info_ != 0,
"One or more non-zero entry in vector info");
// Argument validation for a batched call with per-problem uplo, trans, n, k,
// alpha, A/lda, beta and C/ldc. alpha and beta use the separate element type
// scalarT, while A and C hold T*. Only Op::NoTrans and Op::ConjTrans are
// accepted for trans.
501template <
typename T,
typename scalarT>
504 std::vector<blas::Uplo>
const &uplo,
505 std::vector<blas::Op>
const &trans,
506 std::vector<int64_t>
const &n,
507 std::vector<int64_t>
const &k,
508 std::vector<scalarT>
const &alpha,
509 std::vector<T*>
const &A, std::vector<int64_t>
const &lda,
510 std::vector<scalarT>
const &beta,
511 std::vector<T*>
const &C, std::vector<int64_t>
const &ldc,
512 const size_t batchCount, std::vector<int64_t> &info)
// Size checks. Option/scalar vectors must hold 1 or exactly batchCount
// entries. A must hold 1 or at least batchCount pointers, and C at least
// batchCount.
// Further checks reject a single shared A or C pointer when arguments that
// affect it vary across problems. For example, a shared A is rejected when
// trans varies and n[0] != k[0].
515 blas_error_if( (uplo.size() != 1 && uplo.size() != batchCount) );
516 blas_error_if( (trans.size() != 1 && trans.size() != batchCount) );
518 blas_error_if( (n.size() != 1 && n.size() != batchCount) );
519 blas_error_if( (k.size() != 1 && k.size() != batchCount) );
523 blas_error_if( (A.size() != 1 && A.size() < batchCount) );
524 blas_error_if( C.size() < batchCount );
526 blas_error_if( (lda.size() != 1 && lda.size() != batchCount) );
527 blas_error_if( (ldc.size() != 1 && ldc.size() != batchCount) );
529 blas_error_if( (alpha.size() != 1 && alpha.size() != batchCount) );
530 blas_error_if( (beta.size() != 1 && beta.size() != batchCount) );
532 blas_error_if( A.size() == 1 &&
536 (trans.size() > 1 && n[0] != k[0]) ));
538 blas_error_if( C.size() == 1 &&
// Per-problem codes go to internal_info. When a single aggregated info value
// is requested (info.size() == 1), this is a freshly new[]-allocated array of
// batchCount entries. Otherwise it is (presumably) info's own storage.
549 int64_t* internal_info;
550 if (info.size() == 1) {
551 internal_info =
new int64_t[batchCount];
554 internal_info = &info[0];
// Validate each problem independently in an OpenMP parallel loop.
// Iteration i writes only internal_info[i].
557 #pragma omp parallel for schedule(dynamic)
558 for (
size_t i = 0; i < batchCount; ++i) {
559 Uplo uplo_ = extract<Uplo>( uplo, i );
560 Op trans_ = extract<Op> ( trans, i );
562 int64_t n_ = extract<int64_t>(n, i);
563 int64_t k_ = extract<int64_t>(k, i);
565 int64_t lda_ = extract<int64_t>(lda, i);
566 int64_t ldc_ = extract<int64_t>(ldc, i);
// Minimum leading dimension of A: n when exactly one of
// (trans_ == Op::NoTrans) and (layout == Layout::RowMajor) holds, k otherwise.
// C (checked below) needs at least n.
568 int64_t nrowA_ = ((trans_ == Op::NoTrans) ^ (layout == Layout::RowMajor)) ? n_ : k_;
// Start from "no error". The first failing check sets the code:
// -2 uplo, -3 trans (NoTrans/ConjTrans only), -4 n, -5 k, -8 lda, -11 ldc
// (presumably the negated position of the offending argument).
570 internal_info[i] = 0;
571 if (uplo_ != Uplo::Lower && uplo_ != Uplo::Upper) {
572 internal_info[i] = -2;
574 else if (trans_ != Op::NoTrans && trans_ != Op::ConjTrans) {
575 internal_info[i] = -3;
577 else if (n_ < 0) internal_info[i] = -4;
578 else if (k_ < 0) internal_info[i] = -5;
579 else if (lda_ < nrowA_) internal_info[i] = -8;
580 else if (ldc_ < n_) internal_info[i] = -11;
// Single info slot: max-reduce the per-problem codes into info[0].
// Entries equal to 0 are special-cased (presumably skipped), so the negative
// code closest to zero wins. If the reduction never moves off
// INTERNAL_INFO_DEFAULT, info[0] becomes 0. The temporary buffer is then
// released, and a nonzero info[0] raises blas_error_if_msg.
583 if (info.size() == 1) {
585 int64_t lerror = INTERNAL_INFO_DEFAULT;
586 #pragma omp parallel for reduction(max:lerror)
587 for (
size_t i = 0; i < batchCount; ++i) {
588 if (internal_info[i] == 0)
590 lerror = std::max(lerror, internal_info[i]);
592 info[0] = (lerror == INTERNAL_INFO_DEFAULT) ? 0 : lerror;
595 delete[] internal_info;
598 blas_error_if_msg( info[0] != 0,
"info = %lld", llong( info[0] ) );
// One info entry per problem: this reduction presumably accumulates the
// entries of info into info_. A nonzero total raises blas_error_if_msg.
602 #pragma omp parallel for reduction(+:info_)
603 for (
size_t i = 0; i < batchCount; ++i) {
606 blas_error_if_msg( info_ != 0,
"One or more non-zero entry in vector info");
// Argument validation for a batched call with per-problem side, uplo, m, n,
// alpha, A/lda, B/ldb, beta and C/ldc. It delegates to hemm_check with every
// argument passed through unchanged, so exactly the same rules apply.
615 std::vector<blas::Side>
const &side,
616 std::vector<blas::Uplo>
const &uplo,
617 std::vector<int64_t>
const &m,
618 std::vector<int64_t>
const &n,
619 std::vector<T>
const &alpha,
620 std::vector<T*>
const &A, std::vector<int64_t>
const &lda,
621 std::vector<T*>
const &B, std::vector<int64_t>
const &ldb,
622 std::vector<T>
const &beta,
623 std::vector<T*>
const &C, std::vector<int64_t>
const &ldc,
624 const size_t batchCount, std::vector<int64_t> &info)
626 hemm_check(layout, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc, batchCount, info);
// Argument validation for a batched call with per-problem uplo, trans, n, k,
// alpha, A/lda, beta and C/ldc, all scalars of type T. Only Op::NoTrans and
// Op::Trans are accepted for trans.
634 std::vector<blas::Uplo>
const &uplo,
635 std::vector<blas::Op>
const &trans,
636 std::vector<int64_t>
const &n,
637 std::vector<int64_t>
const &k,
638 std::vector<T>
const &alpha,
639 std::vector<T*>
const &A, std::vector<int64_t>
const &lda,
640 std::vector<T>
const &beta,
641 std::vector<T*>
const &C, std::vector<int64_t>
const &ldc,
642 const size_t batchCount, std::vector<int64_t> &info)
// Size checks. Option/scalar vectors must hold 1 or exactly batchCount
// entries. A must hold 1 or at least batchCount pointers, and C at least
// batchCount.
// Further checks reject a single shared A or C pointer when arguments that
// affect it vary across problems. For example, a shared A is rejected when
// trans varies and n[0] != k[0].
645 blas_error_if( (uplo.size() != 1 && uplo.size() != batchCount) );
646 blas_error_if( (trans.size() != 1 && trans.size() != batchCount) );
648 blas_error_if( (n.size() != 1 && n.size() != batchCount) );
649 blas_error_if( (k.size() != 1 && k.size() != batchCount) );
653 blas_error_if( (A.size() != 1 && A.size() < batchCount) );
654 blas_error_if( C.size() < batchCount );
656 blas_error_if( (lda.size() != 1 && lda.size() != batchCount) );
657 blas_error_if( (ldc.size() != 1 && ldc.size() != batchCount) );
659 blas_error_if( (alpha.size() != 1 && alpha.size() != batchCount) );
660 blas_error_if( (beta.size() != 1 && beta.size() != batchCount) );
662 blas_error_if( A.size() == 1 &&
666 (trans.size() > 1 && n[0] != k[0]) ));
668 blas_error_if( C.size() == 1 &&
// Per-problem codes go to internal_info. When a single aggregated info value
// is requested (info.size() == 1), this is a freshly new[]-allocated array of
// batchCount entries. Otherwise it is (presumably) info's own storage.
679 int64_t* internal_info;
680 if (info.size() == 1) {
681 internal_info =
new int64_t[batchCount];
684 internal_info = &info[0];
// Validate each problem independently in an OpenMP parallel loop.
// Iteration i writes only internal_info[i].
687 #pragma omp parallel for schedule(dynamic)
688 for (
size_t i = 0; i < batchCount; ++i) {
689 Uplo uplo_ = extract<Uplo>( uplo, i );
690 Op trans_ = extract<Op> ( trans, i );
692 int64_t n_ = extract<int64_t>(n, i);
693 int64_t k_ = extract<int64_t>(k, i);
695 int64_t lda_ = extract<int64_t>(lda, i);
696 int64_t ldc_ = extract<int64_t>(ldc, i);
// Minimum leading dimension of A: n when exactly one of
// (trans_ == Op::NoTrans) and (layout == Layout::RowMajor) holds, k otherwise.
// C (checked below) needs at least n.
698 int64_t nrowA_ = ((trans_ == Op::NoTrans) ^ (layout == Layout::RowMajor)) ? n_ : k_;
// Start from "no error". The first failing check sets the code:
// -2 uplo, -3 trans (NoTrans/Trans only), -4 n, -5 k, -8 lda, -11 ldc
// (presumably the negated position of the offending argument).
700 internal_info[i] = 0;
701 if (uplo_ != Uplo::Lower && uplo_ != Uplo::Upper) {
702 internal_info[i] = -2;
704 else if (trans_ != Op::NoTrans && trans_ != Op::Trans) {
705 internal_info[i] = -3;
707 else if (n_ < 0) internal_info[i] = -4;
708 else if (k_ < 0) internal_info[i] = -5;
709 else if (lda_ < nrowA_) internal_info[i] = -8;
710 else if (ldc_ < n_) internal_info[i] = -11;
// Single info slot: max-reduce the per-problem codes into info[0].
// Entries equal to 0 are special-cased (presumably skipped), so the negative
// code closest to zero wins. If the reduction never moves off
// INTERNAL_INFO_DEFAULT, info[0] becomes 0. The temporary buffer is then
// released, and a nonzero info[0] raises blas_error_if_msg.
713 if (info.size() == 1) {
715 int64_t lerror = INTERNAL_INFO_DEFAULT;
716 #pragma omp parallel for reduction(max:lerror)
717 for (
size_t i = 0; i < batchCount; ++i) {
718 if (internal_info[i] == 0)
720 lerror = std::max(lerror, internal_info[i]);
722 info[0] = (lerror == INTERNAL_INFO_DEFAULT) ? 0 : lerror;
725 delete[] internal_info;
728 blas_error_if_msg( info[0] != 0,
"info = %lld", llong( info[0] ) );
// One info entry per problem: this reduction presumably accumulates the
// entries of info into info_. A nonzero total raises blas_error_if_msg.
732 #pragma omp parallel for reduction(+:info_)
733 for (
size_t i = 0; i < batchCount; ++i) {
736 blas_error_if_msg( info_ != 0,
"One or more non-zero entry in vector info");
// Argument validation for a batched call with per-problem uplo, trans, n, k,
// alpha, A/lda, B/ldb, beta and C/ldc. alpha has element type T, while beta
// uses the separate type scalarT. Only Op::NoTrans and Op::ConjTrans are
// accepted for trans.
742template <
typename T,
typename scalarT>
745 std::vector<blas::Uplo>
const &uplo,
746 std::vector<blas::Op>
const &trans,
747 std::vector<int64_t>
const &n,
748 std::vector<int64_t>
const &k,
749 std::vector<T>
const &alpha,
750 std::vector<T*>
const &A, std::vector<int64_t>
const &lda,
751 std::vector<T*>
const &B, std::vector<int64_t>
const &ldb,
752 std::vector<scalarT>
const &beta,
753 std::vector<T*>
const &C, std::vector<int64_t>
const &ldc,
754 const size_t batchCount, std::vector<int64_t> &info)
// Size checks. Option/scalar vectors must hold 1 or exactly batchCount
// entries. A and B must hold 1 or at least batchCount pointers, and C must
// hold at least batchCount.
// Further checks reject a single shared A, B or C pointer when arguments that
// affect it vary across problems. For example, a shared A or B is rejected
// when trans varies and n[0] != k[0].
757 blas_error_if( (uplo.size() != 1 && uplo.size() != batchCount) );
758 blas_error_if( (trans.size() != 1 && trans.size() != batchCount) );
760 blas_error_if( (n.size() != 1 && n.size() != batchCount) );
761 blas_error_if( (k.size() != 1 && k.size() != batchCount) );
765 blas_error_if( (A.size() != 1 && A.size() < batchCount) );
766 blas_error_if( (B.size() != 1 && B.size() < batchCount) );
767 blas_error_if( C.size() < batchCount );
769 blas_error_if( (lda.size() != 1 && lda.size() != batchCount) );
770 blas_error_if( (ldb.size() != 1 && ldb.size() != batchCount) );
771 blas_error_if( (ldc.size() != 1 && ldc.size() != batchCount) );
773 blas_error_if( (alpha.size() != 1 && alpha.size() != batchCount) );
774 blas_error_if( (beta.size() != 1 && beta.size() != batchCount) );
776 blas_error_if( A.size() == 1 &&
780 (trans.size() > 1 && n[0] != k[0]) ));
782 blas_error_if( B.size() == 1 &&
786 (trans.size() > 1 && n[0] != k[0]) ));
788 blas_error_if( C.size() == 1 &&
// Per-problem codes go to internal_info. When a single aggregated info value
// is requested (info.size() == 1), this is a freshly new[]-allocated array of
// batchCount entries. Otherwise it is (presumably) info's own storage.
801 int64_t* internal_info;
802 if (info.size() == 1) {
803 internal_info =
new int64_t[batchCount];
806 internal_info = &info[0];
// Validate each problem independently in an OpenMP parallel loop.
// Iteration i writes only internal_info[i].
809 #pragma omp parallel for schedule(dynamic)
810 for (
size_t i = 0; i < batchCount; ++i) {
811 Uplo uplo_ = extract<Uplo>( uplo, i );
812 Op trans_ = extract<Op> ( trans, i );
814 int64_t n_ = extract<int64_t>(n, i);
815 int64_t k_ = extract<int64_t>(k, i);
817 int64_t lda_ = extract<int64_t>(lda, i);
818 int64_t ldb_ = extract<int64_t>(ldb, i);
819 int64_t ldc_ = extract<int64_t>(ldc, i);
// Minimum leading dimensions of A and B, computed identically: n when exactly
// one of (trans_ == Op::NoTrans) and (layout == Layout::RowMajor) holds, k
// otherwise. C (checked below) needs at least n.
821 int64_t nrowA_ = ((trans_ == Op::NoTrans) ^ (layout == Layout::RowMajor)) ? n_ : k_;
822 int64_t nrowB_ = ((trans_ == Op::NoTrans) ^ (layout == Layout::RowMajor)) ? n_ : k_;
// Start from "no error". The first failing check sets the code:
// -2 uplo, -3 trans (NoTrans/ConjTrans only), -4 n, -5 k, -8 lda, -10 ldb,
// -13 ldc (presumably the negated position of the offending argument).
824 internal_info[i] = 0;
825 if (uplo_ != Uplo::Lower && uplo_ != Uplo::Upper) {
826 internal_info[i] = -2;
828 else if (trans_ != Op::NoTrans && trans_ != Op::ConjTrans) {
829 internal_info[i] = -3;
831 else if (n_ < 0) internal_info[i] = -4;
832 else if (k_ < 0) internal_info[i] = -5;
833 else if (lda_ < nrowA_) internal_info[i] = -8;
834 else if (ldb_ < nrowB_) internal_info[i] = -10;
835 else if (ldc_ < n_) internal_info[i] = -13;
// Single info slot: max-reduce the per-problem codes into info[0].
// Entries equal to 0 are special-cased (presumably skipped), so the negative
// code closest to zero wins. If the reduction never moves off
// INTERNAL_INFO_DEFAULT, info[0] becomes 0. The temporary buffer is then
// released, and a nonzero info[0] raises blas_error_if_msg.
838 if (info.size() == 1) {
840 int64_t lerror = INTERNAL_INFO_DEFAULT;
841 #pragma omp parallel for reduction(max:lerror)
842 for (
size_t i = 0; i < batchCount; ++i) {
843 if (internal_info[i] == 0)
845 lerror = std::max(lerror, internal_info[i]);
847 info[0] = (lerror == INTERNAL_INFO_DEFAULT) ? 0 : lerror;
850 delete[] internal_info;
853 blas_error_if_msg( info[0] != 0,
"info = %lld", llong( info[0] ) );
// One info entry per problem: this reduction presumably accumulates the
// entries of info into info_. A nonzero total raises blas_error_if_msg.
857 #pragma omp parallel for reduction(+:info_)
858 for (
size_t i = 0; i < batchCount; ++i) {
861 blas_error_if_msg( info_ != 0,
"One or more non-zero entry in vector info");
// Argument validation for a batched call with per-problem uplo, trans, n, k,
// alpha, A/lda, B/ldb, beta and C/ldc, all scalars of type T. Only
// Op::NoTrans and Op::Trans are accepted for trans.
870 std::vector<blas::Uplo>
const &uplo,
871 std::vector<blas::Op>
const &trans,
872 std::vector<int64_t>
const &n,
873 std::vector<int64_t>
const &k,
874 std::vector<T>
const &alpha,
875 std::vector<T*>
const &A, std::vector<int64_t>
const &lda,
876 std::vector<T*>
const &B, std::vector<int64_t>
const &ldb,
877 std::vector<T>
const &beta,
878 std::vector<T*>
const &C, std::vector<int64_t>
const &ldc,
879 const size_t batchCount, std::vector<int64_t> &info)
// Size checks. Option/scalar vectors must hold 1 or exactly batchCount
// entries. A and B must hold 1 or at least batchCount pointers, and C must
// hold at least batchCount.
// Further checks reject a single shared A, B or C pointer when arguments that
// affect it vary across problems. For example, a shared A or B is rejected
// when trans varies and n[0] != k[0].
882 blas_error_if( (uplo.size() != 1 && uplo.size() != batchCount) );
883 blas_error_if( (trans.size() != 1 && trans.size() != batchCount) );
885 blas_error_if( (n.size() != 1 && n.size() != batchCount) );
886 blas_error_if( (k.size() != 1 && k.size() != batchCount) );
890 blas_error_if( (A.size() != 1 && A.size() < batchCount) );
891 blas_error_if( (B.size() != 1 && B.size() < batchCount) );
892 blas_error_if( C.size() < batchCount );
894 blas_error_if( (lda.size() != 1 && lda.size() != batchCount) );
895 blas_error_if( (ldb.size() != 1 && ldb.size() != batchCount) );
896 blas_error_if( (ldc.size() != 1 && ldc.size() != batchCount) );
898 blas_error_if( (alpha.size() != 1 && alpha.size() != batchCount) );
899 blas_error_if( (beta.size() != 1 && beta.size() != batchCount) );
901 blas_error_if( A.size() == 1 &&
905 (trans.size() > 1 && n[0] != k[0]) ));
907 blas_error_if( B.size() == 1 &&
911 (trans.size() > 1 && n[0] != k[0]) ));
913 blas_error_if( C.size() == 1 &&
// Per-problem codes go to internal_info. When a single aggregated info value
// is requested (info.size() == 1), this is a freshly new[]-allocated array of
// batchCount entries. Otherwise it is (presumably) info's own storage.
926 int64_t* internal_info;
927 if (info.size() == 1) {
928 internal_info =
new int64_t[batchCount];
931 internal_info = &info[0];
// Validate each problem independently in an OpenMP parallel loop.
// Iteration i writes only internal_info[i].
934 #pragma omp parallel for schedule(dynamic)
935 for (
size_t i = 0; i < batchCount; ++i) {
936 Uplo uplo_ = extract<Uplo>( uplo, i );
937 Op trans_ = extract<Op> ( trans, i );
939 int64_t n_ = extract<int64_t>(n, i);
940 int64_t k_ = extract<int64_t>(k, i);
942 int64_t lda_ = extract<int64_t>(lda, i);
943 int64_t ldb_ = extract<int64_t>(ldb, i);
944 int64_t ldc_ = extract<int64_t>(ldc, i);
// Minimum leading dimensions of A and B, computed identically: n when exactly
// one of (trans_ == Op::NoTrans) and (layout == Layout::RowMajor) holds, k
// otherwise. C (checked below) needs at least n.
946 int64_t nrowA_ = ((trans_ == Op::NoTrans) ^ (layout == Layout::RowMajor)) ? n_ : k_;
947 int64_t nrowB_ = ((trans_ == Op::NoTrans) ^ (layout == Layout::RowMajor)) ? n_ : k_;
// Start from "no error". The first failing check sets the code:
// -2 uplo, -3 trans (NoTrans/Trans only), -4 n, -5 k, -8 lda, -10 ldb,
// -13 ldc (presumably the negated position of the offending argument).
949 internal_info[i] = 0;
950 if (uplo_ != Uplo::Lower && uplo_ != Uplo::Upper) {
951 internal_info[i] = -2;
953 else if (trans_ != Op::NoTrans && trans_ != Op::Trans) {
954 internal_info[i] = -3;
956 else if (n_ < 0) internal_info[i] = -4;
957 else if (k_ < 0) internal_info[i] = -5;
958 else if (lda_ < nrowA_) internal_info[i] = -8;
959 else if (ldb_ < nrowB_) internal_info[i] = -10;
960 else if (ldc_ < n_) internal_info[i] = -13;
// Single info slot: max-reduce the per-problem codes into info[0].
// Entries equal to 0 are special-cased (presumably skipped), so the negative
// code closest to zero wins. If the reduction never moves off
// INTERNAL_INFO_DEFAULT, info[0] becomes 0. The temporary buffer is then
// released, and a nonzero info[0] raises blas_error_if_msg.
963 if (info.size() == 1) {
965 int64_t lerror = INTERNAL_INFO_DEFAULT;
966 #pragma omp parallel for reduction(max:lerror)
967 for (
size_t i = 0; i < batchCount; ++i) {
968 if (internal_info[i] == 0)
970 lerror = std::max(lerror, internal_info[i]);
972 info[0] = (lerror == INTERNAL_INFO_DEFAULT) ? 0 : lerror;
975 delete[] internal_info;
978 blas_error_if_msg( info[0] != 0,
"info = %lld", llong( info[0] ) );
// One info entry per problem: this reduction presumably accumulates the
// entries of info into info_. A nonzero total raises blas_error_if_msg.
982 #pragma omp parallel for reduction(+:info_)
983 for (
size_t i = 0; i < batchCount; ++i) {
986 blas_error_if_msg( info_ != 0,
"One or more non-zero entry in vector info");