polars_utils/mmap.rs

use std::fs::File;
use std::io;

pub use memmap::Mmap;

mod private {
    use std::fs::File;
    use std::ops::Deref;
    use std::sync::Arc;

    use polars_error::PolarsResult;

    use super::MMapSemaphore;
    use crate::mem::prefetch::prefetch_l2;

    /// A read-only reference to a slice of memory that can potentially be memory-mapped.
    ///
    /// A reference count is kept to the underlying buffer to ensure the memory is kept alive.
    /// [`MemSlice::slice`] can be used to slice the memory in a zero-copy manner.
    ///
    /// This still owns all of the original memory and therefore should probably not be a
    /// long-lasting structure.
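    ///
    /// A minimal usage sketch (illustrative, not from the original docs; assumes the
    /// `polars_utils::mmap` re-export path):
    ///
    /// ```ignore
    /// use polars_utils::mmap::MemSlice;
    ///
    /// // Zero-copy: the Vec's allocation is reused; `slice` only narrows the view.
    /// let mem = MemSlice::from_vec(vec![1u8, 2, 3, 4, 5]);
    /// let tail = mem.slice(2..5);
    /// assert_eq!(&*tail, &[3, 4, 5]);
    /// ```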
    #[derive(Clone, Debug)]
    pub struct MemSlice {
        // Store the `&[u8]` to make the `Deref` free.
        // `slice` is not 'static - it is backed by `inner`. This is safe as long as `slice` is not
        // directly accessed, and we are in a private module to guarantee that. Access should only
        // be done through `Deref<Target = [u8]>`, which automatically gives the correct lifetime.
        slice: &'static [u8],
        #[allow(unused)]
        inner: MemSliceInner,
    }

    /// Keeps the underlying buffer alive. This should be cheaply cloneable.
    #[derive(Clone, Debug)]
    #[allow(unused)]
    enum MemSliceInner {
        Bytes(bytes::Bytes), // Separate because it does atomic refcounting internally
        Arc(Arc<dyn std::fmt::Debug + Send + Sync>),
    }

    impl Deref for MemSlice {
        type Target = [u8];

        #[inline(always)]
        fn deref(&self) -> &Self::Target {
            self.slice
        }
    }

    impl AsRef<[u8]> for MemSlice {
        #[inline(always)]
        fn as_ref(&self) -> &[u8] {
            self.slice
        }
    }

    impl Default for MemSlice {
        fn default() -> Self {
            Self::from_bytes(bytes::Bytes::new())
        }
    }

    impl From<Vec<u8>> for MemSlice {
        fn from(value: Vec<u8>) -> Self {
            Self::from_vec(value)
        }
    }

    impl MemSlice {
        pub const EMPTY: Self = Self::from_static(&[]);

        /// Copy the contents into a new owned `Vec`
        #[inline(always)]
        pub fn to_vec(self) -> Vec<u8> {
            <[u8]>::to_vec(self.deref())
        }

        /// Construct a `MemSlice` from an existing `Vec<u8>`. This is zero-copy.
        #[inline]
        pub fn from_vec(v: Vec<u8>) -> Self {
            Self::from_bytes(bytes::Bytes::from(v))
        }

        /// Construct a `MemSlice` from [`bytes::Bytes`]. This is zero-copy.
        #[inline]
        pub fn from_bytes(bytes: bytes::Bytes) -> Self {
            Self {
                slice: unsafe { std::mem::transmute::<&[u8], &'static [u8]>(bytes.as_ref()) },
                inner: MemSliceInner::Bytes(bytes),
            }
        }

        #[inline]
        pub fn from_mmap(mmap: Arc<MMapSemaphore>) -> Self {
            Self {
                slice: unsafe {
                    std::mem::transmute::<&[u8], &'static [u8]>(mmap.as_ref().as_ref())
                },
                inner: MemSliceInner::Arc(mmap),
            }
        }

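        /// Construct a `MemSlice` from a `slice` that is kept alive by `arc`. `slice` must
        /// borrow from memory owned by `arc`.
        ///
        /// A minimal usage sketch (illustrative, not from the original docs):
        ///
        /// ```ignore
        /// use std::sync::Arc;
        ///
        /// use polars_utils::mmap::MemSlice;
        ///
        /// let owner: Arc<Vec<u8>> = Arc::new(vec![1u8, 2, 3, 4]);
        /// // The slice borrows from the Arc'd Vec; the Arc keeps the memory alive.
        /// let mem = MemSlice::from_arc(&owner[1..3], owner.clone());
        /// assert_eq!(&*mem, &[2, 3]);
        /// ```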
        #[inline]
        pub fn from_arc<T>(slice: &[u8], arc: Arc<T>) -> Self
        where
            T: std::fmt::Debug + Send + Sync + 'static,
        {
            Self {
                slice: unsafe { std::mem::transmute::<&[u8], &'static [u8]>(slice) },
                inner: MemSliceInner::Arc(arc),
            }
        }

        #[inline]
        pub fn from_file(file: &File) -> PolarsResult<Self> {
            let mmap = MMapSemaphore::new_from_file(file)?;
            Ok(Self::from_mmap(Arc::new(mmap)))
        }

        /// Construct a `MemSlice` that simply wraps around a `&[u8]`.
        #[inline]
        pub const fn from_static(slice: &'static [u8]) -> Self {
            let inner = MemSliceInner::Bytes(bytes::Bytes::from_static(slice));
            Self { slice, inner }
        }

        /// Attempt to prefetch the memory belonging to this [`MemSlice`]
        #[inline]
        pub fn prefetch(&self) {
            prefetch_l2(self.as_ref());
        }

        /// # Panics
        /// Panics if range is not in bounds.
        #[inline]
        #[track_caller]
        pub fn slice(&self, range: std::ops::Range<usize>) -> Self {
            let mut out = self.clone();
            out.slice = &out.slice[range];
            out
        }
    }

    impl From<bytes::Bytes> for MemSlice {
        fn from(value: bytes::Bytes) -> Self {
            Self::from_bytes(value)
        }
    }
}

use memmap::MmapOptions;
use polars_error::PolarsResult;
#[cfg(target_family = "unix")]
use polars_error::polars_bail;
pub use private::MemSlice;

/// A cursor over a [`MemSlice`].
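///
/// A minimal usage sketch (illustrative, not from the original docs):
///
/// ```ignore
/// use std::io::Read;
///
/// use polars_utils::mmap::{MemReader, MemSlice};
///
/// let mut reader = MemReader::new(MemSlice::from_vec(vec![1u8, 2, 3, 4]));
/// let mut buf = [0u8; 2];
/// reader.read_exact(&mut buf).unwrap();
/// assert_eq!(buf, [1, 2]);
/// assert_eq!(reader.remaining_len(), 2);
/// ```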
#[derive(Debug, Clone)]
pub struct MemReader {
    data: MemSlice,
    position: usize,
}

impl MemReader {
    pub fn new(data: MemSlice) -> Self {
        Self { data, position: 0 }
    }

    #[inline(always)]
    pub fn remaining_len(&self) -> usize {
        self.data.len() - self.position
    }

    #[inline(always)]
    pub fn total_len(&self) -> usize {
        self.data.len()
    }

    #[inline(always)]
    pub fn position(&self) -> usize {
        self.position
    }

    /// Construct a `MemReader` from an existing `Vec<u8>`. This is zero-copy.
    #[inline(always)]
    pub fn from_vec(v: Vec<u8>) -> Self {
        Self::new(MemSlice::from_vec(v))
    }

    /// Construct a `MemReader` from [`bytes::Bytes`]. This is zero-copy.
    #[inline(always)]
    pub fn from_bytes(bytes: bytes::Bytes) -> Self {
        Self::new(MemSlice::from_bytes(bytes))
    }

    /// Construct a `MemReader` that simply wraps around a `&[u8]`. The caller must ensure the
    /// slice outlives the returned `MemReader`.
    #[inline]
    pub fn from_slice(slice: &'static [u8]) -> Self {
        Self::new(MemSlice::from_static(slice))
    }

    #[inline(always)]
    pub fn from_reader<R: io::Read>(mut reader: R) -> io::Result<Self> {
        let mut vec = Vec::new();
        reader.read_to_end(&mut vec)?;
        Ok(Self::from_vec(vec))
    }

    #[inline(always)]
    pub fn read_slice(&mut self, n: usize) -> MemSlice {
        let start = self.position;
        let end = usize::min(self.position + n, self.data.len());
        self.position = end;
        self.data.slice(start..end)
    }
}

impl From<MemSlice> for MemReader {
    fn from(data: MemSlice) -> Self {
        Self { data, position: 0 }
    }
}

impl io::Read for MemReader {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let n = usize::min(buf.len(), self.remaining_len());
        buf[..n].copy_from_slice(&self.data[self.position..self.position + n]);
        self.position += n;
        Ok(n)
    }
}

impl io::Seek for MemReader {
    fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
        let position = match pos {
            io::SeekFrom::Start(position) => usize::min(position as usize, self.total_len()),
            io::SeekFrom::End(offset) => {
                let Some(position) = self.total_len().checked_add_signed(offset as isize) else {
                    return Err(io::Error::other("seek before start of buffer"));
                };

                position
            },
            io::SeekFrom::Current(offset) => {
                let Some(position) = self.position.checked_add_signed(offset as isize) else {
                    return Err(io::Error::other("seek before start of buffer"));
                };

                position
            },
        };

        self.position = position;

        Ok(position as u64)
    }
}

// Keep track of memory mapped files so we don't write to them while reading
// Use a btree as it uses less memory than a hashmap and this thing never shrinks.
// Write handle in Windows is exclusive, so this is only necessary in Unix.
#[cfg(target_family = "unix")]
static MEMORY_MAPPED_FILES: std::sync::LazyLock<
    std::sync::Mutex<std::collections::BTreeMap<(u64, u64), u32>>,
> = std::sync::LazyLock::new(|| std::sync::Mutex::new(Default::default()));

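/// A memory map over a file that, on Unix, registers itself in `MEMORY_MAPPED_FILES` so that
/// writers can detect the mapping (see [`ensure_not_mapped`]), and deregisters itself on drop.
///
/// A minimal usage sketch (illustrative, not from the original docs; the path is a placeholder):
///
/// ```ignore
/// use polars_utils::mmap::MMapSemaphore;
///
/// let file = std::fs::File::open("data.bin").unwrap();
/// let mmap = MMapSemaphore::new_from_file(&file).unwrap();
/// let bytes: &[u8] = mmap.as_ref();
/// ```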
#[derive(Debug)]
pub struct MMapSemaphore {
    #[cfg(target_family = "unix")]
    key: (u64, u64),
    mmap: Mmap,
}

impl MMapSemaphore {
    pub fn new_from_file_with_options(
        file: &File,
        options: MmapOptions,
    ) -> PolarsResult<MMapSemaphore> {
        let mmap = unsafe { options.map(file) }?;

        #[cfg(target_family = "unix")]
        {
            // FIXME: We aren't handling the case where the file is already open in write-mode here.

            use std::os::unix::fs::MetadataExt;
            let metadata = file.metadata()?;

            let mut guard = MEMORY_MAPPED_FILES.lock().unwrap();
            let key = (metadata.dev(), metadata.ino());
            match guard.entry(key) {
                std::collections::btree_map::Entry::Occupied(mut e) => *e.get_mut() += 1,
                std::collections::btree_map::Entry::Vacant(e) => _ = e.insert(1),
            }
            Ok(Self { key, mmap })
        }

        #[cfg(not(target_family = "unix"))]
        Ok(Self { mmap })
    }

    pub fn new_from_file(file: &File) -> PolarsResult<MMapSemaphore> {
        Self::new_from_file_with_options(file, MmapOptions::default())
    }

    pub fn as_ptr(&self) -> *const u8 {
        self.mmap.as_ptr()
    }
}

impl AsRef<[u8]> for MMapSemaphore {
    #[inline]
    fn as_ref(&self) -> &[u8] {
        self.mmap.as_ref()
    }
}

#[cfg(target_family = "unix")]
impl Drop for MMapSemaphore {
    fn drop(&mut self) {
        let mut guard = MEMORY_MAPPED_FILES.lock().unwrap();
        if let std::collections::btree_map::Entry::Occupied(mut e) = guard.entry(self.key) {
            let v = e.get_mut();
            *v -= 1;

            if *v == 0 {
                e.remove_entry();
            }
        }
    }
}

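/// Errors if the file described by `file_md` is currently memory mapped (Unix only; a no-op on
/// other platforms).
///
/// A minimal usage sketch (illustrative, not from the original docs; the path is a placeholder):
///
/// ```ignore
/// use polars_utils::mmap::ensure_not_mapped;
///
/// let file = std::fs::OpenOptions::new().write(true).open("out.bin").unwrap();
/// // Refuse to write to a file that a reader still has memory mapped.
/// ensure_not_mapped(&file.metadata().unwrap()).unwrap();
/// ```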
pub fn ensure_not_mapped(
    #[cfg_attr(not(target_family = "unix"), allow(unused))] file_md: &std::fs::Metadata,
) -> PolarsResult<()> {
    // TODO: We need to actually register that this file has been write-opened and prevent
    // read-opening this file based on that.
    #[cfg(target_family = "unix")]
    {
        use std::os::unix::fs::MetadataExt;
        let guard = MEMORY_MAPPED_FILES.lock().unwrap();
        if guard.contains_key(&(file_md.dev(), file_md.ino())) {
            polars_bail!(ComputeError: "cannot write to file: already memory mapped");
        }
    }
    Ok(())
}

#[cfg(test)]
mod tests {
    #[test]
    fn test_mem_slice_zero_copy() {
        use std::sync::Arc;

        use super::MemSlice;

        {
            let vec = vec![1u8, 2, 3, 4, 5];
            let ptr = vec.as_ptr();

            let mem_slice = MemSlice::from_vec(vec);
            let ptr_out = mem_slice.as_ptr();

            assert_eq!(ptr_out, ptr);
        }

        {
            let mut vec = vec![1u8, 2, 3, 4, 5];
            vec.truncate(2);
            let ptr = vec.as_ptr();

            let mem_slice = MemSlice::from_vec(vec);
            let ptr_out = mem_slice.as_ptr();

            assert_eq!(ptr_out, ptr);
        }

        {
            let bytes = bytes::Bytes::from(vec![1u8, 2, 3, 4, 5]);
            let ptr = bytes.as_ptr();

            let mem_slice = MemSlice::from_bytes(bytes);
            let ptr_out = mem_slice.as_ptr();

            assert_eq!(ptr_out, ptr);
        }

        {
            use crate::mmap::MMapSemaphore;

            let path = "../../examples/datasets/foods1.csv";
            let file = std::fs::File::open(path).unwrap();
            let mmap = MMapSemaphore::new_from_file(&file).unwrap();
            let ptr = mmap.as_ptr();

            let mem_slice = MemSlice::from_mmap(Arc::new(mmap));
            let ptr_out = mem_slice.as_ptr();

            assert_eq!(ptr_out, ptr);
        }

        {
            let vec = vec![1u8, 2, 3, 4, 5];
            let slice = vec.as_slice();
            let ptr = slice.as_ptr();

            let mem_slice = MemSlice::from_static(unsafe {
                std::mem::transmute::<&[u8], &'static [u8]>(slice)
            });
            let ptr_out = mem_slice.as_ptr();

            assert_eq!(ptr_out, ptr);
        }
    }

    #[test]
    fn test_mem_slice_slicing() {
        use super::MemSlice;

        {
            let vec = vec![1u8, 2, 3, 4, 5];
            let slice = vec.as_slice();

            let mem_slice = MemSlice::from_static(unsafe {
                std::mem::transmute::<&[u8], &'static [u8]>(slice)
            });

            let out = &*mem_slice.slice(3..5);
            assert_eq!(out, &slice[3..5]);
            assert_eq!(out.as_ptr(), slice[3..5].as_ptr());
        }
    }
}