NumPy：在含有重复行对的二维数组中高效地查找所有不同的行排列

lvjbypge  于 5个月前  发布在  其他
关注(0)|答案(1)|浏览(49)

考虑数组:

import numpy as np
import numpy_indexed as npi
from itertools import permutations

# Ten rows: four identical pairs (rows 0/9, 1/6, 2/3, 5/7) plus two
# rows that occur once (rows 4 and 8).
arr = np.array(
    [
        [1, 2, 3, 4],
        [3, 3, 3, 6],
        [2, 0, 0, 2],
        [2, 0, 0, 2],
        [8, 2, 8, 2],
        [4, 5, 4, 5],
        [3, 3, 3, 6],
        [4, 5, 4, 5],
        [0, 9, 8, 7],
        [1, 2, 3, 4],
    ]
)

我需要找到所有不同的行排列。目前的做法是先生成全部 10! 种行排列，再用 `npi.unique` 去重，如下：

# Baseline from the question: materialise all len(arr)! row orderings,
# then keep one copy of each distinct arrangement.
arr_perms = np.array([arr[list(order), :] for order in permutations(range(len(arr)))])
u, index = npi.unique(arr_perms, axis=0, return_index=True)


这样做能得到预期结果，但考虑到我所处理的数组的特点，代价似乎很高：这些数组都至少有一对（通常是几对）相同的行。
在上面的小示例中，10 行里包含 4 对相同的行，因此不同的行排列总数只有 $10!/2^4 = 226800$ 个，远少于 $10! = 3628800$。

问题：有没有办法直接高效地求出所有不同的行排列，而不必先生成完整的排列集合？

n6lpvg4x

n6lpvg4x1#

下面的代码或许对你有帮助。它仍然会生成所有排列，但比你发布的代码快约 9 倍（见文末计时：1.52 s 对 13.6 s）。原因是每一行（长度为 4 的子数组）都通过移位运算打包成了一个整数。注意：这种打包方式要求每个元素都在 0–255 之间，且每行最多 8 列。

import math
import numpy as np
import numpy_indexed as npi

def generate_permutations_matrix(n, k):
    """
    Generate every k-permutation of ``range(n)`` as the rows of a matrix.

    Rows come out in lexicographic order, i.e. the same order as
    ``itertools.permutations(range(n), k)``.

    Parameters:
    - n (int): Total number of elements.
    - k (int): Length of each permutation.

    Returns:
    - np.ndarray: Matrix of shape ``(math.perm(n, k), k)``. Its dtype is the
      smallest unsigned integer type that can hold ``n - 1`` (uint8 whenever
      ``n <= 256``).
    """
    # A fixed uint8 overflows as soon as n > 256 (reachable when k < n);
    # pick the narrowest type that holds the largest index instead.
    dtype = np.min_scalar_type(max(n - 1, 0))
    result_matrix = np.zeros((math.perm(n, k), k), dtype)
    f = 1  # number of rows built so far
    # Columns are filled right to left. Before the step for m, the first f
    # rows' trailing columns hold the ordered selections from range(m - 1);
    # the step prepends column n - m and yields the selections from range(m).
    for m in range(n - k + 1, n + 1):
        # A view (not a copy) of the rows built so far, trailing columns only.
        sub_matrix = result_matrix[:f, n - m + 1:]
        for i in range(1, m):
            block = slice(i * f, (i + 1) * f)
            # Leading value i; the tail is the old tail with every value
            # >= i bumped by one, i.e. drawn from range(m) without i.
            result_matrix[block, n - m] = i
            result_matrix[block, n - m + 1:] = sub_matrix + (sub_matrix >= i)
        # Block 0 keeps its leading 0; bumping its tail removes 0 from it.
        # This must come after the loop above, which reads the un-bumped
        # values through the same view.
        sub_matrix += 1
        f *= m
    return result_matrix

def convert_to_integer(arr):
    """
    Pack each row of a 2D array of small non-negative integers into a uint64.

    Column ``j`` occupies bits ``8*j`` to ``8*j + 7`` of the packed value, so
    two valid rows get the same packed value exactly when they are equal.

    Parameters:
    - arr (np.ndarray): Input 2D array. Every entry must be in ``[0, 255]``
      and it may have at most 8 columns.

    Returns:
    - np.ndarray: 1D uint64 array with one packed value per row.

    Raises:
    - ValueError: If ``arr`` has more than 8 columns or contains a value
      outside ``[0, 255]``.
    """
    n_cols = arr.shape[1]
    # Each column gets one byte of a uint64. More than 8 columns would shift
    # past bit 63, and values outside 0..255 (negatives wrap when cast to
    # uint64) would spill into a neighbouring byte. Either way distinct rows
    # could silently collide.
    if n_cols > 8:
        raise ValueError(
            f"at most 8 columns can be packed into a uint64, got {n_cols}"
        )
    if arr.size and (arr.min() < 0 or arr.max() > 255):
        raise ValueError(
            "all values must be in the range [0, 255] to be packed into one byte each"
        )
    arr = arr.astype(np.uint64)
    integer_value = arr[..., 0].copy()
    for ival in range(1, n_cols):
        integer_value += arr[..., ival] << ival * 8
    return integer_value

def get_permutations(arr, len_of_each):
    """
    Get the distinct row arrangements of a 2D array.

    Every ordered selection of ``len_of_each`` rows is generated by index.
    Identical rows are treated as interchangeable, so each distinct
    arrangement appears only once in the result.

    Parameters:
    - arr (np.ndarray): Input 2D array. Rows are packed with
      ``convert_to_integer``, so entries must fit in one byte (0-255) and a
      row may have at most 8 columns.
    - len_of_each (int): Number of rows in each arrangement
      (``len(arr)`` for full row permutations).

    Returns:
    - np.ndarray: uint64 array of shape
      ``(n_unique, len_of_each, arr.shape[1])``, passed through ``squeeze()``
      so any axis of length 1 is dropped. Arrangements are ordered as
      ``npi.unique`` orders the matrix of row ids.
    """
    integer_value = convert_to_integer(arr)

    # Number the distinct rows 0, 1, 2, ... in order of first appearance.
    # Dict insertion order therefore matches the ids.
    id_of_packed = {}
    row_ids = []
    for packed in integer_value.tolist():
        if packed not in id_of_packed:
            id_of_packed[packed] = len(id_of_packed)
        row_ids.append(id_of_packed[packed])

    # uint8 ids whenever they fit (the usual case); a wider type otherwise
    # instead of overflowing.
    id_dtype = np.min_scalar_type(max(len(id_of_packed) - 1, 0))
    row_ids = np.array(row_ids, dtype=id_dtype)

    # Replace each row index in every permutation by that row's id with one
    # fancy-indexing step. Previously this used np.select over one full-size
    # boolean mask per row, i.e. len(arr) temporaries as big as the matrix.
    permutations_matrix = generate_permutations_matrix(len(row_ids), len_of_each)
    unique_ids = npi.unique(row_ids[permutations_matrix], return_index=False, axis=0)

    # Unpack each id's packed value back into its row values (8 bits per
    # column). Explicit uint64 operands keep every intermediate uint64.
    packed_by_id = np.array(list(id_of_packed), dtype=np.uint64)
    shifts = np.arange(arr.shape[1], dtype=np.uint64) * np.uint64(8)
    rows_by_id = (packed_by_id[:, None] >> shifts) & np.uint64(255)

    # (n_unique, len_of_each) ids -> (n_unique, len_of_each, n_cols) values.
    return rows_by_id[unique_ids].squeeze()

# Example usage
# Ten rows: four identical pairs (rows 0/9, 1/6, 2/3, 5/7) plus two rows
# that occur once; only 10!/2**4 of the 10! row orders are distinct.
arr = np.array([
    [1, 2, 3, 4],
    [3, 3, 3, 6],
    [2, 0, 0, 2],
    [2, 0, 0, 2],
    [8, 2, 8, 2],
    [4, 5, 4, 5],
    [3, 3, 3, 6],
    [4, 5, 4, 5],
    [0, 9, 8, 7],
    [1, 2, 3, 4],
])
result_permutations = get_permutations(arr, len_of_each=len(arr))

from itertools import permutations

def test(arr):
    """
    Reference implementation: enumerate every row ordering, then deduplicate.

    Materialises all ``len(arr)!`` row permutations before calling
    ``npi.unique``, so it is only practical for small arrays. It is the
    baseline that ``get_permutations`` is timed against.

    Parameters:
    - arr (np.ndarray): Input 2D array.

    Returns:
    - tuple: ``(u, index)`` exactly as returned by
      ``npi.unique(..., return_index=True, axis=0)``.
    """
    # The last two lines used to be dedented to module level, which made the
    # `return` a SyntaxError; they belong to the function body.
    arr_perms = np.array([arr[i, :] for i in permutations(range(len(arr)))])
    u, index = npi.unique(arr_perms, return_index=True, axis=0)
    return u, index

%timeit result_permutations = get_permutations(arr=arr, len_of_each=10)
1.52 s ± 27.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit test(arr)
13.6 s ± 52.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

字符串

相关问题