Below is the file 'mk2.py' from this revision. You can also download the file.
#!/usr/bin/env python
"""Fixed-order Markov chain builder with entropy-driven "upchunking".

A ``MarkovChain`` records, for every window of ``length`` consecutive
tokens, how often each token follows it.  ``upchunk`` then repeatedly
takes the highest-entropy window, fuses it into a single new token and
rebuilds the chain from the stored input, until the highest entropy
falls below a cutoff.

Run as a script, the file reads each file named on the command line one
character at a time, upchunks, and writes the pickled chain to stdout.
"""

try:
    import cPickle
except ImportError:  # Python 3 has no cPickle module; pickle replaces it
    import pickle as cPickle
import random
import heapq
import math
import sys

random.seed()


class MarkovState(object):
    """Follower counts for one window of tokens.

    Attributes:
        state: tuple of tokens identifying this window.
        h: cached entropy in bits, or None when it must be recomputed.
        total: number of followers recorded so far.
        scores: mapping of follower token -> occurrence count.
    """

    def __init__(self, state):
        self.state = state
        self.h = None
        self.total = 0
        self.scores = {}

    def increment(self, token):
        """Record one more occurrence of ``token`` following this window."""
        self.total += 1
        self.scores[token] = self.scores.get(token, 0) + 1
        self.h = None  # the cached entropy is stale now

    def __entropy(self):
        # Shannon entropy (base 2) of the follower distribution.
        total = float(self.total)
        return -1 * sum(
            p * math.log(p, 2)
            for p in (self.scores[token] / total for token in self.scores)
        )

    def entropy(self):
        """Return the entropy of the follower distribution, cached."""
        if self.h is None:
            self.h = self.__entropy()
        return self.h

    def __repr__(self):
        return "state" + repr(self.scores)

    def __cmp__(self, other):
        """Order states by DESCENDING entropy; any state sorts before None.

        Only Python 2 calls this implicitly.  It is written without the
        ``cmp`` builtin so it can still be called directly on Python 3.
        """
        if other is None:
            return -1
        theirs, ours = other.entropy(), self.entropy()
        return (theirs > ours) - (theirs < ours)

    def __lt__(self, other):
        """Rich-comparison counterpart of ``__cmp__`` (same ordering)."""
        return self.__cmp__(other) < 0


class MarkovChain(object):
    """Markov chain over windows of ``length`` tokens, with upchunking.

    Args:
        length: number of tokens in each window (state key).
        join_token: separator used to fuse a window into one new token.
        cutoff_func: callable invoked as ``cutoff_func(chain, entropies)``
            that returns the entropy below which upchunking stops.
            Defaults to ``MarkovChain.log_chunkable``.

    Attributes:
        states: mapping of window tuple -> MarkovState.
        stash: one list of tokens per ``update`` call, kept so the chain
            can be rebuilt after every upchunk step.
        upchunked: set of fused tokens that have not themselves been
            absorbed into a later fused token.
    """

    def __init__(self, length, join_token='', cutoff_func=None):
        self.length = length
        self.join_token = join_token
        self.upchunked = set()
        self.cutoff_func = cutoff_func or MarkovChain.log_chunkable
        self.clear()

    # The two cutoff functions are classmethods that take the chain as an
    # explicit argument (named ``self``): ``cutoff_func`` is stored on the
    # instance already bound to the class and is always called as
    # ``cutoff_func(chain, entropies)``.

    @classmethod
    def log_chunkable(cls, self, entropies):
        """Cutoff of log2(number of states) / 10; ignores ``entropies``."""
        # fast, but not necessarily as correct
        return math.log(len(self.states), 2) / 10

    @classmethod
    def standard_deviation_chunkable(cls, self, entropies):
        """Cutoff of mean + 3 population standard deviations of entropies.

        ``entropies`` must be non-empty.
        """
        count = len(entropies)
        mean_h = sum(entropies) / count
        sd_h = math.sqrt(sum(pow(h - mean_h, 2) for h in entropies) / count)
        # should really justify in some way other than 'it works'
        return mean_h + 3 * sd_h

    def update(self, gen):
        """Feed every token of iterable ``gen`` into the chain.

        Each token is counted as a follower of the ``length`` tokens that
        precede it.  The full token list is appended to ``self.stash``.
        """
        window = []
        seen = []
        for token in gen:
            seen.append(token)
            if len(window) == self.length:
                key = tuple(window)
                state = self.states.get(key)
                if state is None:
                    state = self.states[key] = MarkovState(key)
                state.increment(token)
                window = window[1:]  # slide: drop the oldest token
            window.append(token)
        self.stash.append(seen)

    def clear(self):
        """Forget all states and stashed input (``upchunked`` is kept)."""
        self.states = {}
        self.stash = []

    def random_next(self, from_state):
        """Pick a successor of ``from_state``, weighted by follower counts.

        Followers that lead to a window with no recorded state (dead
        ends) are excluded.

        Returns:
            The successor MarkovState, or None if every follower is a
            dead end.
        """
        def next_state(token):
            # Same shift as update(): drop the oldest token, append the
            # new one.
            return from_state.state[1:] + (token,)

        # eliminate dead-ends
        possible = [token for token in from_state.scores
                    if next_state(token) in self.states]
        if not possible:
            return None

        total = sum(from_state.scores[token] for token in possible)
        # Weighted draw: walk the candidates, using up the drawn number
        # against each candidate's count until it is exhausted.
        choice = random.randrange(total)
        for token in possible:
            choice -= from_state.scores[token]
            if choice < 0:
                return self.states[next_state(token)]
        raise Exception("Unreachable")

    def upchunk(self):
        """Fuse high-entropy windows into single tokens until none qualify.

        Each round rebuilds ``states`` and ``stash`` from the previous
        stash with the chosen window replaced by its fused token.
        """
        # NOTE(review): with length == 1 a fused window does not shorten
        # the input, so this loop may never terminate -- confirm intended
        # usage is length >= 2.
        while True:
            to_upchunk, to_upchunk_value = self.__select_upchunk()
            if to_upchunk is None:
                break
            previous_stash = self.stash
            self.clear()  # rebinds self.stash; previous_stash is untouched
            self.update_upchunked(to_upchunk, to_upchunk_value)
            for tokens in previous_stash:
                self.update(
                    self.__upchunk_gen(tokens, to_upchunk, to_upchunk_value))

    def update_upchunked(self, to_upchunk, replace_with):
        """Register ``replace_with`` as fused; unregister its components."""
        self.upchunked.add(replace_with)
        for token in to_upchunk:
            self.upchunked.discard(token)

    def __select_upchunk(self):
        # Return (window, fused_token) for the highest-entropy state, or
        # (None, None) when there are no states or that entropy is below
        # the cutoff.  Ties go to the first state in dict iteration order.
        if not self.states:
            return None, None

        max_h = -1
        candidate = None
        entropies = []
        for key in self.states:
            state = self.states[key]
            h = state.entropy()
            entropies.append(h)
            if h > max_h:
                max_h = h
                candidate = state

        cutoff = self.cutoff_func(self, entropies)
        if candidate.entropy() < cutoff:
            return None, None
        return candidate.state, self.join_token.join(candidate.state)

    def __upchunk_gen(self, gen, to_upchunk, replace_with):
        # Yield the tokens of ``gen`` with each left-to-right occurrence
        # of the window ``to_upchunk`` replaced by ``replace_with``.
        window = []
        width = len(to_upchunk)
        for token in gen:
            window.append(token)
            if len(window) == width:
                if tuple(window) == to_upchunk:
                    window = [replace_with]
                else:
                    oldest, window = window[0], window[1:]
                    yield oldest
        # flush whatever is still buffered
        for token in window:
            yield token

    def pprint(self):
        """Pretty-print this chain's state table to stdout."""
        from pprint import pprint
        pprint(self.states)


def simple_gen(fname):
    """Yield the contents of file ``fname`` one character at a time.

    The file is read in binary mode.  Where that produces ``bytes``
    rather than ``str`` (Python 3), each line is decoded as latin-1 so
    that every byte still maps to exactly one character.
    """
    with open(fname, 'rb') as handle:
        for line in handle:
            if not isinstance(line, str):
                line = line.decode('latin-1')
            for char in line:
                yield char
            # word-level alternative: yield word.lower() for each word in
            # line.split()


if __name__ == '__main__':
    chain = MarkovChain(2)
    for infile in sys.argv[1:]:
        sys.stderr.write("Reading input file: %s\n" % infile)
        chain.update(simple_gen(infile))
    chain.upchunk()
    sys.stderr.write(
        "processing produced %d states.\n" % len(chain.states))
    # NOTE(review): the default cutoff_func is a bound classmethod stored
    # on the instance; confirm it pickles on the interpreter in use.
    # Pickle data is binary: use the byte stream behind stdout if present.
    cPickle.dump(chain, getattr(sys.stdout, 'buffer', sys.stdout), protocol=2)