A JAX a TensorFlow és a PyTorch új versenytársa. A JAX az egyszerűséget hangsúlyozza a sebesség és a méretezhetőség feláldozása nélkül. Mivel a JAX kevesebb kazánkódot igényel, a programok rövidebbek, közelebb állnak a matematikához, így könnyebben érthetőek.

TL;DR:

  • 🐍 Hozzáférés a NumPy funkciókhoz a import jax.numpy és a SciPy funkciókhoz a import jax.scipy segítségével.
  • 🔥 Gyorsíts fel az éppen időben történő összeállítással a @jax.jit-al díszítve.
  • ∇ Vegyen deriváltokat a jax.grad használatával.
  • ➡️ Vektorizálás a jax.vmap segítségével, és párhuzamosítás az eszközök között a jax.pmap segítségével.

Ez a bejegyzés egy előadás rövidített változata, amelyet tavaly tavasszal a PyGrunn 11-en, Európa egyik legnagyobb Python konferenciáján tartottam.

Funkcionális programozás

A JAX funkcionális programozási filozófiát követ. Ez azt jelenti, hogy a funkcióknak önállóaknak kell lenniük, vagy tiszta: a mellékhatások nem megengedettek. Lényegében a tiszta függvény úgy néz ki, mint egy matematikai függvény (1. ábra). Bejön a bemenet, kijön valami, de nincs kommunikáció a külvilággal.

❌ 1. példa

A következő részlet egy példa arra, hogy nem funkcionálisan tiszta.

import jax.numpy as jnp

bias = jnp.array(0)
def impure_example(x):
   total = x + bias
   return total

Figyelje meg a bias kívül impure_example. A fordítás során (lásd alább) a bias gyorsítótárba kerülhet, és ezért már nem tükrözi a bias változásait.

✅ 2. példa

Íme egy példa, amely tiszta.

def pure_example(x, weights, bias):
   activation = weights @ x + bias
   return activation

Itt a pure_example önálló: minden paraméter argumentumként kerül átadásra.

🎰 Determinisztikus mintavevők

A számítógépekben az igazi véletlenszerűség nem létezik. Ehelyett az olyan könyvtárak, mint a NumPy és a TensorFlow, egy pszeudo-véletlenszám-állapotot követnek „véletlenszerű” minták létrehozásához. A funkcionális programozás közvetlen következménye, hogy a véletlen függvények eltérően működnek. Mivel a globális állapot már nem engedélyezett, egy pszeudovéletlenszám-generátor (PRNG) kulcsot kell megadnia minden alkalommal, amikor véletlen számból mintavételez (2. ábra).

import jax

key = jax.random.PRNGKey(42)
u = jax.random.uniform(key)

Ezenkívül Ön felelős a „véletlenszerű állapot” előmozdításáért minden későbbi hívás esetén.

key = jax.random.PRNGKey(43)
# Split off and consume subkey.
key, subkey = jax.random.split(key)
u = jax.random.uniform(subkey)
# Split off and consume second subkey.
key, subkey = jax.random.split(key)
u = jax.random.uniform(subkey)

..

🔥 jit

Felgyorsíthatja kódját, ha éppen időben összeállítja a JAX utasításait. Például a skálázott exponenciális lineáris egységek (SELU) függvény összeállításához használja a jax.numpy NumPy függvényeit, és adja hozzá a jax.jit dekorátort a függvényhez az alábbiak szerint:

from jax import jit

@jit
def selu(x, α=1.67, λ=1.05):
 return λ * jnp.where(x > 0, x, α * jnp.exp(x) - α)

A motorháztető alatt a JAX követi az utasításokat, és jaxpr-vé alakítja. Ez lehetővé teszi a gyorsított lineáris algebra (XLA) fordító számára, hogy nagyon hatékonyan optimalizált kódot készítsen a gyorsítóhoz.

∇ grad

A JAX egyik legerősebb tulajdonsága, hogy könnyen átmásolható. A jax.grad paraméterrel egy új függvényt definiálhat, amely a szimbolikus derivált.

from jax import grad

def f(x):
   return x + 0.5 * x**2

df_dx = grad(f)
d2f_dx2 = grad(grad(f))

Ahogy a példában is látható, nem korlátozódik az elsőrendű származékokra. Kiveheti az n-edik sorrendű származékot, ha egyszerűen láncolja a grad függvényt n-szer egymás után.

➡️ vmap és pmap

A mátrixszorzás komoly mentális gimnasztikát igényel, hogy az összes tétel mérete megfelelő legyen. A JAX vektoros térkép funkciója vmap enyhíti ezt a terhet a függvény vektorizálásával. Alapvetően minden olyan kóddarab, amely elemenként alkalmaz egy f függvényt, vmap-re cserélhető. Nézzünk egy példát.

A lineáris függvény kiszámításához:

def linear(x):
 return weights @ x

a [x₁, x₂,..] példák kötegében naivan (a vmap nélkül) ) hajtsa végre az alábbiak szerint:

def naively_batched_linear(X_batched):
 return jnp.stack([linear(x) for x in X_batched])

Ehelyett a linear-vel vmap vektorizálásával a teljes köteget egy menetben kiszámíthatjuk:

def vmap_batched_linear(X_batched):
 return vmap(linear)(X_batched)

Bónusz: Ha el szeretné osztani a munkaterhelést a gyorsítók között, ugyanazt a játékot játszhatja: cserélje le a vmap-t pmap-re, és számításai több eszközön is skálázódnak.

További információért ajánlom Laurence Moroney bemutatkozó videóját. További olvasáshoz vessen egy pillantást a JAX dokumentumokra.