Commit 58082cec authored by Amelie Royer

Adding Kahan summation for precise probabilities in MDP

parent bab6ed28
......@@ -36,7 +36,6 @@ double transition_matrix [n_observations][n_actions][n_actions] = {0};
double rewards [n_observations][n_actions];
/*! \brief Loads the Model parameters from the precomputed data files.
*
* \param tfile Full path to the base_name.transitions file.
......@@ -45,7 +44,7 @@ double rewards [n_observations][n_actions];
* \param precision Maximum precision while reading stored probabilities.
*/
void load_model_parameters(std::string tfile, std::string rfile,
std::string pfile, std::string sfile, double precision) {
std::string pfile, std::string sfile, bool precision) {
// Variables
std::ifstream infile;
......@@ -102,7 +101,6 @@ void load_model_parameters(std::string tfile, std::string rfile,
}
}
// Accumulate
if (precision > 1) { v = std::trunc(v * precision); }
link = is_connected(s1, s2);
assert(("Unfeasible transition with >0 probability", link < n_actions));
transition_matrix[s1][a - 1][link] += v;
......@@ -114,15 +112,80 @@ void load_model_parameters(std::string tfile, std::string rfile,
double nrm;
for (s1 = 0; s1 < n_observations; s1++) {
for (a = 0; a < n_actions; a++) {
//double nrm = normalization[s1][a];
nrm = std::accumulate(transition_matrix[s1][a],
transition_matrix[s1][a] + n_actions, 0.0);
// If asking for precision, use kahan summation [slightly slower]
if (precision) {
double kahan_correction = 0.0;
nrm = 0.0;
for (s2 = 0; s2 < n_actions; s2++) {
double val = transition_matrix[s1][a][s2] - kahan_correction;
double aux = nrm + val;
kahan_correction = (aux - nrm) - val;
nrm = aux;
}
}
// Else basic sum
else {
nrm = std::accumulate(transition_matrix[s1][a],
transition_matrix[s1][a] + n_actions, 0.);
}
// Normalize
std::transform(transition_matrix[s1][a],
transition_matrix[s1][a] + n_actions,
transition_matrix[s1][a],
[nrm](const double t){ return t / nrm; }
);
);
//double kahan_sum = 0
//KahanAccumulation init = {0.0};
//nrm = (std::accumulate(transition_matrix[s1][a],
// transition_matrix[s1][a] + n_actions, init, KahanSum)).sum;
//std::transform(transition_matrix[s1][a],
// transition_matrix[s1][a] + n_actions,
// transition_matrix[s1][a],
// [nrm](const double t){ return ((t/nrm < 0.0001) ? 0 : t / nrm); }
// );
//double nrm = normalization[s1][a];
//nrm = std::accumulate(numbers.begin(), numbers.end(), init, KahanSum);
// KahanAccumulation init = {0.0};
// nrm = (std::accumulate(transition_matrix[s1][a],
// transition_matrix[s1][a] + n_actions, init, KahanSum)).sum;
// std::transform(transition_matrix[s1][a],
// transition_matrix[s1][a] + n_actions,
// transition_matrix[s1][a],
// [nrm](const double t){ return ((t/nrm < 0.0001) ? 0 : t / nrm); }
// );
// init = {0.0};
// nrm = (std::accumulate(transition_matrix[s1][a],
// transition_matrix[s1][a] + n_actions, init, KahanSum)).sum;
//nrm = std::accumulate(transition_matrix[s1][a],
//transition_matrix[s1][a] + n_actions, 0.0);
// std::transform(transition_matrix[s1][a],
// transition_matrix[s1][a] + n_actions,
// transition_matrix[s1][a],
// [nrm](const double t){ return t / nrm; }
// );
// nrm = std::accumulate(transition_matrix[s1][a],
// transition_matrix[s1][a] + n_actions, 0.0);
/// if (AIToolbox::checkDifferentSmall(1.0, nrm)) {
// std::cout << "it is working though";
// std::cout << std::setprecision(15) << nrm << "wtf";
// }
/*std::transform(transition_matrix[s1][a],
transition_matrix[s1][a] + n_actions,
transition_matrix[s1][a],
[nrm](const double t){ return ((AIToolbox::checkDifferentSmall(t / nrm, 0.0)) ? t / nrm : 0.0); }
);
nrm = std::accumulate(transition_matrix[s1][a],
transition_matrix[s1][a] + n_actions, 0.0);
std::transform(transition_matrix[s1][a],
transition_matrix[s1][a] + n_actions,
transition_matrix[s1][a],
[nrm](const double t){ return t / nrm; }*
);*/
//double test = 0;
//for (size_t s2 = 0; s2 < n_actions; s2++) {
// test += transition_matrix[s1][a][s2];
......@@ -245,8 +308,9 @@ int main(int argc, char* argv[]) {
assert(("Unvalid steps parameter", steps > 0));
float epsilon = ((argc > 4) ? std::atof(argv[4]) : 0.01);
assert(("Unvalid epsilon parameter", epsilon >= 0));
int precision = ((argc > 5) ? std::atoi(argv[5]) : 10);
assert(("Unvalid precision parameter", precision >= 0));
bool precision = ((argc > 5) ? (atoi(argv[5]) == 1) : false);
//std::cout << precision;
//return 0;
// Load model parameters
auto start = std::chrono::high_resolution_clock::now();
......@@ -256,7 +320,7 @@ int main(int argc, char* argv[]) {
load_model_parameters(datafile_base + ".transitions",
datafile_base + ".rewards",
datafile_base + ".profiles",
datafile_base + ".summary", std::pow(10, precision));
datafile_base + ".summary", std::pow(10, precision));
auto elapsed = std::chrono::high_resolution_clock::now() - start;
double loading_time = std::chrono::duration_cast<std::chrono::microseconds>(elapsed).count() / 1000000.;
......@@ -273,7 +337,8 @@ int main(int argc, char* argv[]) {
start = std::chrono::high_resolution_clock::now();
RecoMDP world;
std::cout << "\n" << current_time_str() << " - Copying model [sparse]...!\n";
AIToolbox::MDP::SparseModel model(world);
//AIToolbox::MDP::SparseModel model(world);
AIToolbox::MDP::SparseModel model(world, precision);
// Solve
std::cout << current_time_str() << " - Init solver...!\n";
......
......@@ -23,7 +23,7 @@ HORIZON="1"
COMPILE=false
# SET ARGUMENTS FROM CMD LINE
while getopts "m:d:n:k:g:s:p:h:e:x:b:c" opt; do
while getopts "m:d:n:k:g:s:h:e:x:b:cp" opt; do
case $opt in
m)
MODE=$OPTARG
......@@ -44,7 +44,7 @@ while getopts "m:d:n:k:g:s:p:h:e:x:b:c" opt; do
STEPS=$OPTARG
;;
p)
PRECISION=$OPTARG
PRECISION=1
;;
h)
HORIZON=$OPTARG
......
......@@ -26,6 +26,7 @@
//#include <AIToolbox/POMDP/Algorithms/MEMCP.hpp>
#include <AIToolbox/POMDP/Types.hpp>
#include <AIToolbox/Types.hpp>
#include <AIToolbox/Utils.hpp>
/*!
......@@ -54,6 +55,9 @@ static const int NPROFILES =
#endif /*!< Number of environments in the MEMDP */
/*!
* Global variables
*/
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment