1use std::borrow::Cow;
4
5use polars_error::{PolarsResult, polars_bail, to_compute_err};
6use polars_utils::plpath::PlPath;
7
8use crate::cloud::{
9 CloudConfig, CloudOptions, Matcher, USER_AGENT, extract_prefix_expansion,
10 try_build_http_header_map_from_items_slice,
11};
12use crate::path_utils::HiveIdxTracker;
13use crate::pl_async::with_concurrency_budget;
14use crate::utils::{URL_ENCODE_CHARSET, decode_json_response};
15
/// Percent-encoding set used for repository-relative paths: `URL_ENCODE_CHARSET` with `/`
/// removed, so `/` is left unescaped and keeps acting as the path separator in the final URL.
const HF_PATH_ENCODE_CHARSET: &percent_encoding::AsciiSet = &URL_ENCODE_CHARSET.remove(b'/');
24
/// The components of a parsed `hf://` URI, as produced by `HFPathParts::try_from_uri`.
#[derive(Debug, PartialEq)]
struct HFPathParts {
    /// First URI segment after the scheme, e.g. `datasets` or `spaces`.
    bucket: String,
    /// Repository identifier made of the next two segments, e.g. `pola-rs/polars`
    /// (without any `@revision` suffix).
    repository: String,
    /// Revision given after `@` in the URI; `main` when the URI does not specify one.
    revision: String,
    /// Everything after the repository segment; empty when the URI names only the repository.
    path: String,
}
33
34struct HFRepoLocation {
35 api_base_path: String,
36 download_base_path: String,
37}
38
39impl HFRepoLocation {
40 fn new(bucket: &str, repository: &str, revision: &str) -> Self {
41 let encoded_revision =
47 percent_encoding::percent_encode(revision.as_bytes(), URL_ENCODE_CHARSET);
48 let api_base_path = format!(
49 "https://huggingface.co/api/{}/{}/tree/{}/",
50 bucket, repository, encoded_revision
51 );
52 let download_base_path = format!(
53 "https://huggingface.co/{}/{}/resolve/{}/",
54 bucket, repository, encoded_revision
55 );
56
57 Self {
58 api_base_path,
59 download_base_path,
60 }
61 }
62
63 fn get_file_uri(&self, rel_path: &str) -> String {
64 format!(
65 "{}{}",
66 self.download_base_path,
67 percent_encoding::percent_encode(rel_path.as_bytes(), HF_PATH_ENCODE_CHARSET)
68 )
69 }
70
71 fn get_api_uri(&self, rel_path: &str) -> String {
72 format!(
73 "{}{}",
74 self.api_base_path,
75 percent_encoding::percent_encode(rel_path.as_bytes(), HF_PATH_ENCODE_CHARSET)
76 )
77 }
78}
79
80impl HFPathParts {
81 fn try_from_uri(uri: &str) -> PolarsResult<Self> {
84 let Some(this) = (|| {
85 if !uri.starts_with("hf://") {
88 return None;
89 }
90 let uri = &uri[5..];
91
92 let i = memchr::memchr(b'/', uri.as_bytes())?;
95 let bucket = uri.get(..i)?.to_string();
96 let uri = uri.get(1 + i..)?;
97
98 let i = memchr::memchr(b'/', uri.as_bytes())?;
101 let i = {
102 let uri = uri.get(1 + i..)?;
105 if uri.is_empty() {
106 return None;
107 }
108 1 + i + memchr::memchr(b'/', uri.as_bytes()).unwrap_or(uri.len())
109 };
110 let repository = uri.get(..i)?;
111 let uri = uri.get(1 + i..).unwrap_or("");
112
113 let (repository, revision) =
114 if let Some(i) = memchr::memchr(b'@', repository.as_bytes()) {
115 (repository[..i].to_string(), repository[1 + i..].to_string())
116 } else {
117 (repository.to_string(), "main".to_string())
119 };
120
121 let path = uri.to_string();
124
125 Some(HFPathParts {
126 bucket,
127 repository,
128 revision,
129 path,
130 })
131 })() else {
132 polars_bail!(ComputeError: "invalid Hugging Face path: {}", uri);
133 };
134
135 const BUCKETS: [&str; 2] = ["datasets", "spaces"];
136 if !BUCKETS.contains(&this.bucket.as_str()) {
137 polars_bail!(ComputeError: "hugging face uri bucket must be one of {:?}, got {} instead.", BUCKETS, this.bucket);
138 }
139
140 Ok(this)
141 }
142}
143
144#[derive(Debug, serde::Deserialize)]
145struct HFAPIResponse {
146 #[serde(rename = "type")]
147 type_: String,
148 path: String,
149 size: u64,
150}
151
152impl HFAPIResponse {
153 fn is_file(&self) -> bool {
154 self.type_ == "file"
155 }
156}
157
158struct GetPages<'a> {
162 client: &'a reqwest::Client,
163 uri: Option<String>,
164}
165
166impl GetPages<'_> {
167 async fn next(&mut self) -> Option<PolarsResult<bytes::Bytes>> {
168 let uri = self.uri.take()?;
169
170 Some(
171 async {
172 let resp = with_concurrency_budget(1, || async {
173 self.client.get(uri).send().await.map_err(to_compute_err)
174 })
175 .await?;
176
177 self.uri = resp
178 .headers()
179 .get("link")
180 .and_then(|x| Self::find_link(x.as_bytes(), "next".as_bytes()))
181 .transpose()?;
182
183 let resp_bytes = resp.bytes().await.map_err(to_compute_err)?;
184
185 Ok(resp_bytes)
186 }
187 .await,
188 )
189 }
190
191 fn find_link(mut link: &[u8], rel: &[u8]) -> Option<PolarsResult<String>> {
192 while !link.is_empty() {
194 let i = memchr::memchr(b'<', link)?;
195 link = link.get(1 + i..)?;
196 let i = memchr::memchr(b'>', link)?;
197 let uri = &link[..i];
198 link = link.get(1 + i..)?;
199
200 while !link.starts_with("rel=\"".as_bytes()) {
201 link = link.get(1..)?
202 }
203
204 link = link.get(5..)?;
206 let i = memchr::memchr(b'"', link)?;
207
208 if &link[..i] == rel {
209 return Some(
210 std::str::from_utf8(uri)
211 .map_err(to_compute_err)
212 .map(ToString::to_string),
213 );
214 }
215 }
216
217 None
218 }
219}
220
221pub(super) async fn expand_paths_hf(
222 paths: &[PlPath],
223 check_directory_level: bool,
224 cloud_options: &Option<CloudOptions>,
225 glob: bool,
226) -> PolarsResult<(usize, Vec<PlPath>)> {
227 assert!(!paths.is_empty());
228
229 let client = reqwest::ClientBuilder::new()
230 .user_agent(USER_AGENT)
231 .http1_only()
232 .https_only(true);
233
234 let client = if let Some(CloudOptions {
235 config: Some(CloudConfig::Http { headers }),
236 ..
237 }) = cloud_options
238 {
239 client.default_headers(try_build_http_header_map_from_items_slice(
240 headers.as_slice(),
241 )?)
242 } else {
243 client
244 };
245
246 let client = &client.build().unwrap();
247
248 let mut out_paths = vec![];
249 let mut hive_idx_tracker = HiveIdxTracker {
250 idx: usize::MAX,
251 paths,
252 check_directory_level,
253 };
254
255 for (path_idx, path) in paths.iter().enumerate() {
256 let path_parts = &HFPathParts::try_from_uri(path.to_str())?;
257 let repo_location = &HFRepoLocation::new(
258 &path_parts.bucket,
259 &path_parts.repository,
260 &path_parts.revision,
261 );
262 let rel_path = path_parts.path.as_str();
263
264 let (prefix, expansion) = if glob {
265 extract_prefix_expansion(rel_path)?
266 } else {
267 (Cow::Owned(path_parts.path.clone()), None)
268 };
269 let expansion_matcher = &if expansion.is_some() {
270 Some(Matcher::new(prefix.to_string(), expansion.as_deref())?)
271 } else {
272 None
273 };
274
275 let file_uri = repo_location.get_file_uri(rel_path);
276
277 if !path_parts.path.ends_with("/") && expansion.is_none() {
278 if with_concurrency_budget(1, || async {
280 client.head(&file_uri).send().await.map_err(to_compute_err)
281 })
282 .await?
283 .status()
284 == 200
285 {
286 hive_idx_tracker.update(0, path_idx)?;
287 out_paths.push(PlPath::from_string(file_uri));
288 continue;
289 }
290 }
291
292 hive_idx_tracker.update(file_uri.len(), path_idx)?;
293
294 let uri = format!("{}?recursive=true", repo_location.get_api_uri(&prefix));
295 let mut gp = GetPages {
296 uri: Some(uri),
297 client,
298 };
299
300 while let Some(bytes) = gp.next().await {
301 let bytes = bytes?;
302 let response: Vec<HFAPIResponse> = decode_json_response(bytes.as_ref())?;
303
304 for entry in response {
305 if entry.is_file() && entry.size > 0 {
307 let matches = if let Some(matcher) = expansion_matcher {
309 matcher.is_matching(entry.path.as_str())
310 } else {
311 true
312 };
313
314 if matches {
315 out_paths
316 .push(PlPath::from_string(repo_location.get_file_uri(&entry.path)));
317 }
318 }
319 }
320 }
321 }
322
323 Ok((hive_idx_tracker.idx, out_paths))
324}
325
/// Unit tests for URI parsing, `link` header parsing and URL encoding.
#[cfg(test)]
mod tests {

    /// Valid `hf://` URIs are split into their parts; malformed ones are rejected.
    #[test]
    fn test_hf_path_from_uri() {
        use super::HFPathParts;

        let uri = "hf://datasets/pola-rs/polars/README.md";
        let expect = HFPathParts {
            bucket: "datasets".into(),
            repository: "pola-rs/polars".into(),
            revision: "main".into(),
            path: "README.md".into(),
        };

        assert_eq!(HFPathParts::try_from_uri(uri).unwrap(), expect);

        // Explicit revision with a trailing slash and an empty path.
        let uri = "hf://spaces/pola-rs/polars@~parquet/";
        let expect = HFPathParts {
            bucket: "spaces".into(),
            repository: "pola-rs/polars".into(),
            revision: "~parquet".into(),
            path: "".into(),
        };

        assert_eq!(HFPathParts::try_from_uri(uri).unwrap(), expect);

        // Explicit revision without a trailing slash.
        let uri = "hf://spaces/pola-rs/polars@~parquet";
        let expect = HFPathParts {
            bucket: "spaces".into(),
            repository: "pola-rs/polars".into(),
            revision: "~parquet".into(),
            path: "".into(),
        };

        assert_eq!(HFPathParts::try_from_uri(uri).unwrap(), expect);

        for uri in [
            "://",
            "s3://",
            "https://",
            "hf://",
            "hf:///",
            "hf:////",
            "hf://datasets/a",
            "hf://datasets/a/",
            "hf://bucket/a/b/c",
        ] {
            let out = HFPathParts::try_from_uri(uri);
            assert!(
                out.is_err(),
                "expected err result for uri {uri} instead of {out:?}"
            );
        }
    }

    /// `find_link` returns the URI whose `rel` attribute matches, or `None`.
    #[test]
    fn test_get_pages_find_next_link() {
        use super::GetPages;
        let link = r#"<https://api.github.com/repositories/263727855/issues?page=3>; rel="next", <https://api.github.com/repositories/263727855/issues?page=7>; rel="last""#.as_bytes();

        assert_eq!(
            GetPages::find_link(link, "next".as_bytes()).map(Result::unwrap),
            Some("https://api.github.com/repositories/263727855/issues?page=3".into()),
        );

        assert_eq!(
            GetPages::find_link(link, "last".as_bytes()).map(Result::unwrap),
            Some("https://api.github.com/repositories/263727855/issues?page=7".into()),
        );

        assert_eq!(
            GetPages::find_link(link, "non-existent".as_bytes()).map(Result::unwrap),
            None,
        );
    }

    /// Base URLs and file URLs are percent-encoded as expected.
    #[test]
    fn test_hf_url_encoding() {
        use super::HFRepoLocation;

        let loc = HFRepoLocation::new("datasets", "HuggingFaceFW/fineweb-2", "main");

        // Bucket and repository are inserted verbatim.
        assert_eq!(
            loc.api_base_path,
            "https://huggingface.co/api/datasets/HuggingFaceFW/fineweb-2/tree/main/"
        );
        assert_eq!(
            loc.download_base_path,
            "https://huggingface.co/datasets/HuggingFaceFW/fineweb-2/resolve/main/"
        );

        // A plain path needs no escaping; `/` stays a separator.
        let file_uri = loc.get_file_uri("data/aai_Latn/train/000_00000.parquet");
        assert_eq!(
            file_uri,
            "https://huggingface.co/datasets/HuggingFaceFW/fineweb-2/resolve/main/data/aai_Latn/train/000_00000.parquet"
        );

        // Hive-style segments: `=`, space and `:` are escaped.
        let file_uri = loc.get_file_uri(
            "hive_dates/date1=2024-01-01/date2=2023-01-01 00:00:00.000000/00000000.parquet",
        );
        assert_eq!(
            file_uri,
            "https://huggingface.co/datasets/HuggingFaceFW/fineweb-2/resolve/main/hive_dates/date1%3D2024-01-01/date2%3D2023-01-01%2000%3A00%3A00.000000/00000000.parquet"
        );

        // Glob metacharacters in a literal path are escaped.
        let file_uri = loc.get_file_uri("special-chars/[*.parquet");
        assert_eq!(
            file_uri,
            "https://huggingface.co/datasets/HuggingFaceFW/fineweb-2/resolve/main/special-chars/%5B%2A.parquet"
        );

        // Slashes inside the revision are escaped so it remains a single URL segment.
        let loc = HFRepoLocation::new("datasets", "user/repo", "refs/convert/parquet");
        assert_eq!(
            loc.api_base_path,
            "https://huggingface.co/api/datasets/user/repo/tree/refs%2Fconvert%2Fparquet/"
        );
        assert_eq!(
            loc.download_base_path,
            "https://huggingface.co/datasets/user/repo/resolve/refs%2Fconvert%2Fparquet/"
        );
    }
}