BLAS++ 2024.05.31
BLAS C++ API
Loading...
Searching...
No Matches
syr2.hh
1// Copyright (c) 2017-2023, University of Tennessee. All rights reserved.
2// SPDX-License-Identifier: BSD-3-Clause
3// This program is free software: you can redistribute it and/or modify it under
4// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
5
6#ifndef BLAS_SYR2_HH
7#define BLAS_SYR2_HH
8
9#include "blas/util.hh"
10
11#include <limits>
12
13namespace blas {
14
15// =============================================================================
61
62template <typename TA, typename TX, typename TY>
63void syr2(
64 blas::Layout layout,
65 blas::Uplo uplo,
66 int64_t n,
67 blas::scalar_type<TA, TX, TY> alpha,
68 TX const *x, int64_t incx,
69 TY const *y, int64_t incy,
70 TA *A, int64_t lda )
71{
72 typedef blas::scalar_type<TA, TX, TY> scalar_t;
73
74 #define A(i_, j_) A[ (i_) + (j_)*lda ]
75
76 // constants
77 const scalar_t zero = 0;
78
79 // check arguments
80 blas_error_if( layout != Layout::ColMajor &&
81 layout != Layout::RowMajor );
82 blas_error_if( uplo != Uplo::Lower &&
83 uplo != Uplo::Upper );
84 blas_error_if( n < 0 );
85 blas_error_if( incx == 0 );
86 blas_error_if( incy == 0 );
87 blas_error_if( lda < n );
88
89 // quick return
90 if (n == 0 || alpha == zero)
91 return;
92
93 // for row major, swap lower <=> upper
94 if (layout == Layout::RowMajor) {
95 uplo = (uplo == Uplo::Lower ? Uplo::Upper : Uplo::Lower);
96 }
97
98 int64_t kx = (incx > 0 ? 0 : (-n + 1)*incx);
99 int64_t ky = (incy > 0 ? 0 : (-n + 1)*incy);
100 if (uplo == Uplo::Upper) {
101 if (incx == 1 && incy == 1) {
102 // unit stride
103 for (int64_t j = 0; j < n; ++j) {
104 // note: NOT skipping if x[j] or y[j] is zero, for consistent NAN handling
105 scalar_t tmp1 = alpha * y[j];
106 scalar_t tmp2 = alpha * x[j];
107 for (int64_t i = 0; i <= j; ++i) {
108 A(i, j) += x[i]*tmp1 + y[i]*tmp2;
109 }
110 }
111 }
112 else {
113 // non-unit stride
114 int64_t jx = kx;
115 int64_t jy = ky;
116 for (int64_t j = 0; j < n; ++j) {
117 scalar_t tmp1 = alpha * y[jy];
118 scalar_t tmp2 = alpha * x[jx];
119 int64_t ix = kx;
120 int64_t iy = ky;
121 for (int64_t i = 0; i <= j; ++i) {
122 A(i, j) += x[ix]*tmp1 + y[iy]*tmp2;
123 ix += incx;
124 iy += incy;
125 }
126 jx += incx;
127 jy += incy;
128 }
129 }
130 }
131 else {
132 // lower triangle
133 if (incx == 1 && incy == 1) {
134 // unit stride
135 for (int64_t j = 0; j < n; ++j) {
136 scalar_t tmp1 = alpha * y[j];
137 scalar_t tmp2 = alpha * x[j];
138 for (int64_t i = j; i < n; ++i) {
139 A(i, j) += x[i]*tmp1 + y[i]*tmp2;
140 }
141 }
142 }
143 else {
144 // non-unit stride
145 int64_t jx = kx;
146 int64_t jy = ky;
147 for (int64_t j = 0; j < n; ++j) {
148 scalar_t tmp1 = alpha * y[jy];
149 scalar_t tmp2 = alpha * x[jx];
150 int64_t ix = jx;
151 int64_t iy = jy;
152 for (int64_t i = j; i < n; ++i) {
153 A(i, j) += x[ix]*tmp1 + y[iy]*tmp2;
154 ix += incx;
155 iy += incy;
156 }
157 jx += incx;
158 jy += incy;
159 }
160 }
161 }
162
163 #undef A
164}
165
166} // namespace blas
167
168#endif // #ifndef BLAS_SYR2_HH
void syr2(blas::Layout layout, blas::Uplo uplo, int64_t n, blas::scalar_type< TA, TX, TY > alpha, TX const *x, int64_t incx, TY const *y, int64_t incy, TA *A, int64_t lda)
Symmetric matrix rank-2 update: $A = \alpha x y^T + \alpha y x^T + A$.
Definition syr2.hh:63