BLAS++ 2024.05.31
BLAS C++ API
Loading...
Searching...
No Matches
batch_common.hh
1// Copyright (c) 2017-2023, University of Tennessee. All rights reserved.
2// SPDX-License-Identifier: BSD-3-Clause
3// This program is free software: you can redistribute it and/or modify it under
4// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
5
6#ifndef BLAS_BATCH_COMMON_HH
7#define BLAS_BATCH_COMMON_HH
8
9#include "blas/util.hh"
10#include <algorithm> // std::min/max
11#include <vector>
12
13namespace blas {
14namespace batch {
15
16#define INTERNAL_INFO_DEFAULT (-1000)
17
/// Returns the parameter value for problem `index` in a batch.
/// A vector of size 1 holds one value shared by every problem in the
/// batch; otherwise the entry at position `index` is returned.
template <typename T>
T extract(std::vector<T> const &ivector, const int64_t index)
{
    if (ivector.size() == 1) {
        return ivector[0];
    }
    return ivector[index];
}
23
24// -----------------------------------------------------------------------------
25// batch gemm check
26template <typename T>
27void gemm_check(
28 blas::Layout layout,
29 std::vector<blas::Op> const &transA,
30 std::vector<blas::Op> const &transB,
31 std::vector<int64_t> const &m,
32 std::vector<int64_t> const &n,
33 std::vector<int64_t> const &k,
34 std::vector<T > const &alpha,
35 std::vector<T*> const &A, std::vector<int64_t> const &lda,
36 std::vector<T*> const &B, std::vector<int64_t> const &ldb,
37 std::vector<T > const &beta,
38 std::vector<T*> const &C, std::vector<int64_t> const &ldc,
39 const size_t batchCount, std::vector<int64_t> &info)
40{
41 // size error checking
42 blas_error_if( (transA.size() != 1 && transA.size() != batchCount) );
43 blas_error_if( (transB.size() != 1 && transB.size() != batchCount) );
44
45 blas_error_if( (m.size() != 1 && m.size() != batchCount) );
46 blas_error_if( (n.size() != 1 && n.size() != batchCount) );
47 blas_error_if( (k.size() != 1 && k.size() != batchCount) );
48
49 blas_error_if( (alpha.size() != 1 && alpha.size() != batchCount) );
50 blas_error_if( (beta.size() != 1 && beta.size() != batchCount) );
51
52 blas_error_if( (lda.size() != 1 && lda.size() != batchCount) );
53 blas_error_if( (ldb.size() != 1 && ldb.size() != batchCount) );
54 blas_error_if( (ldc.size() != 1 && ldc.size() != batchCount) );
55
56 // to support checking errors for the group interface, batchCount will be equal to group_count
57 // but the data arrays are generally >= group_count
58 blas_error_if( (A.size() != 1 && A.size() < batchCount) );
59 blas_error_if( (B.size() != 1 && B.size() < batchCount) );
60 blas_error_if( (C.size() < batchCount) );
61
62 blas_error_if( A.size() == 1 && (m.size() > 1 || k.size() > 1 || lda.size() > 1) );
63 blas_error_if( B.size() == 1 && (k.size() > 1 || n.size() > 1 || ldb.size() > 1) );
64 blas_error_if( C.size() == 1 &&
65 (transA.size() > 1 || transB.size() > 1 ||
66 m.size() > 1 || n.size() > 1 || k.size() > 1 ||
67 alpha.size() > 1 || beta.size() > 1 ||
68 lda.size() > 1 || ldb.size() > 1 || ldc.size() > 1 ||
69 A.size() > 1 || B.size() > 1
70 )
71 );
72
73 int64_t* internal_info;
74 if (info.size() == 1) {
75 internal_info = new int64_t[batchCount];
76 }
77 else {
78 internal_info = &info[0];
79 }
80
81 #pragma omp parallel for schedule(dynamic)
82 for (size_t i = 0; i < batchCount; ++i) {
83 Op transA_ = extract<Op>(transA, i);
84 Op transB_ = extract<Op>(transB, i);
85
86 int64_t m_ = extract<int64_t>(m, i);
87 int64_t n_ = extract<int64_t>(n, i);
88 int64_t k_ = extract<int64_t>(k, i);
89
90 int64_t lda_ = extract<int64_t>(lda, i);
91 int64_t ldb_ = extract<int64_t>(ldb, i);
92 int64_t ldc_ = extract<int64_t>(ldc, i);
93
94 int64_t nrowA_ = ((transA_ == Op::NoTrans) ^ (layout == Layout::RowMajor)) ? m_ : k_;
95 int64_t nrowB_ = ((transB_ == Op::NoTrans) ^ (layout == Layout::RowMajor)) ? k_ : n_;
96 int64_t nrowC_ = (layout == Layout::ColMajor) ? m_ : n_;
97
98 internal_info[i] = 0;
99 if (transA_ != Op::NoTrans &&
100 transA_ != Op::Trans &&
101 transA_ != Op::ConjTrans) {
102 internal_info[i] = -2;
103 }
104 else if (transB_ != Op::NoTrans &&
105 transB_ != Op::Trans &&
106 transB_ != Op::ConjTrans) {
107 internal_info[i] = -3;
108 }
109 else if (m_ < 0) internal_info[i] = -4;
110 else if (n_ < 0) internal_info[i] = -5;
111 else if (k_ < 0) internal_info[i] = -6;
112 else if (lda_ < nrowA_) internal_info[i] = -8;
113 else if (ldb_ < nrowB_) internal_info[i] = -11;
114 else if (ldc_ < nrowC_) internal_info[i] = -14;
115 }
116
117 if (info.size() == 1) {
118 // do a reduction that finds the first argument to encounter an error
119 int64_t lerror = INTERNAL_INFO_DEFAULT;
120 #pragma omp parallel for reduction(max:lerror)
121 for (size_t i = 0; i < batchCount; ++i) {
122 if (internal_info[i] == 0)
123 continue; // skip problems that passed error checks
124 lerror = std::max(lerror, internal_info[i]);
125 }
126 info[0] = (lerror == INTERNAL_INFO_DEFAULT) ? 0 : lerror;
127
128 // delete the internal vector
129 delete[] internal_info;
130
131 // throw an exception if needed
132 blas_error_if_msg( info[0] != 0, "info = %lld", llong( info[0] ) );
133 }
134 else {
135 int64_t info_ = 0;
136 #pragma omp parallel for reduction(+:info_)
137 for (size_t i = 0; i < batchCount; ++i) {
138 info_ += info[i];
139 }
140 blas_error_if_msg( info_ != 0, "One or more non-zero entry in vector info");
141 }
142}
143
144// -----------------------------------------------------------------------------
145// batch trsm check
146template <typename T>
147void trsm_check(
148 blas::Layout layout,
149 std::vector<blas::Side> const &side,
150 std::vector<blas::Uplo> const &uplo,
151 std::vector<blas::Op> const &trans,
152 std::vector<blas::Diag> const &diag,
153 std::vector<int64_t> const &m,
154 std::vector<int64_t> const &n,
155 std::vector<T> const &alpha,
156 std::vector<T*> const &A, std::vector<int64_t> const &lda,
157 std::vector<T*> const &B, std::vector<int64_t> const &ldb,
158 const size_t batchCount, std::vector<int64_t> &info)
159{
160 // size error checking
161 blas_error_if( (side.size() != 1 && side.size() != batchCount) );
162 blas_error_if( (uplo.size() != 1 && uplo.size() != batchCount) );
163 blas_error_if( (trans.size() != 1 && trans.size() != batchCount) );
164 blas_error_if( (diag.size() != 1 && diag.size() != batchCount) );
165
166 blas_error_if( (m.size() != 1 && m.size() != batchCount) );
167 blas_error_if( (n.size() != 1 && n.size() != batchCount) );
168
169 // to support checking errors for the group interface, batchCount will be equal to group_count
170 // but the data arrays are generally >= group_count
171 blas_error_if( (A.size() != 1 && A.size() < batchCount) );
172 blas_error_if( B.size() < batchCount );
173
174 blas_error_if( (lda.size() != 1 && lda.size() != batchCount) );
175 blas_error_if( (ldb.size() != 1 && ldb.size() != batchCount) );
176
177 blas_error_if( (alpha.size() != 1 && alpha.size() != batchCount) );
178
179 blas_error_if( A.size() == 1 && ( lda.size() > 1 ||
180 side.size() > 1 ||
181 (side[0] == Side::Left && m.size() > 1) ||
182 (side[0] == Side::Right && n.size() > 1) ));
183 blas_error_if( B.size() == 1 && ( side.size() > 1 || uplo.size() > 1 ||
184 trans.size() > 1 || diag.size() > 1 ||
185 m.size() > 1 || n.size() > 1 ||
186 alpha.size() > 1 || A.size() > 1 ||
187 lda.size() > 1 || ldb.size() > 1 ));
188
189 int64_t* internal_info;
190 if (info.size() == 1) {
191 internal_info = new int64_t[batchCount];
192 }
193 else {
194 internal_info = &info[0];
195 }
196
197 #pragma omp parallel for schedule(dynamic)
198 for (size_t i = 0; i < batchCount; ++i) {
199 Side side_ = extract<Side>( side, i );
200 Uplo uplo_ = extract<Uplo>( uplo, i );
201 Op trans_ = extract<Op >( trans, i );
202 Diag diag_ = extract<Diag>( diag, i );
203
204 int64_t m_ = extract<int64_t>(m, i);
205 int64_t n_ = extract<int64_t>(n, i);
206
207 int64_t lda_ = extract<int64_t>(lda, i);
208 int64_t ldb_ = extract<int64_t>(ldb, i);
209
210 int64_t nrowA_ = (side_ == Side::Left) ? m_ : n_;
211 int64_t nrowB_ = (layout == Layout::ColMajor) ? m_ : n_;
212
213 internal_info[i] = 0;
214 if (side_ != Side::Left && side_ != Side::Right) {
215 internal_info[i] = -2;
216 }
217 else if (uplo_ != Uplo::Lower && uplo_ != Uplo::Upper) {
218 internal_info[i] = -3;
219 }
220 else if (trans_ != Op::NoTrans && trans_ != Op::Trans && trans_ != Op::ConjTrans) {
221 internal_info[i] = -4;
222 }
223 else if (diag_ != Diag::NonUnit && diag_ != Diag::Unit) {
224 internal_info[i] = -5;
225 }
226 else if (m_ < 0) internal_info[i] = -6;
227 else if (n_ < 0) internal_info[i] = -7;
228 else if (lda_ < nrowA_) internal_info[i] = -10;
229 else if (ldb_ < nrowB_) internal_info[i] = -12;
230 }
231
232 if (info.size() == 1) {
233 // do a reduction that finds the first argument to encounter an error
234 int64_t lerror = INTERNAL_INFO_DEFAULT;
235 #pragma omp parallel for reduction(max:lerror)
236 for (size_t i = 0; i < batchCount; ++i) {
237 if (internal_info[i] == 0)
238 continue; // skip problems that passed error checks
239 lerror = std::max(lerror, internal_info[i]);
240 }
241 info[0] = (lerror == INTERNAL_INFO_DEFAULT) ? 0 : lerror;
242
243 // delete the internal vector
244 delete[] internal_info;
245
246 // throw an exception if needed
247 blas_error_if_msg( info[0] != 0, "info = %lld", llong( info[0] ) );
248 }
249 else {
250 int64_t info_ = 0;
251 #pragma omp parallel for reduction(+:info_)
252 for (size_t i = 0; i < batchCount; ++i) {
253 info_ += info[i];
254 }
255 blas_error_if_msg( info_ != 0, "One or more non-zero entry in vector info");
256 }
257}
258
259// -----------------------------------------------------------------------------
260// batch trmm check
261template <typename T>
262void trmm_check(
263 blas::Layout layout,
264 std::vector<blas::Side> const &side,
265 std::vector<blas::Uplo> const &uplo,
266 std::vector<blas::Op> const &trans,
267 std::vector<blas::Diag> const &diag,
268 std::vector<int64_t> const &m,
269 std::vector<int64_t> const &n,
270 std::vector<T> const &alpha,
271 std::vector<T*> const &A, std::vector<int64_t> const &lda,
272 std::vector<T*> const &B, std::vector<int64_t> const &ldb,
273 const size_t batchCount, std::vector<int64_t> &info)
274{
275 // size error checking
276 blas_error_if( (side.size() != 1 && side.size() != batchCount) );
277 blas_error_if( (uplo.size() != 1 && uplo.size() != batchCount) );
278 blas_error_if( (trans.size() != 1 && trans.size() != batchCount) );
279 blas_error_if( (diag.size() != 1 && diag.size() != batchCount) );
280
281 blas_error_if( (m.size() != 1 && m.size() != batchCount) );
282 blas_error_if( (n.size() != 1 && n.size() != batchCount) );
283
284 // to support checking errors for the group interface, batchCount will be equal to group_count
285 // but the data arrays are generally >= group_count
286 blas_error_if( (A.size() != 1 && A.size() < batchCount) );
287 blas_error_if( B.size() < batchCount );
288
289 blas_error_if( (lda.size() != 1 && lda.size() != batchCount) );
290 blas_error_if( (ldb.size() != 1 && ldb.size() != batchCount) );
291
292 blas_error_if( (alpha.size() != 1 && alpha.size() != batchCount) );
293
294 blas_error_if( A.size() == 1 && ( lda.size() > 1 ||
295 side.size() > 1 ||
296 (side[0] == Side::Left && m.size() > 1) ||
297 (side[0] == Side::Right && n.size() > 1) ));
298 blas_error_if( B.size() == 1 && ( side.size() > 1 || uplo.size() > 1 ||
299 trans.size() > 1 || diag.size() > 1 ||
300 m.size() > 1 || n.size() > 1 ||
301 alpha.size() > 1 || A.size() > 1 ||
302 lda.size() > 1 || ldb.size() > 1 ));
303
304 int64_t* internal_info;
305 if (info.size() == 1) {
306 internal_info = new int64_t[batchCount];
307 }
308 else {
309 internal_info = &info[0];
310 }
311
312 #pragma omp parallel for schedule(dynamic)
313 for (size_t i = 0; i < batchCount; ++i) {
314 Side side_ = extract<Side>( side, i );
315 Uplo uplo_ = extract<Uplo>( uplo, i );
316 Op trans_ = extract<Op >( trans, i );
317 Diag diag_ = extract<Diag>( diag, i );
318
319 int64_t m_ = extract<int64_t>(m, i);
320 int64_t n_ = extract<int64_t>(n, i);
321
322 int64_t lda_ = extract<int64_t>(lda, i);
323 int64_t ldb_ = extract<int64_t>(ldb, i);
324
325 int64_t nrowA_ = (side_ == Side::Left) ? m_ : n_;
326 int64_t nrowB_ = (layout == Layout::ColMajor) ? m_ : n_;
327
328 internal_info[i] = 0;
329 if (side_ != Side::Left && side_ != Side::Right) {
330 internal_info[i] = -2;
331 }
332 else if (uplo_ != Uplo::Lower && uplo_ != Uplo::Upper) {
333 internal_info[i] = -3;
334 }
335 else if (trans_ != Op::NoTrans && trans_ != Op::Trans && trans_ != Op::ConjTrans) {
336 internal_info[i] = -4;
337 }
338 else if (diag_ != Diag::NonUnit && diag_ != Diag::Unit) {
339 internal_info[i] = -5;
340 }
341 else if (m_ < 0) internal_info[i] = -6;
342 else if (n_ < 0) internal_info[i] = -7;
343 else if (lda_ < nrowA_) internal_info[i] = -10;
344 else if (ldb_ < nrowB_) internal_info[i] = -12;
345 }
346
347 if (info.size() == 1) {
348 // do a reduction that finds the first argument to encounter an error
349 int64_t lerror = INTERNAL_INFO_DEFAULT;
350 #pragma omp parallel for reduction(max:lerror)
351 for (size_t i = 0; i < batchCount; ++i) {
352 if (internal_info[i] == 0)
353 continue; // skip problems that passed error checks
354 lerror = std::max(lerror, internal_info[i]);
355 }
356 info[0] = (lerror == INTERNAL_INFO_DEFAULT) ? 0 : lerror;
357
358 // delete the internal vector
359 delete[] internal_info;
360
361 // throw an exception if needed
362 blas_error_if_msg( info[0] != 0, "info = %lld", llong( info[0] ) );
363 }
364 else {
365 int64_t info_ = 0;
366 #pragma omp parallel for reduction(+:info_)
367 for (size_t i = 0; i < batchCount; ++i) {
368 info_ += info[i];
369 }
370 blas_error_if_msg( info_ != 0, "One or more non-zero entry in vector info");
371 }
372}
373
374// -----------------------------------------------------------------------------
375// batch hemm check
376template <typename T>
377void hemm_check(
378 blas::Layout layout,
379 std::vector<blas::Side> const &side,
380 std::vector<blas::Uplo> const &uplo,
381 std::vector<int64_t> const &m,
382 std::vector<int64_t> const &n,
383 std::vector<T> const &alpha,
384 std::vector<T*> const &A, std::vector<int64_t> const &lda,
385 std::vector<T*> const &B, std::vector<int64_t> const &ldb,
386 std::vector<T> const &beta,
387 std::vector<T*> const &C, std::vector<int64_t> const &ldc,
388 const size_t batchCount, std::vector<int64_t> &info)
389{
390 // size error checking
391 blas_error_if( (side.size() != 1 && side.size() != batchCount) );
392 blas_error_if( (uplo.size() != 1 && uplo.size() != batchCount) );
393
394 blas_error_if( (m.size() != 1 && m.size() != batchCount) );
395 blas_error_if( (n.size() != 1 && n.size() != batchCount) );
396
397 // to support checking errors for the group interface, batchCount will be equal to group_count
398 // but the data arrays are generally >= group_count
399 blas_error_if( (A.size() != 1 && A.size() < batchCount) );
400 blas_error_if( (B.size() != 1 && B.size() < batchCount) );
401 blas_error_if( C.size() < batchCount );
402
403 blas_error_if( (lda.size() != 1 && lda.size() != batchCount) );
404 blas_error_if( (ldb.size() != 1 && ldb.size() != batchCount) );
405 blas_error_if( (ldc.size() != 1 && ldc.size() != batchCount) );
406
407 blas_error_if( (alpha.size() != 1 && alpha.size() != batchCount) );
408 blas_error_if( (beta.size() != 1 && beta.size() != batchCount) );
409
410 blas_error_if( A.size() == 1 &&
411 (lda.size() > 1 ||
412 side.size() > 1 ||
413 (side[0] == Side::Left && m.size() > 1) ||
414 (side[0] == Side::Right && n.size() > 1) ));
415
416 blas_error_if( B.size() == 1 &&
417 (m.size() > 1 ||
418 n.size() > 1 ||
419 ldb.size() > 1 ));
420
421 blas_error_if( C.size() == 1 &&
422 (side.size() > 1 ||
423 uplo.size() > 1 ||
424 m.size() > 1 ||
425 n.size() > 1 ||
426 alpha.size() > 1 ||
427 A.size() > 1 ||
428 lda.size() > 1 ||
429 B.size() > 1 ||
430 ldb.size() > 1 ||
431 beta.size() > 1 ||
432 ldc.size() > 1 ));
433
434 int64_t* internal_info;
435 if (info.size() == 1) {
436 internal_info = new int64_t[batchCount];
437 }
438 else {
439 internal_info = &info[0];
440 }
441
442 #pragma omp parallel for schedule(dynamic)
443 for (size_t i = 0; i < batchCount; ++i) {
444 Side side_ = extract<Side>( side, i );
445 Uplo uplo_ = extract<Uplo>( uplo, i );
446
447 int64_t m_ = extract<int64_t>(m, i);
448 int64_t n_ = extract<int64_t>(n, i);
449
450 int64_t lda_ = extract<int64_t>(lda, i);
451 int64_t ldb_ = extract<int64_t>(ldb, i);
452 int64_t ldc_ = extract<int64_t>(ldc, i);
453
454 int64_t nrowA_ = (side_ == Side::Left) ? m_ : n_;
455 int64_t nrowB_ = (layout == Layout::ColMajor) ? m_ : n_;
456 int64_t nrowC_ = (layout == Layout::ColMajor) ? m_ : n_;
457
458 internal_info[i] = 0;
459 if (side_ != Side::Left && side_ != Side::Right) {
460 internal_info[i] = -2;
461 }
462 else if (uplo_ != Uplo::Lower && uplo_ != Uplo::Upper) {
463 internal_info[i] = -3;
464 }
465 else if (m_ < 0) internal_info[i] = -4;
466 else if (n_ < 0) internal_info[i] = -5;
467 else if (lda_ < nrowA_) internal_info[i] = -8;
468 else if (ldb_ < nrowB_) internal_info[i] = -10;
469 else if (ldc_ < nrowC_) internal_info[i] = -13;
470 }
471
472 if (info.size() == 1) {
473 // do a reduction that finds the first argument to encounter an error
474 int64_t lerror = INTERNAL_INFO_DEFAULT;
475 #pragma omp parallel for reduction(max:lerror)
476 for (size_t i = 0; i < batchCount; ++i) {
477 if (internal_info[i] == 0)
478 continue; // skip problems that passed error checks
479 lerror = std::max(lerror, internal_info[i]);
480 }
481 info[0] = (lerror == INTERNAL_INFO_DEFAULT) ? 0 : lerror;
482
483 // delete the internal vector
484 delete[] internal_info;
485
486 // throw an exception if needed
487 blas_error_if_msg( info[0] != 0, "info = %lld", llong( info[0] ) );
488 }
489 else {
490 int64_t info_ = 0;
491 #pragma omp parallel for reduction(+:info_)
492 for (size_t i = 0; i < batchCount; ++i) {
493 info_ += info[i];
494 }
495 blas_error_if_msg( info_ != 0, "One or more non-zero entry in vector info");
496 }
497}
498
499// -----------------------------------------------------------------------------
500// batch herk check
501template <typename T, typename scalarT>
502void herk_check(
503 blas::Layout layout,
504 std::vector<blas::Uplo> const &uplo,
505 std::vector<blas::Op> const &trans,
506 std::vector<int64_t> const &n,
507 std::vector<int64_t> const &k,
508 std::vector<scalarT> const &alpha,
509 std::vector<T*> const &A, std::vector<int64_t> const &lda,
510 std::vector<scalarT> const &beta,
511 std::vector<T*> const &C, std::vector<int64_t> const &ldc,
512 const size_t batchCount, std::vector<int64_t> &info)
513{
514 // size error checking
515 blas_error_if( (uplo.size() != 1 && uplo.size() != batchCount) );
516 blas_error_if( (trans.size() != 1 && trans.size() != batchCount) );
517
518 blas_error_if( (n.size() != 1 && n.size() != batchCount) );
519 blas_error_if( (k.size() != 1 && k.size() != batchCount) );
520
521 // to support checking errors for the group interface, batchCount will be equal to group_count
522 // but the data arrays are generally >= group_count
523 blas_error_if( (A.size() != 1 && A.size() < batchCount) );
524 blas_error_if( C.size() < batchCount );
525
526 blas_error_if( (lda.size() != 1 && lda.size() != batchCount) );
527 blas_error_if( (ldc.size() != 1 && ldc.size() != batchCount) );
528
529 blas_error_if( (alpha.size() != 1 && alpha.size() != batchCount) );
530 blas_error_if( (beta.size() != 1 && beta.size() != batchCount) );
531
532 blas_error_if( A.size() == 1 &&
533 (lda.size() > 1 ||
534 n.size() > 1 ||
535 k.size() > 1 ||
536 (trans.size() > 1 && n[0] != k[0]) ));
537
538 blas_error_if( C.size() == 1 &&
539 (uplo.size() > 1 ||
540 trans.size() > 1 ||
541 n.size() > 1 ||
542 k.size() > 1 ||
543 alpha.size() > 1 ||
544 A.size() > 1 ||
545 lda.size() > 1 ||
546 beta.size() > 1 ||
547 ldc.size() > 1 ));
548
549 int64_t* internal_info;
550 if (info.size() == 1) {
551 internal_info = new int64_t[batchCount];
552 }
553 else {
554 internal_info = &info[0];
555 }
556
557 #pragma omp parallel for schedule(dynamic)
558 for (size_t i = 0; i < batchCount; ++i) {
559 Uplo uplo_ = extract<Uplo>( uplo, i );
560 Op trans_ = extract<Op> ( trans, i );
561
562 int64_t n_ = extract<int64_t>(n, i);
563 int64_t k_ = extract<int64_t>(k, i);
564
565 int64_t lda_ = extract<int64_t>(lda, i);
566 int64_t ldc_ = extract<int64_t>(ldc, i);
567
568 int64_t nrowA_ = ((trans_ == Op::NoTrans) ^ (layout == Layout::RowMajor)) ? n_ : k_;
569
570 internal_info[i] = 0;
571 if (uplo_ != Uplo::Lower && uplo_ != Uplo::Upper) {
572 internal_info[i] = -2;
573 }
574 else if (trans_ != Op::NoTrans && trans_ != Op::ConjTrans) {
575 internal_info[i] = -3;
576 }
577 else if (n_ < 0) internal_info[i] = -4;
578 else if (k_ < 0) internal_info[i] = -5;
579 else if (lda_ < nrowA_) internal_info[i] = -8;
580 else if (ldc_ < n_) internal_info[i] = -11;
581 }
582
583 if (info.size() == 1) {
584 // do a reduction that finds the first argument to encounter an error
585 int64_t lerror = INTERNAL_INFO_DEFAULT;
586 #pragma omp parallel for reduction(max:lerror)
587 for (size_t i = 0; i < batchCount; ++i) {
588 if (internal_info[i] == 0)
589 continue; // skip problems that passed error checks
590 lerror = std::max(lerror, internal_info[i]);
591 }
592 info[0] = (lerror == INTERNAL_INFO_DEFAULT) ? 0 : lerror;
593
594 // delete the internal vector
595 delete[] internal_info;
596
597 // throw an exception if needed
598 blas_error_if_msg( info[0] != 0, "info = %lld", llong( info[0] ) );
599 }
600 else {
601 int64_t info_ = 0;
602 #pragma omp parallel for reduction(+:info_)
603 for (size_t i = 0; i < batchCount; ++i) {
604 info_ += info[i];
605 }
606 blas_error_if_msg( info_ != 0, "One or more non-zero entry in vector info");
607 }
608}
609
610// -----------------------------------------------------------------------------
// batch symm check
612template <typename T>
613void symm_check(
614 blas::Layout layout,
615 std::vector<blas::Side> const &side,
616 std::vector<blas::Uplo> const &uplo,
617 std::vector<int64_t> const &m,
618 std::vector<int64_t> const &n,
619 std::vector<T> const &alpha,
620 std::vector<T*> const &A, std::vector<int64_t> const &lda,
621 std::vector<T*> const &B, std::vector<int64_t> const &ldb,
622 std::vector<T> const &beta,
623 std::vector<T*> const &C, std::vector<int64_t> const &ldc,
624 const size_t batchCount, std::vector<int64_t> &info)
625{
626 hemm_check(layout, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc, batchCount, info);
627}
628
629// -----------------------------------------------------------------------------
630// batch syrk check
631template <typename T>
632void syrk_check(
633 blas::Layout layout,
634 std::vector<blas::Uplo> const &uplo,
635 std::vector<blas::Op> const &trans,
636 std::vector<int64_t> const &n,
637 std::vector<int64_t> const &k,
638 std::vector<T> const &alpha,
639 std::vector<T*> const &A, std::vector<int64_t> const &lda,
640 std::vector<T> const &beta,
641 std::vector<T*> const &C, std::vector<int64_t> const &ldc,
642 const size_t batchCount, std::vector<int64_t> &info)
643{
644 // size error checking
645 blas_error_if( (uplo.size() != 1 && uplo.size() != batchCount) );
646 blas_error_if( (trans.size() != 1 && trans.size() != batchCount) );
647
648 blas_error_if( (n.size() != 1 && n.size() != batchCount) );
649 blas_error_if( (k.size() != 1 && k.size() != batchCount) );
650
651 // to support checking errors for the group interface, batchCount will be equal to group_count
652 // but the data arrays are generally >= group_count
653 blas_error_if( (A.size() != 1 && A.size() < batchCount) );
654 blas_error_if( C.size() < batchCount );
655
656 blas_error_if( (lda.size() != 1 && lda.size() != batchCount) );
657 blas_error_if( (ldc.size() != 1 && ldc.size() != batchCount) );
658
659 blas_error_if( (alpha.size() != 1 && alpha.size() != batchCount) );
660 blas_error_if( (beta.size() != 1 && beta.size() != batchCount) );
661
662 blas_error_if( A.size() == 1 &&
663 (lda.size() > 1 ||
664 n.size() > 1 ||
665 k.size() > 1 ||
666 (trans.size() > 1 && n[0] != k[0]) ));
667
668 blas_error_if( C.size() == 1 &&
669 (uplo.size() > 1 ||
670 trans.size() > 1 ||
671 n.size() > 1 ||
672 k.size() > 1 ||
673 alpha.size() > 1 ||
674 A.size() > 1 ||
675 lda.size() > 1 ||
676 beta.size() > 1 ||
677 ldc.size() > 1 ));
678
679 int64_t* internal_info;
680 if (info.size() == 1) {
681 internal_info = new int64_t[batchCount];
682 }
683 else {
684 internal_info = &info[0];
685 }
686
687 #pragma omp parallel for schedule(dynamic)
688 for (size_t i = 0; i < batchCount; ++i) {
689 Uplo uplo_ = extract<Uplo>( uplo, i );
690 Op trans_ = extract<Op> ( trans, i );
691
692 int64_t n_ = extract<int64_t>(n, i);
693 int64_t k_ = extract<int64_t>(k, i);
694
695 int64_t lda_ = extract<int64_t>(lda, i);
696 int64_t ldc_ = extract<int64_t>(ldc, i);
697
698 int64_t nrowA_ = ((trans_ == Op::NoTrans) ^ (layout == Layout::RowMajor)) ? n_ : k_;
699
700 internal_info[i] = 0;
701 if (uplo_ != Uplo::Lower && uplo_ != Uplo::Upper) {
702 internal_info[i] = -2;
703 }
704 else if (trans_ != Op::NoTrans && trans_ != Op::Trans) {
705 internal_info[i] = -3;
706 }
707 else if (n_ < 0) internal_info[i] = -4;
708 else if (k_ < 0) internal_info[i] = -5;
709 else if (lda_ < nrowA_) internal_info[i] = -8;
710 else if (ldc_ < n_) internal_info[i] = -11;
711 }
712
713 if (info.size() == 1) {
714 // do a reduction that finds the first argument to encounter an error
715 int64_t lerror = INTERNAL_INFO_DEFAULT;
716 #pragma omp parallel for reduction(max:lerror)
717 for (size_t i = 0; i < batchCount; ++i) {
718 if (internal_info[i] == 0)
719 continue; // skip problems that passed error checks
720 lerror = std::max(lerror, internal_info[i]);
721 }
722 info[0] = (lerror == INTERNAL_INFO_DEFAULT) ? 0 : lerror;
723
724 // delete the internal vector
725 delete[] internal_info;
726
727 // throw an exception if needed
728 blas_error_if_msg( info[0] != 0, "info = %lld", llong( info[0] ) );
729 }
730 else {
731 int64_t info_ = 0;
732 #pragma omp parallel for reduction(+:info_)
733 for (size_t i = 0; i < batchCount; ++i) {
734 info_ += info[i];
735 }
736 blas_error_if_msg( info_ != 0, "One or more non-zero entry in vector info");
737 }
738}
739
740// -----------------------------------------------------------------------------
741// batch her2k check
742template <typename T, typename scalarT>
743void her2k_check(
744 blas::Layout layout,
745 std::vector<blas::Uplo> const &uplo,
746 std::vector<blas::Op> const &trans,
747 std::vector<int64_t> const &n,
748 std::vector<int64_t> const &k,
749 std::vector<T> const &alpha,
750 std::vector<T*> const &A, std::vector<int64_t> const &lda,
751 std::vector<T*> const &B, std::vector<int64_t> const &ldb,
752 std::vector<scalarT> const &beta,
753 std::vector<T*> const &C, std::vector<int64_t> const &ldc,
754 const size_t batchCount, std::vector<int64_t> &info)
755{
756 // size error checking
757 blas_error_if( (uplo.size() != 1 && uplo.size() != batchCount) );
758 blas_error_if( (trans.size() != 1 && trans.size() != batchCount) );
759
760 blas_error_if( (n.size() != 1 && n.size() != batchCount) );
761 blas_error_if( (k.size() != 1 && k.size() != batchCount) );
762
763 // to support checking errors for the group interface, batchCount will be equal to group_count
764 // but the data arrays are generally >= group_count
765 blas_error_if( (A.size() != 1 && A.size() < batchCount) );
766 blas_error_if( (B.size() != 1 && B.size() < batchCount) );
767 blas_error_if( C.size() < batchCount );
768
769 blas_error_if( (lda.size() != 1 && lda.size() != batchCount) );
770 blas_error_if( (ldb.size() != 1 && ldb.size() != batchCount) );
771 blas_error_if( (ldc.size() != 1 && ldc.size() != batchCount) );
772
773 blas_error_if( (alpha.size() != 1 && alpha.size() != batchCount) );
774 blas_error_if( (beta.size() != 1 && beta.size() != batchCount) );
775
776 blas_error_if( A.size() == 1 &&
777 (lda.size() > 1 ||
778 n.size() > 1 ||
779 k.size() > 1 ||
780 (trans.size() > 1 && n[0] != k[0]) ));
781
782 blas_error_if( B.size() == 1 &&
783 (ldb.size() > 1 ||
784 n.size() > 1 ||
785 k.size() > 1 ||
786 (trans.size() > 1 && n[0] != k[0]) ));
787
788 blas_error_if( C.size() == 1 &&
789 (uplo.size() > 1 ||
790 trans.size() > 1 ||
791 n.size() > 1 ||
792 k.size() > 1 ||
793 alpha.size() > 1 ||
794 A.size() > 1 ||
795 lda.size() > 1 ||
796 B.size() > 1 ||
797 ldb.size() > 1 ||
798 beta.size() > 1 ||
799 ldc.size() > 1 ));
800
801 int64_t* internal_info;
802 if (info.size() == 1) {
803 internal_info = new int64_t[batchCount];
804 }
805 else {
806 internal_info = &info[0];
807 }
808
809 #pragma omp parallel for schedule(dynamic)
810 for (size_t i = 0; i < batchCount; ++i) {
811 Uplo uplo_ = extract<Uplo>( uplo, i );
812 Op trans_ = extract<Op> ( trans, i );
813
814 int64_t n_ = extract<int64_t>(n, i);
815 int64_t k_ = extract<int64_t>(k, i);
816
817 int64_t lda_ = extract<int64_t>(lda, i);
818 int64_t ldb_ = extract<int64_t>(ldb, i);
819 int64_t ldc_ = extract<int64_t>(ldc, i);
820
821 int64_t nrowA_ = ((trans_ == Op::NoTrans) ^ (layout == Layout::RowMajor)) ? n_ : k_;
822 int64_t nrowB_ = ((trans_ == Op::NoTrans) ^ (layout == Layout::RowMajor)) ? n_ : k_;
823
824 internal_info[i] = 0;
825 if (uplo_ != Uplo::Lower && uplo_ != Uplo::Upper) {
826 internal_info[i] = -2;
827 }
828 else if (trans_ != Op::NoTrans && trans_ != Op::ConjTrans) {
829 internal_info[i] = -3;
830 }
831 else if (n_ < 0) internal_info[i] = -4;
832 else if (k_ < 0) internal_info[i] = -5;
833 else if (lda_ < nrowA_) internal_info[i] = -8;
834 else if (ldb_ < nrowB_) internal_info[i] = -10;
835 else if (ldc_ < n_) internal_info[i] = -13;
836 }
837
838 if (info.size() == 1) {
839 // do a reduction that finds the first argument to encounter an error
840 int64_t lerror = INTERNAL_INFO_DEFAULT;
841 #pragma omp parallel for reduction(max:lerror)
842 for (size_t i = 0; i < batchCount; ++i) {
843 if (internal_info[i] == 0)
844 continue; // skip problems that passed error checks
845 lerror = std::max(lerror, internal_info[i]);
846 }
847 info[0] = (lerror == INTERNAL_INFO_DEFAULT) ? 0 : lerror;
848
849 // delete the internal vector
850 delete[] internal_info;
851
852 // throw an exception if needed
853 blas_error_if_msg( info[0] != 0, "info = %lld", llong( info[0] ) );
854 }
855 else {
856 int64_t info_ = 0;
857 #pragma omp parallel for reduction(+:info_)
858 for (size_t i = 0; i < batchCount; ++i) {
859 info_ += info[i];
860 }
861 blas_error_if_msg( info_ != 0, "One or more non-zero entry in vector info");
862 }
863}
864
865// -----------------------------------------------------------------------------
866// batch syr2k check
867template <typename T>
868void syr2k_check(
869 blas::Layout layout,
870 std::vector<blas::Uplo> const &uplo,
871 std::vector<blas::Op> const &trans,
872 std::vector<int64_t> const &n,
873 std::vector<int64_t> const &k,
874 std::vector<T> const &alpha,
875 std::vector<T*> const &A, std::vector<int64_t> const &lda,
876 std::vector<T*> const &B, std::vector<int64_t> const &ldb,
877 std::vector<T> const &beta,
878 std::vector<T*> const &C, std::vector<int64_t> const &ldc,
879 const size_t batchCount, std::vector<int64_t> &info)
880{
881 // size error checking
882 blas_error_if( (uplo.size() != 1 && uplo.size() != batchCount) );
883 blas_error_if( (trans.size() != 1 && trans.size() != batchCount) );
884
885 blas_error_if( (n.size() != 1 && n.size() != batchCount) );
886 blas_error_if( (k.size() != 1 && k.size() != batchCount) );
887
888 // to support checking errors for the group interface, batchCount will be equal to group_count
889 // but the data arrays are generally >= group_count
890 blas_error_if( (A.size() != 1 && A.size() < batchCount) );
891 blas_error_if( (B.size() != 1 && B.size() < batchCount) );
892 blas_error_if( C.size() < batchCount );
893
894 blas_error_if( (lda.size() != 1 && lda.size() != batchCount) );
895 blas_error_if( (ldb.size() != 1 && ldb.size() != batchCount) );
896 blas_error_if( (ldc.size() != 1 && ldc.size() != batchCount) );
897
898 blas_error_if( (alpha.size() != 1 && alpha.size() != batchCount) );
899 blas_error_if( (beta.size() != 1 && beta.size() != batchCount) );
900
901 blas_error_if( A.size() == 1 &&
902 (lda.size() > 1 ||
903 n.size() > 1 ||
904 k.size() > 1 ||
905 (trans.size() > 1 && n[0] != k[0]) ));
906
907 blas_error_if( B.size() == 1 &&
908 (ldb.size() > 1 ||
909 n.size() > 1 ||
910 k.size() > 1 ||
911 (trans.size() > 1 && n[0] != k[0]) ));
912
913 blas_error_if( C.size() == 1 &&
914 (uplo.size() > 1 ||
915 trans.size() > 1 ||
916 n.size() > 1 ||
917 k.size() > 1 ||
918 alpha.size() > 1 ||
919 A.size() > 1 ||
920 lda.size() > 1 ||
921 B.size() > 1 ||
922 ldb.size() > 1 ||
923 beta.size() > 1 ||
924 ldc.size() > 1 ));
925
926 int64_t* internal_info;
927 if (info.size() == 1) {
928 internal_info = new int64_t[batchCount];
929 }
930 else {
931 internal_info = &info[0];
932 }
933
934 #pragma omp parallel for schedule(dynamic)
935 for (size_t i = 0; i < batchCount; ++i) {
936 Uplo uplo_ = extract<Uplo>( uplo, i );
937 Op trans_ = extract<Op> ( trans, i );
938
939 int64_t n_ = extract<int64_t>(n, i);
940 int64_t k_ = extract<int64_t>(k, i);
941
942 int64_t lda_ = extract<int64_t>(lda, i);
943 int64_t ldb_ = extract<int64_t>(ldb, i);
944 int64_t ldc_ = extract<int64_t>(ldc, i);
945
946 int64_t nrowA_ = ((trans_ == Op::NoTrans) ^ (layout == Layout::RowMajor)) ? n_ : k_;
947 int64_t nrowB_ = ((trans_ == Op::NoTrans) ^ (layout == Layout::RowMajor)) ? n_ : k_;
948
949 internal_info[i] = 0;
950 if (uplo_ != Uplo::Lower && uplo_ != Uplo::Upper) {
951 internal_info[i] = -2;
952 }
953 else if (trans_ != Op::NoTrans && trans_ != Op::Trans) {
954 internal_info[i] = -3;
955 }
956 else if (n_ < 0) internal_info[i] = -4;
957 else if (k_ < 0) internal_info[i] = -5;
958 else if (lda_ < nrowA_) internal_info[i] = -8;
959 else if (ldb_ < nrowB_) internal_info[i] = -10;
960 else if (ldc_ < n_) internal_info[i] = -13;
961 }
962
963 if (info.size() == 1) {
964 // do a reduction that finds the first argument to encounter an error
965 int64_t lerror = INTERNAL_INFO_DEFAULT;
966 #pragma omp parallel for reduction(max:lerror)
967 for (size_t i = 0; i < batchCount; ++i) {
968 if (internal_info[i] == 0)
969 continue; // skip problems that passed error checks
970 lerror = std::max(lerror, internal_info[i]);
971 }
972 info[0] = (lerror == INTERNAL_INFO_DEFAULT) ? 0 : lerror;
973
974 // delete the internal vector
975 delete[] internal_info;
976
977 // throw an exception if needed
978 blas_error_if_msg( info[0] != 0, "info = %lld", llong( info[0] ) );
979 }
980 else {
981 int64_t info_ = 0;
982 #pragma omp parallel for reduction(+:info_)
983 for (size_t i = 0; i < batchCount; ++i) {
984 info_ += info[i];
985 }
986 blas_error_if_msg( info_ != 0, "One or more non-zero entry in vector info");
987 }
988}
989
990} // namespace batch
991} // namespace blas
992
993#endif // #ifndef BLAS_BATCH_COMMON_HH