如何在TensorFlow 2中重置初始化

ubof19bj  于 6个月前  发布在  其他
关注(0)|答案(2)|浏览(103)

如果我在初始化tf.Variable后尝试在TensorFlow 2中更改并行度,

import tensorflow as tf
_ = tf.Variable([1])
tf.config.threading.set_inter_op_parallelism_threads(1)

字符串
我得到一个错误
RuntimeError:Inter op parallelism cannot be modified after initialization.
我理解为什么会这样,但是它(可能还有其他因素)导致我的测试相互干扰。

def test_model():  # this test
   v = tf.Variable([1])
   ...

def test_threading():  # is breaking this test
   tf.config.threading.set_inter_op_parallelism_threads(1)
   ...


如何重置TensorFlow状态以便设置线程?

owfi6suc

owfi6suc1#

这是可以通过一种“黑客”的方式来实现的。但我建议以正确的方式来实现(即在开始时设置配置)。

import tensorflow as tf
from tensorflow.python.eager import context

_ = tf.Variable([1])

context._context = None
context._create_context()

tf.config.threading.set_inter_op_parallelism_threads(1)

字符串

编辑:一开始设置config是什么意思,

import tensorflow as tf
from tensorflow.python.eager import context

tf.config.threading.set_inter_op_parallelism_threads(1)
_ = tf.Variable([1])


但是在某些情况下你不能总是这样做。只是指出在tf中设置config的传统方法。所以如果你的情况不允许你在开始时修复tf.config,你必须重置你的tf.eager.context,如上面的解决方案所示。

uxhixvfz

uxhixvfz2#

扩展@thushv89提供的解决方案,我们还可以通过以下方式直接设置上下文:

from tensorflow.python.eager import context
context.context().intra_op_parallelism_threads = <num_threads>
context.context().inter_op_parallelism_threads = <num_threads>
import tensorflow as tf

字符串
有关get()set()线程函数的详细信息,请单击此处

相关问题