BLAS++ 2024.05.31
BLAS C++ API
Loading...
Searching...
No Matches
symm.hh
1// Copyright (c) 2017-2023, University of Tennessee. All rights reserved.
2// SPDX-License-Identifier: BSD-3-Clause
3// This program is free software: you can redistribute it and/or modify it under
4// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
5
6#ifndef BLAS_SYMM_HH
7#define BLAS_SYMM_HH
8
9#include "blas/util.hh"
10
11#include <limits>
12
13namespace blas {
14
15// =============================================================================
76
77template <typename TA, typename TB, typename TC>
78void symm(
79 blas::Layout layout,
80 blas::Side side,
81 blas::Uplo uplo,
82 int64_t m, int64_t n,
83 scalar_type<TA, TB, TC> alpha,
84 TA const *A, int64_t lda,
85 TB const *B, int64_t ldb,
86 scalar_type<TA, TB, TC> beta,
87 TC *C, int64_t ldc )
88{
89 using std::swap;
90 using scalar_t = blas::scalar_type<TA, TB>;
91
92 #define A(i_, j_) A[ (i_) + (j_)*lda ]
93 #define B(i_, j_) B[ (i_) + (j_)*ldb ]
94 #define C(i_, j_) C[ (i_) + (j_)*ldc ]
95
96 // constants
97 const scalar_t zero = 0;
98 const scalar_t one = 1;
99
100 // check arguments
101 blas_error_if( layout != Layout::ColMajor &&
102 layout != Layout::RowMajor );
103 blas_error_if( side != Side::Left &&
104 side != Side::Right );
105 blas_error_if( uplo != Uplo::Lower &&
106 uplo != Uplo::Upper &&
107 uplo != Uplo::General );
108 blas_error_if( m < 0 );
109 blas_error_if( n < 0 );
110
111 // adapt if row major
112 if (layout == Layout::RowMajor) {
113 side = (side == Side::Left)
114 ? Side::Right
115 : Side::Left;
116 if (uplo == Uplo::Lower)
117 uplo = Uplo::Upper;
118 else if (uplo == Uplo::Upper)
119 uplo = Uplo::Lower;
120 swap( m, n );
121 }
122
123 // check remaining arguments
124 blas_error_if( lda < ((side == Side::Left) ? m : n) );
125 blas_error_if( ldb < m );
126 blas_error_if( ldc < m );
127
128 // quick return
129 if (m == 0 || n == 0)
130 return;
131
132 // alpha == zero
133 if (alpha == zero) {
134 if (beta == zero) {
135 for (int64_t j = 0; j < n; ++j) {
136 for (int64_t i = 0; i < m; ++i)
137 C(i, j) = zero;
138 }
139 }
140 else if (beta != one) {
141 for (int64_t j = 0; j < n; ++j) {
142 for (int64_t i = 0; i < m; ++i)
143 C(i, j) *= beta;
144 }
145 }
146 return;
147 }
148
149 // alpha != zero
150 if (side == Side::Left) {
151 if (uplo != Uplo::Lower) {
152 // uplo == Uplo::Upper or uplo == Uplo::General
153 for (int64_t j = 0; j < n; ++j) {
154 for (int64_t i = 0; i < m; ++i) {
155
156 scalar_t alpha_Bij = alpha*B(i, j);
157 scalar_t sum = zero;
158
159 for (int64_t k = 0; k < i; ++k) {
160 C(k, j) += A(k, i) * alpha_Bij;
161 sum += A(k, i) * B(k, j);
162 }
163 C(i, j) =
164 beta * C(i, j)
165 + A(i, i) * alpha_Bij
166 + alpha * sum;
167 }
168 }
169 }
170 else {
171 // uplo == Uplo::Lower
172 for (int64_t j = 0; j < n; ++j) {
173 for (int64_t i = m-1; i >= 0; --i) {
174
175 scalar_t alpha_Bij = alpha*B(i, j);
176 scalar_t sum = zero;
177
178 for (int64_t k = i+1; k < m; ++k) {
179 C(k, j) += A(k, i) * alpha_Bij;
180 sum += A(k, i) * B(k, j);
181 }
182 C(i, j) =
183 beta * C(i, j)
184 + A(i, i) * alpha_Bij
185 + alpha * sum;
186 }
187 }
188 }
189 }
190 else { // side == Side::Right
191 if (uplo != Uplo::Lower) {
192 // uplo == Uplo::Upper or uplo == Uplo::General
193 for (int64_t j = 0; j < n; ++j) {
194
195 scalar_t alpha_Akj = alpha * A(j, j);
196
197 for (int64_t i = 0; i < m; ++i)
198 C(i, j) = beta * C(i, j) + B(i, j) * alpha_Akj;
199
200 for (int64_t k = 0; k < j; ++k) {
201 alpha_Akj = alpha*A(k, j);
202 for (int64_t i = 0; i < m; ++i)
203 C(i, j) += B(i, k) * alpha_Akj;
204 }
205
206 for (int64_t k = j+1; k < n; ++k) {
207 alpha_Akj = alpha * A(j, k);
208 for (int64_t i = 0; i < m; ++i)
209 C(i, j) += B(i, k) * alpha_Akj;
210 }
211 }
212 }
213 else {
214 // uplo == Uplo::Lower
215 for (int64_t j = 0; j < n; ++j) {
216
217 scalar_t alpha_Akj = alpha * A(j, j);
218
219 for (int64_t i = 0; i < m; ++i)
220 C(i, j) = beta * C(i, j) + B(i, j) * alpha_Akj;
221
222 for (int64_t k = 0; k < j; ++k) {
223 alpha_Akj = alpha * A(j, k);
224 for (int64_t i = 0; i < m; ++i)
225 C(i, j) += B(i, k) * alpha_Akj;
226 }
227
228 for (int64_t k = j+1; k < n; ++k) {
229 alpha_Akj = alpha*A(k, j);
230 for (int64_t i = 0; i < m; ++i)
231 C(i, j) += B(i, k) * alpha_Akj;
232 }
233 }
234 }
235 }
236
237 #undef A
238 #undef B
239 #undef C
240}
241
242} // namespace blas
243
244#endif // #ifndef BLAS_SYMM_HH
void swap(int64_t n, float *x, int64_t incx, float *y, int64_t incy, blas::Queue &queue)
GPU device, float version.
Definition device_swap.cc:67
void symm(blas::Layout layout, blas::Side side, blas::Uplo uplo, int64_t m, int64_t n, float alpha, float const *A, int64_t lda, float const *B, int64_t ldb, float beta, float *C, int64_t ldc, blas::Queue &queue)
GPU device, float version.
Definition device_symm.cc:106