1use std::borrow::Cow;
4
5use polars_error::{PolarsResult, polars_bail, to_compute_err};
6use polars_utils::pl_path::PlRefPath;
7
8use crate::cloud::{
9 CloudConfig, CloudOptions, Matcher, USER_AGENT, extract_prefix_expansion,
10 try_build_http_header_map_from_items_slice,
11};
12use crate::path_utils::HiveIdxTracker;
13use crate::pl_async::with_concurrency_budget;
14use crate::utils::{URL_ENCODE_CHARSET, decode_json_response};
15
/// Percent-encoding set for repository-relative file paths: the same as
/// `URL_ENCODE_CHARSET`, except that `/` is removed from the set so path
/// separators are kept as-is in the encoded output.
const HF_PATH_ENCODE_CHARSET: &percent_encoding::AsciiSet = &URL_ENCODE_CHARSET.remove(b'/');
24
/// Components parsed from a Hugging Face URI of the form
/// `hf://{bucket}/{owner}/{name}[@{revision}][/{path}]`.
#[derive(Debug, PartialEq)]
struct HFPathParts {
    /// First path component after `hf://`; `try_from_uri` only accepts
    /// `"datasets"` or `"spaces"`.
    bucket: String,
    /// `{owner}/{name}`, with any `@{revision}` suffix stripped.
    repository: String,
    /// Text following `@` in the repository component; `"main"` when absent.
    revision: String,
    /// Remainder of the URI after the repository component (may be empty).
    path: String,
}
33
34struct HFRepoLocation {
35 api_base_path: String,
36 download_base_path: String,
37}
38
39impl HFRepoLocation {
40 fn new(bucket: &str, repository: &str, revision: &str) -> Self {
41 let encoded_revision =
47 percent_encoding::percent_encode(revision.as_bytes(), URL_ENCODE_CHARSET);
48 let api_base_path = format!(
49 "https://huggingface.co/api/{}/{}/tree/{}/",
50 bucket, repository, encoded_revision
51 );
52 let download_base_path = format!(
53 "https://huggingface.co/{}/{}/resolve/{}/",
54 bucket, repository, encoded_revision
55 );
56
57 Self {
58 api_base_path,
59 download_base_path,
60 }
61 }
62
63 fn get_file_uri(&self, rel_path: &str) -> String {
64 format!(
65 "{}{}",
66 self.download_base_path,
67 percent_encoding::percent_encode(rel_path.as_bytes(), HF_PATH_ENCODE_CHARSET)
68 )
69 }
70
71 fn get_api_uri(&self, rel_path: &str) -> String {
72 format!(
73 "{}{}",
74 self.api_base_path,
75 percent_encoding::percent_encode(rel_path.as_bytes(), HF_PATH_ENCODE_CHARSET)
76 )
77 }
78}
79
80impl HFPathParts {
81 fn try_from_uri(uri: &str) -> PolarsResult<Self> {
84 let Some(this) = (|| {
85 if !uri.starts_with("hf://") {
88 return None;
89 }
90 let uri = &uri[5..];
91
92 let i = memchr::memchr(b'/', uri.as_bytes())?;
95 let bucket = uri.get(..i)?.to_string();
96 let uri = uri.get(1 + i..)?;
97
98 let i = memchr::memchr(b'/', uri.as_bytes())?;
101 let i = {
102 let uri = uri.get(1 + i..)?;
105 if uri.is_empty() {
106 return None;
107 }
108 1 + i + memchr::memchr(b'/', uri.as_bytes()).unwrap_or(uri.len())
109 };
110 let repository = uri.get(..i)?;
111 let uri = uri.get(1 + i..).unwrap_or("");
112
113 let (repository, revision) =
114 if let Some(i) = memchr::memchr(b'@', repository.as_bytes()) {
115 (repository[..i].to_string(), repository[1 + i..].to_string())
116 } else {
117 (repository.to_string(), "main".to_string())
119 };
120
121 let path = uri.to_string();
124
125 Some(HFPathParts {
126 bucket,
127 repository,
128 revision,
129 path,
130 })
131 })() else {
132 polars_bail!(ComputeError: "invalid Hugging Face path: {}", uri);
133 };
134
135 const BUCKETS: [&str; 2] = ["datasets", "spaces"];
136 if !BUCKETS.contains(&this.bucket.as_str()) {
137 polars_bail!(ComputeError: "hugging face uri bucket must be one of {:?}, got {} instead.", BUCKETS, this.bucket);
138 }
139
140 Ok(this)
141 }
142}
143
144#[derive(Debug, serde::Deserialize)]
145struct HFAPIResponse {
146 #[serde(rename = "type")]
147 type_: String,
148 path: String,
149 size: u64,
150}
151
152impl HFAPIResponse {
153 fn is_file(&self) -> bool {
154 self.type_ == "file"
155 }
156}
157
158struct GetPages<'a> {
162 client: &'a reqwest::Client,
163 uri: Option<String>,
164}
165
166impl GetPages<'_> {
167 async fn next(&mut self) -> Option<PolarsResult<bytes::Bytes>> {
168 let uri = self.uri.take()?;
169
170 Some(
171 async {
172 let resp = with_concurrency_budget(1, || async {
173 self.client.get(uri).send().await.map_err(to_compute_err)
174 })
175 .await?;
176
177 self.uri = resp
178 .headers()
179 .get("link")
180 .and_then(|x| Self::find_link(x.as_bytes(), "next".as_bytes()))
181 .transpose()?;
182
183 let resp_bytes = resp.bytes().await.map_err(to_compute_err)?;
184
185 Ok(resp_bytes)
186 }
187 .await,
188 )
189 }
190
191 fn find_link(mut link: &[u8], rel: &[u8]) -> Option<PolarsResult<String>> {
192 while !link.is_empty() {
194 let i = memchr::memchr(b'<', link)?;
195 link = link.get(1 + i..)?;
196 let i = memchr::memchr(b'>', link)?;
197 let uri = &link[..i];
198 link = link.get(1 + i..)?;
199
200 while !link.starts_with("rel=\"".as_bytes()) {
201 link = link.get(1..)?
202 }
203
204 link = link.get(5..)?;
206 let i = memchr::memchr(b'"', link)?;
207
208 if &link[..i] == rel {
209 return Some(
210 std::str::from_utf8(uri)
211 .map_err(to_compute_err)
212 .map(ToString::to_string),
213 );
214 }
215 }
216
217 None
218 }
219}
220
221pub(super) async fn expand_paths_hf(
222 paths: &[PlRefPath],
223 check_directory_level: bool,
224 cloud_options: &Option<CloudOptions>,
225 glob: bool,
226) -> PolarsResult<(usize, Vec<PlRefPath>)> {
227 assert!(!paths.is_empty());
228
229 let client = reqwest::ClientBuilder::new()
230 .user_agent(USER_AGENT)
231 .http1_only()
232 .https_only(true);
233
234 let client = if let Some(CloudOptions {
235 config: Some(CloudConfig::Http { headers }),
236 ..
237 }) = cloud_options
238 {
239 client.default_headers(try_build_http_header_map_from_items_slice(
240 headers.as_slice(),
241 )?)
242 } else {
243 client
244 };
245
246 let client = &client.build().unwrap();
247
248 let mut out_paths = vec![];
249 let mut hive_idx_tracker = HiveIdxTracker {
250 idx: usize::MAX,
251 paths,
252 check_directory_level,
253 };
254
255 for (path_idx, path) in paths.iter().enumerate() {
256 let path_parts = &HFPathParts::try_from_uri(path.as_str())?;
257 let repo_location = &HFRepoLocation::new(
258 &path_parts.bucket,
259 &path_parts.repository,
260 &path_parts.revision,
261 );
262 let rel_path = path_parts.path.as_str();
263
264 let (prefix, expansion) = if glob {
265 extract_prefix_expansion(rel_path)?
266 } else {
267 (Cow::Owned(path_parts.path.clone()), None)
268 };
269 let expansion_matcher = &if expansion.is_some() {
270 Some(Matcher::new(prefix.to_string(), expansion.as_deref())?)
271 } else {
272 None
273 };
274
275 let file_uri = repo_location.get_file_uri(rel_path);
276
277 if !path_parts.path.ends_with("/") && expansion.is_none() {
278 if with_concurrency_budget(1, || async {
280 client.head(&file_uri).send().await.map_err(to_compute_err)
281 })
282 .await?
283 .status()
284 == 200
285 {
286 hive_idx_tracker.update(0, path_idx)?;
287 out_paths.push(PlRefPath::new(file_uri));
288 continue;
289 }
290 }
291
292 hive_idx_tracker.update(file_uri.len(), path_idx)?;
293
294 let uri = format!("{}?recursive=true", repo_location.get_api_uri(&prefix));
295 let mut gp = GetPages {
296 uri: Some(uri),
297 client,
298 };
299
300 let sort_start_idx = out_paths.len();
301
302 while let Some(bytes) = gp.next().await {
303 let bytes = bytes?;
304 let response: Vec<HFAPIResponse> = decode_json_response(bytes.as_ref())?;
305
306 for entry in response {
307 if entry.is_file() && entry.size > 0 {
309 let matches = if let Some(matcher) = expansion_matcher {
311 matcher.is_matching(entry.path.as_str())
312 } else {
313 true
314 };
315
316 if matches {
317 out_paths.push(PlRefPath::new(repo_location.get_file_uri(&entry.path)));
318 }
319 }
320 }
321 }
322
323 if let Some(mut_slice) = out_paths.get_mut(sort_start_idx..) {
324 <[PlRefPath]>::sort_unstable(mut_slice);
325 }
326 }
327
328 Ok((hive_idx_tracker.idx, out_paths))
329}
330
#[cfg(test)]
mod tests {

    #[test]
    fn test_hf_path_from_uri() {
        use super::HFPathParts;

        let uri = "hf://datasets/pola-rs/polars/README.md";
        let expect = HFPathParts {
            bucket: "datasets".into(),
            repository: "pola-rs/polars".into(),
            revision: "main".into(),
            path: "README.md".into(),
        };

        assert_eq!(HFPathParts::try_from_uri(uri).unwrap(), expect);

        let uri = "hf://spaces/pola-rs/polars@~parquet/";
        let expect = HFPathParts {
            bucket: "spaces".into(),
            repository: "pola-rs/polars".into(),
            revision: "~parquet".into(),
            path: "".into(),
        };

        assert_eq!(HFPathParts::try_from_uri(uri).unwrap(), expect);

        let uri = "hf://spaces/pola-rs/polars@~parquet";
        let expect = HFPathParts {
            bucket: "spaces".into(),
            repository: "pola-rs/polars".into(),
            revision: "~parquet".into(),
            path: "".into(),
        };

        assert_eq!(HFPathParts::try_from_uri(uri).unwrap(), expect);

        // The revision is split off the repository; everything after it is the
        // in-repository path.
        let uri = "hf://datasets/user/repo@rev/dir/file.parquet";
        let expect = HFPathParts {
            bucket: "datasets".into(),
            repository: "user/repo".into(),
            revision: "rev".into(),
            path: "dir/file.parquet".into(),
        };

        assert_eq!(HFPathParts::try_from_uri(uri).unwrap(), expect);

        for uri in [
            "://",
            "s3://",
            "https://",
            "hf://",
            "hf:///",
            "hf:////",
            "hf://datasets/a",
            "hf://datasets/a/",
            "hf://bucket/a/b/c",
            "hf://models/a/b/c",
        ] {
            let out = HFPathParts::try_from_uri(uri);
            assert!(
                out.is_err(),
                "expected err result for uri {uri} instead of {out:?}"
            );
        }
    }

    #[test]
    fn test_get_pages_find_next_link() {
        use super::GetPages;
        let link = r#"<https://api.github.com/repositories/263727855/issues?page=3>; rel="next", <https://api.github.com/repositories/263727855/issues?page=7>; rel="last""#.as_bytes();

        assert_eq!(
            GetPages::find_link(link, "next".as_bytes()).map(Result::unwrap),
            Some("https://api.github.com/repositories/263727855/issues?page=3".into()),
        );

        assert_eq!(
            GetPages::find_link(link, "last".as_bytes()).map(Result::unwrap),
            Some("https://api.github.com/repositories/263727855/issues?page=7".into()),
        );

        assert_eq!(
            GetPages::find_link(link, "non-existent".as_bytes()).map(Result::unwrap),
            None,
        );

        // An empty header contains no links.
        assert_eq!(
            GetPages::find_link(b"", b"next").map(Result::unwrap),
            None,
        );
    }

    #[test]
    fn test_hf_url_encoding() {
        use super::HFRepoLocation;

        let loc = HFRepoLocation::new("datasets", "HuggingFaceFW/fineweb-2", "main");

        assert_eq!(
            loc.api_base_path,
            "https://huggingface.co/api/datasets/HuggingFaceFW/fineweb-2/tree/main/"
        );
        assert_eq!(
            loc.download_base_path,
            "https://huggingface.co/datasets/HuggingFaceFW/fineweb-2/resolve/main/"
        );

        let file_uri = loc.get_file_uri("data/aai_Latn/train/000_00000.parquet");
        assert_eq!(
            file_uri,
            "https://huggingface.co/datasets/HuggingFaceFW/fineweb-2/resolve/main/data/aai_Latn/train/000_00000.parquet"
        );

        // Tree-API URIs keep `/` separators just like download URIs.
        assert_eq!(
            loc.get_api_uri("data/aai_Latn"),
            "https://huggingface.co/api/datasets/HuggingFaceFW/fineweb-2/tree/main/data/aai_Latn"
        );

        let file_uri = loc.get_file_uri(
            "hive_dates/date1=2024-01-01/date2=2023-01-01 00:00:00.000000/00000000.parquet",
        );
        assert_eq!(
            file_uri,
            "https://huggingface.co/datasets/HuggingFaceFW/fineweb-2/resolve/main/hive_dates/date1%3D2024-01-01/date2%3D2023-01-01%2000%3A00%3A00.000000/00000000.parquet"
        );

        let file_uri = loc.get_file_uri("special-chars/[*.parquet");
        assert_eq!(
            file_uri,
            "https://huggingface.co/datasets/HuggingFaceFW/fineweb-2/resolve/main/special-chars/%5B%2A.parquet"
        );

        let loc = HFRepoLocation::new("datasets", "user/repo", "refs/convert/parquet");
        assert_eq!(
            loc.api_base_path,
            "https://huggingface.co/api/datasets/user/repo/tree/refs%2Fconvert%2Fparquet/"
        );
        assert_eq!(
            loc.download_base_path,
            "https://huggingface.co/datasets/user/repo/resolve/refs%2Fconvert%2Fparquet/"
        );
    }
}