Commit bd114e31 authored by Amelie Royer

memory-friendly transition writing in prepare_synth

parent a39eed3e
from __future__ import print_function
#!/usr/bin/env python
# -*- coding: utf-8 -*-
......@@ -14,6 +15,9 @@ import argparse
from random import randint
from utils import ChunkedWriter, Logger, init_base_writing, get_nstates, get_next_state_id
if sys.version_info[0] == 3:
xrange = range
def init_output_dir(nitems, hlength):
"""
Initializes the output directory.
......@@ -59,7 +63,7 @@ if __name__ == "__main__":
n_users = args.nactions
actions = range(1, n_items + 1)
init_base_writing(n_items, args.history)
n_states = get_nstates(n_items, args.history)
n_states = int(get_nstates(n_items, args.history))
output_base = init_output_dir(args.nactions, args.history)
#### 2. Write .items and .profiles dummy files
......@@ -70,7 +74,7 @@ if __name__ == "__main__":
f.write("\n".join("%d\t1\t1" % i for i in xrange(n_users)))
##### 3. Create dummy test sessions
print "\n\033[91m-----> Test sequences generation\033[0m"
print("\n\033[91m-----> Test sequences generation\033[0m")
exc = 4 * (n_users - 1) # Sample size. Ensure 0.8 probability given to action i
with open("%s.test" % output_base, 'w') as f:
for user in xrange(args.test):
......@@ -92,18 +96,21 @@ if __name__ == "__main__":
f.write("%d\t%d\t%s\n" % (user, cluster, ' '.join(str(x) for x in session) ))
###### 4. Set rewards
print "\n\n\033[91m-----> Rewards generation\033[0m"
print("\n\n\033[91m-----> Rewards generation\033[0m")
with open("%s.rewards" % output_base, 'w') as f:
for item in actions:
sys.stderr.write(" item: %d / %d \r" % (item, len(actions)))
f.write("%d\t%.5f\n" % (item, 1))
###### 5. Create transition function
print "\n\n\033[91m-----> Probability inference\033[0m"
# Write
f = gzip.open("%s.transitions.gz" % output_base, 'w') if args.zip else open("%s.transitions" % output_base, 'w')
buffer_size = 2**16
print("\n\n\033[91m-----> Probability inference\033[0m")
total_count = exc + n_items - 1
transitions_str = ""
for user_profile in xrange(n_users):
print >> sys.stderr, "\n > Profile %d / %d: \n" % (user_profile + 1, n_users),
print("\n > Profile %d / %d: \n" % (user_profile + 1, n_users), end=" ", file=sys.stderr)
sys.stderr.flush()
# For fixed s1
for s1 in xrange(n_states):
......@@ -124,23 +131,18 @@ if __name__ == "__main__":
s2 = get_next_state_id(s1, s2_link)
count = exc if s2_link == user_profile + 1 else 1
transitions_str += "%d\t%d\t%d\t%s\n" % (s1, a, s2, beta * count if not args.norm else beta * count / total_count)
# If buffer overflows, write in the zip file
if len(transitions_str) > buffer_size:
f.write(bytes(transitions_str.encode("UTF-8")) if args.zip else transitions_str)
transitions_str = ""
transitions_str += "\n"
f.write(bytes(transitions_str.encode("UTF-8")) if args.zip else transitions_str)
f.close()
# Write
print "\n\n\033[91m-----> Writing...\033[0m"
if not args.zip:
with open("%s.transitions" % output_base, 'w') as f:
f.write(transitions_str)
else:
f = gzip.open("%s.transitions.gz" % output_base, 'wb')
cw = ChunkedWriter(f)
cw.write(transitions_str)
f.close()
with open("%s.summary" % output_base, 'wb') as f:
with open("%s.summary" % output_base, 'w') as f:
f.write("%d States\n%d Actions (Items)\n%d user profiles\n%d history length\n%d product clustering level\n\n%s" % (n_states, n_items, n_users, args.history, args.nactions, logger.to_string()))
###### 6. Summary
print "\n\n\033[92m-----> End\033[0m"
print " Output directory: %s" % output_base
print("\n\n\033[92m-----> End\033[0m")
print(" Output directory: %s" % output_base)
# End
......@@ -10,9 +10,14 @@ __email__ = "amelie.royer@ist.ac.at"
import sys
import numpy as np
from StringIO import StringIO
try:
from StringIO import StringIO
except ImportError:
from io import StringIO
if sys.version_info[0] == 3:
xrange = range
class Logger:
"""
Capture print statements and write them to a log file while printing on the screen.
......@@ -33,6 +38,9 @@ class Logger:
def to_string(self):
return self.logfile.getvalue()
def flush(self):
self.logfile.flush()
def line_count(f):
"""
......@@ -114,7 +122,7 @@ def get_n_customer_cluster(ulevel):
if ulevel == 0:
return 6
else:
print >> sys.stderr, "Unknown ulevel = %d option. Exit." % ulevel
#print >> sys.stderr, "Unknown ulevel = %d option. Exit." % ulevel
raise SystemExit
......@@ -234,7 +242,7 @@ class ChunkedWriter(object):
Write chunks of data to a given file. Workaround for the overflow bug when writing
with gzip in Python 2.7
"""
def __init__(self, file, chunksize=sys.maxint):
def __init__(self, file, chunksize=65536):
self.file = file
self.chunksize = chunksize
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment