Why do we need TensorFlow tf.Graph?

What is the purpose of:

with tf.Graph().as_default():

I have some TensorFlow code that uses the above. However, the code only ever uses one graph, so why is this needed?


Solution 1:

TL;DR: It's unnecessary, but it's a good practice to follow.

A default graph is always registered, so every op and variable you create is placed into that default graph. The statement with tf.Graph().as_default(): instead creates a new graph and places everything declared inside its scope into that graph. If your program only ever uses one graph, the statement is redundant. It is still good practice, though: once you start working with multiple graphs, it makes it obvious where each op and variable lives, and it costs you nothing to write. It also guarantees that if you refactor the code later, the operations you define still belong to the graph you originally chose.
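To make the placement concrete, here is a minimal sketch (assuming TensorFlow 1.x, where these APIs live at the top level; in TensorFlow 2.x they move under tf.compat.v1):

import tensorflow as tf

# With no context manager active, ops land in the process-wide default graph.
a = tf.constant(1, name="a")
assert a.graph is tf.get_default_graph()

# Inside the context manager, ops land in the newly created graph instead.
g = tf.Graph()
with g.as_default():
    b = tf.constant(2, name="b")
    assert b.graph is g

assert a.graph is not b.graph   # two separate graphs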

Solution 2:

It's an artifact of the time when you had to explicitly specify a graph for every op you created.

I haven't seen any compelling cases that need more than one graph, so you can usually get away with keeping the graph implicit and calling tf.reset_default_graph() when you want to wipe the slate clean.
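A minimal sketch of that pattern (again assuming TensorFlow 1.x):

import tensorflow as tf

x = tf.constant(1, name="x")      # lives in the current default graph

tf.reset_default_graph()          # install a fresh, empty default graph

y = tf.constant(2, name="y")      # lives in the new default graph
assert x.graph is not y.graph     # x is stranded in the old graph

with tf.Session() as sess:        # Session binds to the current default graph
    print(sess.run(y))            # prints 2; sess.run(x) would fail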

Some gotchas:

  • The default graph stack is thread-local, so creating ops in multiple threads will create multiple default graphs
  • A Session keeps a handle to its graph (sess.graph), so if you create the Session before you call tf.reset_default_graph(), the session's graph will be different from the current default graph, which means new ops you create won't be runnable in that session (see the sketch after this list)
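A minimal sketch of the second gotcha (assuming TensorFlow 1.x; the exact error message varies by version):

import tensorflow as tf

sess = tf.Session()             # sess.graph is the default graph at creation time
tf.reset_default_graph()        # the default graph is now a different object
c = tf.constant(3, name="c")    # created in the new default graph

assert sess.graph is not c.graph
try:
    sess.run(c)                 # c is not an element of sess.graph
except ValueError as e:
    print(e)                    # "... is not an element of this graph."
sess.close()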

When you hit one of those gotchas, you can set a particular graph (i.e., one obtained from tf.get_default_graph() in another thread, or sess.graph) to be the default graph as follows:

self.graph_context = graph.as_default()    # save it to some variable that won't get gc'ed
self.graph_context.enforce_nesting = False # relax the strict nesting check on exit
self.graph_context.__enter__()             # enter the context manually; `graph` is now the default
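If the code can instead be structured around a single lexical scope, the plain context-manager form does the same thing and restores the previous default automatically:

with graph.as_default():
    # ops created here are placed into `graph`; the previous default
    # graph is restored when the block exits
    ...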