Commit cd0d2c6f authored by Amelie Royer

Setting optional normalization in load_transitions function

parent 1294bd67
......@@ -71,12 +71,12 @@ int main(int argc, char* argv[]) {
if (!data.compare("reco")) {
Recomodel model (datafile_base + ".summary", discount, true);
model.load_rewards(datafile_base + ".rewards");
model.load_transitions(datafile_base + ".transitions", precision, datafile_base + ".profiles");
model.load_transitions(datafile_base + ".transitions", precision, precision, datafile_base + ".profiles");
mainMDP(model, datafile_base, steps, epsilon, precision, verbose);
} else if (!data.compare("maze")) {
Mazemodel model(datafile_base + ".summary", discount);
model.load_rewards(datafile_base + ".rewards");
model.load_transitions(datafile_base + ".transitions", precision);
model.load_transitions(datafile_base + ".transitions", precision, precision);
mainMDP(model, datafile_base, steps, epsilon, precision, verbose);
}
return 0;
......
......@@ -122,7 +122,7 @@ int main(int argc, char* argv[]) {
if (!data.compare("reco")) {
Recomodel model (datafile_base + ".summary", discount, false);
model.load_rewards(datafile_base + ".rewards");
model.load_transitions(datafile_base + ".transitions", precision, datafile_base + ".profiles");
model.load_transitions(datafile_base + ".transitions", precision, precision, datafile_base + ".profiles");
mainMEMDP(model, datafile_base, algo, horizon, steps, epsilon, beliefSize, exp, precision, verbose, true);
} else if (!data.compare("maze")) {
if (discount < 1) {
......@@ -131,7 +131,7 @@ int main(int argc, char* argv[]) {
}
Mazemodel model(datafile_base + ".summary", discount);
model.load_rewards(datafile_base + ".rewards");
model.load_transitions(datafile_base + ".transitions", precision, verbose);
model.load_transitions(datafile_base + ".transitions", precision, precision, verbose);
mainMEMDP(model, datafile_base, algo, horizon, steps, epsilon, beliefSize, exp, precision, verbose, false);
}
return 0;
......
......@@ -434,7 +434,7 @@ void Mazemodel::load_rewards(std::string rfile) {
/**
* LOAD_TRANSITIONS
*/
void Mazemodel::load_transitions(std::string tfile, bool precision /* =false */, bool verbose /* = false */) {
void Mazemodel::load_transitions(std::string tfile, bool precision /* =false */, bool normalization /* =false */, bool verbose /* = false */) {
std::ifstream infile;
std::string line;
std::istringstream iss;
......@@ -504,33 +504,35 @@ void Mazemodel::load_transitions(std::string tfile, bool precision /* =false */,
infile.close();
// Normalization
double nrm;
for (int p = 0; p < n_environments; p++) {
for (size_t state1 = 3; state1 < n_observations; state1++) {
for (size_t action = 0; action < n_actions; action++) {
nrm = 0.0;
// If asking for precision, use kahan summation [slightly slower]
if (precision) {
double kahan_correction = 0.0;
for (size_t state2 = 0; state2 < n_links; state2++) {
double val = transition_matrix[index(p, state1, action, state2)] - kahan_correction;
double aux = nrm + val;
kahan_correction = (aux - nrm) - val;
nrm = aux;
if (normalization) {
double nrm;
for (int p = 0; p < n_environments; p++) {
for (size_t state1 = 3; state1 < n_observations; state1++) {
for (size_t action = 0; action < n_actions; action++) {
nrm = 0.0;
// If asking for precision, use kahan summation [slightly slower]
if (precision) {
double kahan_correction = 0.0;
for (size_t state2 = 0; state2 < n_links; state2++) {
double val = transition_matrix[index(p, state1, action, state2)] - kahan_correction;
double aux = nrm + val;
kahan_correction = (aux - nrm) - val;
nrm = aux;
}
}
// Else basic sum
else{
nrm = std::accumulate(&transition_matrix[index(p, state1, action, 0)],
&transition_matrix[index(p, state1, action, n_links)], 0.);
}
// Normalize (nrm 0 <-> unreachable wall states)
if (nrm > 0.00000001) {
std::transform(&transition_matrix[index(p, state1, action, 0)],
&transition_matrix[index(p, state1, action, n_links)],
&transition_matrix[index(p, state1, action, 0)],
[nrm](const double t){ return t / nrm; }
);
}
}
// Else basic sum
else{
nrm = std::accumulate(&transition_matrix[index(p, state1, action, 0)],
&transition_matrix[index(p, state1, action, n_links)], 0.);
}
// Normalize (nrm 0 <-> unreachable wall states)
if (nrm > 0.00000001) {
std::transform(&transition_matrix[index(p, state1, action, 0)],
&transition_matrix[index(p, state1, action, n_links)],
&transition_matrix[index(p, state1, action, 0)],
[nrm](const double t){ return t / nrm; }
);
}
}
}
......
......@@ -141,7 +141,7 @@ public:
* \param pfile Profiles distribution file.
* \param precision If true, precise normalization is enabled.
*/
void load_transitions(std::string tfile, bool precision=false, bool verbose=false);
void load_transitions(std::string tfile, bool precision=false, bool normalization=false, bool verbose=false);
/*! \brief Returns a given transition probability.
*
......
......@@ -225,7 +225,7 @@ void Recomodel::load_rewards(std::string rfile) {
/**
* LOAD_TRANSITIONS
*/
void Recomodel::load_transitions(std::string tfile, bool precision /* =false */, std::string pfile /* ="" */) {
void Recomodel::load_transitions(std::string tfile, bool precision /* =false */, bool normalization /* =false */, std::string pfile /* ="" */) {
//std::fstream infile;
std::string line;
std::ifstream file, gzfile;
......@@ -254,7 +254,7 @@ void Recomodel::load_transitions(std::string tfile, bool precision /* =false */,
file.open(tfile, std::ios::in);
// If not found try the zipped version
if (!file.is_open()) {
std::cout << ".transitions not found. Searching for .gz alternative" << std::flush;;
std::cout << ".transitions not found. Loading .gz alternative\n" << std::flush;;
gzfile.open(tfile + ".gz", std::ios_base::in | std::ios_base::binary);
in.push(boost::iostreams::gzip_decompressor());
in.push(gzfile);
......@@ -267,6 +267,7 @@ void Recomodel::load_transitions(std::string tfile, bool precision /* =false */,
std::istringstream iss(line);
// Change of environment
if (!(iss >> s1 >> a >> s2 >> v)) {
std::cerr << "\r env " << profiles_found + 1 << " / " << n_environments;
profiles_found += 1;
assert(("Incomplete transition function in current profile in .transitions",
transitions_found == n_observations * n_actions * n_actions));
......@@ -294,33 +295,37 @@ void Recomodel::load_transitions(std::string tfile, bool precision /* =false */,
}
//Normalization
double nrm;
int env_loop = (is_mdp ? 1 : n_environments);
for (int p = 0; p < env_loop; p++) {
for (s1 = 0; s1 < n_observations; s1++) {
for (a = 0; a < n_actions; a++) {
nrm = 0.0;
// If asking for precision, use kahan summation [slightly slower]
if (precision) {
double kahan_correction = 0.0;
for (s2 = 0; s2 < n_actions; s2++) {
double val = transition_matrix[index(p, s1, a, s2)] - kahan_correction;
double aux = nrm + val;
kahan_correction = (aux - nrm) - val;
nrm = aux;
if (normalization) {
std::cout << "Normalization\n";
double nrm;
int env_loop = (is_mdp ? 1 : n_environments);
for (int p = 0; p < env_loop; p++) {
std::cerr << "\r env " << p + 1 << " / " << n_environments;
for (s1 = 0; s1 < n_observations; s1++) {
for (a = 0; a < n_actions; a++) {
nrm = 0.0;
// If asking for precision, use kahan summation [slightly slower]
if (precision) {
double kahan_correction = 0.0;
for (s2 = 0; s2 < n_actions; s2++) {
double val = transition_matrix[index(p, s1, a, s2)] - kahan_correction;
double aux = nrm + val;
kahan_correction = (aux - nrm) - val;
nrm = aux;
}
}
// Else basic sum
else{
nrm = std::accumulate(&transition_matrix[index(p, s1, a, 0)],
&transition_matrix[index(p, s1, a, n_actions)], 0.);
}
// Normalize
std::transform(&transition_matrix[index(p, s1, a, 0)],
&transition_matrix[index(p, s1, a, n_actions)],
&transition_matrix[index(p, s1, a, 0)],
[nrm](const double t){ return t / nrm; }
);
}
// Else basic sum
else{
nrm = std::accumulate(&transition_matrix[index(p, s1, a, 0)],
&transition_matrix[index(p, s1, a, n_actions)], 0.);
}
// Normalize
std::transform(&transition_matrix[index(p, s1, a, 0)],
&transition_matrix[index(p, s1, a, n_actions)],
&transition_matrix[index(p, s1, a, 0)],
[nrm](const double t){ return t / nrm; }
);
}
}
}
......
......@@ -89,7 +89,7 @@ public:
* \param pfile Profiles distribution file.
* \param precision If true, precise normalization is enabled.
*/
void load_transitions(std::string tfile, bool precision=false, std::string pfile="");
void load_transitions(std::string tfile, bool precision=false, bool normalization=false, std::string pfile="");
/*! \brief Returns a given transition probability.
*
......
......@@ -95,7 +95,7 @@ if __name__ == "__main__":
print "\n\n\033[91m-----> Rewards generation\033[0m"
with open("%s.rewards" % output_base, 'w') as f:
for item in actions:
sys.stderr.write(" item: %d / %d \r" % (item + 1, len(actions)))
sys.stderr.write(" item: %d / %d \r" % (item, len(actions)))
f.write("%d\t%.5f\n" % (item, 1))
###### 5. Create transition function
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment