ʻOhana
Population structure, admixture history, and selection using learning methods.
cpax/jade.improver.hpp
1 /* -------------------------------------------------------------------------
2  Ohana
3  Copyright (c) 2015-2020 Jade Cheng (\___/)
4  Jade Cheng <info@jade-cheng.com> (='.'=)
5  ------------------------------------------------------------------------- */
6 
7 #ifndef JADE_IMPROVER_HPP__
8 #define JADE_IMPROVER_HPP__
9 
10 #include "jade.forced_grouping.hpp"
11 #include "jade.lemke.hpp"
12 #include "jade.verification.hpp"
13 
14 namespace jade
15 {
16  ///
17  /// A template for a class that improves the Q and F matrices.
18  ///
19  template <typename TValue>
21  {
22  public:
23  /// The value type.
24  typedef TValue value_type;
25 
26  /// The matrix type.
28 
29  /// The genotype matrix type.
31 
32  /// The forced grouping type.
34 
35  /// The verification type.
37 
38  /// The Lemke type.
40 
41  ///
  /// Computes an updated F matrix one column at a time. For every column j
  /// a quadratic program is assembled from the derivative vector, the
  /// Hessian matrix and the bound constraints, with its unknowns shifted by
  /// +1 (shift_vec is filled with ones), and is passed to lemke_type::solve.
  /// If the solver reports success the column becomes the current values
  /// plus the solved step with the shift removed; otherwise the column is
  /// copied unchanged from fa.
  ///
  /// fif and frb are mutually exclusive (asserted below): fif must be null
  /// whenever frb is true.
  ///
  /// NOTE(review): the line carrying the function name and return type is
  /// missing from this listing; the cross-reference identifies it as
  /// "static matrix_type improve_f(...)".
  ///
42  /// \return A new-and-improved F matrix.
43  ///
45  const genotype_matrix_type & g, ///< The G matrix.
46  const matrix_type & q, ///< The Q matrix.
47  const matrix_type & fa, ///< The F matrix.
48  const matrix_type & fb, ///< The 1-F matrix.
49  const matrix_type & qfa, ///< The Q*F matrix.
50  const matrix_type & qfb, ///< The Q*(1-F) matrix.
51  const matrix_type * fif, ///< The Fin-force matrix (may be null).
52  const bool frb) ///< Using frequency-bounds.
53  {
54  assert(verification_type::validate_gqf_sizes(g, q, fa));
55  assert(verification_type::validate_gqf_sizes(g, q, fb));
  // NOTE(review): listing lines 56-57 are missing here.
58  assert(nullptr == fif || !frb);
59 
  // I comes from the height of G; K and J are the height and width of fa.
60  const auto I = g.get_height();
61  const auto K = fa.get_height();
62  const auto J = fa.get_width();
63  assert(nullptr == fif || verification_type::
64  validate_fif_size(*fif, K, J));
65 
66  matrix_type f_dst (K, J);
67 
  // Constant unit shift added to the unknowns of every per-column program.
68  matrix_type shift_vec (K, 1);
69  shift_vec.set_values(1);
70 
  // Scratch buffers that are refilled for each column.
71  matrix_type derivative_vec (K, 1);
72  matrix_type hessian_mat (K, K);
73 
  // Amount subtracted from both bound entries when frb is set:
  // 1 / (2 * I + 1).
74  const auto frb_delta = value_type(1.0) /
75  (value_type(2 * I) + value_type(1.0));
76 
77  for (size_t j = 0; j < J; j++)
78  {
79  const auto f_column = fa.copy_column(j);
80 
  // NOTE(review): the line opening this call (listing line 81) is missing;
  // the argument list below matches compute_derivatives_f, which fills
  // derivative_vec and hessian_mat for column j.
82  q,
83  fa,
84  fb,
85  qfa,
86  qfb,
87  j,
88  derivative_vec,
89  hessian_mat);
90 
  // Unpadded constraints: 2K rows (see _create_a_mat / _create_b_vec).
91  const auto a_mat = _create_a_mat(K, false);
92 
93  auto b_vec = _create_b_vec(f_column, a_mat, shift_vec, false);
94  if (nullptr != fif)
95  {
  // For the first fif->get_height() components both bound entries are
  // overwritten with zero. Presumably this holds those components at
  // their current values -- confirm against the solver's conventions.
96  for (size_t k = 0; k < fif->get_height(); k++)
97  {
98  b_vec[k + 0] = value_type(0);
99  b_vec[k + K] = value_type(0);
100  }
101  }
102  else if (frb)
103  {
  // Frequency-bounds: reduce both bound entries of every component
  // by frb_delta.
104  for (size_t k = 0; k < K; k++)
105  {
106  b_vec[k + 0] -= frb_delta;
107  b_vec[k + K] -= frb_delta;
108  }
109  }
110 
  // Inputs handed to the solver: negated Hessian, negated constraints,
  // and a linear term that accounts for the unit shift.
111  const auto sqp_q = -hessian_mat;
112  const auto sqp_a = -a_mat;
113  const auto sqp_c = (hessian_mat * shift_vec) - derivative_vec;
114  const auto sqp_b = -b_vec;
115 
116  matrix_type shifted_delta_vec;
117  if (lemke_type::solve(
118  shifted_delta_vec,
119  sqp_q, // shifted QP's "Q" matrix
120  sqp_a, // shifted QP's "A" matrix
121  sqp_c, // shifted QP's "c" vector
122  sqp_b)) // shifted QP's "b" vector
123  {
  // The solution is expected to hold 3K entries; only the first K are
  // read. Subtracting 1 removes the shift introduced by shift_vec.
124  assert(shifted_delta_vec.get_height() == 3 * K);
125  for (size_t k = 0; k < K; k++)
126  f_dst(k, j) = f_column[k]
127  + shifted_delta_vec[k]
128  - value_type(1);
129  }
130  else
131  {
  // Solver failed: keep this column as it was.
132  for (size_t k = 0; k < K; k++)
133  f_dst(k, j) = f_column[k];
134  }
135  }
136 
137  return f_dst;
138  }
139 
140  ///
  /// Computes an updated Q matrix one row at a time. For every row i a
  /// quadratic program is assembled from the derivative vector, the Hessian
  /// matrix and the padded constraint set, with its unknowns shifted by +1
  /// (shift_vec is filled with ones), and is passed to lemke_type::solve.
  /// If the solver reports success the row becomes the current values plus
  /// the solved step with the shift removed; otherwise the row is copied
  /// unchanged from q. In both cases the row is then clamped to
  /// [1e-6, 1 - 1e-6] and rescaled so that its entries sum to one.
  ///
  /// NOTE(review): the line carrying the function name and return type is
  /// missing from this listing; the cross-reference identifies it as
  /// "static matrix_type improve_q(...)".
  ///
141  /// \return A new-and-improved Q matrix.
142  ///
144  const genotype_matrix_type & g, ///< The G matrix.
145  const matrix_type & q, ///< The Q matrix.
146  const matrix_type & fa, ///< The F matrix.
147  const matrix_type & fb, ///< The 1-F matrix.
148  const matrix_type & qfa, ///< The Q*F matrix.
149  const matrix_type & qfb, ///< The Q*(1-F) matrix.
150  const forced_grouping_type * fg) ///< The force-grouping (may be null).
151  {
152  assert(verification_type::validate_gqf_sizes(g, q, fa));
153  assert(verification_type::validate_gqf_sizes(g, q, fb));
  // NOTE(review): listing line 154 is missing here.
155  assert(verification_type::validate_f(fa));
156 
  // I and K are the height and width of q.
157  const auto I = q.get_height();
158  const auto K = q.get_width();
159 
160  matrix_type q_dst (I, K);
161 
  // Constant unit shift added to the unknowns of every per-row program.
162  matrix_type shift_vec (K, 1);
163  shift_vec.set_values(1);
164 
  // Scratch buffers that are refilled for each row.
165  matrix_type derivative_vec (K, 1);
166  matrix_type hessian_mat (K, K);
167 
168  for (size_t i = 0; i < I; i++)
169  {
170  const auto q_row = q.copy_row(i);
171 
  // NOTE(review): the line opening this call (listing line 172) is missing;
  // the argument list below matches compute_derivatives_q, which fills
  // derivative_vec and hessian_mat for row i.
173  q,
174  fa,
175  fb,
176  qfa,
177  qfb,
178  i,
179  derivative_vec,
180  hessian_mat);
181 
  // Padded constraints: 2K rows plus two extra rows (see _create_a_mat /
  // _create_b_vec).
182  const auto a_mat = _create_a_mat(K, true);
183 
184  auto b_vec = _create_b_vec(q_row, a_mat, shift_vec, true);
185  if (nullptr != fg)
186  {
  // Forced grouping: adjust both bound entries of each component using
  // the minimum and maximum reported by fg for (i, k).
187  for (size_t k = 0; k < K; k++)
188  {
189  b_vec[k + 0] -= fg->get_min(i, k);
190  b_vec[k + K] += fg->get_max(i, k) - value_type(1);
191  }
192  }
193 
  // Inputs handed to the solver: negated Hessian, negated constraints,
  // and a linear term that accounts for the unit shift.
194  const auto sqp_q = -hessian_mat;
195  const auto sqp_a = -a_mat;
196  const auto sqp_c = (hessian_mat * shift_vec) - derivative_vec;
197  const auto sqp_b = -b_vec;
198 
199  matrix_type shifted_delta_vec;
200  if (lemke_type::solve(
201  shifted_delta_vec,
202  sqp_q, // shifted QP's "Q" matrix
203  sqp_a, // shifted QP's "A" matrix
204  sqp_c, // shifted QP's "c" vector
205  sqp_b)) // shifted QP's "b" vector
206  {
  // Only the first K entries of the solution are read. Subtracting 1
  // removes the shift introduced by shift_vec.
  // NOTE(review): a_mat has 2K + 2 rows here but 2K rows in improve_f,
  // yet both functions assert a height of 3 * K -- confirm against
  // lemke_type::solve.
207  assert(shifted_delta_vec.get_height() == 3 * K);
208  for (size_t k = 0; k < K; k++)
209  q_dst(i, k) = q_row[k]
210  + shifted_delta_vec[k]
211  - value_type(1);
212  }
213  else
214  {
  // Solver failed: keep this row as it was.
215  for (size_t k = 0; k < K; k++)
216  q_dst(i, k) = q_row[k];
217  }
218 
  // Keep every entry strictly inside (0, 1) ...
219  static const auto epsilon = value_type(1.0e-6);
220  static const auto min = value_type(0.0) + epsilon;
221  static const auto max = value_type(1.0) - epsilon;
222  q_dst.clamp_row(i, min, max);
223 
  // ... then rescale so the row sums to one again.
224  const auto sum = q_dst.get_row_sum(i);
225  q_dst.multiply_row(i, value_type(1) / sum);
226  }
227 
228  return q_dst;
229  }
230 
231  private:
232  // --------------------------------------------------------------------
233  static matrix_type _create_a_mat(
234  const size_t K,
235  const bool is_padded)
236  {
237  matrix_type c_mat (K + K + (is_padded ? 2 : 0), K);
238 
239  for (size_t k = 0; k < K; k++)
240  {
241  c_mat(k + 0, k) = value_type(-1);
242  c_mat(K + k, k) = value_type(+1);
243  }
244 
245  if (is_padded)
246  {
247  for (size_t k = 0; k < K; k++)
248  {
249  c_mat(K + K + 0, k) = value_type(+1);
250  c_mat(K + K + 1, k) = value_type(-1);
251  }
252  }
253 
254  return c_mat;
255  }
256 
257  // --------------------------------------------------------------------
258  static matrix_type _create_b_vec(
259  const matrix_type & current_values,
260  const matrix_type & a_mat,
261  const matrix_type & shift_vec,
262  const bool is_padded)
263  {
264  assert(current_values.is_vector());
265 
266  const auto K = current_values.get_length();
267 
268  matrix_type b_vec (K + K + (is_padded ? 2 : 0), 1);
269 
270  for (size_t k = 0; k < K; k++)
271  {
272  b_vec[k + 0] = value_type(0);
273  b_vec[k + K] = value_type(1);
274  }
275 
276  if (is_padded)
277  {
278  b_vec[K + K + 0] = value_type(+1);
279  b_vec[K + K + 1] = value_type(-1);
280  }
281 
282  b_vec -= a_mat * current_values.create_transpose();
283  b_vec += a_mat * shift_vec;
284 
285  return b_vec;
286  }
287  };
288 }
289 
290 #endif // JADE_IMPROVER_HPP__
jade::basic_lemke::solve
bool solve()
Executes the algorithm until it has completed or has aborted.
Definition: jade.lemke.hpp:256
jade::basic_verification
A template for a class that performs validation on various types of matrices.
Definition: jade.verification.hpp:20
jade::basic_improver::genotype_matrix_type
basic_genotype_matrix< value_type > genotype_matrix_type
The genotype matrix type.
Definition: cpax/jade.improver.hpp:30
jade::basic_forced_grouping::get_max
value_type get_max(const size_t i, const size_t k) const
Definition: cpax/jade.forced_grouping.hpp:135
jade::basic_matrix::get_width
size_t get_width() const
Definition: jade.matrix.hpp:757
jade::basic_lemke
A template for a class that implements Lemke's algorithm.
Definition: jade.lemke.hpp:19
jade::basic_improver::improve_f
static matrix_type improve_f(const genotype_matrix_type &g, const matrix_type &q, const matrix_type &fa, const matrix_type &fb, const matrix_type &qfa, const matrix_type &qfb, const matrix_type *fif, const bool frb)
Definition: cpax/jade.improver.hpp:44
jade::basic_genotype_matrix
A template for an abstract class implementing operations for a genotype matrix.
Definition: jade.genotype_matrix.hpp:26
jade::basic_forced_grouping::get_min
value_type get_min(const size_t i, const size_t k) const
Definition: cpax/jade.forced_grouping.hpp:148
jade::basic_improver::value_type
TValue value_type
The value type.
Definition: cpax/jade.improver.hpp:24
jade::basic_verification::validate_f
static bool validate_f(const matrix_type &f)
Validates the F matrix and throws an exception if validation fails.
Definition: jade.verification.hpp:73
jade::basic_improver::verification_type
basic_verification< value_type > verification_type
The verification type.
Definition: cpax/jade.improver.hpp:36
jade::basic_improver::improve_q
static matrix_type improve_q(const genotype_matrix_type &g, const matrix_type &q, const matrix_type &fa, const matrix_type &fb, const matrix_type &qfa, const matrix_type &qfb, const forced_grouping_type *fg)
Definition: cpax/jade.improver.hpp:143
jade::basic_matrix::get_height
size_t get_height() const
Definition: jade.matrix.hpp:603
jade::basic_verification::validate_gqf_sizes
static bool validate_gqf_sizes(const genotype_matrix_type &g, const matrix_type &q, const matrix_type &f)
Validates the sizes of the G, Q, and F matrices and throws an exception if validation fails.
Definition: jade.verification.hpp:211
jade::basic_matrix::multiply_row
void multiply_row(const size_t row, const value_type value)
Multiplies a row by a specified value.
Definition: jade.matrix.hpp:976
jade::basic_improver::forced_grouping_type
basic_forced_grouping< value_type > forced_grouping_type
The forced grouping type.
Definition: cpax/jade.improver.hpp:33
jade::basic_improver::matrix_type
basic_matrix< value_type > matrix_type
The matrix type.
Definition: cpax/jade.improver.hpp:27
jade::basic_verification::validate_q
static bool validate_q(const matrix_type &q)
Validates the Q matrix and throws an exception if validation fails.
Definition: jade.verification.hpp:225
jade::basic_genotype_matrix::get_height
virtual size_t get_height() const =0
jade::basic_improver::lemke_type
basic_lemke< value_type > lemke_type
The Lemke type.
Definition: cpax/jade.improver.hpp:39
jade::basic_matrix::get_row_sum
value_type get_row_sum(const size_t row) const
Definition: jade.matrix.hpp:716
jade::basic_matrix::set_values
void set_values(const value_type value)
Sets all values of the matrix to the specified value.
Definition: jade.matrix.hpp:1189
jade::basic_genotype_matrix::compute_derivatives_q
virtual void compute_derivatives_q(const matrix_type &q, const matrix_type &fa, const matrix_type &fb, const matrix_type &qfa, const matrix_type &qfb, const size_t i, matrix_type &d_vec, matrix_type &h_mat) const =0
Computes the derivative vector and hessian matrix for a specified individual of the Q matrix.
jade::basic_forced_grouping
A template for a class that implements the forced grouping feature.
Definition: cpax/jade.forced_grouping.hpp:19
jade::basic_matrix::copy_column
basic_matrix copy_column(const size_t column) const
Definition: jade.matrix.hpp:230
jade::basic_matrix< value_type >
jade::basic_matrix::copy_row
basic_matrix copy_row(const size_t row) const
Definition: jade.matrix.hpp:312
jade::basic_genotype_matrix::compute_derivatives_f
virtual void compute_derivatives_f(const matrix_type &q, const matrix_type &fa, const matrix_type &fb, const matrix_type &qfa, const matrix_type &qfb, const size_t j, matrix_type &d_vec, matrix_type &h_mat) const =0
Computes the derivative vector and hessian matrix for a specified marker of the F matrix.
jade::basic_matrix::clamp_row
void clamp_row(const size_t row, const value_type min, const value_type max)
Clamps all values in a row to the specified range.
Definition: jade.matrix.hpp:178
jade::basic_improver
A template for a class that improves the Q and F matrices.
Definition: cpax/jade.improver.hpp:21