diff --git a/capnp/lib/capnp.pyx b/capnp/lib/capnp.pyx index d5c6d1d..8c531a8 100644 --- a/capnp/lib/capnp.pyx +++ b/capnp/lib/capnp.pyx @@ -29,6 +29,7 @@ import threading as _threading import socket as _socket import random as _random import collections as _collections +import mmap as _mmap _CAPNP_VERSION_MAJOR = capnp.CAPNP_VERSION_MAJOR _CAPNP_VERSION_MINOR = capnp.CAPNP_VERSION_MINOR @@ -277,6 +278,8 @@ ctypedef fused PromiseTypes: cdef extern from "Python.h": cdef int PyObject_AsReadBuffer(object, void** b, Py_ssize_t* c) cdef int PyObject_AsWriteBuffer(object, void** b, Py_ssize_t* c) + cdef int PyObject_GetBuffer(object, Py_buffer *view, int flags) + cdef void PyBuffer_Release(Py_buffer *view) # Templated classes are weird in cython. I couldn't put it in a pxd header for some reason cdef extern from "capnp/list.h" namespace " ::capnp": @@ -3682,6 +3685,19 @@ cdef class _AlignedBuffer: if self.allocated: free(self.buf) + +@cython.internal +cdef class _BufferView: + cdef Py_buffer view + cdef char * buf + + def __init__(self, other): + PyObject_GetBuffer(other, &self.view, 0) + self.buf = self.view.buf + + def __dealloc__(self): + PyBuffer_Release(&self.view) + @cython.internal cdef class _FlatArrayMessageReader(_MessageReader): cdef object _object_to_pin @@ -3698,13 +3714,19 @@ cdef class _FlatArrayMessageReader(_MessageReader): if sz % 8 != 0: raise ValueError("input length must be a multiple of eight bytes") - cdef char * ptr = buf - if (ptr) % 8 != 0: - aligned = _AlignedBuffer(buf) - ptr = aligned.buf - self._object_to_pin = aligned + cdef char * ptr + if type(buf) == _mmap.mmap: + view = _BufferView(buf) + ptr = view.view.buf + self._object_to_pin = view else: - self._object_to_pin = buf + ptr = buf + if (ptr) % 8 != 0: + aligned = _AlignedBuffer(buf) + ptr = aligned.buf + self._object_to_pin = aligned + else: + self._object_to_pin = buf self.thisptr = new schema_cpp.FlatArrayMessageReader(schema_cpp.WordArrayPtr(ptr, sz//8)) diff --git a/test/test_serialization.py b/test/test_serialization.py index a1893e7..87c77cf 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -5,6 +5,7 @@ import platform import test_regression import tempfile import pickle +import mmap this_dir = os.path.dirname(__file__) @@ -40,6 +41,20 @@ def test_roundtrip_bytes(all_types): msg = all_types.TestAllTypes.from_bytes(message_bytes) test_regression.check_all_types(msg) +def test_roundtrip_bytes_mmap(all_types): + msg = all_types.TestAllTypes.new_message() + test_regression.init_all_types(msg) + + with tempfile.TemporaryFile() as f: + msg.write(f) + length = f.tell() + + f.seek(0) + memory = mmap.mmap(f.fileno(), length) + + msg = all_types.TestAllTypes.from_bytes(memory) + test_regression.check_all_types(msg) + def test_roundtrip_bytes_packed(all_types): msg = all_types.TestAllTypes.new_message() test_regression.init_all_types(msg)