// polars_io/path_utils/hugging_face.rs

1// Hugging Face path resolution support
2
3use std::borrow::Cow;
4use std::collections::VecDeque;
5use std::path::PathBuf;
6
7use polars_error::{PolarsResult, polars_bail, to_compute_err};
8
9use crate::cloud::{
10    CloudConfig, CloudOptions, Matcher, extract_prefix_expansion,
11    try_build_http_header_map_from_items_slice,
12};
13use crate::path_utils::HiveIdxTracker;
14use crate::pl_async::with_concurrency_budget;
15use crate::prelude::URL_ENCODE_CHAR_SET;
16use crate::utils::decode_json_response;
17
/// Components parsed from a Hugging Face URI of the form
/// `hf:// [datasets | spaces] / {username} / {reponame} @ {revision} / {path from root}`.
#[derive(Debug, PartialEq)]
struct HFPathParts {
    /// Either "datasets" or "spaces" (validated in `try_from_uri`).
    bucket: String,
    /// `{username}/{reponame}` with any `@{revision}` suffix stripped.
    repository: String,
    /// Revision from the `@{revision}` suffix; defaults to "main" when absent.
    revision: String,
    /// Path relative to the repository root.
    path: String,
}
26
/// Pre-built base URLs (percent-encoded) for one repository at one revision.
struct HFRepoLocation {
    /// Base of the tree-listing API: `https://huggingface.co/api/{bucket}/{repo}/tree/{revision}/`.
    api_base_path: String,
    /// Base for direct file downloads: `https://huggingface.co/{bucket}/{repo}/resolve/{revision}/`.
    download_base_path: String,
}
31
32impl HFRepoLocation {
33    fn new(bucket: &str, repository: &str, revision: &str) -> Self {
34        let bucket = percent_encode(bucket.as_bytes());
35        let repository = percent_encode(repository.as_bytes());
36
37        // "https://huggingface.co/api/ [datasets | spaces] / {username} / {reponame} / tree / {revision} / {path from root}"
38        let api_base_path = format!(
39            "{}{}{}{}{}{}{}",
40            "https://huggingface.co/api/", bucket, "/", repository, "/tree/", revision, "/"
41        );
42        let download_base_path = format!(
43            "{}{}{}{}{}{}{}",
44            "https://huggingface.co/", bucket, "/", repository, "/resolve/", revision, "/"
45        );
46
47        Self {
48            api_base_path,
49            download_base_path,
50        }
51    }
52
53    fn get_file_uri(&self, rel_path: &str) -> String {
54        format!(
55            "{}{}",
56            self.download_base_path,
57            percent_encode(rel_path.as_bytes())
58        )
59    }
60
61    fn get_api_uri(&self, rel_path: &str) -> String {
62        format!(
63            "{}{}",
64            self.api_base_path,
65            percent_encode(rel_path.as_bytes())
66        )
67    }
68}
69
impl HFPathParts {
    /// Extracts path components from a hugging face path:
    /// `hf:// [datasets | spaces] / {username} / {reponame} @ {revision} / {path from root}`
    ///
    /// # Errors
    /// `ComputeError` if the URI is malformed or the bucket is not one of
    /// `datasets` / `spaces`.
    fn try_from_uri(uri: &str) -> PolarsResult<Self> {
        // The closure returns `None` on any malformed input; the `let .. else`
        // below converts that into a single error message.
        let Some(this) = (|| {
            // hf:// [datasets | spaces] / {username} / {reponame} @ {revision} / {path from root}
            //       !>
            if !uri.starts_with("hf://") {
                return None;
            }
            let uri = &uri[5..];

            // [datasets | spaces] / {username} / {reponame} @ {revision} / {path from root}
            // ^-----------------^   !>
            let i = memchr::memchr(b'/', uri.as_bytes())?;
            let bucket = uri.get(..i)?.to_string();
            let uri = uri.get(1 + i..)?;

            // {username} / {reponame} @ {revision} / {path from root}
            // ^----------------------------------^   !>
            let i = memchr::memchr(b'/', uri.as_bytes())?;
            let i = {
                // Also handle if they just give the repository, i.e.:
                // hf:// [datasets | spaces] / {username} / {reponame} @ {revision}
                let uri = uri.get(1 + i..)?;
                if uri.is_empty() {
                    return None;
                }
                // End of `{username}/{reponame}[@{revision}]`: the next '/',
                // or the end of the string when no path component follows.
                1 + i + memchr::memchr(b'/', uri.as_bytes()).unwrap_or(uri.len())
            };
            let repository = uri.get(..i)?;
            // Past-the-end slice yields `None`; treat that as an empty path.
            let uri = uri.get(1 + i..).unwrap_or("");

            // Split an optional `@{revision}` suffix off the repository.
            let (repository, revision) =
                if let Some(i) = memchr::memchr(b'@', repository.as_bytes()) {
                    (repository[..i].to_string(), repository[1 + i..].to_string())
                } else {
                    // No @revision in uri, default to `main`
                    (repository.to_string(), "main".to_string())
                };

            // {path from root}
            // ^--------------^
            let path = uri.to_string();

            Some(HFPathParts {
                bucket,
                repository,
                revision,
                path,
            })
        })() else {
            polars_bail!(ComputeError: "invalid Hugging Face path: {}", uri);
        };

        // Only these two buckets are valid for hf:// path-style access.
        const BUCKETS: [&str; 2] = ["datasets", "spaces"];
        if !BUCKETS.contains(&this.bucket.as_str()) {
            polars_bail!(ComputeError: "hugging face uri bucket must be one of {:?}, got {} instead.", BUCKETS, this.bucket);
        }

        Ok(this)
    }
}
133
/// One entry returned by the Hugging Face `tree` API (a file or a directory).
#[derive(Debug, serde::Deserialize)]
struct HFAPIResponse {
    // "type" is a Rust keyword, hence the rename. Observed values handled
    // below are "file" and "directory".
    #[serde(rename = "type")]
    type_: String,
    // Entry path relative to the repository root.
    path: String,
    // Size in bytes; entries with size 0 are filtered out when listing files.
    size: u64,
}
141
142impl HFAPIResponse {
143    fn is_file(&self) -> bool {
144        self.type_ == "file"
145    }
146
147    fn is_directory(&self) -> bool {
148        self.type_ == "directory"
149    }
150}
151
/// API response is paginated with a `link` header.
/// * https://huggingface.co/docs/hub/en/api#get-apidatasets
/// * https://docs.github.com/en/rest/using-the-rest-api/using-pagination-in-the-rest-api?apiVersion=2022-11-28#using-link-headers
struct GetPages<'a> {
    client: &'a reqwest::Client,
    /// URI of the next page to fetch; `None` once pagination is exhausted.
    uri: Option<String>,
}
159
impl GetPages<'_> {
    /// Fetches the next page of the paginated response, advancing `self.uri`
    /// from the response's `link` header. Returns `None` once exhausted.
    async fn next(&mut self) -> Option<PolarsResult<bytes::Bytes>> {
        // `take` leaves `None` behind, so iteration also stops after a page
        // whose `link` header has no `rel="next"` entry.
        let uri = self.uri.take()?;

        Some(
            async {
                let resp = with_concurrency_budget(1, || async {
                    self.client.get(uri).send().await.map_err(to_compute_err)
                })
                .await?;

                // Stash the `rel="next"` target (if any) for the next call.
                self.uri = resp
                    .headers()
                    .get("link")
                    .and_then(|x| Self::find_link(x.as_bytes(), "next".as_bytes()))
                    .transpose()?;

                let resp_bytes = resp.bytes().await.map_err(to_compute_err)?;

                Ok(resp_bytes)
            }
            .await,
        )
    }

    /// Scans a raw `link` header value for the URI whose `rel` attribute
    /// equals `rel` (e.g. `b"next"`). Returns `None` when no entry matches or
    /// the header is malformed, `Some(Err(..))` if the URI is not UTF-8.
    fn find_link(mut link: &[u8], rel: &[u8]) -> Option<PolarsResult<String>> {
        // "<https://...>; rel=\"next\", <https://...>; rel=\"last\""
        while !link.is_empty() {
            // The URI is enclosed in angle brackets; `?` bails on malformed input.
            let i = memchr::memchr(b'<', link)?;
            link = link.get(1 + i..)?;
            let i = memchr::memchr(b'>', link)?;
            let uri = &link[..i];
            link = link.get(1 + i..)?;

            // Advance byte-by-byte to this entry's `rel="` attribute.
            while !link.starts_with("rel=\"".as_bytes()) {
                link = link.get(1..)?
            }

            // rel="next"
            link = link.get(5..)?;
            let i = memchr::memchr(b'"', link)?;

            if &link[..i] == rel {
                return Some(
                    std::str::from_utf8(uri)
                        .map_err(to_compute_err)
                        .map(ToString::to_string),
                );
            }
        }

        None
    }
}
214
215pub(super) async fn expand_paths_hf(
216    paths: &[PathBuf],
217    check_directory_level: bool,
218    cloud_options: Option<&CloudOptions>,
219    glob: bool,
220) -> PolarsResult<(usize, Vec<PathBuf>)> {
221    assert!(!paths.is_empty());
222
223    let client = reqwest::ClientBuilder::new().http1_only().https_only(true);
224
225    let client = if let Some(CloudOptions {
226        config: Some(CloudConfig::Http { headers }),
227        ..
228    }) = cloud_options
229    {
230        client.default_headers(try_build_http_header_map_from_items_slice(
231            headers.as_slice(),
232        )?)
233    } else {
234        client
235    };
236
237    let client = &client.build().unwrap();
238
239    let mut out_paths = vec![];
240    let mut stack = VecDeque::new();
241    let mut entries = vec![];
242    let mut hive_idx_tracker = HiveIdxTracker {
243        idx: usize::MAX,
244        paths,
245        check_directory_level,
246    };
247
248    for (path_idx, path) in paths.iter().enumerate() {
249        let path_parts = &HFPathParts::try_from_uri(path.to_str().unwrap())?;
250        let repo_location = &HFRepoLocation::new(
251            &path_parts.bucket,
252            &path_parts.repository,
253            &path_parts.revision,
254        );
255        let rel_path = path_parts.path.as_str();
256
257        let (prefix, expansion) = if glob {
258            extract_prefix_expansion(rel_path)?
259        } else {
260            (Cow::Owned(path_parts.path.clone()), None)
261        };
262        let expansion_matcher = &if expansion.is_some() {
263            Some(Matcher::new(prefix.to_string(), expansion.as_deref())?)
264        } else {
265            None
266        };
267
268        if !path_parts.path.ends_with("/") && expansion.is_none() {
269            hive_idx_tracker.update(0, path_idx)?;
270            let file_uri = repo_location.get_file_uri(rel_path);
271            let file_uri = file_uri.as_str();
272
273            if with_concurrency_budget(1, || async {
274                client.head(file_uri).send().await.map_err(to_compute_err)
275            })
276            .await?
277            .status()
278                == 200
279            {
280                out_paths.push(PathBuf::from(file_uri));
281                continue;
282            }
283        }
284
285        hive_idx_tracker.update(repo_location.get_file_uri(rel_path).len(), path_idx)?;
286
287        assert!(stack.is_empty());
288        stack.push_back(prefix.into_owned());
289
290        while let Some(rel_path) = stack.pop_front() {
291            assert!(entries.is_empty());
292
293            let uri = repo_location.get_api_uri(rel_path.as_str());
294            let mut gp = GetPages {
295                uri: Some(uri),
296                client,
297            };
298
299            if let Some(matcher) = expansion_matcher {
300                while let Some(bytes) = gp.next().await {
301                    let bytes = bytes?;
302                    let bytes = bytes.as_ref();
303                    let response: Vec<HFAPIResponse> = decode_json_response(bytes)?;
304                    entries.extend(response.into_iter().filter(|x| {
305                        !x.is_file() || (x.size > 0 && matcher.is_matching(x.path.as_str()))
306                    }));
307                }
308            } else {
309                while let Some(bytes) = gp.next().await {
310                    let bytes = bytes?;
311                    let bytes = bytes.as_ref();
312                    let response: Vec<HFAPIResponse> = decode_json_response(bytes)?;
313                    entries.extend(response.into_iter().filter(|x| !x.is_file() || x.size > 0));
314                }
315            }
316
317            entries.sort_unstable_by(|a, b| a.path.as_str().partial_cmp(b.path.as_str()).unwrap());
318
319            for e in entries.drain(..) {
320                if e.is_file() {
321                    out_paths.push(PathBuf::from(repo_location.get_file_uri(&e.path)));
322                } else if e.is_directory() {
323                    stack.push_back(e.path);
324                }
325            }
326        }
327    }
328
329    Ok((hive_idx_tracker.idx, out_paths))
330}
331
332fn percent_encode(bytes: &[u8]) -> percent_encoding::PercentEncode {
333    percent_encoding::percent_encode(bytes, URL_ENCODE_CHAR_SET)
334}
335
// Gated with #[cfg(test)] so the module (and its imports of private items)
// is only compiled for test builds.
#[cfg(test)]
mod tests {

    #[test]
    fn test_hf_path_from_uri() {
        use super::HFPathParts;

        let uri = "hf://datasets/pola-rs/polars/README.md";
        let expect = HFPathParts {
            bucket: "datasets".into(),
            repository: "pola-rs/polars".into(),
            revision: "main".into(),
            path: "README.md".into(),
        };

        assert_eq!(HFPathParts::try_from_uri(uri).unwrap(), expect);

        let uri = "hf://spaces/pola-rs/polars@~parquet/";
        let expect = HFPathParts {
            bucket: "spaces".into(),
            repository: "pola-rs/polars".into(),
            revision: "~parquet".into(),
            path: "".into(),
        };

        assert_eq!(HFPathParts::try_from_uri(uri).unwrap(), expect);

        let uri = "hf://spaces/pola-rs/polars@~parquet";
        let expect = HFPathParts {
            bucket: "spaces".into(),
            repository: "pola-rs/polars".into(),
            revision: "~parquet".into(),
            path: "".into(),
        };

        assert_eq!(HFPathParts::try_from_uri(uri).unwrap(), expect);

        for uri in [
            "://",
            "s3://",
            "https://",
            "hf://",
            "hf:///",
            "hf:////",
            "hf://datasets/a",
            "hf://datasets/a/",
            "hf://bucket/a/b/c", // Invalid bucket name
        ] {
            let out = HFPathParts::try_from_uri(uri);
            if out.is_err() {
                continue;
            }
            panic!("expected err result for uri {} instead of {:?}", uri, out);
        }
    }

    #[test]
    fn test_get_pages_find_next_link() {
        use super::GetPages;
        let link = r#"<https://api.github.com/repositories/263727855/issues?page=3>; rel="next", <https://api.github.com/repositories/263727855/issues?page=7>; rel="last""#.as_bytes();

        assert_eq!(
            GetPages::find_link(link, "next".as_bytes()).map(Result::unwrap),
            Some("https://api.github.com/repositories/263727855/issues?page=3".into()),
        );

        assert_eq!(
            GetPages::find_link(link, "last".as_bytes()).map(Result::unwrap),
            Some("https://api.github.com/repositories/263727855/issues?page=7".into()),
        );

        assert_eq!(
            GetPages::find_link(link, "non-existent".as_bytes()).map(Result::unwrap),
            None,
        );
    }
}
411}