Note
This example is available as a Jupyter notebook here.
Batched Dynamical Simulation
The System object is a registered JAX PyTree, meaning it is a nested container whose leaves are arrays.
This lets us stack multiple systems (or states) along a new leading batch axis and process them with vectorized operations such as jax.vmap.
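The sketch below illustrates the general PyTree mechanism without depending on ring. ToySystem is a hypothetical stand-in for ring.System with only two array fields, and the stack helper is an assumption for illustration, not ring's own batching code. Registering the class as a PyTree is what allows jax.tree_util.tree_map to stack matching leaves and jax.vmap to map over the resulting leading axis.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclass
class ToySystem:
    # Hypothetical stand-in for ring.System: just two array-valued fields.
    gravity: jnp.ndarray
    damping: jnp.ndarray

    def tree_flatten(self):
        # The children are the array leaves; there is no static auxiliary data.
        return (self.gravity, self.damping), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def stack(*trees):
    # Stack matching leaves of several PyTrees along a new leading (batch) axis.
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *trees)


sys_a = ToySystem(gravity=jnp.array([0.0, 0.0, 9.81]), damping=jnp.array([2.0, 2.0]))
sys_b = ToySystem(gravity=jnp.zeros(3), damping=jnp.array([2.0, 2.0]))
batched = stack(sys_a, sys_b)  # every leaf now has a leading axis of size 2


def gravity_norm(s):
    return jnp.linalg.norm(s.gravity)


# vmap unflattens one ToySystem per batch entry and applies gravity_norm to it.
norms = jax.vmap(gravity_norm)(batched)
```

Here batched.gravity has shape (2, 3) and norms holds one value per stacked system.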
Batched System
That is, simulating two different systems from the same initial state.
import ring
import jax
import jax.numpy as jnp
xml_str = """
<x_xy model="double_pendulum">
<options dt="0.01" gravity="0 0 9.81"></options>
<worldbody>
<body damping="2" euler="0 90 0" joint="ry" name="upper">
<geom dim="1 0.25 0.2" mass="10" pos="0.5 0 0" type="box"></geom>
<body damping="2" joint="ry" name="lower" pos="1 0 0">
<geom dim="1 0.25 0.2" mass="10" pos="0.5 0 0" type="box"></geom>
</body>
</body>
</worldbody>
</x_xy>
"""
sys = ring.System.create(xml_str)
state = ring.State.create(sys)
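# second system: identical, but with gravity disabled
sys_nograv = sys.replace(gravity=sys.gravity * 0.0)
# stack both systems along a new leading batch axis
sys_batched = sys.batch(sys_nograv)
# map over the batch axis of the systems (0), share the single state (None)
next_state_batched = jax.vmap(ring.step, in_axes=(0, None))(sys_batched, state)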
# note how the state of the system without gravity has not changed at all:
# it starts at rest, so without gravity no force acts on it
next_state_batched.q
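To see what in_axes=(0, None) does independently of ring, here is a sketch with a hypothetical one-degree-of-freedom damped pendulum. toy_step, its parameter dictionary and the semi-implicit Euler update are assumptions for illustration, not ring.step. The parameter PyTree carries a leading batch axis of size 2, while the single state is broadcast (None) to both batch entries.

```python
import jax
import jax.numpy as jnp


def toy_step(params, state):
    # Hypothetical stand-in for ring.step: one semi-implicit Euler step of a
    # damped pendulum, theta'' = -(g / l) * sin(theta) - c * theta'.
    q, qd = state
    qdd = -(params["g"] / params["l"]) * jnp.sin(q) - params["c"] * qd
    qd_next = qd + params["dt"] * qdd
    q_next = q + params["dt"] * qd_next
    return q_next, qd_next


params = {"g": 9.81, "l": 1.0, "c": 2.0, "dt": 0.01}
params_nograv = {**params, "g": 0.0}
# Stack the two parameter sets leaf by leaf along a new leading axis.
params_batched = jax.tree_util.tree_map(
    lambda a, b: jnp.stack([a, b]), params, params_nograv
)

# Start at rest, displaced by 0.5 rad from the hanging position.
state = (jnp.array(0.5), jnp.array(0.0))

# Map over axis 0 of every params leaf; reuse the same state for each entry.
q_next, qd_next = jax.vmap(toy_step, in_axes=(0, None))(params_batched, state)
```

As with the ring example, the entry without gravity stays exactly at q = 0.5 with zero velocity, while the entry with gravity starts to swing.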
Batched State
That is, simulating the same system from two different initial states.
second_state = ring.State.create(sys, qd=jnp.ones((2,)))
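# stack both states along a new leading batch axis
state_batched = state.batch(second_state)
# share the single system (None), map over the batch axis of the states (0)
next_state_batched = jax.vmap(ring.step, in_axes=(None, 0))(sys, state_batched)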
next_state_batched.q
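A vmapped step is equivalent to stepping each state on its own and stacking the results. The sketch below checks this with a hypothetical damped two-joint toy_step (an assumption for illustration, not ring.step), using one state at rest and one with unit joint velocities, mirroring qd=jnp.ones((2,)) above.

```python
import jax
import jax.numpy as jnp


def toy_step(params, state):
    # Hypothetical stand-in for ring.step: damped free joints, qdd = -c * qd,
    # integrated with one semi-implicit Euler step.
    q, qd = state
    qd_next = qd - params["dt"] * params["c"] * qd
    return q + params["dt"] * qd_next, qd_next


params = {"c": 2.0, "dt": 0.01}
state_rest = (jnp.zeros(2), jnp.zeros(2))
state_moving = (jnp.zeros(2), jnp.ones(2))
states_batched = jax.tree_util.tree_map(
    lambda a, b: jnp.stack([a, b]), state_rest, state_moving
)

# Broadcast params (None); map over axis 0 of every leaf of the state PyTree (0).
q_b, qd_b = jax.vmap(toy_step, in_axes=(None, 0))(params, states_batched)

# Reference: step each state separately, then stack the joint positions.
q_loop = jnp.stack([toy_step(params, s)[0] for s in (state_rest, state_moving)])
```

The vectorized version avoids the Python loop and lets JAX compile a single batched computation.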
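Batched Kinematic Simulation
Batched kinematic simulation is done by passing the batch size (the sizes argument) to RCMG.to_lazy_gen; the resulting generator produces a whole batch of sequences from a single PRNG key.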
batchsize = 8
seed = 1
# keep_output_extras=True makes the generator also return extras, including q and x
gen = ring.RCMG(sys, ring.MotionConfig(T=10.0, t_max=1.5), keep_output_extras=True).to_lazy_gen(batchsize)
(X, y), (_, q, x, _) = gen(jax.random.PRNGKey(seed))
q.shape
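A common JAX pattern for batching a random generator is to split one PRNG key into one key per sequence and vmap a single-sequence generator over those keys. The sketch below shows that pattern with a hypothetical toy_motion function (random constant joint velocities integrated over time), which is not RCMG's motion model; whether to_lazy_gen batches internally in this way is an assumption, not something stated here.

```python
import jax
import jax.numpy as jnp


def toy_motion(key, T=10.0, dt=0.01, n_joints=2):
    # Hypothetical single-sequence generator: a random constant velocity per
    # joint, integrated over T seconds. Not ring's RCMG motion model.
    n_steps = int(round(T / dt))
    qd = jax.random.uniform(key, (n_joints,), minval=-1.0, maxval=1.0)
    t = jnp.arange(n_steps) * dt
    return t[:, None] * qd[None, :]  # shape (n_steps, n_joints)


batchsize = 8
seed = 1
# One independent PRNG key per sequence, then vmap over the keys.
keys = jax.random.split(jax.random.PRNGKey(seed), batchsize)
q = jax.vmap(toy_motion)(keys)  # shape (batchsize, n_steps, n_joints)
```

Splitting the key, rather than reusing it, ensures that every sequence in the batch is drawn independently.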