Index _ | A | B | C | D | E | F | G | H | I | J | K | L | M | N | O | P | Q | R | S | T | U | V | W | X | Z _ __call__() (jax.stages.Compiled method) (jax.stages.Wrapped method) __init__() (jax.Array method) (jax.custom_batching.custom_vmap method) (jax.custom_jvp method) (jax.custom_vjp method) (jax.Device method) (jax.dtypes.bfloat16 method) (jax.dtypes.prng_key method) (jax.experimental.checkify.Error method) (jax.experimental.custom_dce.custom_dce method) (jax.experimental.pallas.BlockSpec method) (jax.experimental.pallas.GridSpec method) (jax.experimental.pallas.mosaic_gpu.Barrier method) (jax.experimental.pallas.mosaic_gpu.BlockSpec method) (jax.experimental.pallas.mosaic_gpu.CompilerParams method) (jax.experimental.pallas.mosaic_gpu.Layout method) (jax.experimental.pallas.mosaic_gpu.MemorySpace method) (jax.experimental.pallas.mosaic_gpu.SemaphoreSignal method) (jax.experimental.pallas.mosaic_gpu.SemaphoreType method) (jax.experimental.pallas.mosaic_gpu.SwizzleTransform method) (jax.experimental.pallas.mosaic_gpu.TilingTransform method) (jax.experimental.pallas.mosaic_gpu.TransposeTransform method) (jax.experimental.pallas.mosaic_gpu.WGMMAAccumulatorRef method) (jax.experimental.pallas.Slice method) (jax.experimental.pallas.tpu.BufferedRef method) (jax.experimental.pallas.tpu.BufferedRefBase method) (jax.experimental.pallas.tpu.ChipVersion method) (jax.experimental.pallas.tpu.CompilerParams method) (jax.experimental.pallas.tpu.GridDimensionSemantics method) (jax.experimental.pallas.tpu.InterpretParams method) (jax.experimental.pallas.tpu.MemorySpace method) (jax.experimental.pallas.tpu.PrefetchScalarGridSpec method) (jax.experimental.pallas.tpu.SemaphoreType method) (jax.experimental.pallas.tpu.TpuInfo method) (jax.experimental.pallas.triton.CompilerParams method) (jax.experimental.sparse.BCOO method) (jax.experimental.sparse.BCSR method) (jax.experimental.sparse.COO method) (jax.experimental.sparse.CSC method) (jax.experimental.sparse.CSR method) (jax.experimental.sparse.JAXSparse method) (jax.export.SymbolicScope method) (jax.extend.core.ClosedJaxpr method) (jax.extend.core.Jaxpr method) (jax.extend.core.JaxprEqn method) (jax.extend.core.Literal method) (jax.extend.core.Primitive method) (jax.extend.core.Token method) (jax.extend.core.Var method) (jax.extend.linear_util.Callable method) (jax.extend.linear_util.WrappedFun method) (jax.lax.linalg.SvdAlgorithm method) (jax.nn.initializers.Initializer method) (jax.numpy.character method) (jax.numpy.complex128 method) (jax.numpy.complex64 method) (jax.numpy.complexfloating method) (jax.numpy.dtype method) (jax.numpy.finfo method) (jax.numpy.flexible method) (jax.numpy.float16 method) (jax.numpy.float32 method) (jax.numpy.float64 method) (jax.numpy.floating method) (jax.numpy.generic method) (jax.numpy.iinfo method) (jax.numpy.inexact method) (jax.numpy.int16 method) (jax.numpy.int32 method) (jax.numpy.int64 method) (jax.numpy.int8 method) (jax.numpy.integer method) (jax.numpy.number method) (jax.numpy.object_ method) (jax.numpy.signedinteger method) (jax.numpy.ufunc method) (jax.numpy.uint16 method) (jax.numpy.uint32 method) (jax.numpy.uint64 method) (jax.numpy.uint8 method) (jax.numpy.unsignedinteger method) (jax.profiler.StepTraceAnnotation method) (jax.profiler.TraceAnnotation method) (jax.ref.AbstractRef method) (jax.ref.Ref method) (jax.scipy.interpolate.RegularGridInterpolator method) (jax.scipy.optimize.OptimizeResults method) (jax.scipy.spatial.transform.Rotation method) (jax.scipy.spatial.transform.Slerp method) (jax.scipy.stats.gaussian_kde method) (jax.set_mesh method) (jax.ShapeDtypeStruct method) (jax.tree_util.Partial method) _get_vjp (jax.export.Exported attribute) A abs() (in module jax.lax) (in module jax.numpy) absolute() (in module jax.numpy) AbstractRef (class in jax.ref) ACC (in module jax.experimental.pallas.mosaic_gpu) accum_ref (jax.experimental.pallas.tpu.BufferedRef attribute) accumulation_type (jax.lax.DotAlgorithmPreset property) AccuracyMode (class in jax.lax) acos() (in module jax.lax) (in module jax.numpy) acosh() (in module jax.lax) (in module jax.numpy) adagrad() (in module jax.example_libraries.optimizers) adam() (in module jax.example_libraries.optimizers) adamax() (in module jax.example_libraries.optimizers) add (in module jax.numpy) add() (in module jax.lax) addressable_devices (jax.sharding.NamedSharding property) (jax.sharding.Sharding property) addressable_devices_indices_map() (jax.sharding.Sharding method) addressable_shards (jax.Array property) addupdate() (in module jax.ref) after_all() (in module jax.lax) all() (in module jax.numpy) (in module jax.tree) (jax.Array method) all_checks (in module jax.experimental.checkify) all_gather() (in module jax.lax) all_leaves() (in module jax.tree_util) all_to_all() (in module jax.lax) allclose() (in module jax.numpy) allow_collective_id_without_custom_barrier (jax.experimental.pallas.tpu.CompilerParams attribute) allow_hbm_allocation_in_run_scoped (jax.experimental.pallas.tpu.InterpretParams attribute) allow_input_fusion (jax.experimental.pallas.tpu.CompilerParams attribute) amax() (in module jax.numpy) amin() (in module jax.numpy) angle() (in module jax.numpy) annotate_function() (in module jax.profiler) any() (in module jax.numpy) (jax.Array method) ANY_F8_ANY_F8_ANY (jax.lax.DotAlgorithmPreset attribute) ANY_F8_ANY_F8_ANY_FAST_ACCUM (jax.lax.DotAlgorithmPreset attribute) ANY_F8_ANY_F8_F32 (jax.lax.DotAlgorithmPreset attribute) ANY_F8_ANY_F8_F32_FAST_ACCUM (jax.lax.DotAlgorithmPreset attribute) append() (in module jax.numpy) apply_along_axis() (in module jax.numpy) apply_over_axes() (in module jax.numpy) approx_math (jax.experimental.pallas.mosaic_gpu.CompilerParams attribute) approx_max_k() (in module jax.lax) approx_min_k() (in module jax.lax) approx_tanh() (in module jax.experimental.pallas.triton) arange() (in module jax.numpy) arccos() (in module jax.numpy) arccosh() (in module jax.numpy) arcsin() (in module jax.numpy) arcsinh() (in module jax.numpy) arctan() (in module jax.numpy) arctan2() (in module jax.numpy) arctanh() (in module jax.numpy) argmax() (in module jax.lax) (in module jax.numpy) (jax.Array method) argmin() (in module jax.lax) (in module jax.numpy) (jax.Array method) argpartition() (in module jax.numpy) (jax.Array method) argsort() (in module jax.numpy) (jax.Array method) argwhere() (in module jax.numpy) around() (in module jax.numpy) Array (class in jax) array() (in module jax.numpy) array_equal() (in module jax.numpy) array_equiv() (in module jax.numpy) array_repr() (in module jax.numpy) array_split() (in module jax.numpy) array_str() (in module jax.numpy) array_types (in module jax.extend.core) ArrayLike (in module jax.typing) as_text() (jax.stages.Compiled method) (jax.stages.Lowered method) as_torch_kernel() (in module jax.experimental.pallas.mosaic_gpu) asarray() (in module jax.numpy) asin() (in module jax.lax) (in module jax.numpy) asinh() (in module jax.lax) (in module jax.numpy) assert_equal() (in module jax.experimental.multihost_utils) associative_scan() (in module jax.lax) astype() (in module jax.numpy) (jax.Array method) async_copy() (in module jax.experimental.pallas.tpu) async_load_tmem() (in module jax.experimental.pallas.mosaic_gpu) async_remote_copy() (in module jax.experimental.pallas.tpu) async_store_tmem() (in module jax.experimental.pallas.mosaic_gpu) at (jax.Array property) (jax.numpy.ndarray property) atan() (in module jax.lax) (in module jax.numpy) atan2() (in module jax.lax) (in module jax.numpy) atanh() (in module jax.lax) (in module jax.numpy) atleast_1d() (in module jax.numpy) atleast_2d() (in module jax.numpy) atleast_3d() (in module jax.numpy) atomic_add() (in module jax.experimental.pallas.triton) atomic_and() (in module jax.experimental.pallas.triton) atomic_cas() (in module jax.experimental.pallas.triton) atomic_max() (in module jax.experimental.pallas.triton) atomic_min() (in module jax.experimental.pallas.triton) atomic_or() (in module jax.experimental.pallas.triton) atomic_xchg() (in module jax.experimental.pallas.triton) atomic_xor() (in module jax.experimental.pallas.triton) automatic_checks (in module jax.experimental.checkify) average() (in module jax.numpy) AvgPool() (in module jax.example_libraries.stax) AWAY_FROM_ZERO (jax.lax.RoundingMethod attribute) axis_index() (in module jax.lax) axis_size() (in module jax.lax) B backend_xla_version() (in module jax.extend.backend) backends() (in module jax.extend.backend) ball() (in module jax.random) Barrier (class in jax.experimental.pallas.mosaic_gpu) barrier_arrive() (in module jax.experimental.pallas.mosaic_gpu) barrier_wait() (in module jax.experimental.pallas.mosaic_gpu) bartlett() (in module jax.numpy) batch_matmul() (in module jax.lax) BatchNorm() (in module jax.example_libraries.stax) BCOO (class in jax.experimental.sparse) bcoo_broadcast_in_dim() (in module jax.experimental.sparse) bcoo_concatenate() (in module jax.experimental.sparse) bcoo_dot_general() (in module jax.experimental.sparse) bcoo_dot_general_sampled() (in module jax.experimental.sparse) bcoo_dynamic_slice() (in module jax.experimental.sparse) bcoo_extract() (in module jax.experimental.sparse) bcoo_fromdense() (in module jax.experimental.sparse) bcoo_gather() (in module jax.experimental.sparse) bcoo_multiply_dense() (in module jax.experimental.sparse) bcoo_multiply_sparse() (in module jax.experimental.sparse) bcoo_reduce_sum() (in module jax.experimental.sparse) bcoo_reshape() (in module jax.experimental.sparse) bcoo_slice() (in module jax.experimental.sparse) bcoo_sort_indices() (in module jax.experimental.sparse) bcoo_squeeze() (in module jax.experimental.sparse) bcoo_sum_duplicates() (in module jax.experimental.sparse) bcoo_todense() (in module jax.experimental.sparse) bcoo_transpose() (in module jax.experimental.sparse) bcoo_update_layout() (in module jax.experimental.sparse) BCSR (class in jax.experimental.sparse) bcsr_dot_general() (in module jax.experimental.sparse) bcsr_extract() (in module jax.experimental.sparse) bcsr_fromdense() (in module jax.experimental.sparse) bcsr_todense() (in module jax.experimental.sparse) bernoulli() (in module jax.random) (in module jax.scipy.special) bessel_i0e() (in module jax.lax) bessel_i1e() (in module jax.lax) beta() (in module jax.random) (in module jax.scipy.special) betainc() (in module jax.lax) (in module jax.scipy.special) betaln() (in module jax.scipy.special) BF16_BF16_BF16 (jax.lax.DotAlgorithmPreset attribute) BF16_BF16_F32 (jax.lax.DotAlgorithmPreset attribute) BF16_BF16_F32_X3 (jax.lax.DotAlgorithmPreset attribute) BF16_BF16_F32_X6 (jax.lax.DotAlgorithmPreset attribute) BF16_BF16_F32_X9 (jax.lax.DotAlgorithmPreset attribute) bfloat16 (class in jax.dtypes) bicgstab() (in module jax.scipy.sparse.linalg) bincount() (in module jax.numpy) binomial() (in module jax.random) bitcast_convert_type() (in module jax.lax) bits (jax.numpy.finfo attribute) bits() (in module jax.random) bitwise_and (in module jax.numpy) bitwise_and() (in module jax.lax) bitwise_count() (in module jax.numpy) bitwise_invert() (in module jax.numpy) bitwise_left_shift() (in module jax.numpy) bitwise_not() (in module jax.lax) (in module jax.numpy) bitwise_or (in module jax.numpy) bitwise_or() (in module jax.lax) bitwise_right_shift() (in module jax.numpy) bitwise_xor (in module jax.numpy) bitwise_xor() (in module jax.lax) blackman() (in module jax.numpy) block() (in module jax.numpy) block_diag() (in module jax.scipy.linalg) block_shape (jax.experimental.pallas.tpu.BufferedRef attribute) block_until_ready() (in module jax) BlockSpec (class in jax.experimental.pallas) (class in jax.experimental.pallas.mosaic_gpu) bool_ (in module jax.numpy) breakpoint() (in module jax.debug) broadcast() (in module jax.lax) (in module jax.tree) broadcast_arrays() (in module jax.numpy) broadcast_in_dim() (in module jax.lax) broadcast_one_to_all() (in module jax.experimental.multihost_utils) broadcast_shapes() (in module jax.lax) (in module jax.numpy) broadcast_to() (in module jax.experimental.pallas) (in module jax.numpy) broadcast_to_rank() (in module jax.lax) broadcasted_iota() (in module jax.lax) buffer_type (jax.experimental.pallas.tpu.BufferedRef attribute) BufferedRef (class in jax.experimental.pallas.tpu) BufferedRefBase (class in jax.experimental.pallas.tpu) C c_ (in module jax.numpy) cache() (in module jax.extend.linear_util) call() (jax.export.Exported method) Callable (class in jax.extend.linear_util) callback() (in module jax.debug) calling_convention_version (jax.export.Exported attribute) can_cast() (in module jax.numpy) canonicalize_dtype() (in module jax.dtypes) categorical() (in module jax.random) cauchy() (in module jax.random) cbrt() (in module jax.lax) (in module jax.numpy) cdf() (in module jax.scipy.stats.bernoulli) (in module jax.scipy.stats.beta) (in module jax.scipy.stats.cauchy) (in module jax.scipy.stats.chi2) (in module jax.scipy.stats.expon) (in module jax.scipy.stats.gamma) (in module jax.scipy.stats.gennorm) (in module jax.scipy.stats.gumbel_l) (in module jax.scipy.stats.gumbel_r) (in module jax.scipy.stats.laplace) (in module jax.scipy.stats.logistic) (in module jax.scipy.stats.norm) (in module jax.scipy.stats.pareto) (in module jax.scipy.stats.poisson) (in module jax.scipy.stats.truncnorm) (in module jax.scipy.stats.uniform) cdiv() (in module jax.experimental.pallas) cdouble (in module jax.numpy) ceil() (in module jax.lax) (in module jax.numpy) celu() (in module jax.nn) cg() (in module jax.scipy.sparse.linalg) character (class in jax.numpy) check() (in module jax.experimental.checkify) check_error() (in module jax.experimental.checkify) check_grads() (in module jax.test_util) check_jvp() (in module jax.test_util) check_tracer_leaks (in module jax) check_vjp() (in module jax.test_util) checkify() (in module jax.experimental.checkify) checking_leaks (in module jax) checkpoint() (in module jax) checkpoint_dots() (jax.checkpoint_policies method) checkpoint_dots_with_no_batch_dims() (jax.checkpoint_policies method) checkpoint_name() (in module jax.ad_checkpoint) ChipVersion (class in jax.experimental.pallas.tpu) chisquare() (in module jax.random) cho_factor() (in module jax.scipy.linalg) cho_solve() (in module jax.scipy.linalg) choice() (in module jax.random) cholesky() (in module jax.lax.linalg) (in module jax.numpy.linalg) (in module jax.scipy.linalg) cholesky_update() (in module jax.lax.linalg) choose() (in module jax.numpy) (jax.Array method) clamp() (in module jax.lax) clear_backends() (in module jax.extend.backend) clear_caches() (in module jax) clip() (in module jax.numpy) (jax.Array method) clip_grads() (in module jax.example_libraries.optimizers) clone() (in module jax.random) ClosedJaxpr (class in jax.extend.core) closure_convert() (in module jax) clz() (in module jax.lax) collapse() (in module jax.lax) collective_axes (jax.experimental.pallas.mosaic_gpu.BlockSpec attribute) collective_id (jax.experimental.pallas.tpu.CompilerParams attribute) column_stack() (in module jax.numpy) commit_smem() (in module jax.experimental.pallas.mosaic_gpu) commit_tmem() (in module jax.experimental.pallas.mosaic_gpu) committed (jax.Array property) compile() (jax.stages.Lowered method) Compiled (class in jax.stages) compiler_ir() (jax.stages.Lowered method) CompilerParams (class in jax.experimental.pallas.mosaic_gpu) (class in jax.experimental.pallas.tpu) (class in jax.experimental.pallas.triton) complex() (in module jax.lax) complex128 (class in jax.numpy) complex64 (class in jax.numpy) complex_ (in module jax.numpy) complexfloating (class in jax.numpy) ComplexWarning composite() (in module jax.lax) compress() (in module jax.numpy) (jax.Array method) compute_index (jax.experimental.pallas.tpu.BufferedRef attribute) concat() (in module jax.numpy) concatenate() (in module jax.lax) (in module jax.numpy) ConcretizationTypeError (class in jax.errors) cond() (in module jax.lax) (in module jax.numpy.linalg) config (in module jax) conj() (in module jax.lax) (in module jax.numpy) (jax.Array method) conjugate() (in module jax.numpy) (jax.Array method) constant() (in module jax.example_libraries.optimizers) (in module jax.nn.initializers) Conv() (in module jax.example_libraries.stax) conv() (in module jax.lax) Conv1DTranspose() (in module jax.example_libraries.stax) conv_dimension_numbers() (in module jax.lax) conv_general_dilated() (in module jax.lax) conv_general_dilated_local() (in module jax.lax) conv_general_dilated_patches() (in module jax.lax) conv_transpose() (in module jax.lax) conv_with_general_padding() (in module jax.lax) ConvDimensionNumbers (class in jax.lax) convert_element_type() (in module jax.lax) ConvGeneralDilatedDimensionNumbers (in module jax.lax) convolve() (in module jax.numpy) (in module jax.scipy.signal) convolve2d() (in module jax.scipy.signal) ConvTranspose() (in module jax.example_libraries.stax) COO (class in jax.experimental.sparse) coo_fromdense() (in module jax.experimental.sparse) coo_matmat() (in module jax.experimental.sparse) coo_matvec() (in module jax.experimental.sparse) coo_todense() (in module jax.experimental.sparse) copy() (in module jax.numpy) (jax.Array method) copy_gmem_to_smem() (in module jax.experimental.pallas.mosaic_gpu) copy_in_slot (jax.experimental.pallas.tpu.BufferedRef attribute) copy_out_slot (jax.experimental.pallas.tpu.BufferedRef attribute) copy_smem_to_gmem() (in module jax.experimental.pallas.mosaic_gpu) copy_to_host_async() (in module jax) (jax.Array method) copysign() (in module jax.numpy) core_barrier() (in module jax.experimental.pallas.tpu) core_map() (in module jax.experimental.pallas) corrcoef() (in module jax.numpy) correlate() (in module jax.numpy) (in module jax.scipy.signal) correlate2d() (in module jax.scipy.signal) cos() (in module jax.lax) (in module jax.numpy) cosh() (in module jax.lax) (in module jax.numpy) cost_analysis() (jax.stages.Compiled method) (jax.stages.Lowered method) count_nonzero() (in module jax.numpy) cov() (in module jax.numpy) CPU create_device_mesh() (in module jax.experimental.mesh_utils) create_hybrid_device_mesh() (in module jax.experimental.mesh_utils) cross() (in module jax.numpy) (in module jax.numpy.linalg) CSC (class in jax.experimental.sparse) csd() (in module jax.scipy.signal) csingle (in module jax.numpy) CSR (class in jax.experimental.sparse) csr_fromdense() (in module jax.experimental.sparse) csr_matmat() (in module jax.experimental.sparse) csr_matvec() (in module jax.experimental.sparse) csr_todense() (in module jax.experimental.sparse) cumlogsumexp() (in module jax.lax) cummax() (in module jax.lax) cummin() (in module jax.lax) cumprod() (in module jax.lax) (in module jax.numpy) (jax.Array method) cumsum() (in module jax.lax) (in module jax.numpy) (jax.Array method) cumulative_prod() (in module jax.numpy) cumulative_sum() (in module jax.numpy) current_ref (jax.experimental.pallas.tpu.BufferedRef attribute) CUSOLVER (jax.lax.linalg.EigImplementation attribute) custom_call() (jax.export.DisabledSafetyCheck class method) custom_dce (class in jax.experimental.custom_dce) custom_gradient() (in module jax) custom_jvp (class in jax) custom_linear_solve() (in module jax.lax) custom_partitioning() (in module jax.experimental.custom_partitioning) custom_root() (in module jax.lax) custom_vjp (class in jax) custom_vmap (class in jax.custom_batching) D data (jax.experimental.sparse.BCOO attribute) dct() (in module jax.scipy.fft) dctn() (in module jax.scipy.fft) debug_barrier() (in module jax.experimental.pallas.triton) debug_check() (in module jax.experimental.pallas) debug_infs (in module jax) debug_nans (in module jax) debug_print() (in module jax.experimental.pallas) def_dce() (jax.experimental.custom_dce.custom_dce method) def_vmap() (jax.custom_batching.custom_vmap method) DEFAULT (jax.lax.AccuracyMode attribute) (jax.lax.DotAlgorithmPreset attribute) default_backend() (in module jax) default_device (in module jax) default_export_platform() (in module jax.export) default_matmul_precision (in module jax) default_prng_impl (in module jax) define_prng_impl() (in module jax.extend.random) defjvp() (jax.custom_jvp method) defjvps() (jax.custom_jvp method) defvjp() (jax.custom_vjp method) deg2rad() (in module jax.numpy) degrees() (in module jax.numpy) delay_release (jax.experimental.pallas.mosaic_gpu.BlockSpec attribute) (jax.experimental.pallas.mosaic_gpu.CompilerParams attribute) delete() (in module jax.numpy) delta_orthogonal() (in module jax.nn.initializers) Dense() (in module jax.example_libraries.stax) deserialize() (in module jax.export) deserialize_and_load() (in module jax.experimental.serialize_executable) deserialize_portable_artifact (in module jax.extend.mlir) det() (in module jax.numpy.linalg) (in module jax.scipy.linalg) detrend() (in module jax.scipy.signal) Device (class in jax) device (jax.Array property) device_count() (in module jax) device_get() (in module jax) device_memory_profile() (in module jax.profiler) device_put() (in module jax) device_set (jax.sharding.NamedSharding property) (jax.sharding.Sharding property) (jax.sharding.SingleDeviceSharding property) devices() (in module jax) devices_indices_map() (jax.sharding.Sharding method) (jax.sharding.SingleDeviceSharding method) diag() (in module jax.numpy) diag_indices() (in module jax.numpy) diag_indices_from() (in module jax.numpy) diagflat() (in module jax.numpy) diagonal() (in module jax.numpy) (in module jax.numpy.linalg) (jax.Array method) diff() (in module jax.numpy) digamma() (in module jax.lax) (in module jax.scipy.special) digitize() (in module jax.numpy) dimension_semantics (jax.experimental.pallas.mosaic_gpu.CompilerParams attribute) (jax.experimental.pallas.tpu.CompilerParams attribute) dirichlet() (in module jax.random) disable_bounds_checks (jax.experimental.pallas.tpu.CompilerParams attribute) disable_jit() (in module jax) disable_x64() (in module jax.experimental) disabled_safety_checks (jax.export.Exported attribute) DisabledSafetyCheck (class in jax.export) div() (in module jax.lax) div_checks (in module jax.experimental.checkify) divide() (in module jax.numpy) divmod() (in module jax.numpy) dma_execution_mode (jax.experimental.pallas.tpu.InterpretParams attribute) dot() (in module jax.experimental.pallas) (in module jax.lax) (in module jax.numpy) (jax.Array method) dot_general() (in module jax.lax) dot_product_attention() (in module jax.nn) DotAlgorithm (class in jax.lax) DotAlgorithmPreset (class in jax.lax) DotDimensionNumbers (in module jax.lax) dots_saveable() (jax.checkpoint_policies method) dots_with_no_batch_dims_saveable() (jax.checkpoint_policies method) double (in module jax.numpy) double_sided_maxwell() (in module jax.random) Dropout() (in module jax.example_libraries.stax) dslice() (in module jax.experimental.pallas) dsplit() (in module jax.numpy) dstack() (in module jax.numpy) dtype (class in jax.numpy) (jax.Array property) (jax.numpy.finfo attribute) DTypeLike (in module jax.typing) dynamic_index_in_dim() (in module jax.lax) dynamic_scheduling_loop() (in module jax.experimental.pallas.mosaic_gpu) dynamic_slice() (in module jax.lax) dynamic_slice_in_dim() (in module jax.lax) dynamic_update_index_in_dim() (in module jax.lax) dynamic_update_slice() (in module jax.lax) dynamic_update_slice_in_dim() (in module jax.lax) E ediff1d() (in module jax.numpy) eig() (in module jax.lax.linalg) (in module jax.numpy.linalg) eigh() (in module jax.lax.linalg) (in module jax.numpy.linalg) (in module jax.scipy.linalg) eigh_tridiagonal() (in module jax.scipy.linalg) EighImplementation (class in jax.lax.linalg) EigImplementation (class in jax.lax.linalg) eigvals() (in module jax.numpy.linalg) eigvalsh() (in module jax.numpy.linalg) einsum() (in module jax.numpy) einsum_path() (in module jax.numpy) elementwise() (in module jax.example_libraries.stax) elementwise_inline_asm() (in module jax.experimental.pallas.triton) elu() (in module jax.nn) emit_pipeline() (in module jax.experimental.pallas.mosaic_gpu) (in module jax.experimental.pallas.tpu) emit_pipeline_warp_specialized() (in module jax.experimental.pallas.mosaic_gpu) emit_pipeline_with_allocations() (in module jax.experimental.pallas.tpu) empty() (in module jax.experimental.pallas) (in module jax.experimental.sparse) (in module jax.lax) (in module jax.numpy) empty_like() (in module jax.experimental.pallas) (in module jax.numpy) enable_checks (in module jax) enable_custom_prng (in module jax) enable_custom_vjp_by_custom_transpose (in module jax) enable_x64 (in module jax) enable_x64() (in module jax.experimental) ensure_compile_time_eval() (in module jax) entr() (in module jax.scipy.special) entropy() (in module jax.scipy.stats.poisson) eps (jax.numpy.finfo attribute) epsneg (jax.numpy.finfo attribute) eq() (in module jax.lax) equal() (in module jax.numpy) erf() (in module jax.lax) (in module jax.scipy.special) erf_inv() (in module jax.lax) erfc() (in module jax.lax) (in module jax.scipy.special) erfinv() (in module jax.scipy.special) Error (class in jax.experimental.checkify) eval_shape() (in module jax) evaluate() (jax.scipy.stats.gaussian_kde method) everything_saveable() (jax.checkpoint_policies method) exp() (in module jax.lax) (in module jax.numpy) exp1() (in module jax.scipy.special) exp2() (in module jax.lax) (in module jax.numpy) expand_dims() (in module jax.lax) (in module jax.numpy) expi (in module jax.scipy.special) expit() (in module jax.scipy.special) expm() (in module jax.scipy.linalg) expm1() (in module jax.lax) (in module jax.numpy) expm_frechet() (in module jax.scipy.linalg) expn (in module jax.scipy.special) exponential() (in module jax.random) exponential_decay() (in module jax.example_libraries.optimizers) export() (in module jax.export) Exported (class in jax.export) extract() (in module jax.numpy) eye() (in module jax.experimental.sparse) (in module jax.numpy) F f() (in module jax.random) F16_F16_F16 (jax.lax.DotAlgorithmPreset attribute) F16_F16_F32 (jax.lax.DotAlgorithmPreset attribute) F32_F32_F32 (jax.lax.DotAlgorithmPreset attribute) F64_F64_F64 (jax.lax.DotAlgorithmPreset attribute) fabs() (in module jax.numpy) factorial() (in module jax.scipy.special) FanInConcat() (in module jax.example_libraries.stax) FanOut() (in module jax.example_libraries.stax) ffi_call() (in module jax.ffi) ffi_lowering() (in module jax.ffi) FFT (jax.lax.FftType attribute) fft() (in module jax.lax) (in module jax.numpy.fft) fft2() (in module jax.numpy.fft) fftconvolve() (in module jax.scipy.signal) fftfreq() (in module jax.numpy.fft) fftn() (in module jax.numpy.fft) fftshift() (in module jax.numpy.fft) FftType (class in jax.lax) fill_diagonal() (in module jax.numpy) finfo (class in jax.numpy) fix() (in module jax.numpy) flags (jax.experimental.pallas.tpu.CompilerParams attribute) flat (jax.Array property) flatnonzero() (in module jax.numpy) flatten() (in module jax.tree) (jax.Array method) flatten_with_path() (in module jax.tree) flexible (class in jax.numpy) flip() (in module jax.numpy) fliplr() (in module jax.numpy) flipud() (in module jax.numpy) float0 (in module jax.dtypes) float16 (class in jax.numpy) float32 (class in jax.numpy) float64 (class in jax.numpy) float_ (in module jax.numpy) float_checks (in module jax.experimental.checkify) float_power() (in module jax.numpy) floating (class in jax.numpy) floor() (in module jax.lax) (in module jax.numpy) floor_divide() (in module jax.numpy) fmax() (in module jax.numpy) fmin() (in module jax.numpy) fmod() (in module jax.numpy) fold_in() (in module jax.random) force_tpu_interpret_mode() (in module jax.experimental.pallas.tpu) fori_loop() (in module jax.lax) forward-mode autodiff freeze() (in module jax.ref) fresnel (in module jax.scipy.special) frexp() (in module jax.numpy) from_dlpack() (in module jax.dlpack) (in module jax.numpy) frombuffer() (in module jax.numpy) fromfile() (in module jax.numpy) fromfunction() (in module jax.numpy) fromiter() (in module jax.numpy) frompyfunc() (in module jax.numpy) fromstring() (in module jax.numpy) full() (in module jax.lax) (in module jax.numpy) full_like() (in module jax.lax) (in module jax.numpy) fun_name (jax.export.Exported attribute) functional programming funm() (in module jax.scipy.linalg) G gamma() (in module jax.random) (in module jax.scipy.special) gammainc() (in module jax.scipy.special) gammaincc() (in module jax.scipy.special) gammaln() (in module jax.scipy.special) gammasgn() (in module jax.scipy.special) gather() (in module jax.lax) GatherDimensionNumbers (class in jax.lax) GatherScatterMode (class in jax.lax) gaussian_kde (class in jax.scipy.stats) gcd() (in module jax.numpy) ge() (in module jax.lax) gelu() (in module jax.nn) GeneralConv() (in module jax.example_libraries.stax) GeneralConvTranspose() (in module jax.example_libraries.stax) generalized_normal() (in module jax.random) generic (class in jax.numpy) geometric() (in module jax.random) geomspace() (in module jax.numpy) get() (in module jax.ref) get_backend() (in module jax.extend.backend) get_barrier_semaphore() (in module jax.experimental.pallas.tpu) get_compile_options() (in module jax.extend.backend) get_default_device() (in module jax.extend.backend) get_global() (in module jax.experimental.pallas) get_pipeline_schedule() (in module jax.experimental.pallas.tpu) get_printoptions() (in module jax.numpy) get_scaled_dot_general_config() (in module jax.nn) get_tpu_info() (in module jax.experimental.pallas.tpu) global_array_to_host_local_array() (in module jax.experimental.multihost_utils) global_shards (jax.Array property) glorot_normal() (in module jax.nn.initializers) glorot_uniform() (in module jax.nn.initializers) glu() (in module jax.nn) GMEM (in module jax.experimental.pallas.mosaic_gpu) gmres() (in module jax.scipy.sparse.linalg) GPU grad() (in module jax) (in module jax.experimental.sparse) gradient() (in module jax.numpy) greater() (in module jax.numpy) greater_equal() (in module jax.numpy) grid_point_recorder (jax.experimental.pallas.tpu.InterpretParams attribute) GridDimensionSemantics (class in jax.experimental.pallas.tpu) GridSpec (class in jax.experimental.pallas) gt() (in module jax.lax) gumbel() (in module jax.random) H hamming() (in module jax.numpy) hanning() (in module jax.numpy) hard_sigmoid() (in module jax.nn) hard_silu() (in module jax.nn) hard_swish() (in module jax.nn) hard_tanh() (in module jax.nn) has_side_effects (jax.experimental.pallas.tpu.CompilerParams attribute) has_vjp() (jax.export.Exported method) he_normal() (in module jax.nn.initializers) he_uniform() (in module jax.nn.initializers) heaviside() (in module jax.numpy) hessenberg() (in module jax.lax.linalg) (in module jax.scipy.linalg) hessian() (in module jax) hfft() (in module jax.numpy.fft) HIGHEST (jax.lax.AccuracyMode attribute) hilbert() (in module jax.scipy.linalg) histogram() (in module jax.numpy) histogram2d() (in module jax.numpy) histogram_bin_edges() (in module jax.numpy) histogramdd() (in module jax.numpy) hlo_to_stablehlo (in module jax.extend.mlir) host_local_array_to_global_array() (in module jax.experimental.multihost_utils) householder_product() (in module jax.lax.linalg) hsplit() (in module jax.numpy) hstack() (in module jax.numpy) hyp1f1 (in module jax.scipy.special) hyp2f1 (in module jax.scipy.special) hypot() (in module jax.numpy) I i0() (in module jax.numpy) (in module jax.scipy.special) i0e() (in module jax.scipy.special) i1() (in module jax.scipy.special) i1e() (in module jax.scipy.special) idct() (in module jax.scipy.fft) idctn() (in module jax.scipy.fft) identity() (in module jax.nn) (in module jax.numpy) iexp (jax.numpy.finfo attribute) IFFT (jax.lax.FftType attribute) ifft() (in module jax.numpy.fft) ifft2() (in module jax.numpy.fft) ifftn() (in module jax.numpy.fft) ifftshift() (in module jax.numpy.fft) igamma() (in module jax.lax) igamma_grad_a() (in module jax.lax) igammac() (in module jax.lax) ihfft() (in module jax.numpy.fft) iinfo (class in jax.numpy) imag (jax.Array property) imag() (in module jax.lax) (in module jax.numpy) in_avals (jax.export.Exported attribute) in_shardings_hlo (jax.export.Exported attribute) in_shardings_jax() (jax.export.Exported method) in_tree (jax.export.Exported attribute) (jax.stages.Compiled property) (jax.stages.Lowered property) index_checks (in module jax.experimental.checkify) index_exp (in module jax.numpy) index_in_dim() (in module jax.lax) index_take() (in module jax.lax) indices (jax.experimental.sparse.BCOO attribute) indices() (in module jax.numpy) inexact (class in jax.numpy) init_fn (jax.example_libraries.optimizers.Optimizer attribute) initialize() (in module jax.distributed) Initializer (class in jax.nn.initializers) inner() (in module jax.numpy) insert() (in module jax.numpy) inspect_array_sharding() (in module jax.debug) int16 (class in jax.numpy) int32 (class in jax.numpy) int64 (class in jax.numpy) int8 (class in jax.numpy) int_ (in module jax.numpy) integer (class in jax.numpy) integer_pow() (in module jax.lax) integrate_box_1d() (jax.scipy.stats.gaussian_kde method) integrate_gaussian() (jax.scipy.stats.gaussian_kde method) integrate_kde() (jax.scipy.stats.gaussian_kde method) internal_scratch_in_bytes (jax.experimental.pallas.tpu.CompilerParams attribute) interp() (in module jax.numpy) InterpretParams (class in jax.experimental.pallas.tpu) intersect1d() (in module jax.numpy) inv() (in module jax.numpy.linalg) (in module jax.scipy.linalg) inverse_time_decay() (in module jax.example_libraries.optimizers) invert() (in module jax.numpy) io_callback() (in module jax.experimental) iota() (in module jax.lax) IRFFT (jax.lax.FftType attribute) irfft() (in module jax.numpy.fft) irfft2() (in module jax.numpy.fft) irfftn() (in module jax.numpy.fft) is_accumulator (jax.experimental.pallas.tpu.BufferedRef attribute) is_custom_call() (jax.export.DisabledSafetyCheck method) is_equivalent_to() (jax.sharding.Sharding method) is_finite() (in module jax.lax) is_fully_addressable (jax.Array property) (jax.sharding.NamedSharding property) (jax.sharding.Sharding property) (jax.sharding.SingleDeviceSharding property) is_fully_replicated (jax.Array property) (jax.sharding.NamedSharding property) (jax.sharding.Sharding property) (jax.sharding.SingleDeviceSharding property) is_input (jax.experimental.pallas.tpu.BufferedRef attribute) is_input_output (jax.experimental.pallas.tpu.BufferedRef attribute) is_output (jax.experimental.pallas.tpu.BufferedRef attribute) is_supported_dtype() (in module jax.dlpack) is_symbolic_dim() (in module jax.export) is_tpu_device() (in module jax.experimental.pallas.tpu) isclose() (in module jax.numpy) iscomplex() (in module jax.numpy) iscomplexobj() (in module jax.numpy) isdtype() (in module jax.numpy) isf() (in module jax.scipy.stats.cauchy) (in module jax.scipy.stats.logistic) (in module jax.scipy.stats.norm) isfinite() (in module jax.numpy) isin() (in module jax.numpy) isinf() (in module jax.numpy) isnan() (in module jax.numpy) isneginf() (in module jax.numpy) isposinf() (in module jax.numpy) isreal() (in module jax.numpy) isrealobj() (in module jax.numpy) isscalar() (in module jax.numpy) issubdtype() (in module jax.dtypes) (in module jax.numpy) istft() (in module jax.scipy.signal) item() (jax.Array method) itemsize (jax.Array property) itemsize_bits() (in module jax.dtypes) iterable() (in module jax.numpy) ix_() (in module jax.numpy) J jacfwd() (in module jax) JACOBI (jax.lax.linalg.EighImplementation attribute) jacobian() (in module jax) jacrev() (in module jax) jax.ad_checkpoint module jax.debug module jax.distributed module jax.dlpack module jax.dtypes module jax.example_libraries module jax.example_libraries.optimizers module jax.example_libraries.stax module jax.experimental.checkify module jax.experimental.compilation_cache.compilation_cache module jax.experimental.custom_dce module jax.experimental.custom_partitioning module jax.experimental.jet module jax.experimental.key_reuse module jax.experimental.mesh_utils module jax.experimental.multihost_utils module jax.experimental.pallas module jax.experimental.pallas.mosaic_gpu module jax.experimental.pallas.tpu module jax.experimental.pallas.triton module jax.experimental.serialize_executable module jax.experimental.sparse module jax.experimental.sparse.linalg module jax.export module jax.export.maximum_supported_serialization_version (in module jax.export) jax.export.minimum_supported_serialization_version (in module jax.export) jax.extend module jax.extend.backend module jax.extend.backend.ifrt_proxy module jax.extend.core module jax.extend.core.primitives module jax.extend.linear_util module jax.extend.mlir module jax.extend.mlir.dialects module jax.extend.mlir.ir module jax.extend.mlir.passmanager module jax.extend.random module jax.ffi module jax.flatten_util module jax.image module jax.lax module jax.lax.linalg module jax.nn module jax.nn.initializers module jax.numpy module jax.numpy.fft module jax.numpy.linalg module jax.ops module jax.profiler module jax.random module jax.ref module jax.scipy.cluster.vq module jax.scipy.fft module jax.scipy.integrate module jax.scipy.interpolate module jax.scipy.linalg module jax.scipy.ndimage module jax.scipy.optimize module jax.scipy.signal module jax.scipy.sparse.linalg module jax.scipy.spatial.transform module jax.scipy.special module jax.scipy.stats module jax.scipy.stats.bernoulli module jax.scipy.stats.beta module jax.scipy.stats.betabinom module jax.scipy.stats.binom module jax.scipy.stats.cauchy module jax.scipy.stats.chi2 module jax.scipy.stats.dirichlet module jax.scipy.stats.expon module jax.scipy.stats.gamma module jax.scipy.stats.gennorm module jax.scipy.stats.geom module jax.scipy.stats.gumbel_l module jax.scipy.stats.gumbel_r module jax.scipy.stats.laplace module jax.scipy.stats.logistic module jax.scipy.stats.multinomial module jax.scipy.stats.multivariate_normal module jax.scipy.stats.nbinom module jax.scipy.stats.norm module jax.scipy.stats.pareto module jax.scipy.stats.poisson module jax.scipy.stats.t module jax.scipy.stats.truncnorm module jax.scipy.stats.uniform module jax.scipy.stats.vonmises module jax.scipy.stats.wrapcauchy module jax.sharding module jax.stages module jax.test_util module jax.tree module jax.tree_util module jax.typing module JAXIndexError (class in jax.errors) jaxpr Jaxpr (class in jax.extend.core) jaxpr_as_fun (in module jax.extend.core) JaxprEqn (class in jax.extend.core) JaxRuntimeError (class in jax.errors) JAXSparse (class in jax.experimental.sparse) JAXTypeError (class in jax.errors) jet() (in module jax.experimental.jet) JIT jit() (in module jax) JoinPoint (class in jax.example_libraries.optimizers) JVP jvp() (in module jax) K kaiming_normal() (in module jax.nn.initializers) kaiming_uniform() (in module jax.nn.initializers) kaiser() (in module jax.numpy) kernel() (in module jax.experimental.pallas) (in module jax.experimental.pallas.mosaic_gpu) kernel_type (jax.experimental.pallas.tpu.CompilerParams attribute) key() (in module jax.random) key_data() (in module jax.random) KeyEntry (in module jax.tree_util) KeyPath (in module jax.tree_util) KeyReuseError (class in jax.errors) keystr() (in module jax.tree_util) kl_div() (in module jax.scipy.special) kron() (in module jax.numpy) L l2_norm() (in module jax.example_libraries.optimizers) LAPACK (jax.lax.linalg.EigImplementation attribute) laplace() (in module jax.random) Layout (class in jax.experimental.pallas.mosaic_gpu) layout_cast() (in module jax.experimental.pallas.mosaic_gpu) lcm() (in module jax.numpy) ldexp() (in module jax.numpy) le() (in module jax.lax) leaky_relu() (in module jax.nn) leaves() (in module jax.tree) leaves_with_path() (in module jax.tree) lecun_normal() (in module jax.nn.initializers) lecun_uniform() (in module jax.nn.initializers) left_shift() (in module jax.numpy) less() (in module jax.numpy) less_equal() (in module jax.numpy) lexsort() (in module jax.numpy) lgamma() (in module jax.lax) linear_transpose() (in module jax) linearize() (in module jax) linspace() (in module jax.numpy) Literal (class in jax.extend.core) live_arrays() (in module jax) load() (in module jax.experimental.pallas) (in module jax.numpy) lobpcg_standard() (in module jax.experimental.sparse.linalg) local_device_count() (in module jax) local_devices() (in module jax) log() (in module jax.lax) (in module jax.numpy) log10() (in module jax.numpy) log1mexp (in module jax.nn) log1p() (in module jax.lax) (in module jax.numpy) log2() (in module jax.numpy) log_compiles (in module jax) log_ndtr (in module jax.scipy.special) log_sigmoid() (in module jax.nn) log_softmax() (in module jax.nn) (in module jax.scipy.special) logaddexp (in module jax.numpy) logaddexp2 (in module jax.numpy) logcdf() (in module jax.scipy.stats.beta) (in module jax.scipy.stats.cauchy) (in module jax.scipy.stats.chi2) (in module jax.scipy.stats.expon) (in module jax.scipy.stats.gamma) (in module jax.scipy.stats.gumbel_l) (in module jax.scipy.stats.gumbel_r) (in module jax.scipy.stats.norm) (in module jax.scipy.stats.pareto) (in module jax.scipy.stats.truncnorm) loggamma() (in module jax.random) logical_and (in module jax.numpy) logical_not() (in module jax.numpy) logical_or (in module jax.numpy) logical_xor (in module jax.numpy) logistic() (in module jax.lax) (in module jax.random) logit (in module jax.scipy.special) logmeanexp() (in module jax.nn) lognormal() (in module jax.random) logpdf() (in module jax.scipy.stats.beta) (in module jax.scipy.stats.cauchy) (in module jax.scipy.stats.chi2) (in module jax.scipy.stats.dirichlet) (in module jax.scipy.stats.expon) (in module jax.scipy.stats.gamma) (in module jax.scipy.stats.gennorm) (in module jax.scipy.stats.gumbel_l) (in module jax.scipy.stats.gumbel_r) (in module jax.scipy.stats.laplace) (in module jax.scipy.stats.logistic) (in module jax.scipy.stats.multivariate_normal) (in module jax.scipy.stats.norm) (in module jax.scipy.stats.pareto) (in module jax.scipy.stats.t) (in module jax.scipy.stats.truncnorm) (in module jax.scipy.stats.uniform) (in module jax.scipy.stats.vonmises) (in module jax.scipy.stats.wrapcauchy) (jax.scipy.stats.gaussian_kde method) logpmf() (in module jax.scipy.stats.bernoulli) (in module jax.scipy.stats.betabinom) (in module jax.scipy.stats.binom) (in module jax.scipy.stats.geom) (in module jax.scipy.stats.multinomial) (in module jax.scipy.stats.nbinom) (in module jax.scipy.stats.poisson) logsf() (in module jax.scipy.stats.beta) (in module jax.scipy.stats.cauchy) (in module jax.scipy.stats.chi2) (in module jax.scipy.stats.expon) (in module jax.scipy.stats.gamma) (in module jax.scipy.stats.gumbel_l) (in module jax.scipy.stats.gumbel_r) (in module jax.scipy.stats.norm) (in module jax.scipy.stats.pareto) (in module jax.scipy.stats.truncnorm) logspace() (in module jax.numpy) logsumexp() (in module jax.nn) (in module jax.scipy.special) loop() (in module jax.experimental.pallas) lower() (jax.stages.Traced method) (jax.stages.Wrapped method) Lowered (class in jax.stages) lpmn() (in module jax.scipy.special) lpmn_values() (in module jax.scipy.special) lstsq() (in module jax.numpy.linalg) lt() (in module jax.lax) lu() (in module jax.lax.linalg) (in module jax.scipy.linalg) lu_factor() (in module jax.scipy.linalg) lu_pivots_to_permutation() (in module jax.lax.linalg) lu_solve() (in module jax.scipy.linalg) M machep (jax.numpy.finfo attribute) MAGMA (jax.lax.linalg.EigImplementation attribute) make_array_from_callback() (in module jax) make_array_from_process_local_data() (in module jax) make_array_from_single_device_arrays() (in module jax) make_async_copy() (in module jax.experimental.pallas.tpu) make_async_remote_copy() (in module jax.experimental.pallas.tpu) make_jaxpr() (in module jax) make_mesh() (in module jax) make_pipeline_allocations() (in module jax.experimental.pallas.tpu) make_schedule() (in module jax.example_libraries.optimizers) map() (in module jax.lax) (in module jax.tree) map_coordinates() (in module jax.scipy.ndimage) map_with_path() (in module jax.tree) mask_indices() (in module jax.numpy) matmul() (in module jax.numpy) (in module jax.numpy.linalg) matrix_norm() (in module jax.numpy.linalg) matrix_power() (in module jax.numpy.linalg) matrix_rank() (in module jax.numpy.linalg) matrix_transpose() (in module jax.numpy) (in module jax.numpy.linalg) matvec() (in module jax.numpy) max (jax.numpy.finfo attribute) max() (in module jax.lax) (in module jax.numpy) (jax.Array method) max_concurrent_steps (jax.experimental.pallas.mosaic_gpu.CompilerParams attribute) max_contiguous() (in module jax.experimental.pallas) maxexp (jax.numpy.finfo attribute) maximum (in module jax.numpy) maximum_supported_calling_convention_version (in module jax.export) MaxPool() (in module jax.example_libraries.stax) maxwell() (in module jax.random) mean() (in module jax.numpy) (jax.Array method) median() (in module jax.numpy) memory_analysis() (jax.stages.Compiled method) memory_kind (jax.sharding.NamedSharding property) (jax.sharding.Sharding property) (jax.sharding.SingleDeviceSharding property) memory_space (jax.experimental.pallas.tpu.BufferedRef attribute) MemorySpace (class in jax.experimental.pallas.mosaic_gpu) (class in jax.experimental.pallas.tpu) merge_linear_aux() (in module jax.extend.linear_util) Mesh (class in jax.sharding) mesh (jax.sharding.NamedSharding property) meshgrid() (in module jax.numpy) mgrid (in module jax.numpy) min (jax.numpy.finfo attribute) min() (in module jax.lax) (in module jax.numpy) (jax.Array method) minexp (jax.numpy.finfo attribute) minimize() (in module jax.scipy.optimize) minimum (in module jax.numpy) minimum_supported_calling_convention_version (in module jax.export) mish() (in module jax.nn) mlir_module() (jax.export.Exported method) mlir_module_serialized (jax.export.Exported attribute) mod() (in module jax.numpy) mode() (in module jax.scipy.stats) modf() (in module jax.numpy) module jax.ad_checkpoint jax.debug jax.distributed jax.dlpack jax.dtypes jax.example_libraries jax.example_libraries.optimizers jax.example_libraries.stax jax.experimental.checkify jax.experimental.compilation_cache.compilation_cache jax.experimental.custom_dce jax.experimental.custom_partitioning jax.experimental.jet jax.experimental.key_reuse jax.experimental.mesh_utils jax.experimental.multihost_utils jax.experimental.pallas jax.experimental.pallas.mosaic_gpu jax.experimental.pallas.tpu jax.experimental.pallas.triton jax.experimental.serialize_executable jax.experimental.sparse jax.experimental.sparse.linalg jax.export jax.extend jax.extend.backend jax.extend.backend.ifrt_proxy jax.extend.core jax.extend.core.primitives jax.extend.linear_util jax.extend.mlir jax.extend.mlir.dialects jax.extend.mlir.ir jax.extend.mlir.passmanager jax.extend.random jax.ffi jax.flatten_util jax.image jax.lax jax.lax.linalg jax.nn jax.nn.initializers jax.numpy jax.numpy.fft jax.numpy.linalg jax.ops jax.profiler jax.random jax.ref jax.scipy.cluster.vq jax.scipy.fft jax.scipy.integrate jax.scipy.interpolate jax.scipy.linalg jax.scipy.ndimage jax.scipy.optimize jax.scipy.signal jax.scipy.sparse.linalg jax.scipy.spatial.transform jax.scipy.special jax.scipy.stats jax.scipy.stats.bernoulli jax.scipy.stats.beta jax.scipy.stats.betabinom jax.scipy.stats.binom jax.scipy.stats.cauchy jax.scipy.stats.chi2 jax.scipy.stats.dirichlet jax.scipy.stats.expon jax.scipy.stats.gamma jax.scipy.stats.gennorm jax.scipy.stats.geom jax.scipy.stats.gumbel_l jax.scipy.stats.gumbel_r jax.scipy.stats.laplace jax.scipy.stats.logistic jax.scipy.stats.multinomial jax.scipy.stats.multivariate_normal jax.scipy.stats.nbinom jax.scipy.stats.norm jax.scipy.stats.pareto jax.scipy.stats.poisson jax.scipy.stats.t jax.scipy.stats.truncnorm jax.scipy.stats.uniform jax.scipy.stats.vonmises jax.scipy.stats.wrapcauchy jax.sharding jax.stages jax.test_util jax.tree jax.tree_util jax.typing module_kept_var_idx (jax.export.Exported attribute) momentum() (in module jax.example_libraries.optimizers) moveaxis() (in module jax.numpy) mT (jax.Array property) mul() (in module jax.lax) multi_dot() (in module jax.numpy.linalg) multigammaln() (in module jax.scipy.special) multimem_load_reduce() (in module jax.experimental.pallas.mosaic_gpu) multimem_store() (in module jax.experimental.pallas.mosaic_gpu) multinomial() (in module jax.random) multiple_of() (in module jax.experimental.pallas) multiply (in module jax.numpy) multivariate_normal() (in module jax.random) N named_call() (in module jax) named_scope() (in module jax) NamedSharding (class in jax.sharding) nan_checks (in module jax.experimental.checkify) nan_to_num() (in module jax.numpy) nanargmax() (in module jax.numpy) nanargmin() (in module jax.numpy) nancumprod() (in module jax.numpy) nancumsum() (in module jax.numpy) nanmax() (in module jax.numpy) nanmean() (in module jax.numpy) nanmedian() (in module jax.numpy) nanmin() (in module jax.numpy) nanpercentile() (in module jax.numpy) nanprod() (in module jax.numpy) nanquantile() (in module jax.numpy) nanstd() (in module jax.numpy) nansum() (in module jax.numpy) nanvar() (in module jax.numpy) nbytes (jax.Array property) nd_loop() (in module jax.experimental.pallas.mosaic_gpu) ndarray (in module jax.numpy) ndim (jax.Array property) ndim() (in module jax.numpy) ndtr() (in module jax.scipy.special) ndtri() (in module jax.scipy.special) ne() (in module jax.lax) neg() (in module jax.lax) negative (in module jax.numpy) negep (jax.numpy.finfo attribute) nesterov() (in module jax.example_libraries.optimizers) new_ref() (in module jax.ref) nexp (jax.numpy.finfo attribute) next_fetch_smem (jax.experimental.pallas.tpu.BufferedRef attribute) next_fetch_sreg (jax.experimental.pallas.tpu.BufferedRef attribute) nextafter() (in module jax.lax) (in module jax.numpy) nmant (jax.numpy.finfo attribute) no_tracing (in module jax) NonConcreteBooleanIndexError (class in jax.errors) nonzero() (in module jax.numpy) (jax.Array method) norm() (in module jax.numpy.linalg) normal() (in module jax.nn.initializers) (in module jax.random) not_equal() (in module jax.numpy) nothing_saveable() (jax.checkpoint_policies method) nr_devices (jax.export.Exported attribute) num_arrivals (jax.experimental.pallas.mosaic_gpu.Barrier attribute) num_barriers (jax.experimental.pallas.mosaic_gpu.Barrier attribute) num_devices (jax.sharding.NamedSharding property) (jax.sharding.Sharding property) (jax.sharding.SingleDeviceSharding property) num_programs() (in module jax.experimental.pallas) num_stages (jax.experimental.pallas.triton.CompilerParams attribute) num_warps (jax.experimental.pallas.triton.CompilerParams attribute) number (class in jax.numpy) numpy_rank_promotion (in module jax) O object_ (class in jax.numpy) offload_dot_with_no_batch_dims() (jax.checkpoint_policies method) ogrid (in module jax.numpy) one_hot() (in module jax.nn) ones() (in module jax.nn.initializers) (in module jax.numpy) ones_like() (in module jax.numpy) optimization_barrier() (in module jax.lax) Optimizer (class in jax.example_libraries.optimizers) optimizer() (in module jax.example_libraries.optimizers) OptimizeResults (class in jax.scipy.optimize) OptimizerState (class in jax.example_libraries.optimizers) ordered_effects (jax.export.Exported attribute) orders_tensor_core (jax.experimental.pallas.mosaic_gpu.Barrier attribute) orthogonal() (in module jax.nn.initializers) (in module jax.random) out_avals (jax.export.Exported attribute) out_shardings_hlo (jax.export.Exported attribute) out_shardings_jax() (jax.export.Exported method) out_tree (jax.export.Exported attribute) outer() (in module jax.numpy) (in module jax.numpy.linalg) P pack_optimizer_state() (in module jax.example_libraries.optimizers) packbits() (in module jax.numpy) packed_state (jax.example_libraries.optimizers.OptimizerState attribute) pad() (in module jax.lax) (in module jax.numpy) pallas_call() (in module jax.experimental.pallas) parallel() (in module jax.example_libraries.stax) params_fn (jax.example_libraries.optimizers.Optimizer attribute) pareto() (in module jax.random) Partial (class in jax.tree_util) partition() (in module jax.numpy) PartitionSpec (class in jax.sharding) pascal() (in module jax.scipy.linalg) pdf() (in module jax.scipy.stats.beta) (in module jax.scipy.stats.cauchy) (in module jax.scipy.stats.chi2) (in module jax.scipy.stats.dirichlet) (in module jax.scipy.stats.expon) (in module jax.scipy.stats.gamma) (in module jax.scipy.stats.gennorm) (in module jax.scipy.stats.gumbel_l) (in module jax.scipy.stats.gumbel_r) (in module jax.scipy.stats.laplace) (in module jax.scipy.stats.logistic) (in module jax.scipy.stats.multivariate_normal) (in module jax.scipy.stats.norm) (in module jax.scipy.stats.pareto) (in module jax.scipy.stats.t) (in module jax.scipy.stats.truncnorm) (in module jax.scipy.stats.uniform) (in module jax.scipy.stats.vonmises) (in module jax.scipy.stats.wrapcauchy) (jax.scipy.stats.gaussian_kde method) percentile() (in module jax.numpy) permutation() (in module jax.random) permute_dims() (in module jax.numpy) piecewise() (in module jax.numpy) piecewise_constant() (in module jax.example_libraries.optimizers) pinv() (in module jax.numpy.linalg) place() (in module jax.numpy) planar_snake() (in module jax.experimental.pallas.mosaic_gpu) platform() (jax.export.DisabledSafetyCheck class method) platform_dependent() (in module jax.lax) platforms (jax.export.Exported attribute) pmap() (in module jax) pmax() (in module jax.lax) pmean() (in module jax.lax) pmf() (in module jax.scipy.stats.bernoulli) (in module jax.scipy.stats.betabinom) (in module jax.scipy.stats.binom) (in module jax.scipy.stats.geom) (in module jax.scipy.stats.multinomial) (in module jax.scipy.stats.nbinom) (in module jax.scipy.stats.poisson) pmin() (in module jax.lax) poch (in module jax.scipy.special) poisson() (in module jax.random) polar() (in module jax.scipy.linalg) poly() (in module jax.numpy) polyadd() (in module jax.numpy) polyder() (in module jax.numpy) polydiv() (in module jax.numpy) polyfit() (in module jax.numpy) polygamma() (in module jax.lax) (in module jax.scipy.special) polyint() (in module jax.numpy) polymul() (in module jax.numpy) polynomial_decay() (in module jax.example_libraries.optimizers) polysub() (in module jax.numpy) polyval() (in module jax.numpy) population_count() (in module jax.lax) positive() (in module jax.numpy) pow() (in module jax.lax) (in module jax.numpy) power() (in module jax.numpy) ppermute() (in module jax.lax) ppf() (in module jax.scipy.stats.bernoulli) (in module jax.scipy.stats.cauchy) (in module jax.scipy.stats.expon) (in module jax.scipy.stats.gumbel_l) (in module jax.scipy.stats.gumbel_r) (in module jax.scipy.stats.logistic) (in module jax.scipy.stats.norm) (in module jax.scipy.stats.pareto) (in module jax.scipy.stats.uniform) Precision (class in jax.lax) precision (jax.numpy.finfo attribute) PrecisionLike (in module jax.lax) precv() (in module jax.lax) PrefetchScalarGridSpec (class in jax.experimental.pallas.tpu) primitive Primitive (class in jax.extend.core) print() (in module jax.debug) print_environment_info() (in module jax) printoptions() (in module jax.numpy) prng_key (class in jax.dtypes) prng_seed() (in module jax.experimental.pallas.tpu) PRNGKey() (in module jax.random) process_allgather() (in module jax.experimental.multihost_utils) process_count() (in module jax) process_index() (in module jax) process_indices() (in module jax) prod() (in module jax.numpy) (jax.Array method) profile_dir (jax.experimental.pallas.mosaic_gpu.CompilerParams attribute) profile_space (jax.experimental.pallas.mosaic_gpu.CompilerParams attribute) program_id() (in module jax.experimental.pallas) promote_types() (in module jax.numpy) psend() (in module jax.lax) pshuffle() (in module jax.lax) psum() (in module jax.lax) psum_scatter() (in module jax.lax) pswapaxes() (in module jax.lax) ptp() (in module jax.numpy) (jax.Array method) pure function pure_callback() (in module jax) put() (in module jax.numpy) put_along_axis() (in module jax.numpy) pycapsule() (in module jax.ffi) pytree Q QDWH (jax.lax.linalg.EighImplementation attribute) qdwh() (in module jax.lax.linalg) QR (jax.lax.linalg.EighImplementation attribute) qr() (in module jax.lax.linalg) (in module jax.numpy.linalg) (in module jax.scipy.linalg) quantile() (in module jax.numpy) query_cluster_cancel() (in module jax.experimental.pallas.mosaic_gpu) R r_ (in module jax.numpy) rad2deg() (in module jax.numpy) rademacher() (in module jax.random) radians() (in module jax.numpy) ragged_all_to_all() (in module jax.lax) ragged_dot() (in module jax.lax) ragged_dot_general() (in module jax.lax) RaggedDotDimensionNumbers (class in jax.lax) randint() (in module jax.random) random_bcoo() (in module jax.experimental.sparse) random_gamma_grad() (in module jax.lax) random_seed (jax.experimental.pallas.tpu.InterpretParams attribute) random_seed() (in module jax.extend.random) RandomAlgorithm (class in jax.lax) rankdata() (in module jax.scipy.stats) ravel() (in module jax.numpy) (jax.Array method) ravel_multi_index() (in module jax.numpy) ravel_pytree() (in module jax.flatten_util) rayleigh() (in module jax.random) rbg_prng_impl (in module jax.extend.random) real (jax.Array property) real() (in module jax.lax) (in module jax.numpy) reciprocal() (in module jax.lax) (in module jax.numpy) reduce() (in module jax.lax) (in module jax.tree) reduce_and() (in module jax.lax) reduce_associative() (in module jax.tree) reduce_max() (in module jax.lax) reduce_min() (in module jax.lax) reduce_or() (in module jax.lax) reduce_precision() (in module jax.lax) reduce_prod() (in module jax.lax) reduce_sum() (in module jax.lax) reduce_window() (in module jax.lax) reduce_xor() (in module jax.lax) reduced (jax.sharding.PartitionSpec property) Ref (class in jax.ref) refine_polymorphic_shapes (in module jax.extend.mlir) register_backend_cache() (in module jax.extend.backend) register_backend_factory() (in module jax.extend.backend) register_dataclass() (in module jax.tree_util) register_ffi_target() (in module jax.ffi) register_ffi_type() (in module jax.ffi) register_namedtuple_serialization() (in module jax.export) register_pytree_node() (in module jax.tree_util) register_pytree_node_class() (in module jax.tree_util) register_pytree_node_serialization() (in module jax.export) register_pytree_with_keys() (in module jax.tree_util) register_pytree_with_keys_class() (in module jax.tree_util) register_static() (in module jax.tree_util) RegularGridInterpolator (class in jax.scipy.interpolate) rel_entr() (in module jax.scipy.special) relu (in module jax.nn) relu6 (in module jax.nn) rem() (in module jax.lax) remainder() (in module jax.numpy) repeat() (in module jax.numpy) (jax.Array method) resample() (jax.scipy.stats.gaussian_kde method) reset_cache() (in module jax.experimental.compilation_cache.compilation_cache) reset_tpu_interpret_mode_state() (in module jax.experimental.pallas.tpu) reshape() (in module jax.lax) (in module jax.numpy) (jax.Array method) resize() (in module jax.image) (in module jax.numpy) ResizeMethod (class in jax.image) resolution (jax.numpy.finfo attribute) result_type() (in module jax.dtypes) (in module jax.numpy) rev() (in module jax.lax) reverse-mode autodiff RFFT (jax.lax.FftType attribute) rfft() (in module jax.numpy.fft) rfft2() (in module jax.numpy.fft) rfftfreq() (in module jax.numpy.fft) rfftn() (in module jax.numpy.fft) right_shift() (in module jax.numpy) rint() (in module jax.numpy) rmsprop() (in module jax.example_libraries.optimizers) rmsprop_momentum() (in module jax.example_libraries.optimizers) rng_bit_generator() (in module jax.lax) RNG_DEFAULT (jax.lax.RandomAlgorithm attribute) RNG_PHILOX (jax.lax.RandomAlgorithm attribute) RNG_THREE_FRY (jax.lax.RandomAlgorithm attribute) rng_uniform() (in module jax.lax) roll() (in module jax.numpy) rollaxis() (in module jax.numpy) roots() (in module jax.numpy) rot90() (in module jax.numpy) Rotation (class in jax.scipy.spatial.transform) round() (in module jax.lax) (in module jax.numpy) (jax.Array method) RoundingMethod (class in jax.lax) rsf2csf() (in module jax.scipy.linalg) rsqrt() (in module jax.lax) run_on_first_core() (in module jax.experimental.pallas.tpu) run_scoped() (in module jax.experimental.pallas) runtime_executable() (jax.stages.Compiled method) S s_ (in module jax.numpy) sample_block() (in module jax.experimental.pallas.tpu) save() (in module jax.numpy) save_and_offload_only_these_names() (jax.checkpoint_policies method) save_any_names_but_these() (jax.checkpoint_policies method) save_device_memory_profile() (in module jax.profiler) save_from_both_policies() (jax.checkpoint_policies method) save_only_these_names() (jax.checkpoint_policies method) savez() (in module jax.numpy) scalar_type_of() (in module jax.dtypes) scale_and_translate() (in module jax.image) scaled_dot_general() (in module jax.nn) scaled_matmul() (in module jax.nn) scan() (in module jax.lax) scatter() (in module jax.lax) scatter_add() (in module jax.lax) scatter_apply() (in module jax.lax) scatter_max() (in module jax.lax) scatter_min() (in module jax.lax) scatter_mul() (in module jax.lax) scatter_sub() (in module jax.lax) ScatterDimensionNumbers (class in jax.lax) schur() (in module jax.lax.linalg) (in module jax.scipy.linalg) searchsorted() (in module jax.numpy) (jax.Array method) seed_with_impl() (in module jax.extend.random) segment_max() (in module jax.ops) segment_min() (in module jax.ops) segment_prod() (in module jax.ops) segment_sum() (in module jax.ops) select() (in module jax.lax) (in module jax.numpy) select_n() (in module jax.lax) selu() (in module jax.nn) sem() (in module jax.scipy.stats) sem_recvs (jax.experimental.pallas.tpu.BufferedRef attribute) sem_sends (jax.experimental.pallas.tpu.BufferedRef attribute) semaphore_read() (in module jax.experimental.pallas) semaphore_signal() (in module jax.experimental.pallas) semaphore_signal_parallel() (in module jax.experimental.pallas.mosaic_gpu) semaphore_wait() (in module jax.experimental.pallas) SemaphoreSignal (class in jax.experimental.pallas.mosaic_gpu) SemaphoreType (class in jax.experimental.pallas.mosaic_gpu) (class in jax.experimental.pallas.tpu) sequential_vmap() (in module jax.custom_batching) serial() (in module jax.example_libraries.stax) serialization_format (jax.experimental.pallas.tpu.CompilerParams attribute) serialize() (in module jax.experimental.serialize_executable) (jax.export.Exported method) serialize_portable_artifact (in module jax.extend.mlir) set() (in module jax.ref) set_cache_dir() (in module jax.experimental.compilation_cache.compilation_cache) set_max_registers() (in module jax.experimental.pallas.mosaic_gpu) set_mesh (class in jax) set_printoptions() (in module jax.numpy) set_tpu_interpret_mode() (in module jax.experimental.pallas.tpu) setdiff1d() (in module jax.numpy) setxor1d() (in module jax.numpy) sf() (in module jax.scipy.stats.beta) (in module jax.scipy.stats.cauchy) (in module jax.scipy.stats.chi2) (in module jax.scipy.stats.expon) (in module jax.scipy.stats.gamma) (in module jax.scipy.stats.gumbel_l) (in module jax.scipy.stats.gumbel_r) (in module jax.scipy.stats.logistic) (in module jax.scipy.stats.norm) (in module jax.scipy.stats.pareto) (in module jax.scipy.stats.truncnorm) sgd() (in module jax.example_libraries.optimizers) shape (jax.Array property) shape() (in module jax.numpy) shape_dependent() (in module jax.example_libraries.stax) ShapeDtypeStruct (class in jax) shard_map() (in module jax) shard_shape() (jax.sharding.Sharding method) Sharding (class in jax.sharding) sharding (jax.Array property) shift_left() (in module jax.lax) shift_right_arithmetic() (in module jax.lax) shift_right_logical() (in module jax.lax) shutdown() (in module jax.distributed) sici (in module jax.scipy.special) sigmoid() (in module jax.nn) sign() (in module jax.lax) (in module jax.numpy) signbit() (in module jax.numpy) signedinteger (class in jax.numpy) silu() (in module jax.nn) sin() (in module jax.lax) (in module jax.numpy) sinc() (in module jax.numpy) single (in module jax.numpy) SingleDeviceSharding (class in jax.sharding) sinh() (in module jax.lax) (in module jax.numpy) size (jax.Array property) size() (in module jax.numpy) skip_device_barrier (jax.experimental.pallas.tpu.CompilerParams attribute) Slerp (class in jax.scipy.spatial.transform) Slice (class in jax.experimental.pallas) slice() (in module jax.lax) slice_in_dim() (in module jax.lax) slogdet() (in module jax.numpy.linalg) sm3() (in module jax.example_libraries.optimizers) smallest_normal (jax.numpy.finfo attribute) smallest_subnormal (jax.numpy.finfo attribute) smap() (in module jax) SMEM (in module jax.experimental.pallas.mosaic_gpu) soft_sign() (in module jax.nn) softmax() (in module jax.nn) (in module jax.scipy.special) softplus() (in module jax.nn) solve() (in module jax.numpy.linalg) (in module jax.scipy.linalg) solve_sylvester() (in module jax.scipy.linalg) solve_triangular() (in module jax.scipy.linalg) sort() (in module jax.lax) (in module jax.numpy) (jax.Array method) sort_complex() (in module jax.numpy) sort_key_val() (in module jax.lax) spacing() (in module jax.numpy) sparse_plus() (in module jax.nn) sparse_sigmoid() (in module jax.nn) sparsify() (in module jax.experimental.sparse) spec (jax.experimental.pallas.tpu.BufferedRef attribute) (jax.sharding.NamedSharding property) spence() (in module jax.scipy.special) sph_harm() (in module jax.scipy.special) split() (in module jax.lax) (in module jax.numpy) (in module jax.random) SPMD spsolve() (in module jax.experimental.sparse.linalg) sqrt() (in module jax.lax) (in module jax.numpy) sqrtm() (in module jax.scipy.linalg) square() (in module jax.lax) (in module jax.numpy) squareplus() (in module jax.nn) squeeze() (in module jax.lax) (in module jax.numpy) (jax.Array method) stack() (in module jax.numpy) standardize() (in module jax.nn) start_server() (in module jax.profiler) start_trace() (in module jax.profiler) stateful_bernoulli() (in module jax.experimental.pallas.tpu) stateful_bits() (in module jax.experimental.pallas.tpu) stateful_normal() (in module jax.experimental.pallas.tpu) stateful_uniform() (in module jax.experimental.pallas.tpu) static std() (in module jax.numpy) (jax.Array method) StepTraceAnnotation (class in jax.profiler) stft() (in module jax.scipy.signal) stop_gradient() (in module jax.lax) stop_trace() (in module jax.profiler) store() (in module jax.experimental.pallas) StoreException structure() (in module jax.tree) sub() (in module jax.lax) subtract (in module jax.numpy) subtree_defs (jax.example_libraries.optimizers.OptimizerState attribute) sum() (in module jax.numpy) (jax.Array method) SumPool() (in module jax.example_libraries.stax) supported_lhs_types (jax.lax.DotAlgorithmPreset property) supported_output_types() (jax.lax.DotAlgorithmPreset method) supported_rhs_types (jax.lax.DotAlgorithmPreset property) svd() (in module jax.lax.linalg) (in module jax.numpy.linalg) (in module jax.scipy.linalg) SvdAlgorithm (class in jax.lax.linalg) svdvals() (in module jax.numpy.linalg) swap (jax.experimental.pallas.tpu.BufferedRef attribute) swap() (in module jax.experimental.pallas) (in module jax.ref) swapaxes() (in module jax.numpy) (jax.Array method) swish() (in module jax.nn) switch() (in module jax.lax) SwizzleTransform (class in jax.experimental.pallas.mosaic_gpu) symbolic_args_specs() (in module jax.export) symbolic_shape() (in module jax.export) SymbolicScope (class in jax.export) symmetric_product() (in module jax.lax.linalg) sync_copy() (in module jax.experimental.pallas.tpu) sync_global_devices() (in module jax.experimental.multihost_utils) T T (jax.Array property) t() (in module jax.random) take() (in module jax.numpy) (jax.Array method) take_along_axis() (in module jax.numpy) tan() (in module jax.lax) (in module jax.numpy) tanh() (in module jax.lax) (in module jax.nn) (in module jax.numpy) tcgen05_commit_arrive() (in module jax.experimental.pallas.mosaic_gpu) tcgen05_mma() (in module jax.experimental.pallas.mosaic_gpu) tensordot() (in module jax.numpy) (in module jax.numpy.linalg) tensorinv() (in module jax.numpy.linalg) tensorsolve() (in module jax.numpy.linalg) TF32_TF32_F32 (jax.lax.DotAlgorithmPreset attribute) TF32_TF32_F32_X3 (jax.lax.DotAlgorithmPreset attribute) threefry2x32_p (in module jax.extend.random) threefry_2x32() (in module jax.extend.random) threefry_prng_impl (in module jax.extend.random) tile() (in module jax.numpy) TilingTransform (class in jax.experimental.pallas.mosaic_gpu) tiny (jax.numpy.finfo attribute) to_device() (jax.Array method) TO_NEAREST_EVEN (jax.lax.RoundingMethod attribute) to_pallas_key() (in module jax.experimental.pallas.tpu) todense() (in module jax.experimental.sparse) toeplitz() (in module jax.scipy.linalg) Token (class in jax.extend.core) Tolerance (class in jax.lax) top_k() (in module jax.lax) TPU TpuInfo (class in jax.experimental.pallas.tpu) trace() (in module jax.numpy) (in module jax.numpy.linalg) (in module jax.profiler) (jax.Array method) (jax.stages.Wrapped method) TraceAnnotation (class in jax.profiler) Traced (class in jax.stages) Tracer TracerArrayConversionError (class in jax.errors) TracerBoolConversionError (class in jax.errors) TracerIntegerConversionError (class in jax.errors) transfer_guard() (in module jax) transformation (in module jax.extend.linear_util) transformation2 (in module jax.extend.linear_util) transformation_with_aux (in module jax.extend.linear_util) transformation_with_aux2 (in module jax.extend.linear_util) transforms (jax.experimental.pallas.mosaic_gpu.BlockSpec attribute) transpose() (in module jax.lax) (in module jax.numpy) (in module jax.tree) (jax.Array method) TransposeTransform (class in jax.experimental.pallas.mosaic_gpu) trapezoid() (in module jax.numpy) (in module jax.scipy.integrate) tree_all() (in module jax.tree_util) tree_broadcast() (in module jax.tree_util) tree_def (jax.example_libraries.optimizers.OptimizerState attribute) tree_flatten() (in module jax.tree_util) tree_flatten_with_path() (in module jax.tree_util) tree_leaves() (in module jax.tree_util) tree_leaves_with_path() (in module jax.tree_util) tree_map() (in module jax.tree_util) tree_map_with_path() (in module jax.tree_util) tree_reduce() (in module jax.tree_util) tree_reduce_associative() (in module jax.tree_util) tree_structure() (in module jax.tree_util) tree_transpose() (in module jax.tree_util) tree_unflatten() (in module jax.tree_util) treedef_children() (in module jax.tree_util) treedef_is_leaf() (in module jax.tree_util) treedef_tuple() (in module jax.tree_util) tri() (in module jax.numpy) triangular() (in module jax.random) triangular_solve() (in module jax.lax.linalg) tridiagonal() (in module jax.lax.linalg) tridiagonal_solve() (in module jax.lax.linalg) tril() (in module jax.numpy) tril_indices() (in module jax.numpy) tril_indices_from() (in module jax.numpy) trim_zeros() (in module jax.numpy) triu() (in module jax.numpy) triu_indices() (in module jax.numpy) triu_indices_from() (in module jax.numpy) true_divide() (in module jax.numpy) trunc() (in module jax.numpy) truncated_normal() (in module jax.nn.initializers) (in module jax.random) try_cluster_cancel() (in module jax.experimental.pallas.mosaic_gpu) typeof() (in module jax) U ufunc (class in jax.numpy) uint (in module jax.numpy) uint16 (class in jax.numpy) uint32 (class in jax.numpy) uint64 (class in jax.numpy) uint8 (class in jax.numpy) UnexpectedTracerError (class in jax.errors) unflatten() (in module jax.tree) uniform() (in module jax.nn.initializers) (in module jax.random) union1d() (in module jax.numpy) unique() (in module jax.numpy) unique_all() (in module jax.numpy) unique_counts() (in module jax.numpy) unique_inverse() (in module jax.numpy) unique_values() (in module jax.numpy) unordered_effects (jax.export.Exported attribute) unpack_optimizer_state() (in module jax.example_libraries.optimizers) unpackbits() (in module jax.numpy) unravel_index() (in module jax.numpy) unreduced (jax.sharding.PartitionSpec property) unsafe_no_auto_barriers (jax.experimental.pallas.mosaic_gpu.CompilerParams attribute) unsafe_rbg_prng_impl (in module jax.extend.random) unsignedinteger (class in jax.numpy) unstack() (in module jax.numpy) unwrap() (in module jax.numpy) update_fn (jax.example_libraries.optimizers.Optimizer attribute) use_tc_tiling_on_sc (jax.experimental.pallas.tpu.CompilerParams attribute) user_checks (in module jax.experimental.checkify) uses_global_constants (jax.export.Exported attribute) V value_and_grad() (in module jax) (in module jax.experimental.sparse) vander() (in module jax.numpy) Var (class in jax.extend.core) var() (in module jax.numpy) (jax.Array method) variance_scaling() (in module jax.nn.initializers) vdot() (in module jax.numpy) vecdot() (in module jax.numpy) (in module jax.numpy.linalg) vecmat() (in module jax.numpy) vector_norm() (in module jax.numpy.linalg) vectorize() (in module jax.numpy) view() (jax.Array method) visualize_array_sharding() (in module jax.debug) visualize_sharding() (in module jax.debug) VJP vjp() (in module jax) (jax.export.Exported method) vmap() (in module jax) vmem_limit_bytes (jax.experimental.pallas.tpu.CompilerParams attribute) vq() (in module jax.scipy.cluster.vq) vsplit() (in module jax.numpy) vstack() (in module jax.numpy) W wait_in_slot (jax.experimental.pallas.tpu.BufferedRef attribute) wait_load_tmem() (in module jax.experimental.pallas.mosaic_gpu) wait_out_slot (jax.experimental.pallas.tpu.BufferedRef attribute) wait_smem_to_gmem() (in module jax.experimental.pallas.mosaic_gpu) wald() (in module jax.random) weak type weibull_min() (in module jax.random) welch() (in module jax.scipy.signal) wgmma() (in module jax.experimental.pallas.mosaic_gpu) wgmma_wait() (in module jax.experimental.pallas.mosaic_gpu) WGMMAAccumulatorRef (class in jax.experimental.pallas.mosaic_gpu) when() (in module jax.experimental.pallas) where() (in module jax.numpy) while_loop() (in module jax.lax) window_ref (jax.experimental.pallas.tpu.BufferedRef attribute) with_memory_kind() (jax.sharding.NamedSharding method) (jax.sharding.Sharding method) (jax.sharding.SingleDeviceSharding method) with_memory_space_constraint() (in module jax.experimental.pallas.tpu) with_sharding_constraint() (in module jax.lax) wrap_init() (in module jax.extend.linear_util) wrap_key_data() (in module jax.random) Wrapped (class in jax.stages) WrappedFun (class in jax.extend.linear_util) X xavier_normal() (in module jax.nn.initializers) xavier_uniform() (in module jax.nn.initializers) XLA xlog1py (in module jax.scipy.special) xlogy (in module jax.scipy.special) Z zeros() (in module jax.nn.initializers) (in module jax.numpy) zeros_like() (in module jax.numpy) zeta (in module jax.scipy.special) zeta() (in module jax.lax)