import jax
import jax.numpy as jnp
from jax import random, grad, jit

#
# init random key
#
key = random.PRNGKey(0)

#
# XOR training data: four input pairs and their targets
#
X = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=jnp.float32)
y = jnp.array([0, 1, 1, 0], dtype=jnp.float32).reshape(4, 1)

#
# parameter initialization: small random weights, zero biases
#
def init_params(key):
    k1, k2 = random.split(key)
    W1 = 0.1 * random.normal(k1, (2, 2))
    b1 = jnp.zeros(2)
    W2 = 0.1 * random.normal(k2, (2, 1))
    b2 = jnp.zeros(1)
    return (W1, b1, W2, b2)

#
# forward pass: 2-unit tanh hidden layer, sigmoid output
#
def forward(params, in0):
    W1, b1, W2, b2 = params
    h1 = jnp.tanh(in0 @ W1 + b1)
    o2 = jax.nn.sigmoid(h1 @ W2 + b2)
    return o2

#
# loss function: mean squared error over the four XOR examples
#
def loss(params):
    ypred = forward(params, X)
    return jnp.mean((ypred - y) ** 2)

#
# update weights: one gradient-descent step applied across the
# whole parameter tree
#
@jit
def update(params, rate=0.5):
    gradient = grad(loss)(params)
    return jax.tree.map(lambda p, g: p - rate * g, params, gradient)

#
# initialize parameters
#
params = init_params(key)

#
# training steps
#
for step in range(1000):
    params = update(params, rate=10.0)
    if step % 100 == 0:
        print(f"step {step:4d} loss={loss(params):.3f}")

#
# evaluate fit: print each input pair next to its predicted output
#
pred = forward(params, X)
jnp.set_printoptions(precision=2)
print("\nPredictions:")
print(jnp.concatenate([X, pred], axis=1))
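
#
# sanity check (a minimal sketch; the names `labels` and `accuracy`
# are illustrative, not from the original script): threshold the
# sigmoid outputs at 0.5 to get hard 0/1 labels and compare them
# with the targets; assuming training converged, all four XOR
# cases should match
#
labels = (pred > 0.5).astype(y.dtype)
accuracy = jnp.mean(labels == y)
print(f"accuracy: {accuracy:.2f}")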