python 在tensorflow中可以使用环境变量来代替tf.constant吗?

6psbrbz9  于 6个月前  发布在  Python
关注(0)|答案(1)|浏览(61)

我是tensorflow的新手,最近我正在处理一个项目,该项目有以下代码

def build_fv_graph(self) -> tf.Tensor:
        with tf.variable_scope("modifier_attr"):
            t_mdl_name = tf.constant(self.model_name, name="mdl_name", dtype=tf.string)

字符串
正如你所看到的,这里的代码使用tf.constant将一个字符串硬编码到模型中,这将方便用户覆盖它,例如,通过环境变量。我的问题是这是可能的,或者对于这种用例有更好的解决方案吗?

0sgqnhkj

0sgqnhkj1#

我仍然不确定这是否对你有用。但是让我们试试。用这样的东西加载你的模型:

def load_pb(path_to_pb):
    with tf.gfile.GFile(path_to_pb, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
        return graph

字符串
当创建模型时,使用类似这样的东西:

# Create a placeholder for mdl_name
mdl_name_placeholder = tf.placeholder(tf.string, name="mdl_name")

# Create the graph
with tf.Graph().as_default() as graph:
    with tf.variable_scope("modifier_attr"):
        # Use a placeholder instead of a constant
        t_mdl_name = tf.identity(mdl_name_placeholder, name="mdl_name_identity")
        
    # Add your logic for constructing the graph


当你想改变变量时,使用类似这样的方法:

# Example usage:
with tf.Session(graph=graph) as sess:
    # Initialize variables (if you have any in the graph)
    sess.run(tf.global_variables_initializer())

    # Feed the placeholder with the desired value
    mdl_name_value = "xyz"
    
    # Run the graph with a new value for mdl_name
    result = sess.run(t_mdl_name, feed_dict={mdl_name_placeholder: mdl_name_value})
    print(result)


这就是我知道你提供的细节所能想到的。也许有些语法是不正确的,但想法就在这里。

相关问题