ʻOhana
Population structure, admixture history, and selection using learning methods.
jade.blas.hpp
1 /* -------------------------------------------------------------------------
2  Ohana
3  Copyright (c) 2015-2020 Jade Cheng (\___/)
4  Jade Cheng <info@jade-cheng.com> (='.'=)
5  ------------------------------------------------------------------------- */
6 
7 #ifndef JADE_BLAS_HPP__
8 #define JADE_BLAS_HPP__
9 
10 #include "jade.assert.hpp"
11 
12 namespace jade
13 {
14  ///
15  /// A template for a class providing access to BLAS.
16  ///
17  template <typename TValue>
18  class basic_blas
19  {
20  public:
21  /// The value type.
22  typedef TValue value_type;
23 
24  /// A type indicating whether the elements of a matrix are in row-major
25  /// or column-major order.
26  typedef CBLAS_ORDER layout_type;
27 
28  /// A type indicating a kind of transpose operation to perform on a
29  /// matrix.
30  typedef CBLAS_TRANSPOSE transpose_type;
31 
32  ///
33  /// Computes a vector-vector dot product.
34  ///
35  /// \return The result of the dot product of x and y, if n is positive.
36  /// Otherwise, returns 0.
37  ///
38  static value_type dot(
39  const int n, ///< The number of elements.
40  const value_type * x, ///< The X array.
41  const int incx, ///< The X increment.
42  const value_type * y, ///< The Y array.
43  const int incy) ///< The Y increment.
44  ;
45 
46  ///
47  /// Computes a matrix-matrix product with general matrices.
48  ///
49  static void gemm(
50  const layout_type Order, ///< The order.
51  const transpose_type TransA, ///< Transpose A.
52  const transpose_type TransB, ///< Transpose B.
53  const int M, ///< Rows of A.
54  const int N, ///< Columns of B.
55  const int K, ///< The order.
56  const value_type alpha, ///< Alpha scalar.
57  const value_type * A, ///< A matrix.
58  const int lda, ///< Stride of A.
59  const value_type * B, ///< B matrix.
60  const int ldb, ///< Stride of B.
61  const value_type beta, ///< Beta scalar.
62  value_type * C, ///< C matrix.
63  const int ldc) ///< Stride of C.
64  ;
65 
66  ///
67  /// Computes a matrix-vector product using a general matrix.
68  ///
69  static void gemv(
70  const layout_type order, ///< The layout.
71  const transpose_type trans, ///< Transpose.
72  const int m, ///< Rows of A.
73  const int n, ///< Columns of A.
74  const value_type alpha, ///< Alpha scalar.
75  const value_type * a, ///< A matrix.
76  const int lda, ///< Stride of A.
77  const value_type * x, ///< X vector.
78  const int incx, ///< Stride of X.
79  const value_type beta, ///< Beta scalar.
80  value_type * y, ///< Y vector.
81  const int incy) ///< Stride of Y.
82  ;
83  };
84 
85  #ifndef DOXYGEN_IGNORE
86 
87  // ------------------------------------------------------------------------
88  template <>
89  inline double basic_blas<double>::dot(
90  const int n,
91  const double * x,
92  const int incx,
93  const double * y,
94  const int incy)
95  {
96  assert(x != nullptr);
97  assert(y != nullptr);
98 
99  return ::cblas_ddot(n, x, incx, y, incy);
100  }
101 
102  // ------------------------------------------------------------------------
103  template <>
104  inline float basic_blas<float>::dot(
105  const int n,
106  const float * x,
107  const int incx,
108  const float * y,
109  const int incy)
110  {
111  assert(x != nullptr);
112  assert(y != nullptr);
113 
114  return ::cblas_sdot(n, x, incx, y, incy);
115  }
116 
117  // ------------------------------------------------------------------------
118  template <>
119  inline void basic_blas<double>::gemm(
120  const layout_type Order,
121  const transpose_type TransA,
122  const transpose_type TransB,
123  const int M,
124  const int N,
125  const int K,
126  const double alpha,
127  const double * A,
128  const int lda,
129  const double * B,
130  const int ldb,
131  const double beta,
132  double * C,
133  const int ldc)
134  {
135  assert(A != nullptr);
136  assert(B != nullptr);
137  assert(C != nullptr);
138 
139  ::cblas_dgemm(
140  Order, TransA, TransB, M, N, K, alpha,
141  A, lda, B, ldb, beta, C, ldc);
142  }
143 
144  // ------------------------------------------------------------------------
145  template <>
146  inline void basic_blas<float>::gemm(
147  const layout_type Order,
148  const transpose_type TransA,
149  const transpose_type TransB,
150  const int M,
151  const int N,
152  const int K,
153  const float alpha,
154  const float * A,
155  const int lda,
156  const float * B,
157  const int ldb,
158  const float beta,
159  float * C,
160  const int ldc)
161  {
162  assert(A != nullptr);
163  assert(B != nullptr);
164  assert(C != nullptr);
165 
166  ::cblas_sgemm(
167  Order, TransA, TransB, M, N, K, alpha,
168  A, lda, B, ldb, beta, C, ldc);
169  }
170 
171  // ------------------------------------------------------------------------
172  template <>
173  inline void basic_blas<double>::gemv(
174  const layout_type order,
175  const transpose_type trans,
176  const int m,
177  const int n,
178  const double alpha,
179  const double * a,
180  const int lda,
181  const double * x,
182  const int incx,
183  const double beta,
184  double * y,
185  const int incy)
186  {
187  assert(a != nullptr);
188  assert(x != nullptr);
189  assert(y != nullptr);
190 
191  ::cblas_dgemv(order, trans, m, n, alpha,
192  a, lda, x, incx, beta, y, incy);
193  }
194 
195  // ------------------------------------------------------------------------
196  template <>
197  inline void basic_blas<float>::gemv(
198  const layout_type order,
199  const transpose_type trans,
200  const int m,
201  const int n,
202  const float alpha,
203  const float * a,
204  const int lda,
205  const float * x,
206  const int incx,
207  const float beta,
208  float * y,
209  const int incy)
210  {
211  assert(a != nullptr);
212  assert(x != nullptr);
213  assert(y != nullptr);
214 
215  ::cblas_sgemv(order, trans, m, n, alpha,
216  a, lda, x, incx, beta, y, incy);
217  }
218 
219  #endif // DOXYGEN_IGNORE
220 }
221 
222 #endif // JADE_BLAS_HPP__
jade::basic_blas::value_type
TValue value_type
The value type.
Definition: jade.blas.hpp:22
jade::basic_blas::dot
static value_type dot(const int n, const value_type *x, const int incx, const value_type *y, const int incy)
Computes a vector-vector dot product.
jade::basic_blas
A template for a class providing access to BLAS.
Definition: jade.blas.hpp:19
jade::basic_blas::gemm
static void gemm(const layout_type Order, const transpose_type TransA, const transpose_type TransB, const int M, const int N, const int K, const value_type alpha, const value_type *A, const int lda, const value_type *B, const int ldb, const value_type beta, value_type *C, const int ldc)
Computes a matrix-matrix product with general matrices.
jade::basic_blas::gemv
static void gemv(const layout_type order, const transpose_type trans, const int m, const int n, const value_type alpha, const value_type *a, const int lda, const value_type *x, const int incx, const value_type beta, value_type *y, const int incy)
Computes a matrix-vector product using a general matrix.
jade::basic_blas::layout_type
CBLAS_ORDER layout_type
A type indicating whether the elements of a matrix are in row-major or column-major order.
Definition: jade.blas.hpp:26
jade::basic_blas::transpose_type
CBLAS_TRANSPOSE transpose_type
A type indicating a kind of transpose operation to perform on a matrix.
Definition: jade.blas.hpp:30