Note
This example is also available as a Jupyter notebook here.
Balance an inverted pendulum on a cart
import ring
from ring.algorithms.generator.pd_control import _pd_control
import jax
import jax.numpy as jnp
import numpy as np
import mediapy as media
The `step` function also takes the generalized forces `taus`, applied to the degrees of freedom, as its third input: `step(sys, state, taus)`.
Let's consider an inverted pendulum on a cart and apply a left-right force to the cart so that the pole stays in the upright position.
# Model: a damped cart on joint "px" carrying a damped pendulum on joint "ry".
xml_str = """
<x_xy model="inv_pendulum">
<options dt="0.01" gravity="0 0 9.81"></options>
<defaults>
<geom color="white" edge_color="black"></geom>
</defaults>
<worldbody>
<body damping="0.01" joint="px" name="cart">
<geom dim="0.4 0.1 0.1" mass="1" type="box"></geom>
<body damping="0.01" euler="0 -90 0" joint="ry" name="pendulum">
<geom dim="1 0.1 0.1" mass="0.5" pos="0.5 0 0" type="box"></geom>
</body>
</body>
</worldbody>
</x_xy>
"""
sys = ring.System.create(xml_str)
# Initial configuration; q[1] is the pole angle read by the controller below
# (0.2 rad off the nominal pose). Presumably q follows joint order [px, ry].
state = ring.State.create(sys, q=jnp.array([0.0, 0.2]))

# Seeded generator so the noisy closed-loop run is reproducible.
rng = np.random.default_rng(seed=0)
# Wrap the step function with jax.jit once; calling jax.jit(...) inside the
# loop would build a fresh jitted callable on every iteration.
step_jit = jax.jit(ring.step)

xs = []
T = 10.0  # simulated duration, in the same time unit as sys.dt
n_steps = int(T / sys.dt)
for _ in range(n_steps):
    # Noisy measurement of the pole angle in degrees (noise std: 5 degrees).
    measurement_noise = rng.normal() * 5
    phi = jnp.rad2deg(state.q[1]) + measurement_noise
    # Sign-preserving quadratic proportional law on the measured angle.
    cart_motor_input = 0.1 * phi * abs(phi)
    # Actuate only the cart DoF; saturate the generalized forces to [-10, 10].
    taus = jnp.clip(jnp.array([cart_motor_input, 0.0]), -10, 10)
    state = step_jit(sys, state, taus)
    xs.append(state.x)
def _frame_stride(dt: float, fps: float) -> int:
    """Return how many simulation steps separate two rendered frames.

    The stride ``n`` is chosen so that ``n * dt`` is as close as possible to
    ``1 / fps``, but it is never smaller than 1 (every step rendered).

    Raises:
        ValueError: if ``dt`` or ``fps`` is not strictly positive.
    """
    if dt <= 0 or fps <= 0:
        raise ValueError(f"dt and fps must be positive, got dt={dt}, fps={fps}")
    return max(1, int(round(1.0 / (fps * dt))))


def show_video(sys, xs: "list[ring.Transform]", fps: int = 25):
    """Render the trajectory ``xs`` of ``sys`` and display it as a video.

    Args:
        sys: the system that produced ``xs``; its ``dt`` is the simulation
            time step used to derive the rendering stride.
        xs: one transform per simulation step (e.g. ``state.x`` collected after
            each ``ring.step``).
        fps: target playback framerate. With ``sys.dt == 0.01`` and the default
            of 25, every fourth step is rendered, as before.
    """
    dt = float(sys.dt)
    stride = _frame_stride(dt, fps)
    frames = sys.render(
        xs,
        render_every_nth=stride,
        camera="c",
        add_cameras={-1: '<camera mode="targetbody" name="c" pos="0 -2 2" target="0"></camera>'},
    )
    # Play back at the rate that keeps video time equal to simulated time
    # (equals `fps` whenever 1 / (fps * dt) is an integer).
    media.show_video(frames, fps=1.0 / (stride * dt))
# Render the cart-pole trajectory collected above.
show_video(sys=sys, xs=xs)
PD Control
# Model: a single damped pendulum hinged on joint "ry", no cart.
xml_str = """
<x_xy>
<options dt="0.01" gravity="0 0 9.81"></options>
<worldbody>
<body damping="0.01" euler="0 90 0" joint="ry" name="pendulum" pos="0 0 1">
<geom dim="1 0.1 0.1" mass="0.5" pos="0.5 0 0" type="box"></geom>
</body>
</worldbody>
</x_xy>
"""
sys = ring.System.create(xml_str)

# Single-element gain arrays, matching the model's single joint
# (presumably one gain per degree of freedom).
P = jnp.array([10.0])  # proportional gain
D = jnp.array([1.0])  # derivative gain
def simulate_pd_control(sys, P, D, T: float = 5.0, q_target: float = jnp.pi / 2):
    """Simulate ``sys`` driven by a PD controller towards a constant angle.

    Args:
        sys: the system to simulate (a single-joint pendulum in this example).
        P: proportional gains, passed to ``_pd_control``.
        D: derivative gains, passed to ``_pd_control``.
        T: simulated duration, in the same time unit as ``sys.dt``.
        q_target: constant reference joint position (default: pi / 2).

    Returns:
        A list with ``state.x`` recorded after every simulation step.
    """
    n_steps = int(T / sys.dt)
    controller = _pd_control(P, D)
    # Constant reference signal of shape (rows, 1). It must cover every
    # simulated step; keep the original 1000-row minimum so the default run
    # is unchanged while longer horizons or smaller dt are still covered.
    n_ref = max(n_steps, 1000)
    q_ref = jnp.full((n_ref, 1), q_target)
    controller_state = controller.init(sys, q_ref)
    state = ring.State.create(sys)

    # Wrap with jax.jit once, outside the loop, instead of re-creating the
    # jitted callables on every iteration.
    apply_jit = jax.jit(controller.apply)
    step_jit = jax.jit(ring.step)

    xs = []
    for _ in range(n_steps):
        controller_state, taus = apply_jit(controller_state, sys, state)
        state = step_jit(sys, state, taus)
        xs.append(state.x)
    return xs
# Closed-loop PD run with gravity enabled, then render it.
xs = simulate_pd_control(sys=sys, P=P, D=D)
show_video(sys=sys, xs=xs)
Note the steady-state error. It arises because gravity acts on the pendulum while the controller has no integral term (i.e. this is PD rather than PID control).
If we remove gravity, the steady-state error vanishes, as expected.
# Same pendulum with the gravity vector scaled to zero.
sys_nograv = sys.replace(gravity=0.0 * sys.gravity)
xs = simulate_pd_control(sys=sys_nograv, P=P, D=D)
show_video(sys=sys_nograv, xs=xs)