Tensorflow float64 error while running in eager execution

ykejflvf  于 6个月前  发布在  其他
关注(0)|答案(1)|浏览(60)

我正在使用TF 2.13.0,只有在启用即时执行时才会出现错误。有解决方法吗?
误差

tensorflow.python.framework.errors_impl.InvalidArgumentError: TensorArray dtype is float64 but Op is trying to write dtype float32

字符串
代码

import tensorflow as tf

# when the next line is uncommented, we get an error
tf.config.run_functions_eagerly(True)

@tf.function(input_signature=[tf.TensorSpec(shape=None,dtype=tf.float64)])
def TrySine(dev):
    mytensor = tf.map_fn(fn=lambda t,dev=dev: tf.math.sin(dev*3.14/180.0), elems=tf.ones(shape=(8,),dtype='float64'))
    return mytensor

output = TrySine(dev=5.0)

print(output)

neskvpey

neskvpey1#

一般来说,tensorflow不喜欢使用不同的dtypes进行计算。它通常会抛出警告或错误。
除非你有充分的理由使用float64,否则我建议你坚持使用float32,这是深度学习的标准。

import tensorflow as tf

# when the next line is uncommented, we get an error
tf.config.run_functions_eagerly(True)

@tf.function(input_signature=[tf.TensorSpec(shape=None,dtype=tf.float32)])
def TrySine(dev):
    mytensor = tf.map_fn(fn=lambda t,dev=dev: tf.math.sin(dev*3.14/180.0), elems=tf.ones(shape=(8,),dtype='float32'))
    return mytensor

output = TrySine(dev=5.0)

print(output)

字符串
但是如果你真的想使用float64,你可以确保你的输入常量在这个dtype中:

@tf.function(input_signature=[tf.TensorSpec(shape=None,dtype=tf.float64)])
def TrySine(dev):
    mytensor = tf.map_fn(fn=lambda t,dev=dev: tf.math.sin(dev*3.14/180.0), elems=tf.ones(shape=(8,),dtype='float64'))
    return mytensor

output = TrySine(dev=tf.constant(0.5, dtype=tf.float64))

print(output)

相关问题