using-jax
Applies JAX patterns for scientific Python development. Use when working with JAX, distrax, numpyro, blackjax, or scientific computing. Covers vmap, JIT, RNG handling.
Install
git clone https://github.com/yallup/claude-skills /tmp/claude-skills && cp -r /tmp/claude-skills/plugins/claude-skills/skills/using-jax ~/.claude/skills/claude-skills/
Run this command in your terminal to install the skill.
SKILL.md
name: using-jax
description: Applies JAX patterns for scientific Python development. Use when working with JAX, distrax, numpyro, blackjax, or scientific computing. Covers vmap, JIT, RNG handling.
version: 2.0.0
JAX Scientific Computing
Core Rules
- Pure functions - No side effects
- JIT outer functions - @jax.jit on hot paths
- vmap not loops - jax.vmap(fn) instead of list comprehensions
- Split RNG keys - Never reuse keys
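The sketch below illustrates the first three rules. It assumes a hypothetical standard-normal log-density `log_prob`, which is for illustration only and not part of the skill. The function is pure, `jax.jit` wraps the outermost batched function, and `jax.vmap` replaces a Python list comprehension. The second half shows why purity matters: a Python side effect inside a jitted function runs only while JAX traces it, not on later calls.

```python
import jax
import jax.numpy as jnp


def log_prob(theta):
    # Pure: the output depends only on theta and nothing outside is mutated.
    # Hypothetical standard-normal log-density, up to an additive constant.
    return -0.5 * jnp.sum(theta ** 2)


# vmap not loops: map log_prob over the leading (batch) axis.
# JIT the outer function so the whole batched computation compiles once.
batched_log_prob = jax.jit(jax.vmap(log_prob))

thetas = jnp.arange(6.0).reshape(3, 2)
slow = jnp.stack([log_prob(t) for t in thetas])  # one Python call per row
fast = batched_log_prob(thetas)                  # a single compiled call

# Why purity matters: side effects run at trace time, not on every call.
calls = []


@jax.jit
def impure_double(x):
    calls.append("traced")  # Python side effect, executed only while tracing
    return 2 * x


impure_double(1.0)
impure_double(2.0)  # same shape and dtype, so the cached trace is reused
print(len(calls))  # 1
```

Both versions compute the same values. The list comprehension, however, dispatches one call per row, while the jitted `vmap` version compiles a single vectorised kernel.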
Patterns
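import jax
# RNG: always split; use each subkey exactly once
key, k1, k2 = jax.random.split(key, 3)
# Batching: vmap not loops (maps fn over the leading axis of inputs)
batched = jax.vmap(fn)(inputs)
# Loops: use scan; step_fn(carry, x) returns (new_carry, y) and the ys are stacked
_, results = jax.lax.scan(step_fn, init, xs)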
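The following runnable sketch combines the three patterns. It assumes a hypothetical Gaussian random walk as the workload. `lax.scan` carries the position and an RNG key, each step splits the key before drawing, and `vmap` runs several independent walks, each with its own key.

```python
import jax
import jax.numpy as jnp


def step_fn(carry, _):
    # carry = (position, key); the per-step input from xs is unused here.
    position, key = carry
    # Always split: keep one key to carry forward, consume the subkey once.
    key, subkey = jax.random.split(key)
    position = position + jax.random.normal(subkey)
    # scan expects (new_carry, per-step output).
    return (position, key), position


def random_walk(key, n_steps=100):
    init = (jnp.zeros(()), key)
    # With xs=None, length= sets how many steps scan runs.
    _, path = jax.lax.scan(step_fn, init, xs=None, length=n_steps)
    return path


# Batching: vmap over independent walks, each with its own key.
key = jax.random.PRNGKey(0)
key, walk_key = jax.random.split(key)
walk_keys = jax.random.split(walk_key, 4)
paths = jax.jit(jax.vmap(random_walk))(walk_keys)
print(paths.shape)  # (4, 100)
```

Threading the key through the scan carry guarantees every step gets a fresh subkey. Reusing one key inside `step_fn` would instead add the same "random" increment at every step.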
Gotchas
- Arrays are immutable
- No Python control flow on traced values inside JIT - use jax.lax.cond, jax.lax.scan
- Check NaNs: jnp.isnan(x).any()
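The sketch below walks through the three gotchas. `safe_log` is a hypothetical helper for illustration. The `.at[...]` update API, `jax.lax.cond`, and `jnp.isnan` are standard JAX.

```python
import jax
import jax.numpy as jnp

# Arrays are immutable: `x[0] = 1.0` raises a TypeError.
# The .at[] API returns an updated copy instead.
x = jnp.zeros(3)
y = x.at[0].set(1.0)  # x is still all zeros; y is [1., 0., 0.]


@jax.jit
def safe_log(v):
    # `if v > 0:` would fail here because v is an abstract tracer under jit.
    # lax.cond traces both branches and picks one from the runtime predicate.
    return jax.lax.cond(
        v > 0,
        jnp.log,                               # branch taken when v > 0
        lambda u: jnp.full_like(u, -jnp.inf),  # branch taken otherwise
        v,
    )


pos = safe_log(jnp.float32(2.0))   # log(2), about 0.693
neg = safe_log(jnp.float32(-1.0))  # -inf

# NaN check on a result, outside the hot path.
z = jnp.array([1.0, jnp.nan])
has_nan = bool(jnp.isnan(z).any())  # True
```

To locate the operation that first produces a NaN, `jax.config.update("jax_debug_nans", True)` makes JAX raise an error at that point, at some runtime cost.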
Repository: yallup/claude-skills/plugins/claude-skills/skills/using-jax
Author: yallup