MuJoCo XLA (MJX)#
Starting with version 3.0.0, MuJoCo includes MuJoCo XLA (MJX) under the mjx directory. MJX allows MuJoCo to run on compute hardware supported by the XLA compiler via the JAX framework. MJX runs on all platforms supported by JAX: Nvidia and AMD GPUs, Apple Silicon, and Google Cloud TPUs.
The MJX API is consistent with the main simulation functions in the MuJoCo API, although it is currently missing some features. While the API documentation is applicable to both libraries, we indicate features unsupported by MJX in the notes below.
MJX is distributed as a separate package called mujoco-mjx on PyPI. Although it depends on the main mujoco package for model compilation and visualization, it is a re-implementation of MuJoCo that uses the same algorithms as the MuJoCo implementation. However, in order to properly leverage JAX, MJX deliberately diverges from the MuJoCo API in a few places, see below.
MJX is a successor to the generalized physics pipeline in Google’s Brax physics and reinforcement learning library. MJX was built by core contributors to both MuJoCo and Brax, who will together continue to support both Brax (for its reinforcement learning algorithms and included environments) and MJX (for its physics algorithms). A future version of Brax will depend on the mujoco-mjx package, and Brax’s existing generalized pipeline will be deprecated. This change will be largely transparent to users of Brax.
Tutorial notebook#
The following IPython notebook demonstrates the use of MJX along with reinforcement learning to train humanoid and
quadruped robots to locomote: .
Installation#
The recommended way to install this package is via PyPI:
pip install mujoco-mjx
A copy of the MuJoCo library is provided as part of this package’s dependencies and does not need to be downloaded or installed separately.
Basic usage#
Once installed, the package can be imported via from mujoco import mjx. Structs, functions, and enums are available directly from the top-level mjx module.
Structs#
Before running MJX functions on an accelerator device, structs must be copied onto the device via the mjx.device_put function. Placing an mjModel on device yields an mjx.Model. Placing an mjData on device yields an mjx.Data:
model = mujoco.MjModel.from_xml_string("...")
data = mujoco.MjData(model)
mjx_model = mjx.device_put(model)
mjx_data = mjx.device_put(data)
These MJX variants mirror their MuJoCo counterparts but have three key differences:
- Fields in mjx.Model and mjx.Data are JAX arrays copied onto device, instead of numpy arrays.
- Some fields are missing from mjx.Model and mjx.Data for features that are unsupported in MJX.
- Arrays in mjx.Model and mjx.Data support adding batch dimensions. Batch dimensions are a natural way to express domain randomization (in the case of mjx.Model) or high-throughput simulation for reinforcement learning (in the case of mjx.Data).
Neither mjx.Model nor mjx.Data is meant to be constructed manually. An mjx.Data may be created by calling mjx.make_data, which mirrors the mj_makeData function in MuJoCo:
model = mujoco.MjModel.from_xml_string("...")
mjx_model = mjx.device_put(model)
mjx_data = mjx.make_data(model)
Using mjx.make_data may be preferable when constructing batched mjx.Data structures inside of a vmap.
Functions#
MuJoCo functions are exposed as MJX functions of the same name, but with PEP 8-compliant names (for example, mj_makeData becomes mjx.make_data). Most of the main simulation functions and some of the sub-components for forward simulation are available from the top-level mjx module.
MJX functions are not JIT compiled by default – we leave it to the user to JIT MJX functions, or JIT their own functions that reference MJX functions. See the minimal example below.
Enums and constants#
MJX enums are available as mjx.EnumType.ENUM_VALUE, for example mjx.JointType.FREE. Enums for unsupported MJX features are omitted from the MJX enum declaration. MJX declares no constants but references MuJoCo constants directly.
Minimal example#
# Throw a ball at 100 different velocities.
import jax
import mujoco
from mujoco import mjx
XML = r"""
<mujoco>
  <worldbody>
    <body>
      <freejoint/>
      <geom size=".15" mass="1" type="sphere"/>
    </body>
  </worldbody>
</mujoco>
"""

model = mujoco.MjModel.from_xml_string(XML)
mjx_model = mjx.device_put(model)

@jax.vmap
def batched_step(vel):
    # Each batch element starts from fresh data with its own x velocity.
    mjx_data = mjx.make_data(mjx_model)
    qvel = mjx_data.qvel.at[0].set(vel)
    mjx_data = mjx_data.replace(qvel=qvel)
    # Take one physics step and return the resulting x position.
    pos = mjx.step(mjx_model, mjx_data).qpos[0]
    return pos

# 100 velocities: 0.00, 0.01, ..., 0.99.
vel = jax.numpy.arange(0.0, 1.0, 0.01)
pos = jax.jit(batched_step)(vel)
print(pos)
Feature Parity#
MJX supports most of the main simulation features of MuJoCo, with a few exceptions. MJX will raise an exception if asked to copy to device an mjModel whose field values reference unsupported features.
The following features are fully supported in MJX:
[Table with columns Category and Feature. Surviving entries: "Dynamics", "3", "Fluid Model".]
The following features are in development and coming soon:
[Table with columns Category and Feature. Surviving entries: "Dynamics", "1, 4, 6", "Fluid Model", "All except".]
The following features are unsupported:
[Table with columns Category and Feature. No entries survive.]
🔪 MJX - The Sharp Bits 🔪#
GPUs and TPUs have unique performance tradeoffs that MJX is subject to. MJX specializes in simulating big batches of parallel identical physics scenes using algorithms that can be efficiently vectorized on SIMD hardware. This specialization is useful for machine learning workloads such as reinforcement learning that require massive data throughput.
There are certain workflows that MJX is ill-suited for:
- Single scene simulation
When simulating a single scene (one instance of mjData), MJX can be 10x slower than MuJoCo, which has been carefully optimized for CPU. MJX works best when simulating thousands or tens of thousands of scenes in parallel.
- Large, complex scenes with many contacts
Accelerators exhibit poor performance for branching code. Branching is used in broad-phase collision detection, when identifying potential collisions between large numbers of bodies in a scene. MJX ships with a simple branchless broad-phase algorithm (see performance tuning) but it is not as powerful as the one in MuJoCo.
To see how this affects simulation, let us consider a physics scene with increasing numbers of physics bodies. We simulate a scene with a variable number of humanoids (from 1 to 10) and then compare MJX’s performance on an Nvidia A100 GPU to MuJoCo on a 12-core workstation:
Notice that as we increase the number of humanoids (which increases the number of potential contacts in a scene), MJX performance degrades more rapidly than MuJoCo. At the limit, for such a large scene, MuJoCo performance nearly matches MJX.
- Scenes with collisions between meshes with many vertices
MJX supports mesh geometries and can determine if two meshes are colliding using branchless versions of mesh collision algorithms. These algorithms work well for smaller meshes (with hundreds of vertices) but suffer with large meshes. With careful tuning, MJX can simulate scenes with mesh collisions well – see the MJX shadow hand config for an example.
Performance tuning#
For MJX to perform well, some configuration parameters should be adjusted from their default MuJoCo values:
- option (*) element
The iterations and ls_iterations attributes, which control solver and linesearch iterations, respectively, should be brought down to just low enough that the simulation remains stable. Accurate solver forces are not so important in reinforcement learning, in which domain randomization is often used to add noise to physics for sim-to-real. The NEWTON solver often delivers reasonable convergence with one solver iteration, and performs well on GPU. CG is currently a better choice for TPU.
- contact/pair (*) element
Consider explicitly marking geoms for collision detection to reduce the number of contacts that MJX must consider during each step. Enabling only an explicit list of valid contacts can have a dramatic effect on simulation performance in MJX. Doing this well often requires an understanding of the task – for example, the OpenAI Gym Humanoid task resets when the humanoid starts to fall, so full contact with the floor is not needed.
- option/flag (?) element
Disabling eulerdamp can help performance; it is often not needed for stability.