我有一组不同形状的矩阵M = (M_1, M_2, ... M_K)
。为了提高效率,我可以将所有的M
存储到一个大小为K x max(M_k.shape[0]) x max(M_k.shape[1])
的Tensor中。这对于批量矩阵乘法和逐元素加法之类的事情来说很好。但是如果我想做逐元素除法,忽略零元素呢?
我想到的最好的版本是:
import numpy as np
import tensorflow as tf
M = tf.constant(np.array([[1.,2.,0],[3.,4.,5.],[6.,0,0]]), tf.float32)
Minv = tf.select(tf.equal(M, 0), tf.zeros_like(M), tf.inv(M))
字符串
这是最快的方法吗?tf.select
通过GPU仍然可以很好地加速吗?
1条答案
按热度按时间f0ofjuux1#
您使用的是哪个版本的Tensorflow?
tf.select
是deprecated。你可以使用
tf.where
,下面是一个实现:字符串
其中:
tf.abs(M)>eps
是掩码