1use std::borrow::Cow;
4use std::collections::VecDeque;
5use std::path::PathBuf;
6
7use polars_error::{PolarsResult, polars_bail, to_compute_err};
8
9use crate::cloud::{
10 CloudConfig, CloudOptions, Matcher, extract_prefix_expansion,
11 try_build_http_header_map_from_items_slice,
12};
13use crate::path_utils::HiveIdxTracker;
14use crate::pl_async::with_concurrency_budget;
15use crate::prelude::URL_ENCODE_CHAR_SET;
16use crate::utils::decode_json_response;
17
/// Decomposed form of a Hugging Face URI:
/// `hf://{bucket}/{repository}[@{revision}]/{path}`.
#[derive(Debug, PartialEq)]
struct HFPathParts {
    /// Top-level bucket; validated by `try_from_uri` to be one of
    /// "datasets" or "spaces".
    bucket: String,
    /// Two path segments (`{owner}/{name}`), with any `@revision` suffix
    /// stripped off.
    repository: String,
    /// Revision taken from the `@revision` suffix; defaults to "main".
    revision: String,
    /// Remaining path inside the repository; may be empty, and may contain
    /// glob patterns (expanded in `expand_paths_hf`).
    path: String,
}
26
/// Pre-rendered URL prefixes for one repository at a fixed revision.
struct HFRepoLocation {
    /// Prefix for tree-listing API calls:
    /// `https://huggingface.co/api/{bucket}/{repo}/tree/{revision}/`.
    api_base_path: String,
    /// Prefix for raw file downloads:
    /// `https://huggingface.co/{bucket}/{repo}/resolve/{revision}/`.
    download_base_path: String,
}
31
32impl HFRepoLocation {
33 fn new(bucket: &str, repository: &str, revision: &str) -> Self {
34 let bucket = percent_encode(bucket.as_bytes());
35 let repository = percent_encode(repository.as_bytes());
36
37 let api_base_path = format!(
39 "{}{}{}{}{}{}{}",
40 "https://huggingface.co/api/", bucket, "/", repository, "/tree/", revision, "/"
41 );
42 let download_base_path = format!(
43 "{}{}{}{}{}{}{}",
44 "https://huggingface.co/", bucket, "/", repository, "/resolve/", revision, "/"
45 );
46
47 Self {
48 api_base_path,
49 download_base_path,
50 }
51 }
52
53 fn get_file_uri(&self, rel_path: &str) -> String {
54 format!(
55 "{}{}",
56 self.download_base_path,
57 percent_encode(rel_path.as_bytes())
58 )
59 }
60
61 fn get_api_uri(&self, rel_path: &str) -> String {
62 format!(
63 "{}{}",
64 self.api_base_path,
65 percent_encode(rel_path.as_bytes())
66 )
67 }
68}
69
impl HFPathParts {
    /// Parses an `hf://{bucket}/{owner}/{name}[@{revision}]/{path}` URI into
    /// its components.
    ///
    /// # Errors
    /// Returns a `ComputeError` when the URI does not have the shape above,
    /// or when the bucket is not one of `["datasets", "spaces"]`.
    fn try_from_uri(uri: &str) -> PolarsResult<Self> {
        // The closure returns `None` for any malformed URI; the single
        // `polars_bail!` below turns that into one uniform error message.
        let Some(this) = (|| {
            if !uri.starts_with("hf://") {
                return None;
            }
            // Strip the "hf://" scheme prefix (5 bytes).
            let uri = &uri[5..];

            // {bucket}: everything up to the first '/'.
            let i = memchr::memchr(b'/', uri.as_bytes())?;
            let bucket = uri.get(..i)?.to_string();
            let uri = uri.get(1 + i..)?;

            // {owner}/{name}: the repository spans exactly two segments, so
            // find the '/' after the owner, then the end of the name (the
            // next '/' or end-of-string). `i` becomes the length of
            // "{owner}/{name}" within `uri`.
            let i = memchr::memchr(b'/', uri.as_bytes())?;
            let i = {
                let uri = uri.get(1 + i..)?;
                if uri.is_empty() {
                    return None;
                }
                1 + i + memchr::memchr(b'/', uri.as_bytes()).unwrap_or(uri.len())
            };
            let repository = uri.get(..i)?;
            // Anything after the repository is the in-repo path; a URI that
            // ends at the repository yields an empty path.
            let uri = uri.get(1 + i..).unwrap_or("");

            // Split an optional "@revision" suffix off the repository;
            // default the revision to "main" when absent.
            let (repository, revision) =
                if let Some(i) = memchr::memchr(b'@', repository.as_bytes()) {
                    (repository[..i].to_string(), repository[1 + i..].to_string())
                } else {
                    (repository.to_string(), "main".to_string())
                };

            let path = uri.to_string();

            Some(HFPathParts {
                bucket,
                repository,
                revision,
                path,
            })
        })() else {
            polars_bail!(ComputeError: "invalid Hugging Face path: {}", uri);
        };

        // Only these two buckets are addressable through the HF API here.
        const BUCKETS: [&str; 2] = ["datasets", "spaces"];
        if !BUCKETS.contains(&this.bucket.as_str()) {
            polars_bail!(ComputeError: "hugging face uri bucket must be one of {:?}, got {} instead.", BUCKETS, this.bucket);
        }

        Ok(this)
    }
}
133
/// One entry of a Hugging Face tree-API listing (a file or a directory).
#[derive(Debug, serde::Deserialize)]
struct HFAPIResponse {
    /// Entry kind as reported by the API; compared against "file" and
    /// "directory" below.
    #[serde(rename = "type")]
    type_: String,
    /// Repository-relative path of the entry.
    path: String,
    /// Size in bytes; files with size 0 are filtered out by
    /// `expand_paths_hf`.
    size: u64,
}
141
142impl HFAPIResponse {
143 fn is_file(&self) -> bool {
144 self.type_ == "file"
145 }
146
147 fn is_directory(&self) -> bool {
148 self.type_ == "directory"
149 }
150}
151
/// Pull-based pager over the HF API: each `next()` call fetches `uri` and
/// advances it to the response's `Link: rel="next"` target.
struct GetPages<'a> {
    client: &'a reqwest::Client,
    /// URI of the next page to fetch; `None` once pagination is exhausted.
    uri: Option<String>,
}
159
impl GetPages<'_> {
    /// Fetches the page at `self.uri` and returns its body bytes, advancing
    /// `self.uri` to the `rel="next"` link from the response's `link` header
    /// (or `None` when there is no further page).
    ///
    /// Returns `None` once pagination is finished. Because the URI is taken
    /// out of `self.uri` up-front, an error result also terminates
    /// iteration.
    async fn next(&mut self) -> Option<PolarsResult<bytes::Bytes>> {
        let uri = self.uri.take()?;

        Some(
            async {
                // NOTE(review): with_concurrency_budget presumably caps
                // in-flight requests (project helper) — confirm semantics.
                let resp = with_concurrency_budget(1, || async {
                    self.client.get(uri).send().await.map_err(to_compute_err)
                })
                .await?;

                // Advance pagination state before consuming the body;
                // `transpose` surfaces a UTF-8 decode error from find_link.
                self.uri = resp
                    .headers()
                    .get("link")
                    .and_then(|x| Self::find_link(x.as_bytes(), "next".as_bytes()))
                    .transpose()?;

                let resp_bytes = resp.bytes().await.map_err(to_compute_err)?;

                Ok(resp_bytes)
            }
            .await,
        )
    }

    /// Scans an HTTP `Link` header value of the form
    /// `<uri1>; rel="next", <uri2>; rel="last"` and returns the URI whose
    /// relation equals `rel`, or `None` when absent.
    ///
    /// The inner `Result` is `Err` only when the matched URI is not valid
    /// UTF-8. Any malformed input simply yields `None` (every `?` here is on
    /// an `Option`).
    fn find_link(mut link: &[u8], rel: &[u8]) -> Option<PolarsResult<String>> {
        while !link.is_empty() {
            // The URI is enclosed in angle brackets.
            let i = memchr::memchr(b'<', link)?;
            link = link.get(1 + i..)?;
            let i = memchr::memchr(b'>', link)?;
            let uri = &link[..i];
            link = link.get(1 + i..)?;

            // Skip forward byte-by-byte to this entry's `rel="` parameter;
            // `get(1..)?` bails out with None at end of input.
            while !link.starts_with("rel=\"".as_bytes()) {
                link = link.get(1..)?
            }

            // Skip the 5 bytes of `rel="`, then read up to the closing quote.
            link = link.get(5..)?;
            let i = memchr::memchr(b'"', link)?;

            if &link[..i] == rel {
                return Some(
                    std::str::from_utf8(uri)
                        .map_err(to_compute_err)
                        .map(ToString::to_string),
                );
            }
        }

        None
    }
}
214
/// Expands `hf://` URIs into concrete download URLs by listing repository
/// contents through the Hugging Face tree API.
///
/// For each input path: a non-glob path without a trailing `/` is first
/// probed with a `HEAD` request and used directly when it exists (status
/// 200); otherwise the path (or its glob prefix) is walked as a directory
/// tree, collecting matching files breadth-first.
///
/// Returns `(idx, expanded_paths)` where `idx` is the prefix length tracked
/// by `HiveIdxTracker` — presumably used downstream for hive-partition
/// detection; see that type for the exact semantics.
///
/// # Errors
/// Fails on invalid `hf://` URIs, header construction, HTTP transport
/// errors, or undecodable API responses.
pub(super) async fn expand_paths_hf(
    paths: &[PathBuf],
    check_directory_level: bool,
    cloud_options: Option<&CloudOptions>,
    glob: bool,
) -> PolarsResult<(usize, Vec<PathBuf>)> {
    assert!(!paths.is_empty());

    // HTTP/1, HTTPS-only client; custom headers are taken from the HTTP
    // cloud config when provided.
    let client = reqwest::ClientBuilder::new().http1_only().https_only(true);

    let client = if let Some(CloudOptions {
        config: Some(CloudConfig::Http { headers }),
        ..
    }) = cloud_options
    {
        client.default_headers(try_build_http_header_map_from_items_slice(
            headers.as_slice(),
        )?)
    } else {
        client
    };

    let client = &client.build().unwrap();

    let mut out_paths = vec![];
    // BFS work queue of repo-relative directory paths; `entries` is a
    // per-directory scratch buffer reused across iterations.
    let mut stack = VecDeque::new();
    let mut entries = vec![];
    let mut hive_idx_tracker = HiveIdxTracker {
        idx: usize::MAX,
        paths,
        check_directory_level,
    };

    for (path_idx, path) in paths.iter().enumerate() {
        let path_parts = &HFPathParts::try_from_uri(path.to_str().unwrap())?;
        let repo_location = &HFRepoLocation::new(
            &path_parts.bucket,
            &path_parts.repository,
            &path_parts.revision,
        );
        let rel_path = path_parts.path.as_str();

        // With globbing enabled, split the path into a literal prefix and an
        // optional glob expansion; otherwise treat the whole path literally.
        let (prefix, expansion) = if glob {
            extract_prefix_expansion(rel_path)?
        } else {
            (Cow::Owned(path_parts.path.clone()), None)
        };
        let expansion_matcher = &if expansion.is_some() {
            Some(Matcher::new(prefix.to_string(), expansion.as_deref())?)
        } else {
            None
        };

        // Fast path: a literal path not ending in '/' may be a single file —
        // probe it with HEAD before falling back to directory listing.
        if !path_parts.path.ends_with("/") && expansion.is_none() {
            hive_idx_tracker.update(0, path_idx)?;
            let file_uri = repo_location.get_file_uri(rel_path);
            let file_uri = file_uri.as_str();

            if with_concurrency_budget(1, || async {
                client.head(file_uri).send().await.map_err(to_compute_err)
            })
            .await?
            .status()
                == 200
            {
                out_paths.push(PathBuf::from(file_uri));
                continue;
            }
        }

        hive_idx_tracker.update(repo_location.get_file_uri(rel_path).len(), path_idx)?;

        // Breadth-first traversal starting from the literal prefix.
        assert!(stack.is_empty());
        stack.push_back(prefix.into_owned());

        while let Some(rel_path) = stack.pop_front() {
            assert!(entries.is_empty());

            let uri = repo_location.get_api_uri(rel_path.as_str());
            let mut gp = GetPages {
                uri: Some(uri),
                client,
            };

            // Keep every directory (to recurse into) and every non-empty
            // file; with a matcher, files must additionally match the glob.
            if let Some(matcher) = expansion_matcher {
                while let Some(bytes) = gp.next().await {
                    let bytes = bytes?;
                    let bytes = bytes.as_ref();
                    let response: Vec<HFAPIResponse> = decode_json_response(bytes)?;
                    entries.extend(response.into_iter().filter(|x| {
                        !x.is_file() || (x.size > 0 && matcher.is_matching(x.path.as_str()))
                    }));
                }
            } else {
                while let Some(bytes) = gp.next().await {
                    let bytes = bytes?;
                    let bytes = bytes.as_ref();
                    let response: Vec<HFAPIResponse> = decode_json_response(bytes)?;
                    entries.extend(response.into_iter().filter(|x| !x.is_file() || x.size > 0));
                }
            }

            // Deterministic output ordering. NOTE(review): `str` ordering is
            // total, so `cmp` would avoid this `partial_cmp().unwrap()`.
            entries.sort_unstable_by(|a, b| a.path.as_str().partial_cmp(b.path.as_str()).unwrap());

            for e in entries.drain(..) {
                if e.is_file() {
                    out_paths.push(PathBuf::from(repo_location.get_file_uri(&e.path)));
                } else if e.is_directory() {
                    stack.push_back(e.path);
                }
            }
        }
    }

    Ok((hive_idx_tracker.idx, out_paths))
}
331
/// Percent-encodes `bytes` with the crate-wide URL character set; the
/// returned `PercentEncode` renders lazily via `Display`.
fn percent_encode(bytes: &[u8]) -> percent_encoding::PercentEncode {
    percent_encoding::percent_encode(bytes, URL_ENCODE_CHAR_SET)
}
335
// Fix: the module was missing `#[cfg(test)]`, so it (and its test-only
// `use super::…` items) was compiled into non-test builds.
#[cfg(test)]
mod tests {
    /// `HFPathParts::try_from_uri` parses well-formed `hf://` URIs and
    /// rejects malformed ones.
    #[test]
    fn test_hf_path_from_uri() {
        use super::HFPathParts;

        // Plain file path, implicit "main" revision.
        let uri = "hf://datasets/pola-rs/polars/README.md";
        let expect = HFPathParts {
            bucket: "datasets".into(),
            repository: "pola-rs/polars".into(),
            revision: "main".into(),
            path: "README.md".into(),
        };

        assert_eq!(HFPathParts::try_from_uri(uri).unwrap(), expect);

        // Explicit revision with a trailing slash: empty in-repo path.
        let uri = "hf://spaces/pola-rs/polars@~parquet/";
        let expect = HFPathParts {
            bucket: "spaces".into(),
            repository: "pola-rs/polars".into(),
            revision: "~parquet".into(),
            path: "".into(),
        };

        assert_eq!(HFPathParts::try_from_uri(uri).unwrap(), expect);

        // Same, without the trailing slash.
        let uri = "hf://spaces/pola-rs/polars@~parquet";
        let expect = HFPathParts {
            bucket: "spaces".into(),
            repository: "pola-rs/polars".into(),
            revision: "~parquet".into(),
            path: "".into(),
        };

        assert_eq!(HFPathParts::try_from_uri(uri).unwrap(), expect);

        // Malformed URIs and unknown buckets must all error.
        for uri in [
            "://",
            "s3://",
            "https://",
            "hf://",
            "hf:///",
            "hf:////",
            "hf://datasets/a",
            "hf://datasets/a/",
            "hf://bucket/a/b/c", // invalid bucket name
        ] {
            let out = HFPathParts::try_from_uri(uri);
            assert!(
                out.is_err(),
                "expected err result for uri {} instead of {:?}",
                uri,
                out
            );
        }
    }

    /// `GetPages::find_link` extracts the URI for a given `rel` from a
    /// multi-entry `Link` header, and returns `None` for unknown relations.
    #[test]
    fn test_get_pages_find_next_link() {
        use super::GetPages;
        let link = r#"<https://api.github.com/repositories/263727855/issues?page=3>; rel="next", <https://api.github.com/repositories/263727855/issues?page=7>; rel="last""#.as_bytes();

        assert_eq!(
            GetPages::find_link(link, "next".as_bytes()).map(Result::unwrap),
            Some("https://api.github.com/repositories/263727855/issues?page=3".into()),
        );

        assert_eq!(
            GetPages::find_link(link, "last".as_bytes()).map(Result::unwrap),
            Some("https://api.github.com/repositories/263727855/issues?page=7".into()),
        );

        assert_eq!(
            GetPages::find_link(link, "non-existent".as_bytes()).map(Result::unwrap),
            None,
        );
    }
}