ʻOhana
Population structure, admixture history, and selection using learning methods.
jade.lapack.hpp
1 /* -------------------------------------------------------------------------
2  Ohana
3  Copyright (c) 2015-2020 Jade Cheng (\___/)
4  Jade Cheng <info@jade-cheng.com> (='.'=)
5  ------------------------------------------------------------------------- */
6 
7 #ifndef JADE_LAPACK_HPP__
8 #define JADE_LAPACK_HPP__
9 
10 #include "jade.assert.hpp"
11 
12 namespace jade
13 {
///
/// A template for a class providing access to LAPACK.
///
/// The static routines are declared here; their definitions are supplied as
/// explicit specializations for the supported value types.
///
template <typename TValue>
class basic_lapack
{
public:
    /// The value type.
    typedef TValue value_type;

#ifndef LAPACK_ROW_MAJOR
#define LAPACK_ROW_MAJOR 101
#endif

#ifndef LAPACK_COL_MAJOR
#define LAPACK_COL_MAJOR 102
#endif

    /// A type indicating whether the elements of a matrix are in row-major
    /// or column-major order.
    enum layout_type
    {
        col_major = LAPACK_COL_MAJOR, ///< A column-major order.
        row_major = LAPACK_ROW_MAJOR, ///< A row-major order.
    };

    ///
    /// Computes the solution to the system of linear equations with a
    /// square coefficient matrix A and multiple right-hand sides.
    ///
    /// The \p layout argument indicates whether the matrices are stored in
    /// row-major or column-major order.
    ///
    /// \return Zero if successful; otherwise, non-zero.
    ///
    static int gesv(
            layout_type  layout, ///< The matrix layout.
            int          n,      ///< The number of equations.
            int          nrhs,   ///< The number of right-hand sides.
            value_type * a,      ///< The coefficient matrix.
            int          lda,    ///< The leading dimension of a.
            int *        ipiv,   ///< The pivot table.
            value_type * b,      ///< The right-hand sides.
            int          ldb)    ///< The leading dimension of b.
            ;

    ///
    /// Computes the Cholesky factorization of a symmetric (Hermitian)
    /// positive-definite matrix.
    ///
    /// The \p layout argument indicates whether the matrix is stored in
    /// row-major or column-major order.
    ///
    /// \return Zero if successful; otherwise, non-zero.
    ///
    static int potrf(
            layout_type  layout, ///< The matrix layout.
            char         uplo,   ///< The upper-lower flag.
            int          n,      ///< The order of the matrix.
            value_type * a,      ///< The matrix data.
            int          lda)    ///< The stride of the matrix.
            ;

    ///
    /// Computes the inverse of a symmetric (Hermitian) positive-definite
    /// matrix using the Cholesky factorization.
    ///
    /// The \p layout argument indicates whether the matrix is stored in
    /// row-major or column-major order.
    ///
    /// \return Zero if successful; otherwise, non-zero.
    ///
    static int potri(
            layout_type  layout, ///< The matrix layout.
            char         uplo,   ///< The upper-lower flag.
            int          n,      ///< The order of the matrix.
            value_type * a,      ///< The matrix data.
            int          lda)    ///< The stride of the matrix.
            ;

private:
    // --------------------------------------------------------------------
    // A scoped helper that presents a row-major matrix as column-major
    // data. The constructor copies the matrix into a temporary packed
    // column-major buffer and redirects the caller's data pointer and
    // stride to that buffer. The destructor restores the caller's pointer
    // and stride and copies the buffer contents back into the original
    // row-major matrix.
    class col_storage
    {
        col_storage() = delete;
        col_storage(const col_storage &) = delete;
        col_storage & operator = (const col_storage &) = delete;

    public:
        // ----------------------------------------------------------------
        // data_ptr   - address of the caller's pointer to the row-major data
        // stride_ptr - address of the caller's row stride (leading dimension)
        // rows, cols - the dimensions of the matrix
        col_storage(
                value_type ** data_ptr,
                int *         stride_ptr,
                int           rows,
                int           cols)
            : _data_ptr   (data_ptr)
            , _stride_ptr (stride_ptr)
            , _data_0     (nullptr)
            , _stride_0   (0)
            , _rows       (0)
            , _cols       (0)
        {
            // Validate every argument before anything is dereferenced so
            // that a bad pointer trips an assertion rather than a crash.
            assert(nullptr != data_ptr);
            assert(nullptr != *data_ptr);
            assert(nullptr != stride_ptr);
            assert(*stride_ptr > 0);
            assert(rows > 0);
            assert(cols > 0);

            // Remember the caller's original values for the destructor.
            _data_0   = *data_ptr;
            _stride_0 = size_t(*stride_ptr);
            _rows     = size_t(rows);
            _cols     = size_t(cols);

            _temp.resize(_rows * _cols);
            auto t = _temp.data();

            // Fill the buffer one column at a time, reading each element
            // from its row-major position.
            for (size_t c = 0; c < _cols; c++)
                for (size_t r = 0; r < _rows; r++)
                    *t++ = _data_0[r * _stride_0 + c];

            // The buffer is packed, so its leading dimension is the number
            // of rows.
            *data_ptr   = _temp.data();
            *stride_ptr = int(_rows);
        }

        // ----------------------------------------------------------------
        ~col_storage()
        {
            *_data_ptr   = _data_0;
            *_stride_ptr = int(_stride_0);

            // Write the (possibly modified) buffer back in row-major order.
            auto t = _temp.data();
            for (size_t c = 0; c < _cols; c++)
                for (size_t r = 0; r < _rows; r++)
                    _data_0[r * _stride_0 + c] = *t++;
        }

    private:
        value_type **           _data_ptr;   // the caller's data pointer
        int *                   _stride_ptr; // the caller's stride
        value_type *            _data_0;     // the original data pointer
        size_t                  _stride_0;   // the original stride
        size_t                  _rows;       // the number of rows
        size_t                  _cols;       // the number of columns
        std::vector<value_type> _temp;       // the column-major buffer
    };

    // ----------------------------------------------------------------
    // Returns a col_storage for the specified matrix if the layout is
    // row-major; otherwise, returns an empty pointer and leaves the data
    // pointer and stride untouched.
    static std::unique_ptr<col_storage> init_storage(
            layout_type   layout,
            value_type ** data_ptr,
            int *         stride_ptr,
            int           rows,
            int           cols)
    {
        assert(layout == row_major || layout == col_major);

        std::unique_ptr<col_storage> ptr;

        if (layout == row_major)
        {
            ptr.reset(new col_storage(
                data_ptr,
                stride_ptr,
                rows,
                cols));
        }

        return ptr;
    }
};
177 
178  #ifndef DOXYGEN_IGNORE
179 
180  // ------------------------------------------------------------------------
181  template <>
182  inline int basic_lapack<double>::gesv(
183  layout_type layout,
184  int n,
185  int nrhs,
186  double * a,
187  int lda,
188  int * ipiv,
189  double * b,
190  int ldb)
191  {
192  assert(a != nullptr);
193  assert(b != nullptr);
194 
195 #if defined(JADE_USE_ACCELERATE_FRAMEWORK)
196  const auto a_storage = init_storage(layout, &a, &lda, n, n);
197  const auto b_storage = init_storage(layout, &b, &ldb, n, nrhs);
198  int info;
199  ::dgesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, &info);
200  return info;
201 #elif defined(JADE_USE_NETLIB_PACKAGES)
202  return LAPACKE_dgesv(layout, n, nrhs, a, lda, ipiv, b, ldb);
203 #else
204  #error Unsupported build environment
205 #endif
206  }
207 
208  // ------------------------------------------------------------------------
209  template <>
210  inline int basic_lapack<float>::gesv(
211  layout_type layout,
212  int n,
213  int nrhs,
214  float * a,
215  int lda,
216  int * ipiv,
217  float * b,
218  int ldb)
219  {
220  assert(a != nullptr);
221  assert(b != nullptr);
222 
223 #if defined(JADE_USE_ACCELERATE_FRAMEWORK)
224  const auto a_storage = init_storage(layout, &a, &lda, n, n);
225  const auto b_storage = init_storage(layout, &b, &ldb, n, nrhs);
226  int info;
227  (void)::sgesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, &info);
228  return info;
229 #elif defined(JADE_USE_NETLIB_PACKAGES)
230  return LAPACKE_sgesv(layout, n, nrhs, a, lda, ipiv, b, ldb);
231 #else
232  #error Unsupported build environment
233 #endif
234  }
235 
236  // ------------------------------------------------------------------------
237  template <>
238  inline int basic_lapack<double>::potrf(
239  layout_type layout,
240  char uplo,
241  int n,
242  double * a,
243  int lda)
244  {
245  assert(a != nullptr);
246 
247 #if defined(JADE_USE_ACCELERATE_FRAMEWORK)
248  const auto storage = init_storage(layout, &a, &lda, n, n);
249  int info;
250  (void)::dpotrf_(&uplo, &n, a, &lda, &info);
251  return info;
252 #elif defined(JADE_USE_NETLIB_PACKAGES)
253  return LAPACKE_dpotrf(layout, uplo, n, a, lda);
254 #else
255  #error Unsupported build environment
256 #endif
257  }
258 
259  // ------------------------------------------------------------------------
260  template <>
261  inline int basic_lapack<float>::potrf(
262  layout_type layout,
263  char uplo,
264  int n,
265  float * a,
266  int lda)
267  {
268  assert(a != nullptr);
269 
270 #if defined(JADE_USE_ACCELERATE_FRAMEWORK)
271  const auto storage = init_storage(layout, &a, &lda, n, n);
272  int info;
273  (void)::spotrf_(&uplo, &n, a, &lda, &info);
274  return info;
275 #elif defined(JADE_USE_NETLIB_PACKAGES)
276  return LAPACKE_spotrf(layout, uplo, n, a, lda);
277 #else
278  #error Unsupported build environment
279 #endif
280  }
281 
282  // ------------------------------------------------------------------------
283  template <>
284  inline int basic_lapack<double>::potri(
285  layout_type layout,
286  char uplo,
287  int n,
288  double * a,
289  int lda)
290  {
291  assert(a != nullptr);
292 
293 #if defined(JADE_USE_ACCELERATE_FRAMEWORK)
294  const auto storage = init_storage(layout, &a, &lda, n, n);
295  int info;
296  (void)::dpotri_(&uplo, &n, a, &lda, &info);
297  return info;
298 #elif defined(JADE_USE_NETLIB_PACKAGES)
299  return LAPACKE_dpotri(layout, uplo, n, a, lda);
300 #else
301  #error Unsupported build environment
302 #endif
303  }
304 
305  // ------------------------------------------------------------------------
306  template <>
307  inline int basic_lapack<float>::potri(
308  layout_type layout,
309  char uplo,
310  int n,
311  float * a,
312  int lda)
313  {
314  assert(a != nullptr);
315 
316 #if defined(JADE_USE_ACCELERATE_FRAMEWORK)
317  const auto storage = init_storage(layout, &a, &lda, n, n);
318  int info;
319  (void)::spotri_(&uplo, &n, a, &lda, &info);
320  return info;
321 #elif defined(JADE_USE_NETLIB_PACKAGES)
322  return LAPACKE_spotri(layout, uplo, n, a, lda);
323 #else
324  #error Unsupported build environment
325 #endif
326  }
327 
328  #endif // DOXYGEN_IGNORE
329 }
330 
331 #endif // JADE_LAPACK_HPP__
jade::basic_lapack::potrf
static int potrf(layout_type layout, char uplo, int n, value_type *a, int lda)
Computes the Cholesky factorization of a symmetric (Hermitian) positive-definite matrix.
jade::basic_lapack::layout_type
layout_type
A type indicating whether the elements of a matrix are in row-major or column-major order.
Definition: jade.lapack.hpp:35
jade::basic_lapack::potri
static int potri(layout_type layout, char uplo, int n, value_type *a, int lda)
Computes the inverse of a symmetric (Hermitian) positive-definite matrix using the Cholesky factorization.
jade::basic_lapack::gesv
static int gesv(layout_type layout, int n, int nrhs, value_type *a, int lda, int *ipiv, value_type *b, int ldb)
Computes the solution to the system of linear equations with a square coefficient matrix A and multiple right-hand sides.
jade::basic_lapack::col_major
@ col_major
A column-major order.
Definition: jade.lapack.hpp:36
jade::basic_lapack::value_type
TValue value_type
The value type.
Definition: jade.lapack.hpp:22
jade::basic_lapack
A template for a class providing access to LAPACK.
Definition: jade.lapack.hpp:19
jade::basic_lapack::row_major
@ row_major
A row-major order.
Definition: jade.lapack.hpp:37