Skip to main content

polars_io/path_utils/
hugging_face.rs

1// Hugging Face path resolution support
2
3use std::borrow::Cow;
4
5use polars_error::{PolarsResult, polars_bail, to_compute_err};
6use polars_utils::pl_path::PlRefPath;
7
8use crate::cloud::{
9    CloudConfig, CloudOptions, Matcher, USER_AGENT, extract_prefix_expansion,
10    try_build_http_header_map_from_items_slice,
11};
12use crate::path_utils::HiveIdxTracker;
13use crate::pl_async::with_concurrency_budget;
14use crate::utils::{URL_ENCODE_CHARSET, decode_json_response};
15
/// Percent-encoding character set for paths inside a Hugging Face Hub repository.
///
/// This is `URL_ENCODE_CHARSET` with `/` removed from the set, so slashes in
/// repository-relative paths are left unencoded and act as path separators. By not encoding
/// slashes, the API request is counted under the higher "resolvers" rate limit
/// (3000 requests / 5 min) instead of the default "pages" rate limit (100 requests / 5 min).
///
/// ref <https://github.com/pola-rs/polars/issues/25389>
const HF_PATH_ENCODE_CHARSET: &percent_encoding::AsciiSet = &URL_ENCODE_CHARSET.remove(b'/');
24
/// Components of a parsed `hf://` URI, as produced by `HFPathParts::try_from_uri`.
#[derive(Debug, PartialEq)]
struct HFPathParts {
    /// Top-level repository kind; `try_from_uri` only accepts `"datasets"` or `"spaces"`.
    bucket: String,
    /// `{username}/{reponame}`, without any `@{revision}` suffix.
    repository: String,
    /// Revision taken from the `@{revision}` suffix, or `"main"` when none was given.
    revision: String,
    /// Path relative to the repository root (empty when the URI names only the repository).
    path: String,
}
33
34struct HFRepoLocation {
35    api_base_path: String,
36    download_base_path: String,
37}
38
39impl HFRepoLocation {
40    fn new(bucket: &str, repository: &str, revision: &str) -> Self {
41        // * Don't percent-encode bucket/repository - they are path segments where
42        //   slashes are separators. E.g. "HuggingFaceFW/fineweb-2" must stay as-is.
43        // * DO encode revision - slashes in revisions like "refs/convert/parquet"
44        //   are part of the revision name, not path separators.
45        //   See: https://github.com/pola-rs/polars/issues/25389
46        let encoded_revision =
47            percent_encoding::percent_encode(revision.as_bytes(), URL_ENCODE_CHARSET);
48        let api_base_path = format!(
49            "https://huggingface.co/api/{}/{}/tree/{}/",
50            bucket, repository, encoded_revision
51        );
52        let download_base_path = format!(
53            "https://huggingface.co/{}/{}/resolve/{}/",
54            bucket, repository, encoded_revision
55        );
56
57        Self {
58            api_base_path,
59            download_base_path,
60        }
61    }
62
63    fn get_file_uri(&self, rel_path: &str) -> String {
64        format!(
65            "{}{}",
66            self.download_base_path,
67            percent_encoding::percent_encode(rel_path.as_bytes(), HF_PATH_ENCODE_CHARSET)
68        )
69    }
70
71    fn get_api_uri(&self, rel_path: &str) -> String {
72        format!(
73            "{}{}",
74            self.api_base_path,
75            percent_encoding::percent_encode(rel_path.as_bytes(), HF_PATH_ENCODE_CHARSET)
76        )
77    }
78}
79
80impl HFPathParts {
81    /// Extracts path components from a hugging face path:
82    /// `hf:// [datasets | spaces] / {username} / {reponame} @ {revision} / {path from root}`
83    fn try_from_uri(uri: &str) -> PolarsResult<Self> {
84        let Some(this) = (|| {
85            // hf:// [datasets | spaces] / {username} / {reponame} @ {revision} / {path from root}
86            //       !>
87            if !uri.starts_with("hf://") {
88                return None;
89            }
90            let uri = &uri[5..];
91
92            // [datasets | spaces] / {username} / {reponame} @ {revision} / {path from root}
93            // ^-----------------^   !>
94            let i = memchr::memchr(b'/', uri.as_bytes())?;
95            let bucket = uri.get(..i)?.to_string();
96            let uri = uri.get(1 + i..)?;
97
98            // {username} / {reponame} @ {revision} / {path from root}
99            // ^----------------------------------^   !>
100            let i = memchr::memchr(b'/', uri.as_bytes())?;
101            let i = {
102                // Also handle if they just give the repository, i.e.:
103                // hf:// [datasets | spaces] / {username} / {reponame} @ {revision}
104                let uri = uri.get(1 + i..)?;
105                if uri.is_empty() {
106                    return None;
107                }
108                1 + i + memchr::memchr(b'/', uri.as_bytes()).unwrap_or(uri.len())
109            };
110            let repository = uri.get(..i)?;
111            let uri = uri.get(1 + i..).unwrap_or("");
112
113            let (repository, revision) =
114                if let Some(i) = memchr::memchr(b'@', repository.as_bytes()) {
115                    (repository[..i].to_string(), repository[1 + i..].to_string())
116                } else {
117                    // No @revision in uri, default to `main`
118                    (repository.to_string(), "main".to_string())
119                };
120
121            // {path from root}
122            // ^--------------^
123            let path = uri.to_string();
124
125            Some(HFPathParts {
126                bucket,
127                repository,
128                revision,
129                path,
130            })
131        })() else {
132            polars_bail!(ComputeError: "invalid Hugging Face path: {}", uri);
133        };
134
135        const BUCKETS: [&str; 2] = ["datasets", "spaces"];
136        if !BUCKETS.contains(&this.bucket.as_str()) {
137            polars_bail!(ComputeError: "hugging face uri bucket must be one of {:?}, got {} instead.", BUCKETS, this.bucket);
138        }
139
140        Ok(this)
141    }
142}
143
144#[derive(Debug, serde::Deserialize)]
145struct HFAPIResponse {
146    #[serde(rename = "type")]
147    type_: String,
148    path: String,
149    size: u64,
150}
151
152impl HFAPIResponse {
153    fn is_file(&self) -> bool {
154        self.type_ == "file"
155    }
156}
157
158/// API response is paginated with a `link` header.
159/// * https://huggingface.co/docs/hub/en/api#get-apidatasets
160/// * https://docs.github.com/en/rest/using-the-rest-api/using-pagination-in-the-rest-api?apiVersion=2022-11-28#using-link-headers
161struct GetPages<'a> {
162    client: &'a reqwest::Client,
163    uri: Option<String>,
164}
165
166impl GetPages<'_> {
167    async fn next(&mut self) -> Option<PolarsResult<bytes::Bytes>> {
168        let uri = self.uri.take()?;
169
170        Some(
171            async {
172                let resp = with_concurrency_budget(1, || async {
173                    self.client.get(uri).send().await.map_err(to_compute_err)
174                })
175                .await?;
176
177                self.uri = resp
178                    .headers()
179                    .get("link")
180                    .and_then(|x| Self::find_link(x.as_bytes(), "next".as_bytes()))
181                    .transpose()?;
182
183                let resp_bytes = resp.bytes().await.map_err(to_compute_err)?;
184
185                Ok(resp_bytes)
186            }
187            .await,
188        )
189    }
190
191    fn find_link(mut link: &[u8], rel: &[u8]) -> Option<PolarsResult<String>> {
192        // "<https://...>; rel=\"next\", <https://...>; rel=\"last\""
193        while !link.is_empty() {
194            let i = memchr::memchr(b'<', link)?;
195            link = link.get(1 + i..)?;
196            let i = memchr::memchr(b'>', link)?;
197            let uri = &link[..i];
198            link = link.get(1 + i..)?;
199
200            while !link.starts_with("rel=\"".as_bytes()) {
201                link = link.get(1..)?
202            }
203
204            // rel="next"
205            link = link.get(5..)?;
206            let i = memchr::memchr(b'"', link)?;
207
208            if &link[..i] == rel {
209                return Some(
210                    std::str::from_utf8(uri)
211                        .map_err(to_compute_err)
212                        .map(ToString::to_string),
213                );
214            }
215        }
216
217        None
218    }
219}
220
221pub(super) async fn expand_paths_hf(
222    paths: &[PlRefPath],
223    check_directory_level: bool,
224    cloud_options: &Option<CloudOptions>,
225    glob: bool,
226) -> PolarsResult<(usize, Vec<PlRefPath>)> {
227    assert!(!paths.is_empty());
228
229    let client = reqwest::ClientBuilder::new()
230        .user_agent(USER_AGENT)
231        .http1_only()
232        .https_only(true);
233
234    let client = if let Some(CloudOptions {
235        config: Some(CloudConfig::Http { headers }),
236        ..
237    }) = cloud_options
238    {
239        client.default_headers(try_build_http_header_map_from_items_slice(
240            headers.as_slice(),
241        )?)
242    } else {
243        client
244    };
245
246    let client = &client.build().unwrap();
247
248    let mut out_paths = vec![];
249    let mut hive_idx_tracker = HiveIdxTracker {
250        idx: usize::MAX,
251        paths,
252        check_directory_level,
253    };
254
255    for (path_idx, path) in paths.iter().enumerate() {
256        let path_parts = &HFPathParts::try_from_uri(path.as_str())?;
257        let repo_location = &HFRepoLocation::new(
258            &path_parts.bucket,
259            &path_parts.repository,
260            &path_parts.revision,
261        );
262        let rel_path = path_parts.path.as_str();
263
264        let (prefix, expansion) = if glob {
265            extract_prefix_expansion(rel_path)?
266        } else {
267            (Cow::Owned(path_parts.path.clone()), None)
268        };
269        let expansion_matcher = &if expansion.is_some() {
270            Some(Matcher::new(prefix.to_string(), expansion.as_deref())?)
271        } else {
272            None
273        };
274
275        let file_uri = repo_location.get_file_uri(rel_path);
276
277        if !path_parts.path.ends_with("/") && expansion.is_none() {
278            // Confirm that this is a file using a HEAD request.
279            if with_concurrency_budget(1, || async {
280                client.head(&file_uri).send().await.map_err(to_compute_err)
281            })
282            .await?
283            .status()
284                == 200
285            {
286                hive_idx_tracker.update(0, path_idx)?;
287                out_paths.push(PlRefPath::new(file_uri));
288                continue;
289            }
290        }
291
292        hive_idx_tracker.update(file_uri.len(), path_idx)?;
293
294        let uri = format!("{}?recursive=true", repo_location.get_api_uri(&prefix));
295        let mut gp = GetPages {
296            uri: Some(uri),
297            client,
298        };
299
300        let sort_start_idx = out_paths.len();
301
302        while let Some(bytes) = gp.next().await {
303            let bytes = bytes?;
304            let response: Vec<HFAPIResponse> = decode_json_response(bytes.as_ref())?;
305
306            for entry in response {
307                // Only include files with size > 0
308                if entry.is_file() && entry.size > 0 {
309                    // If we have a glob pattern, filter by it; otherwise include all files
310                    let matches = if let Some(matcher) = expansion_matcher {
311                        matcher.is_matching(entry.path.as_str())
312                    } else {
313                        true
314                    };
315
316                    if matches {
317                        out_paths.push(PlRefPath::new(repo_location.get_file_uri(&entry.path)));
318                    }
319                }
320            }
321        }
322
323        if let Some(mut_slice) = out_paths.get_mut(sort_start_idx..) {
324            <[PlRefPath]>::sort_unstable(mut_slice);
325        }
326    }
327
328    Ok((hive_idx_tracker.idx, out_paths))
329}
330
#[cfg(test)]
mod tests {

    #[test]
    fn test_hf_path_from_uri() {
        use super::HFPathParts;

        let uri = "hf://datasets/pola-rs/polars/README.md";
        let expect = HFPathParts {
            bucket: "datasets".into(),
            repository: "pola-rs/polars".into(),
            revision: "main".into(),
            path: "README.md".into(),
        };

        assert_eq!(HFPathParts::try_from_uri(uri).unwrap(), expect);

        let uri = "hf://spaces/pola-rs/polars@~parquet/";
        let expect = HFPathParts {
            bucket: "spaces".into(),
            repository: "pola-rs/polars".into(),
            revision: "~parquet".into(),
            path: "".into(),
        };

        assert_eq!(HFPathParts::try_from_uri(uri).unwrap(), expect);

        let uri = "hf://spaces/pola-rs/polars@~parquet";
        let expect = HFPathParts {
            bucket: "spaces".into(),
            repository: "pola-rs/polars".into(),
            revision: "~parquet".into(),
            path: "".into(),
        };

        assert_eq!(HFPathParts::try_from_uri(uri).unwrap(), expect);

        for uri in [
            "://",
            "s3://",
            "https://",
            "hf://",
            "hf:///",
            "hf:////",
            "hf://datasets/a",
            "hf://datasets/a/",
            "hf://bucket/a/b/c", // Invalid bucket name
        ] {
            let out = HFPathParts::try_from_uri(uri);
            assert!(
                out.is_err(),
                "expected err result for uri {uri} instead of {out:?}"
            );
        }
    }

    #[test]
    fn test_get_pages_find_next_link() {
        use super::GetPages;
        let link = r#"<https://api.github.com/repositories/263727855/issues?page=3>; rel="next", <https://api.github.com/repositories/263727855/issues?page=7>; rel="last""#.as_bytes();

        assert_eq!(
            GetPages::find_link(link, "next".as_bytes()).map(Result::unwrap),
            Some("https://api.github.com/repositories/263727855/issues?page=3".into()),
        );

        assert_eq!(
            GetPages::find_link(link, "last".as_bytes()).map(Result::unwrap),
            Some("https://api.github.com/repositories/263727855/issues?page=7".into()),
        );

        assert_eq!(
            GetPages::find_link(link, "non-existent".as_bytes()).map(Result::unwrap),
            None,
        );
    }

    #[test]
    fn test_hf_url_encoding() {
        // Verify URLs preserve slashes (don't encode as %2F) but encode special chars.
        // Slashes must remain for correct rate limit classification by HF Hub.
        // Special chars (spaces, colons) must be encoded for file downloads to work.
        // See: https://github.com/pola-rs/polars/issues/25389
        use super::HFRepoLocation;

        let loc = HFRepoLocation::new("datasets", "HuggingFaceFW/fineweb-2", "main");

        // Check base paths don't encode slashes
        assert_eq!(
            loc.api_base_path,
            "https://huggingface.co/api/datasets/HuggingFaceFW/fineweb-2/tree/main/"
        );
        assert_eq!(
            loc.download_base_path,
            "https://huggingface.co/datasets/HuggingFaceFW/fineweb-2/resolve/main/"
        );

        // Check file URIs preserve slashes in paths
        let file_uri = loc.get_file_uri("data/aai_Latn/train/000_00000.parquet");
        assert_eq!(
            file_uri,
            "https://huggingface.co/datasets/HuggingFaceFW/fineweb-2/resolve/main/data/aai_Latn/train/000_00000.parquet"
        );

        // Check that special characters ARE encoded (spaces -> %20, colons -> %3A)
        // This is needed for hive-partitioned paths like "date2=2023-01-01 00:00:00.000000"
        let file_uri = loc.get_file_uri(
            "hive_dates/date1=2024-01-01/date2=2023-01-01 00:00:00.000000/00000000.parquet",
        );
        assert_eq!(
            file_uri,
            "https://huggingface.co/datasets/HuggingFaceFW/fineweb-2/resolve/main/hive_dates/date1%3D2024-01-01/date2%3D2023-01-01%2000%3A00%3A00.000000/00000000.parquet"
        );

        // Check that brackets are encoded ([ -> %5B, ] -> %5D)
        let file_uri = loc.get_file_uri("special-chars/[*.parquet");
        assert_eq!(
            file_uri,
            "https://huggingface.co/datasets/HuggingFaceFW/fineweb-2/resolve/main/special-chars/%5B%2A.parquet"
        );

        // Check that revision slashes ARE encoded (they're part of the revision name)
        // e.g. "refs/convert/parquet" -> "refs%2Fconvert%2Fparquet"
        let loc = HFRepoLocation::new("datasets", "user/repo", "refs/convert/parquet");
        assert_eq!(
            loc.api_base_path,
            "https://huggingface.co/api/datasets/user/repo/tree/refs%2Fconvert%2Fparquet/"
        );
        assert_eq!(
            loc.download_base_path,
            "https://huggingface.co/datasets/user/repo/resolve/refs%2Fconvert%2Fparquet/"
        );
    }
}