polars.Series.to_jax
- Series.to_jax(device: jax.Device | str | None = None) → jax.Array
Convert this Series to a JAX Array.
Added in version 0.20.27.
Warning
This functionality is currently considered unstable. It may be changed at any point without it being considered a breaking change.
- Parameters:
- device
Specify the JAX Device on which the array will be created. You can pass a string (such as "cpu", "gpu", or "tpu"), in which case the device is retrieved as jax.devices(string)[0]. For more specific control, you can supply an instantiated Device directly. If None, the array is created on the default device.
Examples
>>> s = pl.Series("x", [10.5, 0.0, -10.0, 5.5])
>>> s.to_jax()
Array([ 10.5,   0. , -10. ,   5.5], dtype=float32)