Note

This example is also available as a Jupyter notebook here.

Balance an inverted pendulum on a cart

import ring

from ring.algorithms.generator.pd_control import _pd_control

import jax
import jax.numpy as jnp
import numpy as np

import mediapy as media

The `step` function also takes the generalized forces `taus`, applied to the degrees of freedom, as its third input: `step(sys, state, taus)`.

Let's consider an inverted pendulum on a cart and apply a left–right force to the cart such that the pole stays upright.

# Cart-pole model in ring's XML format: a box-shaped body "cart" (joint "px",
# mass 1) carrying a box-shaped body "pendulum" (joint "ry", mass 0.5,
# dim 1 x 0.1 x 0.1, geom offset 0.5 along its local x). Both joints have
# damping 0.01, and the system is integrated with dt=0.01.
# NOTE(review): "px"/"ry" presumably denote a prismatic joint along x and a
# revolute joint about y, and euler="0 -90 0" presumably turns the pole
# upright; confirm against ring's joint and euler conventions.
xml_str = """
<x_xy model="inv_pendulum">
<options dt="0.01" gravity="0 0 9.81"></options>
<defaults>
<geom color="white" edge_color="black"></geom>
</defaults>
<worldbody>
<body damping="0.01" joint="px" name="cart">
<geom dim="0.4 0.1 0.1" mass="1" type="box"></geom>
<body damping="0.01" euler="0 -90 0" joint="ry" name="pendulum">
<geom dim="1 0.1 0.1" mass="0.5" pos="0.5 0 0" type="box"></geom>
</body>
</body>
</worldbody>
</x_xy>
"""

sys = ring.System.create(xml_str)
# Initial joint coordinates: cart at 0, pendulum joint at 0.2 rad (the loop
# below treats q[1] as radians via rad2deg).
state = ring.State.create(sys, q=jnp.array([0.0, 0.2]))

# jit once, outside the loop, instead of creating a fresh jitted wrapper of
# ring.step on every iteration.
step = jax.jit(ring.step)
# Seeded generator so the noisy run (and hence the rendered video) is
# reproducible from one execution to the next.
rng = np.random.default_rng(0)

xs = []
T = 10.0
# Round instead of truncating: T / dt is a float quotient, and int() can lose
# a step to representation error (e.g. int(0.3 / 0.1) == 2).
n_steps = int(round(T / float(sys.dt)))
for _ in range(n_steps):
    # Noisy measurement of the pendulum joint angle in degrees
    # (Gaussian noise with standard deviation 5 degrees).
    measurement_noise = rng.normal() * 5
    phi = jnp.rad2deg(state.q[1]) + measurement_noise
    # Signed quadratic control law: sign(phi) * 0.1 * phi**2.
    cart_motor_input = 0.1 * phi * abs(phi)
    # Actuate only the cart joint (first DoF), not the pendulum joint, and
    # saturate the generalized forces to [-10, 10].
    taus = jnp.clip(jnp.array([cart_motor_input, 0.0]), -10, 10)
    state = step(sys, state, taus)
    xs.append(state.x)
def show_video(sys, xs: list[ring.Transform], fps: int = 25):
    """Render a simulated trajectory and display it as a real-time video.

    The simulation usually runs at a higher rate than the video, so only every
    ``n``-th simulation step is rendered, with ``n = 1 / (sys.dt * fps)``
    (e.g. ``sys.dt == 0.01`` and ``fps == 25`` render every 4th step).

    Args:
        sys: The ``ring.System`` that produced ``xs``; its ``dt`` sets the
            simulation rate.
        xs: One ``ring.Transform`` per simulation step, as collected by the
            simulation loops in this example.
        fps: Frame rate of the displayed video. Defaults to 25.

    Raises:
        ValueError: If ``1 / (sys.dt * fps)`` is not a positive integer, i.e.
            the trajectory cannot be shown in real time by skipping whole steps.
    """
    steps_per_frame = 1.0 / (float(sys.dt) * fps)
    render_every_nth = int(round(steps_per_frame))
    # Validate with an exception rather than `assert`, which is stripped under
    # `python -O`; any dt whose step rate is a multiple of fps is accepted.
    if render_every_nth < 1 or abs(steps_per_frame - render_every_nth) > 1e-6:
        raise ValueError(
            f"cannot play back dt={sys.dt} in real time at {fps} fps: "
            f"1 / (dt * fps) = {steps_per_frame} is not a positive integer"
        )
    # NOTE(review): key -1 in add_cameras presumably attaches the camera to the
    # world body, and target="0" presumably makes it track body 0; confirm
    # against ring's renderer documentation.
    frames = sys.render(
        xs,
        render_every_nth=render_every_nth,
        camera="c",
        add_cameras={-1: '<camera mode="targetbody" name="c" pos="0 -2 2" target="0"></camera>'},
    )
    media.show_video(frames, fps=fps)


show_video(sys, xs)
Rendering frames..: 100%|██████████| 250/250 [00:02<00:00, 102.28it/s]

PD Control

# Single-pendulum model (no cart): one box-shaped body "pendulum" with joint
# "ry" and damping 0.01, placed at pos="0 0 1"; its geom (dim, mass, offset)
# is the same as the pole of the cart-pole model above. Integrated with dt=0.01.
# NOTE(review): euler="0 90 0" (opposite sign to the cart-pole's pole)
# presumably makes the pendulum hang downwards at q = 0; confirm against
# ring's euler convention.
xml_str = """
<x_xy>
<options dt="0.01" gravity="0 0 9.81"></options>
<worldbody>
<body damping="0.01" euler="0 90 0" joint="ry" name="pendulum" pos="0 0 1">
<geom dim="1 0.1 0.1" mass="0.5" pos="0.5 0 0" type="box"></geom>
</body>
</worldbody>
</x_xy>
"""

sys = ring.System.create(xml_str)
# PD gains for the pendulum's single joint: proportional gain 10,
# derivative gain 1 (one entry per degree of freedom).
P = jnp.array([10.0])
D = jnp.array([1.0])

def simulate_pd_control(sys, P, D, T: float = 5.0, q_ref_value: float = jnp.pi / 2):
    """Simulate ``sys`` under PD control towards a constant joint reference.

    Args:
        sys: The ``ring.System`` to simulate. The reference built here has a
            single column, so the system is assumed to have exactly one joint
            coordinate (as the pendulum in this example does).
        P: Proportional gains, passed to ``_pd_control``.
        D: Derivative gains, passed to ``_pd_control``.
        T: Simulated duration in seconds. Defaults to 5.0.
        q_ref_value: Constant reference for the joint coordinate (radians).
            Defaults to ``pi / 2``.

    Returns:
        A list with ``state.x`` after each of the ``round(T / sys.dt)``
        simulation steps.
    """
    # Round instead of truncating the float quotient T / dt.
    n_steps = int(round(T / float(sys.dt)))
    controller = _pd_control(P, D)
    # Constant reference signal, one row per simulation step, so its length
    # always follows T instead of being a fixed 1000 rows.
    q_ref = jnp.full((n_steps, 1), q_ref_value)
    controller_state = controller.init(sys, q_ref)
    state = ring.State.create(sys)

    # jit once, outside the loop, instead of re-wrapping on every step.
    apply_controller = jax.jit(controller.apply)
    step = jax.jit(ring.step)

    xs = []
    for _ in range(n_steps):
        controller_state, taus = apply_controller(controller_state, sys, state)
        state = step(sys, state, taus)
        xs.append(state.x)
    return xs


xs = simulate_pd_control(sys, P, D)
show_video(sys, xs)
Rendering frames..: 100%|██████████| 125/125 [00:01<00:00, 108.19it/s]

Note the steady-state error. It arises because gravity acts on the pendulum while the controller has no integral term (it is PD, not PID, control).

If we remove gravity, the steady-state error vanishes, as expected.

# Repeat the PD experiment with the gravity vector scaled to zero.
zero_gravity = sys.gravity * 0.0
sys_nograv = sys.replace(gravity=zero_gravity)
xs = simulate_pd_control(sys_nograv, P, D)
show_video(sys_nograv, xs)
Rendering frames..: 100%|██████████| 125/125 [00:00<00:00, 166.82it/s]