A JAX a TensorFlow és a PyTorch új versenytársa. A JAX az egyszerűséget hangsúlyozza a sebesség és a méretezhetőség feláldozása nélkül. Mivel a JAX kevesebb kazánkódot igényel, a programok rövidebbek, közelebb állnak a matematikához, így könnyebben érthetőek.
TL;DR:
- 🐍 Hozzáférés a NumPy funkciókhoz a
import jax.numpy
és a SciPy funkciókhoz aimport jax.scipy
segítségével. - 🔥 Gyorsíts fel az éppen időben történő összeállítással a
@jax.jit
-al díszítve. - ∇ Vegyen deriváltokat a
jax.grad
használatával. - ➡️ Vektorizálás a
jax.vmap
segítségével, és párhuzamosítás az eszközök között ajax.pmap
segítségével.
Ez a bejegyzés egy előadás rövidített változata, amelyet tavaly tavasszal a PyGrunn 11-en, Európa egyik legnagyobb Python konferenciáján tartottam.
Funkcionális programozás
A JAX funkcionális programozási filozófiát követ. Ez azt jelenti, hogy a funkcióknak önállóaknak kell lenniük, vagy tiszta: a mellékhatások nem megengedettek. Lényegében a tiszta függvény úgy néz ki, mint egy matematikai függvény (1. ábra). Bejön a bemenet, kijön valami, de nincs kommunikáció a külvilággal.
❌ 1. példa
A következő részlet egy példa arra, hogy nem funkcionálisan tiszta.
import jax.numpy as jnp bias = jnp.array(0) def impure_example(x): total = x + bias return total
Figyelje meg a bias
kívül impure_example
. A fordítás során (lásd alább) a bias
gyorsítótárba kerülhet, és ezért már nem tükrözi a bias
változásait.
✅ 2. példa
Íme egy példa, amely tiszta.
def pure_example(x, weights, bias): activation = weights @ x + bias return activation
Itt a pure_example
önálló: minden paraméter argumentumként kerül átadásra.
🎰 Determinisztikus mintavevők
A számítógépekben az igazi véletlenszerűség nem létezik. Ehelyett az olyan könyvtárak, mint a NumPy és a TensorFlow, egy pszeudo-véletlenszám-állapotot követnek „véletlenszerű” minták létrehozásához. A funkcionális programozás közvetlen következménye, hogy a véletlen függvények eltérően működnek. Mivel a globális állapot már nem engedélyezett, egy pszeudovéletlenszám-generátor (PRNG) kulcsot kell megadnia minden alkalommal, amikor véletlen számból mintavételez (2. ábra).
import jax key = jax.random.PRNGKey(42) u = jax.random.uniform(key)
Ezenkívül Ön felelős a „véletlenszerű állapot” előmozdításáért minden későbbi hívás esetén.
key = jax.random.PRNGKey(43) # Split off and consume subkey. key, subkey = jax.random.split(key) u = jax.random.uniform(subkey) # Split off and consume second subkey. key, subkey = jax.random.split(key) u = jax.random.uniform(subkey) ..
🔥 jit
Felgyorsíthatja kódját, ha éppen időben összeállítja a JAX utasításait. Például a skálázott exponenciális lineáris egységek (SELU) függvény összeállításához használja a jax.numpy
NumPy függvényeit, és adja hozzá a jax.jit
dekorátort a függvényhez az alábbiak szerint:
from jax import jit @jit def selu(x, α=1.67, λ=1.05): return λ * jnp.where(x > 0, x, α * jnp.exp(x) - α)
A motorháztető alatt a JAX követi az utasításokat, és jaxpr-vé alakítja. Ez lehetővé teszi a gyorsított lineáris algebra (XLA) fordító számára, hogy nagyon hatékonyan optimalizált kódot készítsen a gyorsítóhoz.
∇ grad
A JAX egyik legerősebb tulajdonsága, hogy könnyen átmásolható. A jax.grad
paraméterrel egy új függvényt definiálhat, amely a szimbolikus derivált.
from jax import grad def f(x): return x + 0.5 * x**2 df_dx = grad(f) d2f_dx2 = grad(grad(f))
Ahogy a példában is látható, nem korlátozódik az elsőrendű származékokra. Kiveheti az n-edik sorrendű származékot, ha egyszerűen láncolja a grad
függvényt n-szer egymás után.
➡️ vmap és pmap
A mátrixszorzás komoly mentális gimnasztikát igényel, hogy az összes tétel mérete megfelelő legyen. A JAX vektoros térkép funkciója vmap
enyhíti ezt a terhet a függvény vektorizálásával. Alapvetően minden olyan kóddarab, amely elemenként alkalmaz egy f
függvényt, vmap
-re cserélhető. Nézzünk egy példát.
A lineáris függvény kiszámításához:
def linear(x): return weights @ x
a [x₁, x₂,..] példák kötegében naivan (a vmap
nélkül) ) hajtsa végre az alábbiak szerint:
def naively_batched_linear(X_batched): return jnp.stack([linear(x) for x in X_batched])
Ehelyett a linear
-vel vmap
vektorizálásával a teljes köteget egy menetben kiszámíthatjuk:
def vmap_batched_linear(X_batched): return vmap(linear)(X_batched)
Bónusz: Ha el szeretné osztani a munkaterhelést a gyorsítók között, ugyanazt a játékot játszhatja: cserélje le a vmap
-t pmap
-re, és számításai több eszközön is skálázódnak.
További információért ajánlom Laurence Moroney bemutatkozó videóját. További olvasáshoz vessen egy pillantást a JAX dokumentumokra.