使用 Ray 和 Apache Arrow 实现 Python 快速序列化
已发布 2017年10月15日
作者 Philipp Moritz, Robert Nishihara
本文最初发布在 Ray 博客 上。 Philipp Moritz 和 Robert Nishihara 是加州大学伯克利分校的研究生。
这篇文章详细阐述了 Ray 和 Apache Arrow 之间的集成。这主要解决的是 数据序列化 问题。
根据 维基百科,序列化是
……将数据结构或对象状态转换为可以存储……或传输……并在以后(可能在不同的计算机环境中)重建的格式的过程。
为什么需要进行任何转换?因为当你创建一个 Python 对象时,它可能包含指向其他 Python 对象的指针,而这些对象都分配在不同的内存区域中,所有这些都必须在另一台机器上的另一个进程解包时才能理解。
序列化和反序列化是并行和分布式计算的瓶颈,尤其是在具有大型对象和大量数据的机器学习应用程序中。
设计目标
由于 Ray 针对机器学习和 AI 应用程序进行了优化,我们非常关注序列化和数据处理,并制定了以下设计目标:
- 它应该对大型数值数据非常高效(这包括 NumPy 数组和 Pandas DataFrame,以及递归包含 NumPy 数组和 Pandas DataFrame 的对象)。
- 对于一般 Python 类型,它的速度应该与 Pickle 差不多。
- 它应该与共享内存兼容,允许多个进程在不复制数据的情况下使用相同的数据。
- 反序列化应该非常快(如果可能,它不应该需要读取整个序列化对象)。
- 它应该是语言无关的(最终我们希望 Python worker 能够使用 Java 或其他语言的 worker 创建的对象,反之亦然)。
我们的方法和替代方案
Python 中的首选序列化方法是pickle 模块。Pickle 非常通用,尤其是在使用 cloudpickle 等变体时。但是,它不满足要求 1、3、4 或 5。像 json 这样的替代方案满足 5,但不满足 1-4。
我们的方法:为了满足要求 1-5,我们选择使用 Apache Arrow 格式作为我们的底层数据表示。我们与 Apache Arrow 团队合作,构建了用于将通用 Python 对象映射到 Arrow 格式和从 Arrow 格式映射回来的库。这种方法的一些特性:
- 数据布局与语言无关(要求 5)。
- 可以恒定时间内计算序列化数据块中的偏移量,而无需读取整个对象(要求 1 和 4)。
- Arrow 支持零拷贝读取,因此对象可以自然地存储在共享内存中并由多个进程使用(要求 1 和 3)。
- 我们可以自然地回退到 pickle 来处理我们无法很好处理的任何事情(要求 2)。
Arrow 的替代方案:我们本可以基于 Protocol Buffers 构建,但 Protocol Buffers 并非真正为数值数据而设计,并且该方法不满足 1、3 或 4。基于 Flatbuffers 构建实际上可以实现,但它需要实现 Arrow 已经拥有的许多功能,而且我们更喜欢针对大数据进行优化的列式数据布局。
加速
这里我们展示了 Python 的 pickle 模块的一些性能改进。实验是使用 pickle.HIGHEST_PROTOCOL
完成的。生成这些图的代码包含在文章末尾。
使用 NumPy 数组:在机器学习和 AI 应用程序中,数据(例如图像、神经网络权重、文本文档)通常表示为包含 NumPy 数组的数据结构。使用 NumPy 数组时,加速效果令人印象深刻。
反序列化的 Ray 条形几乎不可见,这并非错误。这是对零拷贝读取的支持的结果(节省主要来自缺乏内存移动)。


请注意,最大的优势在于反序列化。这里的加速是几个数量级的,并且随着 NumPy 数组变大而变得更好(由于设计目标 1、3 和 4)。使**反序列化**快速很重要,原因有两个。首先,一个对象可以被序列化一次,然后被反序列化多次(例如,广播到所有 worker 的对象)。其次,一种常见的模式是并行序列化许多对象,然后在单个 worker 上一次聚合和反序列化一个对象,这使得反序列化成为瓶颈。
**不使用 NumPy 数组:当使用我们无法利用共享内存的常规 Python 对象时,结果与 pickle 相当。


这些只是一些有趣的 Python 对象的示例。最重要的案例是 NumPy 数组嵌套在其他对象中的案例。请注意,我们的序列化库适用于非常通用的 Python 类型,包括自定义 Python 类和深度嵌套的对象。
API
序列化库可以通过 pyarrow 直接使用,如下所示。更多文档可在此处 获取。
x = [(1, 2), 'hello', 3, 4, np.array([5.0, 6.0])]
serialized_x = pyarrow.serialize(x).to_buffer()
deserialized_x = pyarrow.deserialize(serialized_x)
它可以通过 Ray API 直接使用,如下所示。
x = [(1, 2), 'hello', 3, 4, np.array([5.0, 6.0])]
x_id = ray.put(x)
deserialized_x = ray.get(x_id)
数据表示
我们使用 Apache Arrow 作为底层语言无关的数据布局。对象存储在两个部分中:一个**模式**和一个**数据块**。在高层次上,数据块大致是对象中递归包含的所有数据值的扁平化串联,而模式定义了数据块的类型和嵌套结构。
**技术细节:** Python 序列(例如字典、列表、元组、集合)被编码为其他类型的 Arrow 联合数组(例如布尔值、整数、字符串、字节、浮点数、双精度数、date64、张量(即 NumPy 数组)、列表、元组、字典和集合)。嵌套序列使用 Arrow 列表数组 进行编码。所有张量都被收集并追加到序列化对象的末尾,并且联合数组包含对这些张量的引用。
举一个具体的例子,请考虑以下对象。
[(1, 2), 'hello', 3, 4, np.array([5.0, 6.0])]
它将在 Arrow 中用以下结构表示。
UnionArray(type_ids=[tuple, string, int, int, ndarray],
tuples=ListArray(offsets=[0, 2],
UnionArray(type_ids=[int, int],
ints=[1, 2])),
strings=['hello'],
ints=[3, 4],
ndarrays=[<offset of numpy array>])
Arrow 使用 Flatbuffers 对序列化模式进行编码。**仅使用模式,我们就可以计算数据块中每个值的偏移量,而无需扫描数据块**(与 Pickle 不同,这就是实现快速反序列化的原因)。这意味着我们可以在反序列化过程中避免复制或转换大型数组和其他值。张量被附加在联合数组的末尾,可以使用共享内存高效地共享和访问。
请注意,实际对象将在内存中按如下方式布局。

Arrow 序列化表示如下。

参与进来
我们欢迎贡献,尤其是在以下领域。
- 使用 Arrow 的 C++ 和 Java 实现来为 C++ 和 Java 实现这些版本。
- 实现对更多 Python 类型的支持和更好的测试覆盖率。
重现上述图表
作为参考,可以使用以下代码重现这些图表。对 ray.put
和 ray.get
而不是 pyarrow.serialize
和 pyarrow.deserialize
进行基准测试可以得到类似的图表。这些图是在此 提交 中生成的。
import pickle
import pyarrow
import matplotlib.pyplot as plt
import numpy as np
import timeit
def benchmark_object(obj, number=10):
# Time serialization and deserialization for pickle.
pickle_serialize = timeit.timeit(
lambda: pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL),
number=number)
serialized_obj = pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
pickle_deserialize = timeit.timeit(lambda: pickle.loads(serialized_obj),
number=number)
# Time serialization and deserialization for Ray.
ray_serialize = timeit.timeit(
lambda: pyarrow.serialize(obj).to_buffer(), number=number)
serialized_obj = pyarrow.serialize(obj).to_buffer()
ray_deserialize = timeit.timeit(
lambda: pyarrow.deserialize(serialized_obj), number=number)
return [[pickle_serialize, pickle_deserialize],
[ray_serialize, ray_deserialize]]
def plot(pickle_times, ray_times, title, i):
fig, ax = plt.subplots()
fig.set_size_inches(3.8, 2.7)
bar_width = 0.35
index = np.arange(2)
opacity = 0.6
plt.bar(index, pickle_times, bar_width,
alpha=opacity, color='r', label='Pickle')
plt.bar(index + bar_width, ray_times, bar_width,
alpha=opacity, color='c', label='Ray')
plt.title(title, fontweight='bold')
plt.ylabel('Time (seconds)', fontsize=10)
labels = ['serialization', 'deserialization']
plt.xticks(index + bar_width / 2, labels, fontsize=10)
plt.legend(fontsize=10, bbox_to_anchor=(1, 1))
plt.tight_layout()
plt.yticks(fontsize=10)
plt.savefig('plot-' + str(i) + '.png', format='png')
test_objects = [
[np.random.randn(50000) for i in range(100)],
{'weight-' + str(i): np.random.randn(50000) for i in range(100)},
{i: set(['string1' + str(i), 'string2' + str(i)]) for i in range(100000)},
[str(i) for i in range(200000)]
]
titles = [
'List of large numpy arrays',
'Dictionary of large numpy arrays',
'Large dictionary of small sets',
'Large list of strings'
]
for i in range(len(test_objects)):
plot(*benchmark_object(test_objects[i]), titles[i], i)