在redis中存储numpy数组的最快方法

euoag5mw  于 5个月前  发布在  Redis
关注(0)|答案(6)|浏览(56)

我在一个AI项目中使用Redis。
我们的想法是让多个环境模拟器在多个CPU内核上运行策略。模拟器将经验(状态/动作/奖励元组的列表)写入redis服务器(重放缓冲区)。然后训练进程将经验作为数据集读取以生成新策略。新策略部署到模拟器,删除以前运行的数据,然后进程继续。
大部分的体验都是在“状态”中捕获的。状态通常表示为一个大的numpy数组,比如80 x 80。模拟器在cpu允许的情况下尽可能快地生成这些。
为此,有没有人有好的想法或经验的最好/最快/最简单的方法来编写大量的numpy数组到redis。这是所有在同一台机器上,但后来,可以在一组云服务器上。代码样本欢迎!

bqjvbblv

bqjvbblv1#

我不知道它是不是最快,但你可以试试这样的...
将Numpy数组存储到Redis是这样的-参见函数toRedis()

  • 获取Numpy数组的形状并编码
  • 将Numpy数组作为字节追加到形状中
  • 在提供的键下存储编码数组

检索Numpy数组是这样的-请参阅函数fromRedis()

  • 从Redis中检索与提供的key对应的编码字符串
  • 从字符串中提取Numpy数组的形状
  • 提取数据并重新填充Numpy数组,重塑为原始形状
#!/usr/bin/env python3

import struct
import redis
import numpy as np

def toRedis(r,a,n):
   """Store given Numpy array 'a' in Redis under key 'n'"""
   h, w = a.shape
   shape = struct.pack('>II',h,w)
   encoded = shape + a.tobytes()

   # Store encoded data in Redis
   r.set(n,encoded)
   return

def fromRedis(r,n):
   """Retrieve Numpy array from Redis key 'n'"""
   encoded = r.get(n)
   h, w = struct.unpack('>II',encoded[:8])
   # Add slicing here, or else the array would differ from the original
   a = np.frombuffer(encoded[8:]).reshape(h,w)
   return a

# Create 80x80 numpy array to store
a0 = np.arange(6400,dtype=np.uint16).reshape(80,80) 

# Redis connection
r = redis.Redis(host='localhost', port=6379, db=0)

# Store array a0 in Redis under name 'a0array'
toRedis(r,a0,'a0array')

# Retrieve from Redis
a1 = fromRedis(r,'a0array')

np.testing.assert_array_equal(a0,a1)

字符串
您可以通过将Numpy数组的dtype与形状沿着编码来增加更多灵活性。我没有这样做,因为可能您已经知道所有数组都是一种特定类型,然后代码会变得更大,更难阅读。

现代iMac上的粗略基准

80x80 Numpy array of np.uint16   => 58 microseconds to write
200x200 Numpy array of np.uint16 => 88 microseconds to write

Keywords:Python,Numpy,Redis,array,serialize,serialize,key,incr,unique

mwngjboj

mwngjboj2#

你也可以考虑使用msgpack-numpy,它提供了“编码和解码例程,可以使用高效的msgpack格式对numpy提供的数值和数组数据类型进行序列化和非序列化”--参见https://msgpack.org/
快速概念验证:

import msgpack
import msgpack_numpy as m
import numpy as np
m.patch()               # Important line to monkey-patch for numpy support!

from redis import Redis

r = Redis('127.0.0.1')

# Create an array, then use msgpack to serialize it 
d_orig = np.array([1,2,3,4])
d_orig_packed = m.packb(d_orig)

# Set the data in redis
r.set('d', d_orig_packed)

# Retrieve and unpack the data
d_out = m.unpackb(r.get('d'))

# Check they match
assert np.alltrue(d_orig == d_out)
assert d_orig.dtype == d_out.dtype

字符串
在我的机器上,msgpack运行速度比使用struct快得多:

In: %timeit struct.pack('4096L', *np.arange(0, 4096))
1000 loops, best of 3: 443 µs per loop

In: %timeit m.packb(np.arange(0, 4096))
The slowest run took 7.74 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 32.6 µs per loop

g6baxovj

g6baxovj3#

下面我重写了函数fromRedistoRedis来处理可变维大小的数组,并包含数组的形状。

def toRedis(arr: np.array) -> str:
    arr_dtype = bytearray(str(arr.dtype), 'utf-8')
    arr_shape = bytearray(','.join([str(a) for a in arr.shape]), 'utf-8')
    sep = bytearray('|', 'utf-8')
    arr_bytes = arr.ravel().tobytes()
    to_return = arr_dtype + sep + arr_shape + sep + arr_bytes
    return to_return

def fromRedis(serialized_arr: str) -> np.array:
    sep = '|'.encode('utf-8')
    i_0 = serialized_arr.find(sep)
    i_1 = serialized_arr.find(sep, i_0 + 1)
    arr_dtype = serialized_arr[:i_0].decode('utf-8')
    arr_shape = tuple([int(a) for a in serialized_arr[i_0 + 1:i_1].decode('utf-8').split(',')])
    arr_str = serialized_arr[i_1 + 1:]
    arr = np.frombuffer(arr_str, dtype = arr_dtype).reshape(arr_shape)
    return arr

字符串

xwmevbvl

xwmevbvl4#

尝试给予plasma,因为它避免了串行化/并行化开销。
使用pip install pyarrow安装血浆
文档:https://arrow.apache.org/docs/python/plasma.html
首先,推出1gb内存的plasma [终端]:
plasma_store -m 100000000-s /tmp/plasma

import pyarrow.plasma as pa
import numpy as np
client = pa.connect("/tmp/plasma")
temp = np.random.rand(80,80)

字符串
写入时间:130 µs vs 782 µs(Redis实现:Mark Sethoff的回答)
通过使用plasma巨大页面可以改善写入时间,但仅适用于Linux机器:https://arrow.apache.org/docs/python/plasma.html#using-plasma-with-huge-pages
读取时间:31.2 µs vs 99.5 µs(Redis实现:Mark Sethoff的回答)
PS:代码在MacPro上运行

xt0899hw

xt0899hw5#

tobytes()函数的存储效率不是很高。为了减少必须写入redis服务器的存储,您可以使用base64包:

def encode_vector(ar):
    return base64.encodestring(ar.tobytes()).decode('ascii')

def decode_vector(ar):
    return np.fromstring(base64.decodestring(bytes(ar.decode('ascii'), 'ascii')), dtype='uint16')

字符串
@编辑:好的,由于Redis将值存储为字节字符串,因此直接存储字节字符串会更有效。但是,如果您将其转换为字符串,将其打印到控制台,或将其存储在文本文件中,则进行编码是有意义的。

dnph8jn4

dnph8jn46#

这是我从Jadiel de Armas修改的代码,他的代码几乎是正确的,只是缺少解码部分。我测试了它,它为我工作。

def set_numpy(redis, key: str, np_value: np.ndarray):
        d_type =  bytearray(str(np_value.dtype),'utf-8')
        d_shape =  bytearray(','.join([str(a) for a in np_value.shape]), 'utf-8')
        sep = bytearray('|', 'utf-8')
        data = np_value.ravel().tobytes()
        value = base64.a85encode(d_type + sep + d_shape + sep + data)
        redis.set(key, value)

   def get_numpy(redis, key:str):
        binary_value = redis.get(key)
        binary_value = base64.a85decode(binary_value)
        sep = '|'.encode('utf-8')
        i_0 = binary_value.find(sep)
        i_1 = binary_value.find(sep, i_0 + 1)
        arr_dtype = binary_value[:i_0].decode('utf-8')
        arr_shape = tuple([int(a) for a in binary_value[i_0 + 1:i_1].decode('utf-8').split(',')])
        arr_str = binary_value[i_1 + 1:]
        arr = np.frombuffer(arr_str, dtype=arr_dtype).reshape(arr_shape)
        return arr

字符串

相关问题