Note

This example is available as a Jupyter notebook here.

The error quaternion (required for ML purposes)

In this notebook, we discuss which functions you need to do ML with quaternions. After all, the purpose of this library is to create training data.

Typically, this involves quaternions as target values (to be predicted), similar to the output of an orientation estimation filter (like VQF).

So, suppose you want to train some ML model that predicts a quaternion \(\hat{q} = f_\theta(X)\).

import ring
import jax 
import jax.numpy as jnp
import matplotlib.pyplot as plt

How to get a quaternion as network output?

That's easy enough: normalize a four-dimensional vector.

# suppose a 6D IMU input
feature_dim = 6

params = jax.random.normal(jax.random.PRNGKey(1), (4, feature_dim))
def neural_network(params, X):
    q_unnormalized = params@X
    norm = jnp.linalg.norm(q_unnormalized)
    return q_unnormalized / norm


def loss_fn(params, X, y):
    q, qhat = y, neural_network(params, X)
    # squared angle error
    return ring.maths.angle_error(q, qhat)**2

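But this is dangerous, because it can produce NaNs: for this linear model, a zero input X gives a zero output vector, whose norm is zero, so the normalization evaluates 0/0.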

X = jnp.zeros((6,))
y = jnp.array([1.0, 0, 0, 0])
loss_fn(params, X, y)
Array(nan, dtype=float32)

We could try to fix it by adding a small number to the denominator.

# suppose a 6D IMU input
feature_dim = 6

params = jax.random.normal(jax.random.PRNGKey(1), (4, feature_dim))
def neural_network(params, X):
    q_unnormalized = params@X
    norm = jnp.linalg.norm(q_unnormalized)
    eps = 1e-8
    return q_unnormalized / (norm + eps)


def loss_fn(params, X, y):
    q, qhat = y, neural_network(params, X)
    # squared angle error
    return ring.maths.angle_error(q, qhat)**2

X = jnp.zeros((6,))
y = jnp.array([1.0, 0, 0, 0])
loss_fn(params, X, y)
Array(0., dtype=float32)

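However, the gradient required for backpropagation still contains NaNs. The eps only shifts the forward value; backpropagation still differentiates the norm itself, whose derivative \(x / \lVert x \rVert\) is undefined at \(x = 0\).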

jax.grad(loss_fn)(params, X, y)
Array([[nan, nan, nan, nan, nan, nan],
       [nan, nan, nan, nan, nan, nan],
       [nan, nan, nan, nan, nan, nan],
       [nan, nan, nan, nan, nan, nan]], dtype=float32)
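A minimal pure-JAX reproduction that isolates the denominator from the network and the loss. `norm_plus_eps` is a hypothetical helper introduced only for this illustration; it does not use ring.

```python
import jax
import jax.numpy as jnp


def norm_plus_eps(x, eps=1e-8):
    # Hypothetical helper: the `norm + eps` denominator from above, in isolation.
    return jnp.linalg.norm(x) + eps


x0 = jnp.zeros(4)
value = norm_plus_eps(x0)           # forward pass: finite because of eps
grad = jax.grad(norm_plus_eps)(x0)  # backward pass: every entry is NaN
print(value, grad)
```

For a vector, the default 2-norm is a square root of a sum of squares. In the backward pass, the derivative of the square root at 0 is infinite, while the derivative of the sum of squares at x = 0 is zero, and inf · 0 = NaN. Adding eps outside the norm cannot change that product.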

The solution is somewhat involved. TL;DR: use ring.maths.safe_normalize.

# suppose a 6D IMU input
feature_dim = 6

params = jax.random.normal(jax.random.PRNGKey(1), (4, feature_dim))
def neural_network(params, X):
    q_unnormalized = params@X
    return ring.maths.safe_normalize(q_unnormalized)


def loss_fn(params, X, y):
    q, qhat = y, neural_network(params, X)
    # squared angle error
    return ring.maths.angle_error(q, qhat)**2

X = jnp.zeros((6,))
y = jnp.array([1.0, 0, 0, 0])
loss_fn(params, X, y)
Array(0., dtype=float32)
jax.grad(loss_fn)(params, X, y)
Array([[0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.]], dtype=float32)
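One common way to get this behaviour in JAX is the "double where" trick described in the JAX FAQ. The sketch below is a hypothetical stand-in (`safe_normalize_sketch` is our name), not necessarily how `ring.maths.safe_normalize` is implemented; in particular, the threshold `eps` and the zero fallback are assumptions.

```python
import jax
import jax.numpy as jnp


def safe_normalize_sketch(x, eps=1e-8):
    """Hypothetical stand-in for a NaN-safe normalize (not ring's code).

    Returns x / ||x||, or zeros when ||x|| < eps, with finite gradients
    in both cases.
    """
    sq_norm = jnp.sum(x * x)
    is_small = sq_norm < eps**2
    # Inner where: feed sqrt a harmless value in the degenerate case, so the
    # backward pass never evaluates d sqrt(s)/ds at s = 0 (inf * 0 = NaN).
    norm = jnp.sqrt(jnp.where(is_small, 1.0, sq_norm))
    # Outer where: select the actual output for the degenerate case.
    return jnp.where(is_small, jnp.zeros_like(x), x / norm)


x0 = jnp.zeros(4)
print(safe_normalize_sketch(x0))
print(jax.grad(lambda x: jnp.sum(safe_normalize_sketch(x)))(x0))
```

Both `where`s are needed: with only the outer one, the unselected branch `x / norm` is still differentiated, and its NaN gradient leaks through as 0 · NaN = NaN.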

A closer look at the function ring.maths.angle_error

Let's take a closer look at ring.maths.angle_error, which we used in loss_fn above.

How does the error function (a kind of metric) between two quaternions behave as one quaternion approaches the other?

A first implementation might look like this:

def quat_error(q, qhat):
    q_error = ring.maths.quat_mul(ring.maths.quat_inv(q), qhat)
    phi = 2 * jnp.arccos(q_error[0])
    return jnp.abs(phi)

Let's reduce this function to the critical operation phi = ... and let's assume, without loss of generality, that the target quaternion is the identity quaternion.
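Concretely, writing the code's `quat_mul(quat_inv(q), qhat)` as \(q^{-1} \otimes \hat{q}\) (our notation) and using the scalar-first layout implied by `q_error[0]`, the identity target collapses the error quaternion to the prediction itself:

\[
q = (1, 0, 0, 0) \;\Rightarrow\; q^{-1} = q \;\Rightarrow\; q_{\text{error}} = q^{-1} \otimes \hat{q} = \hat{q} \;\Rightarrow\; \varphi = 2\arccos\left(\hat{q}_0\right)
\]

so the only remaining operation is \(\arccos\) applied to the scalar part \(\hat{q}_0\).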

The problem then reduces to extracting the rotation angle from a quaternion in a numerically safe way.

def quat_angle(q):
    return 2 * jnp.arccos(q[0])

input_angles = jnp.linspace(-0.005, 0.005, num=1000)

def input_to_output_angles_incorrect(angle):
    q = ring.maths.quat_rot_axis(jnp.array([1.0, 0, 0]), angle)
    return quat_angle(q)

def input_to_output_angles_correct(angle):
    q = ring.maths.quat_rot_axis(jnp.array([1.0, 0, 0]), angle)
    return ring.maths.quat_angle(q)

plt.plot(input_angles, jax.vmap(input_to_output_angles_incorrect)(input_angles), label="incorrect")
plt.plot(input_angles, jax.vmap(input_to_output_angles_correct)(input_angles), label="correct")
plt.legend()
plt.show()
[Figure: output angle vs. input angle for the "incorrect" and "correct" implementations]

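As one might expect, the gradients are also much more stable. The naive version's gradient contains \(\frac{d}{dx}\arccos(x) = -1/\sqrt{1-x^2}\), which is unbounded as \(x \to 1\), and \(x = q_0\) is close to 1 precisely for small rotation angles.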
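Assuming, as in the plots, a rotation by \(\theta\) about a fixed axis, so that \(q_0 = \cos(\theta/2)\), the chain rule applied to the naive `quat_angle` gives:

\[
\frac{d}{d\theta}\, 2\arccos\left(\cos\tfrac{\theta}{2}\right)
= 2 \cdot \frac{-1}{\sqrt{1-\cos^2\frac{\theta}{2}}} \cdot \left(-\frac{1}{2}\sin\frac{\theta}{2}\right)
= \frac{\sin\frac{\theta}{2}}{\sqrt{1-\cos^2\frac{\theta}{2}}}
= \operatorname{sign}(\theta), \qquad 0 < |\theta| < 2\pi .
\]

Analytically the slope is \(\pm 1\). Numerically, however, the denominator is computed from the float32 value of \(\cos(\theta/2)\), which for small \(|\theta|\) is at or within a few rounding steps of 1, so the computed denominator is zero or dominated by rounding error and the gradient becomes infinite or far from \(\pm 1\).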

plt.plot(input_angles, jax.vmap(jax.grad(input_to_output_angles_incorrect))(input_angles), label="incorrect")
plt.plot(input_angles, jax.vmap(jax.grad(input_to_output_angles_correct))(input_angles), label="correct")
plt.legend()
plt.show()
[Figure: gradient of the output angle vs. input angle for the "incorrect" and "correct" implementations]
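The plots compare against `ring.maths.quat_angle`. As an independent, self-contained illustration (not necessarily how ring implements it), the sketch below swaps `arccos` for an `atan2`-based formula, \(\varphi = 2\,\operatorname{atan2}(\lVert v \rVert, |w|)\) for \(q = (w, v)\), and checks value and gradient at a tiny angle. The helpers `quat_from_x_rotation`, `quat_angle_arccos`, and `quat_angle_atan2` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def quat_from_x_rotation(angle):
    """Hypothetical helper: unit quaternion [w, x, y, z] for a rotation about x."""
    half = angle / 2.0
    zero = jnp.zeros_like(half)
    return jnp.stack([jnp.cos(half), jnp.sin(half), zero, zero])


def quat_angle_arccos(q):
    # Naive extraction, same formula as quat_angle above.
    return 2.0 * jnp.arccos(q[0])


def quat_angle_atan2(q):
    """Sketch: rotation angle in [0, pi] from the vector-part norm via atan2."""
    w, v = q[0], q[1:]
    sq = jnp.sum(v * v)
    is_zero = sq == 0.0
    # Double-where trick: never differentiate sqrt at 0 (inf * 0 = NaN).
    v_norm = jnp.where(is_zero, 0.0, jnp.sqrt(jnp.where(is_zero, 1.0, sq)))
    # Using |w| maps q and -q (the same rotation) to the same angle.
    return 2.0 * jnp.arctan2(v_norm, jnp.abs(w))


theta = jnp.float32(1e-4)  # true rotation angle in radians
naive = lambda a: quat_angle_arccos(quat_from_x_rotation(a))
safe = lambda a: quat_angle_atan2(quat_from_x_rotation(a))
print("arccos:", naive(theta), jax.grad(naive)(theta))
print("atan2: ", safe(theta), jax.grad(safe)(theta))
```

The exact numbers the arccos version prints depend on how float32 rounds \(\cos(\theta/2)\), but no float32 value near 1 lets it recover the true angle of 1e-4 or the true slope of 1. Taking \(|w|\) restricts the result to \([0, \pi]\); drop the `abs` if you want the \([0, 2\pi]\) range of the arccos formula.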

PyTorch library for quaternion operations

These functions are written for JAX. For PyTorch, the following library should provide similar functionality: https://naver.github.io/roma/