-
paguro.Collection.to_jax(return_type: JaxExportType = 'array', *, device: jax.Device | str | None = None, label: str | Expr | Sequence[str | Expr] | None = None, features: str | Expr | Sequence[str | Expr] | None = None, dtype: PolarsDataType | None = None, order: IndexOrder = 'fortran') -> dict[str, jax.Array | dict[str, jax.Array]]