polars.DataFrame.to_jax#
- DataFrame.to_jax(
- return_type: JaxExportType = 'array',
- *,
- device: jax.Device | str | None = None,
- label: str | Expr | Sequence[str | Expr] | None = None,
- features: str | Expr | Sequence[str | Expr] | None = None,
- dtype: PolarsDataType | None = None,
- order: IndexOrder = 'fortran',
Convert DataFrame to a Jax Array, or dict of Jax Arrays.
Added in version 0.20.27.
Warning
This functionality is currently considered unstable. It may be changed at any point without it being considered a breaking change.
- Parameters:
- return_type{“array”, “dict”}
Set return type; a Jax Array, or dict of Jax Arrays.
- device
Specify the jax
Device
on which the array will be created; can provide a string (such as “cpu”, “gpu”, or “tpu”) in which case the device is retrieved asjax.devices(string)[0]
. For more specific control you can supply the instantiatedDevice
directly. If None, arrays are created on the default device.- label
One or more column names, expressions, or selectors that label the feature data; results in a
{"label": ..., "features": ...}
dict being returned whenreturn_type
is “dict” instead of a{"col": array, }
dict.- features
One or more column names, expressions, or selectors that contain the feature data; if omitted, all columns that are not designated as part of the label are used. Only applies when
return_type
is “dict”.- dtype
Unify the dtype of all returned arrays; this casts any column that is not already of the required dtype before converting to Array. Note that export will be single-precision (32bit) unless the Jax config/environment directs otherwise (eg: “jax_enable_x64” was set True in the config object at startup, or “JAX_ENABLE_X64” is set to “1” in the environment).
- order{“c”, “fortran”}
The index order of the returned Jax array, either C-like (row-major) or Fortran-like (column-major).
See also
Examples
>>> df = pl.DataFrame( ... { ... "lbl": [0, 1, 2, 3], ... "feat1": [1, 0, 0, 1], ... "feat2": [1.5, -0.5, 0.0, -2.25], ... } ... )
Standard return type (2D Array), on the standard device:
>>> df.to_jax() Array([[ 0. , 1. , 1.5 ], [ 1. , 0. , -0.5 ], [ 2. , 0. , 0. ], [ 3. , 1. , -2.25]], dtype=float32)
Create the Array on the default GPU device:
>>> a = df.to_jax(device="gpu") >>> a.device() GpuDevice(id=0, process_index=0)
Create the Array on a specific GPU device:
>>> gpu_device = jax.devices("gpu")[1] >>> a = df.to_jax(device=gpu_device) >>> a.device() GpuDevice(id=1, process_index=0)
As a dictionary of individual Arrays:
>>> df.to_jax("dict") {'lbl': Array([0, 1, 2, 3], dtype=int32), 'feat1': Array([1, 0, 0, 1], dtype=int32), 'feat2': Array([ 1.5 , -0.5 , 0. , -2.25], dtype=float32)}
As a “label” and “features” dictionary; note that as “features” is not declared, it defaults to all the columns that are not in “label”:
>>> df.to_jax("dict", label="lbl") {'label': Array([[0], [1], [2], [3]], dtype=int32), 'features': Array([[ 1. , 1.5 ], [ 0. , -0.5 ], [ 0. , 0. ], [ 1. , -2.25]], dtype=float32)}
As a “label” and “features” dictionary where each is designated using a col or selector expression (which can also be used to cast the data if the label and features are better-represented with different dtypes):
>>> import polars.selectors as cs >>> df.to_jax( ... return_type="dict", ... features=cs.float(), ... label=pl.col("lbl").cast(pl.UInt8), ... ) {'label': Array([[0], [1], [2], [3]], dtype=uint8), 'features': Array([[ 1.5 ], [-0.5 ], [ 0. ], [-2.25]], dtype=float32)}