Note

This example is available as a Jupyter notebook here.

Batched Dynamical Simulation

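The System object is registered as a JAX PyTree, i.e. a nested container whose leaves are arrays.

This makes it possible to stack multiple systems (or states) along a new leading batch axis and to process the whole batch with vectorized operations such as jax.vmap.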
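The following is a minimal sketch of what stacking PyTrees means. It assumes a hypothetical ToySystem NamedTuple in place of ring.System, because JAX treats NamedTuples as PyTrees out of the box. Stacking two instances leaf by leaf with jax.tree_util.tree_map produces a single PyTree whose leaves carry a leading batch axis of size 2. This only illustrates the idea and is not ring's own implementation of batch.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToySystem(NamedTuple):
    # hypothetical stand-in for ring.System: a PyTree with two array leaves
    gravity: jax.Array  # shape (3,)
    damping: jax.Array  # shape (2,)


def stack_pytrees(*trees):
    # stack corresponding leaves of all trees along a new leading axis
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *trees)


sys_a = ToySystem(gravity=jnp.array([0.0, 0.0, 9.81]), damping=jnp.array([2.0, 2.0]))
sys_b = sys_a._replace(gravity=sys_a.gravity * 0.0)  # same system, gravity disabled

batched = stack_pytrees(sys_a, sys_b)
print(batched.gravity.shape)  # (2, 3)
print(batched.damping.shape)  # (2, 2)
```

Every leaf gains the same leading axis. This is what allows jax.vmap to map over "systems" as if they were ordinary arrays.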

Batched System

For example, simulating two different systems from the same initial state.

import ring

import jax
import jax.numpy as jnp


xml_str = """
<x_xy model="double_pendulum">
<options dt="0.01" gravity="0 0 9.81"></options>
<worldbody>
<body damping="2" euler="0 90 0" joint="ry" name="upper">
<geom dim="1 0.25 0.2" mass="10" pos="0.5 0 0" type="box"></geom>
<body damping="2" joint="ry" name="lower" pos="1 0 0">
<geom dim="1 0.25 0.2" mass="10" pos="0.5 0 0" type="box"></geom>
</body>
</body>
</worldbody>
</x_xy>
"""

sys = ring.System.create(xml_str)
state = ring.State.create(sys)
# second system with gravity disabled
sys_nograv = sys.replace(gravity=sys.gravity * 0.0)
sys_batched = sys.batch(sys_nograv)

next_state_batched = jax.vmap(ring.step, in_axes=(0, None))(sys_batched, state)
# note how the state of the system without gravity has not changed at all
next_state_batched.q
Array([[0., 0.],
       [0., 0.]], dtype=float32)
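To see the in_axes=(0, None) pattern in isolation, here is a sketch on a toy problem. The hypothetical toy_step function advances a single damped pendulum by one semi-implicit Euler step. It is an assumption for illustration, not ring.step. The "batch of systems" is reduced to a batch of two scalar gravity values. The shared initial state is placed off equilibrium (q = 0.3) so that the effect of gravity is visible after one step.

```python
import jax
import jax.numpy as jnp


def toy_step(gravity, state, dt=0.01, length=1.0, damping=2.0):
    # hypothetical single-pendulum step (semi-implicit Euler), not ring.step
    q, qd = state
    qdd = -(gravity / length) * jnp.sin(q) - damping * qd
    qd_next = qd + dt * qdd
    q_next = q + dt * qd_next
    return q_next, qd_next


gravities = jnp.array([9.81, 0.0])        # batch of two "systems"
state = (jnp.array(0.3), jnp.array(0.0))  # one shared initial state

# in_axes=(0, None): map over the leading axis of `gravities`,
# and pass the same unbatched `state` to every batch element
q_next, qd_next = jax.vmap(toy_step, in_axes=(0, None))(gravities, state)
print(q_next)  # first entry decreases slightly, second stays at 0.3
```

The unbatched argument is broadcast rather than copied by hand, so a single initial state can be reused for any number of systems.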

Batched State

Here, the same system is simulated from two different initial states.

second_state = ring.State.create(sys, qd=jnp.ones((2,)))
state_batched = state.batch(second_state)
next_state_batched = jax.vmap(ring.step, in_axes=(None, 0))(sys, state_batched)
next_state_batched.q
Array([[0.        , 0.        ],
       [0.01004834, 0.00982152]], dtype=float32)
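Batching states also combines with multi-step rollouts. The sketch below reuses a hypothetical toy pendulum step (not ring.step) and unrolls it over time with jax.lax.scan. It then maps the rollout over a batch of two initial states with in_axes=(None, 0), which yields an array of shape (batch, time). Using lax.scan here is an illustrative choice and not necessarily how ring rolls out trajectories.

```python
import jax
import jax.numpy as jnp


def toy_step(gravity, state, dt=0.01, length=1.0, damping=2.0):
    # hypothetical single-pendulum step (semi-implicit Euler), not ring.step
    q, qd = state
    qdd = -(gravity / length) * jnp.sin(q) - damping * qd
    qd_next = qd + dt * qdd
    return q + dt * qd_next, qd_next


def rollout(gravity, state, n_steps=100):
    # repeatedly apply toy_step, recording q after every step
    def body(carry, _):
        carry = toy_step(gravity, carry)
        return carry, carry[0]

    _, qs = jax.lax.scan(body, state, xs=None, length=n_steps)
    return qs  # shape (n_steps,)


# two initial states: at rest, and with initial velocity qd = 1
states = (jnp.array([0.0, 0.0]), jnp.array([0.0, 1.0]))

# in_axes=(None, 0): one shared gravity value, a batch of initial states
qs = jax.vmap(rollout, in_axes=(None, 0))(9.81, states)
print(qs.shape)  # (2, 100)
```

The state at rest in the hanging position stays at q = 0. The state with qd = 1 moves by roughly dt * qd after the first step, which matches the order of magnitude of the second row printed above.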

Batched Kinematic Simulation

Batched kinematic simulation is done by passing the batch size as the sizes argument of RCMG.to_lazy_gen. The resulting generator returns one batch of sequences per PRNG key.

batchsize = 8
seed = 1
gen = ring.RCMG(sys, ring.MotionConfig(T=10.0, t_max=1.5), keep_output_extras=True).to_lazy_gen(batchsize)
(X, y), (_, q, x, _) = gen(jax.random.PRNGKey(seed))
q.shape
(8, 1000, 2)
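The shape (8, 1000, 2) reads as (batch size, number of time steps, number of degrees of freedom). T=10.0 s at dt=0.01 s gives 1000 steps, and the double pendulum has two ry joints. The sketch below shows one way a batch can be derived from a single seed. The hypothetical toy_motion function draws one random joint trajectory per PRNG key and is vectorized with jax.vmap over keys obtained from jax.random.split. It is a plain random-walk stand-in and not the motion model that RCMG actually uses.

```python
import jax
import jax.numpy as jnp

T, DT, N_DOF = 10.0, 0.01, 2
N_STEPS = round(T / DT)  # 1000 time steps


def toy_motion(key):
    # hypothetical stand-in for a motion generator: a random walk in joint space
    # (not the algorithm RCMG uses); returns q with shape (N_STEPS, N_DOF)
    increments = 0.01 * jax.random.normal(key, (N_STEPS, N_DOF))
    return jnp.cumsum(increments, axis=0)


batchsize, seed = 8, 1
keys = jax.random.split(jax.random.PRNGKey(seed), batchsize)  # one key per sample
q = jax.vmap(toy_motion)(keys)
print(q.shape)  # (8, 1000, 2)
```

Splitting the key gives every sample its own independent randomness while the whole batch stays reproducible from one seed.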