diff options
Diffstat (limited to 'analysis/tensorflow/fast_em.py')
-rwxr-xr-x | analysis/tensorflow/fast_em.py | 180 |
1 file changed, 180 insertions, 0 deletions
diff --git a/analysis/tensorflow/fast_em.py b/analysis/tensorflow/fast_em.py new file mode 100755 index 0000000..ea001e4 --- /dev/null +++ b/analysis/tensorflow/fast_em.py @@ -0,0 +1,180 @@ +#!/usr/bin/python +""" +fast_em.py: Tensorflow implementation of expectation maximization for RAPPOR +association analysis. + +TODO: + - Use TensorFlow ops for reading input (so that reading input can be + distributed) + - Reduce the number of ops (currently proportional to the number of reports). + May require new TensorFlow ops. + - Fix performance bug (v_split is probably being recomputed on every + iteration): + bin$ ./test.sh decode-assoc-cpp - 1.1 seconds (single-threaded C++) + bin$ ./test.sh decode-assoc-tensorflow - 226 seconds on GPU +""" + +import sys + +import numpy as np +import tensorflow as tf + + +def log(msg, *args): + if args: + msg = msg % args + print >>sys.stderr, msg + + +def ExpectTag(f, expected): + """Read and consume a 4 byte tag from the given file.""" + b = f.read(4) + if b != expected: + raise RuntimeError('Expected %r, got %r' % (expected, b)) + + +def ReadListOfMatrices(f): + """ + Read a big list of conditional probability matrices from a binary file. + """ + ExpectTag(f, 'ne \0') + num_entries = np.fromfile(f, np.uint32, count=1)[0] + log('Number of entries: %d', num_entries) + + ExpectTag(f, 'es \0') + entry_size = np.fromfile(f, np.uint32, count=1)[0] + log('Entry size: %d', entry_size) + + ExpectTag(f, 'dat\0') + vec_length = num_entries * entry_size + v = np.fromfile(f, np.float64, count=vec_length) + + log('Values read: %d', len(v)) + log('v: %s', v[:10]) + #print 'SUM', sum(v) + + # NOTE: We're not reshaping because we're using one TensorFlow tensor object + # per matrix, since it makes the algorithm expressible with current + # TensorFlow ops. + #v = v.reshape((num_entries, entry_size)) + + return num_entries, entry_size, v + + +def WriteTag(f, tag): + if len(tag) != 3: + raise AssertionError("Tags should be 3 bytes. 
Got %r" % tag) + f.write(tag + '\0') # NUL terminated + + +def WriteResult(f, num_em_iters, pij): + WriteTag(f, 'emi') + emi = np.array([num_em_iters], np.uint32) + emi.tofile(f) + + WriteTag(f, 'pij') + pij.tofile(f) + + +def DebugSum(num_entries, entry_size, v): + """Sum the entries as a sanity check.""" + cond_prob = tf.placeholder(tf.float64, shape=(num_entries * entry_size,)) + debug_sum = tf.reduce_sum(cond_prob) + with tf.Session() as sess: + s = sess.run(debug_sum, feed_dict={cond_prob: v}) + log('Debug sum: %f', s) + + +def BuildEmIter(num_entries, entry_size, v): + # Placeholder for the value from the previous iteration. + pij_in = tf.placeholder(tf.float64, shape=(entry_size,)) + + # split along dimension 0 + # TODO: + # - make sure this doesn't get run for every EM iteration + # - investigate using tf.tile() instead? (this may cost more memory) + v_split = tf.split(0, num_entries, v) + + z_numerator = [report * pij_in for report in v_split] + sum_z = [tf.reduce_sum(report) for report in z_numerator] + z = [z_numerator[i] / sum_z[i] for i in xrange(num_entries)] + + # Concat per-report tensors and reshape. This is probably inefficient? + z_concat = tf.concat(0, z) + z_concat = tf.reshape(z_concat, [num_entries, entry_size]) + + # This whole expression represents an EM iteration. Bind the pij_in + # placeholder, and get a new estimation of Pij. + em_iter_expr = tf.reduce_sum(z_concat, 0) / num_entries + + return pij_in, em_iter_expr + + +def RunEm(pij_in, entry_size, em_iter_expr, max_em_iters, epsilon=1e-6): + """Run the iterative EM algorithm (using the TensorFlow API). + + Args: + num_entries: number of matrices (one per report) + entry_size: total number of cells in each matrix + v: numpy.ndarray (e.g. 7000 x 8 matrix) + max_em_iters: maximum number of EM iterations + + Returns: + pij: numpy.ndarray (e.g. 
vector of length 8) + """ + # Initial value is the uniform distribution + pij = np.ones(entry_size) / entry_size + + i = 0 # visible outside loop + + # Do EM iterations. + with tf.Session() as sess: + for i in xrange(max_em_iters): + print 'PIJ', pij + new_pij = sess.run(em_iter_expr, feed_dict={pij_in: pij}) + dif = max(abs(new_pij - pij)) + log('EM iteration %d, dif = %e', i, dif) + pij = new_pij + + if dif < epsilon: + log('Early EM termination: %e < %e', max_dif, epsilon) + break + + # If i = 9, then we did 10 iteratinos. + return i + 1, pij + + +def sep(): + print '-' * 80 + + +def main(argv): + input_path = argv[1] + output_path = argv[2] + max_em_iters = int(argv[3]) + + sep() + with open(input_path) as f: + num_entries, entry_size, cond_prob = ReadListOfMatrices(f) + + sep() + DebugSum(num_entries, entry_size, cond_prob) + + sep() + pij_in, em_iter_expr = BuildEmIter(num_entries, entry_size, cond_prob) + num_em_iters, pij = RunEm(pij_in, entry_size, em_iter_expr, max_em_iters) + + sep() + log('Final Pij: %s', pij) + + with open(output_path, 'wb') as f: + WriteResult(f, num_em_iters, pij) + log('Wrote %s', output_path) + + +if __name__ == '__main__': + try: + main(sys.argv) + except RuntimeError, e: + print >>sys.stderr, 'FATAL: %s' % e + sys.exit(1) |