DLPack 协议#
DLPack 协议 是一种稳定的内存数据结构,允许在使用多维数组或张量的主要框架之间交换数据。 它旨在支持跨硬件,这意味着它允许交换 CPU(例如 GPU)以外设备上的数据。
DLPack 协议已被 选为 Python 数组 API 标准 由 Python 数据 API 标准联盟,以便在 Python 生态系统中的数组/张量库之间实现设备感知数据交换。 请参阅有关该标准的更多信息 协议文档,有关 DLPack 的更多信息,请参阅 Python 规范 for DLPack.
PyArrow 中 DLPack 的实现#
DLPack 协议的生产端已针对 pa.Array
实现,可用于在 PyArrow 和其他张量库之间交换数据。 支持的数据类型是整数、无符号整数和浮点数。 该协议没有缺少数据支持,这意味着具有缺失值的 PyArrow 数组无法通过 DLPack 协议传输。 当前,Arrow 的协议实现仅支持 CPU 设备上的数据。
协议的数据交换语法包括
from_dlpack(x)
:使用实现__dlpack__
方法的数组对象,并在共享内存的同时创建一个新数组。__dlpack__(self, stream=None)
和__dlpack_device__
:使用 DLPack 结构生成一个 PyCapsule,该结构在from_dlpack(x)
中被调用。
PyArrow 实现了协议的第二部分(__dlpack__(self, stream=None)
和 __dlpack_device__
),因此可以被实现 from_dlpack
的库使用。
示例#
将 PyArrow CPU 数组转换为 NumPy 数组
>>> import pyarrow as pa
>>> array = pa.array([2, 0, 2, 4])
<pyarrow.lib.Int64Array object at 0x121fd4880>
[
2,
0,
2,
4
]
>>> import numpy as np
>>> np.from_dlpack(array)
array([2, 0, 2, 4])
将 PyArrow CPU 数组转换为 PyTorch 张量
>>> import torch
>>> torch.from_dlpack(array)
tensor([2, 0, 2, 4])
将 PyArrow CPU 数组转换为 JAX 数组
>>> import jax
>>> jax.numpy.from_dlpack(array)
Array([2, 0, 2, 4], dtype=int32)
>>> jax.dlpack.from_dlpack(array)
Array([2, 0, 2, 4], dtype=int32)