PyTorch: count how many elements of one tensor exist in another tensor

Posted by vmpqdwk3 9 months ago in Other

I have two 1D tensors:

A = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
B = torch.tensor([2, 5, 6, 8, 12, 15, 16])

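The tensors are very large, have different lengths, and their values are in no particular order (they are not sorted).
I want to get the number of elements of B that (i) exist in A and (ii) do not exist in A. So the output would be: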

Exists: 4
Do not exist: 3

I tried:

exists = torch.eq(A,B).sum().item()
not_exist = torch.numel(B) - exists

But this gives the error:

RuntimeError: The size of tensor a (10) must match the size of tensor b (7) at non-singleton dimension 0

The following works, but it first has to create a boolean tensor and then sum the True elements. Is it efficient for very large tensors?

import numpy as np
# np.isin(element, test_elements): mask over B marking values that occur in A
exists = np.isin(B, A).sum()
not_exist = torch.numel(B) - exists
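For reference, the same boolean-mask idea can stay entirely inside PyTorch with torch.isin (available since PyTorch 1.10), which avoids the NumPy round trip and also runs on GPU tensors. A minimal sketch on the example data; note that the argument order matters, since the mask has the shape of the first argument:

```python
import torch

A = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
B = torch.tensor([2, 5, 6, 8, 12, 15, 16])

# torch.isin(elements, test_elements) returns a boolean mask shaped like
# `elements`: True where that element of B occurs anywhere in A.
mask = torch.isin(B, A)

# Summing a boolean tensor counts its True entries.
exists = mask.sum().item()
not_exist = B.numel() - exists

print(f"Exists: {exists}")
print(f"Do not exist: {not_exist}")
```

This still allocates one boolean per element of B, but that is O(len(B)) memory rather than a len(B) × len(A) matrix.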

Is there a better or more efficient way?

Answer #1 by rsl1atfo

Try the following:

import torch

A = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
B = torch.tensor([2, 5, 6, 8, 12, 15, 16])

# Convert tensors to Python sets (.tolist() also works for GPU tensors, unlike .numpy())
setA = set(A.tolist())
setB = set(B.tolist())

# Find intersection and difference
intersection = setA & setB
difference = setB - setA

# Calculate the counts
exists = len(intersection)
not_exist = len(difference)

print(f"Exists: {exists}")
print(f"Do not exist: {not_exist}")
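One caveat with the set-based version: it counts distinct values, so it matches a per-element count only when B contains no repeated values. A minimal sketch of the difference, using a hypothetical B_dup with duplicates:

```python
import torch

A = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
# Hypothetical variant of B with repeated values (2 twice, 12 twice)
B_dup = torch.tensor([2, 2, 5, 12, 12])

# Set-based count: distinct values of B_dup that do / do not occur in A
set_A = set(A.tolist())
set_B = set(B_dup.tolist())
distinct_exists = len(set_A & set_B)      # {2, 5}
distinct_not_exist = len(set_B - set_A)   # {12}

# Per-element count: every position of B_dup is counted separately
per_element_exists = sum(x in set_A for x in B_dup.tolist())
per_element_not_exist = B_dup.numel() - per_element_exists

print(distinct_exists, distinct_not_exist)        # 2 1
print(per_element_exists, per_element_not_exist)  # 3 2
```

Which count is "right" depends on whether duplicates in B should each be counted.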

Update:

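You can also stick to native PyTorch operations such as broadcasting. This approach uses more memory, since it materializes a len(B) × len(A) boolean matrix, but it can be faster for large tensors, especially with GPU acceleration.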

import torch

A = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
B = torch.tensor([2, 5, 6, 8, 12, 15, 16])

# Use broadcasting to compare each element of B with all of A
# The result will be a 2D tensor with shape (len(B), len(A))
comparison_matrix = B[:, None] == A

# Sum along the second axis will give a tensor where each element
# indicates how many times an element from B was found in A
matches_per_element = comparison_matrix.sum(dim=1)

# Calculate the counts
exists = (matches_per_element > 0).sum().item()
not_exist = len(B) - exists

print(f"Exists: {exists}")
print(f"Do not exist: {not_exist}")
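Because the comparison matrix grows as len(B) × len(A), it can exhaust memory for very large inputs. A minimal sketch that keeps the same broadcasting comparison but processes B in slices to cap peak memory; the function name count_in and the chunk_size parameter are hypothetical, and both tensors are assumed to be 1D and on the same device:

```python
from typing import Tuple

import torch


def count_in(A: torch.Tensor, B: torch.Tensor, chunk_size: int = 4096) -> Tuple[int, int]:
    """Count elements of B that do / do not occur in A, comparing B in chunks.

    Peak memory for the comparison is about chunk_size * len(A) bools
    instead of len(B) * len(A).
    """
    exists = 0
    for start in range(0, B.numel(), chunk_size):
        chunk = B[start:start + chunk_size]
        # (len(chunk), len(A)) boolean matrix; any(dim=1) marks chunk elements found in A
        exists += (chunk[:, None] == A).any(dim=1).sum().item()
    return exists, B.numel() - exists


A = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
B = torch.tensor([2, 5, 6, 8, 12, 15, 16])

exists, not_exist = count_in(A, B, chunk_size=3)
print(f"Exists: {exists}")
print(f"Do not exist: {not_exist}")
```

Larger chunks mean fewer Python-level iterations (better GPU utilization) at the cost of more memory per step.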
