Commit 1f4b81b8 authored by Amelie Royer

Adding explicit mdp fake environment to prepare_synth

parent 897a28bd
......@@ -13,7 +13,7 @@ import sys, os
import gzip
import argparse
from random import randint
from utils import ChunkedWriter, Logger, init_base_writing, get_nstates, get_next_state_id
from utils import Logger, init_base_writing, get_nstates, get_next_state_id
if sys.version_info[0] == 3:
xrange = range
......@@ -93,7 +93,7 @@ if __name__ == "__main__":
s2 = get_next_state_id(session[-1], a)
session.append(a)
session.append(s2)
f.write("%d\t%d\t%s\n" % (user, cluster, ' '.join(str(x) for x in session) ))
f.write("%d\t%d\t%s\n" % (user, cluster + 1, ' '.join(str(x) for x in session) )) # Cluster 0 is the average transition probability for the MDP ==> not a user, no test sequences
###### 4. Set rewards
print("\n\n\033[91m-----> Rewards generation\033[0m")
......@@ -107,6 +107,39 @@ if __name__ == "__main__":
f = gzip.open("%s.transitions.gz" % output_base, 'w') if args.zip else open("%s.transitions" % output_base, 'w')
buffer_size = 2**31 -1
print("\n\n\033[91m-----> Probability inference\033[0m")
total_count = n_users * (exc + n_items - 1)
transitions_str = ""
#### Define the average transition probability for the MDP model
# For fixed s1
for s1 in xrange(n_states):
sys.stderr.write(" state: %d / %d \r" % (s1 + 1, n_states))
sys.stderr.flush()
# For fixed a
for a in actions:
# Positive P(s1 -a-> s1.a)
count = exc + n_users - 1
new_count = args.alpha * count if not args.norm else args.alpha * count / total_count
assert (args.alpha * count < total_count), "AssertionError: alpha parameter too large. Probabilities out of range."
s2 = get_next_state_id(s1, a)
transitions_str += "%d\t%d\t%d\t%s\n" % (s1, a, s2, new_count)
# Negative P(s1 -a-> s1.b), b!= a
beta = (total_count - args.alpha * count) / (total_count - count)
# For every s2, sample T(s1, a, s2)
for s2_link in actions:
if (s2_link != a):
s2 = get_next_state_id(s1, s2_link)
count = exc + n_users - 1
transitions_str += "%d\t%d\t%d\t%s\n" % (s1, a, s2, beta * count if not args.norm else beta * count / total_count)
# If buffer overflows, write in file
if len(transitions_str) > buffer_size:
f.write(bytes(transitions_str.encode("UTF-8")) if args.zip else transitions_str)
transitions_str = ""
# Environment change
transitions_str += "\n"
f.write(bytes(transitions_str.encode("UTF-8")) if args.zip else transitions_str)
#### Define the transition probability for each environment
total_count = exc + n_items - 1
transitions_str = ""
for user_profile in xrange(n_users):
......@@ -121,7 +154,7 @@ if __name__ == "__main__":
# Positive P(s1 -a-> s1.a)
count = exc if a == user_profile + 1 else 1
new_count = args.alpha * count if not args.norm else args.alpha * count / total_count
assert (args.alpha * count < nrm), "AssertionError: alpha parameter too large. Probabilities out of range."
assert (args.alpha * count < total_count), "AssertionError: alpha parameter too large. Probabilities out of range."
s2 = get_next_state_id(s1, a)
transitions_str += "%d\t%d\t%d\t%s\n" % (s1, a, s2, new_count)
# Negative P(s1 -a-> s1.b), b!= a
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment