polars.Series.to_jax

Series.to_jax(device: jax.Device | str | None = None) → jax.Array

Convert this Series to a JAX array.

New in version 0.20.27.

Warning

This functionality is currently considered unstable. It may be changed at any point without it being considered a breaking change.

Parameters:
device

The JAX device on which to create the array. If a string is given (such as “cpu”, “gpu”, or “tpu”), the device is resolved as jax.devices(string)[0]; for more specific control, pass an instantiated Device directly. If None, the array is created on JAX's default device.

Examples

>>> s = pl.Series("x", [10.5, 0.0, -10.0, 5.5])
>>> s.to_jax()
Array([ 10.5,   0. , -10. ,   5.5], dtype=float32)