# Copyright 2025 The Newton Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import dataclasses
from typing import Any, Optional, Sequence, Union
import mujoco
import numpy as np
import warp as wp
from mujoco_warp._src import types
from mujoco_warp._src import warp_util
from mujoco_warp._src.warp_util import nested_kernel
def _create_array(data: Any, spec: wp.array, sizes: dict[str, int]) -> Union[wp.array, None]:
  """Builds a warp array described by a field spec, optionally filled with data.

  String dimensions in ``spec.shape`` (e.g. ``"nv"`` or ``"*"``) are looked up in
  ``sizes``; integer dimensions are used as-is. A spec shape of ``(0,)`` means the
  field has no fixed shape, in which case the shape is inferred from ``data``.

  Args:
    data: Host values to copy into the array, or None to allocate zeros.
    spec: Array type annotation supplying the target ``shape`` and ``dtype``.
    sizes: Mapping from symbolic dimension names to concrete sizes.

  Returns:
    The new warp array, or None when there is neither data nor a fixed shape.
  """
  fixed_shape = spec.shape != (0,)
  if fixed_shape:
    shape = tuple(dim if not isinstance(dim, str) else sizes[dim] for dim in spec.shape)
  else:
    shape = None

  if data is not None:
    arr = wp.array(np.array(data), dtype=spec.dtype, shape=shape)
  elif shape is not None:
    arr = wp.zeros(shape, dtype=spec.dtype)
  else:
    # no data and no fixed shape: nothing to allocate
    return None

  if spec.shape[0] == "*":
    # private attribute read by JAX to tell which fields are batched
    arr._is_batched = True
    # zero leading stride: legacy behavior that callers still expect (deprecated);
    # every index along the leading axis aliases the same memory
    arr.strides = (0,) + arr.strides[1:]
  return arr
[docs]
def put_model(mjm: mujoco.MjModel) -> types.Model:
  """Creates a model on device.

  Validates that ``mjm`` only uses supported features, copies the host model
  fields into a ``types.Model``, precomputes the warp-only index tables used by
  the device code (tree levels, collision pair tables, sensor / tendon /
  equality addresses, sparse mass-matrix index lists), and finally places all
  array fields on the device.

  Args:
    mjm: The model containing kinematic and dynamic information (host).

  Returns:
    The model containing kinematic and dynamic information (device).

  Raises:
    NotImplementedError: if the model uses an unsupported actuator / equality /
      geom / sensor / wrap type, integrator, cone, solver or option flag; more
      than one flex; flex collisions; the noslip solver; an implicit integrator
      together with a fluid model; body, actuator or sensor plugins; or
      collision sensors between unsupported geom type pairs.
    ValueError: if the dense Jacobian is requested and nv exceeds the supported
      maximum.
  """
  # check for compatible cuda toolkit and driver versions
  warp_util.check_toolkit_driver()

  # model: check supported features in array types
  for field, field_type in (
    (mjm.actuator_trntype, types.TrnType),
    (mjm.actuator_dyntype, types.DynType),
    (mjm.actuator_gaintype, types.GainType),
    (mjm.actuator_biastype, types.BiasType),
    (mjm.eq_type, types.EqType),
    (mjm.geom_type, types.GeomType),
    (mjm.sensor_type, types.SensorType),
    (mjm.wrap_type, types.WrapType),
  ):
    missing = ~np.isin(field, field_type)
    if missing.any():
      raise NotImplementedError(f"{field_type.__name__}: {field[missing]} not supported.")

  # opt: check supported features in scalar types
  for field, field_type in (
    (mjm.opt.integrator, types.IntegratorType),
    (mjm.opt.cone, types.ConeType),
    (mjm.opt.solver, types.SolverType),
  ):
    if field not in set(field_type):
      raise NotImplementedError(f"{field_type.__name__} {field} is unsupported.")

  # opt: check supported features in scalar flag types
  for field, field_type in (
    (mjm.opt.disableflags, types.DisableBit),
    (mjm.opt.enableflags, types.EnableBit),
  ):
    # any bit outside the union of the supported flags is unsupported
    if field & ~np.bitwise_or.reduce(field_type):
      raise NotImplementedError(f"{field_type.__name__} {field} is unsupported.")

  if mjm.nflex > 1:
    raise NotImplementedError("Only one flex is supported.")

  if ((mjm.flex_contype != 0) | (mjm.flex_conaffinity != 0)).any():
    raise NotImplementedError("Flex collisions are not implemented.")

  if mjm.opt.noslip_iterations > 0:
    raise NotImplementedError("noslip solver not implemented.")

  if (mjm.opt.viscosity > 0 or mjm.opt.density > 0) and mjm.opt.integrator in (
    mujoco.mjtIntegrator.mjINT_IMPLICITFAST,
    mujoco.mjtIntegrator.mjINT_IMPLICIT,
  ):
    raise NotImplementedError("Implicit integrators and fluid model not implemented.")

  if (mjm.body_plugin != -1).any():
    raise NotImplementedError("Body plugins not supported.")

  if (mjm.actuator_plugin != -1).any():
    raise NotImplementedError("Actuator plugins not supported.")

  if (mjm.sensor_plugin != -1).any():
    raise NotImplementedError("Sensor plugins not supported.")

  # TODO(team): remove after _update_gradient for Newton uses tile operations for islands
  nv_max = 60
  if mjm.nv > nv_max and mjm.opt.jacobian == mujoco.mjtJacobian.mjJAC_DENSE:
    raise ValueError(f"Dense is unsupported for nv > {nv_max} (nv = {mjm.nv}).")

  collision_sensors = (mujoco.mjtSensor.mjSENS_GEOMDIST, mujoco.mjtSensor.mjSENS_GEOMNORMAL, mujoco.mjtSensor.mjSENS_GEOMFROMTO)
  is_collision_sensor = np.isin(mjm.sensor_type, collision_sensors)

  def not_implemented(objtype, objid, geomtype):
    # True if the sensor object (a body's geoms, or a single geom) contains a geom of type geomtype
    if objtype == mujoco.mjtObj.mjOBJ_BODY:
      geomnum = mjm.body_geomnum[objid]
      geomadr = mjm.body_geomadr[objid]
      for geomid in range(geomadr, geomadr + geomnum):
        if mjm.geom_type[geomid] == geomtype:
          return True
    elif objtype == mujoco.mjtObj.mjOBJ_GEOM:
      if mjm.geom_type[objid] == geomtype:
        return True
    return False

  # geom type pairs that collision sensors do not support yet
  for geoms in [
    (types.GeomType.BOX, types.GeomType.BOX),
    (types.GeomType.CAPSULE, types.GeomType.BOX),
    (types.GeomType.CYLINDER, types.GeomType.BOX),
    (types.GeomType.PLANE, types.GeomType.BOX),
  ]:
    for objtype, objid, reftype, refid in zip(
      mjm.sensor_objtype[is_collision_sensor],
      mjm.sensor_objid[is_collision_sensor],
      mjm.sensor_reftype[is_collision_sensor],
      mjm.sensor_refid[is_collision_sensor],
    ):
      if not_implemented(objtype, objid, geoms[0]) and not_implemented(reftype, refid, geoms[1]):
        raise NotImplementedError(f"Collision sensors with {geoms[0]} and {geoms[1]} are not implemented.")

  # create opt
  opt = types.Option(**{f.name: getattr(mjm.opt, f.name, None) for f in dataclasses.fields(types.Option)})

  # C MuJoCo tolerance was chosen for float64 architecture, but we default to float32 on GPU
  # adjust the tolerance for lower precision, to avoid the solver spending iterations needlessly
  # bouncing around the optimal solution
  opt.tolerance = max(opt.tolerance, 1e-6)

  # warp only fields
  opt.is_sparse = bool(mujoco.mj_isSparse(mjm))
  ls_parallel_id = mujoco.mj_name2id(mjm, mujoco.mjtObj.mjOBJ_NUMERIC, "ls_parallel")
  opt.ls_parallel = (ls_parallel_id > -1) and (mjm.numeric_data[mjm.numeric_adr[ls_parallel_id]] == 1)
  opt.ls_parallel_min_step = 1.0e-6  # TODO(team): determine good default setting
  opt.has_fluid = mjm.opt.wind.any() or mjm.opt.density > 0 or mjm.opt.viscosity > 0
  opt.broadphase = types.BroadphaseType.NXN
  opt.broadphase_filter = types.BroadphaseFilter.PLANE | types.BroadphaseFilter.SPHERE | types.BroadphaseFilter.OBB
  opt.graph_conditional = True
  opt.run_collision_detection = True
  opt.contact_sensor_maxmatch = 64

  # place opt on device
  for f in dataclasses.fields(types.Option):
    if isinstance(f.type, wp.array):
      setattr(opt, f.name, _create_array(getattr(opt, f.name), f.type, {"*": 1}))
    else:
      setattr(opt, f.name, f.type(getattr(opt, f.name)))

  # create stat
  stat = types.Statistic(meaninertia=mjm.stat.meaninertia)

  # create model
  m = types.Model(**{f.name: getattr(mjm, f.name, None) for f in dataclasses.fields(types.Model)})

  m.opt = opt
  m.stat = stat

  m.nacttrnbody = (mjm.actuator_trntype == mujoco.mjtTrn.mjTRN_BODY).sum()
  m.nsensortaxel = mjm.mesh_vertnum[mjm.sensor_objid[mjm.sensor_type == mujoco.mjtSensor.mjSENS_TACTILE]].sum()
  m.nsensorcontact = (mjm.sensor_type == mujoco.mjtSensor.mjSENS_CONTACT).sum()
  m.nrangefinder = (mjm.sensor_type == mujoco.mjtSensor.mjSENS_RANGEFINDER).sum()
  m.nmaxcondim = np.concatenate(([0], mjm.geom_condim, mjm.pair_dim)).max()
  m.nmaxpyramid = np.maximum(1, 2 * (m.nmaxcondim - 1))
  m.has_sdf_geom = (mjm.geom_type == mujoco.mjtGeom.mjGEOM_SDF).any()
  m.block_dim = types.BlockDim()

  # body ids grouped by tree level (the world body 0 is its own parent, so it lands at depth 0)
  bodies, body_depth = {}, np.zeros(mjm.nbody, dtype=int) - 1
  for i in range(mjm.nbody):
    body_depth[i] = body_depth[mjm.body_parentid[i]] + 1
    bodies.setdefault(body_depth[i], []).append(i)
  m.body_tree = tuple(wp.array(bodies[i], dtype=int) for i in sorted(bodies))

  # mocap body ids ordered by their mocap index
  m.mocap_bodyid = np.arange(mjm.nbody)[mjm.body_mocapid >= 0]
  m.mocap_bodyid = m.mocap_bodyid[mjm.body_mocapid[mjm.body_mocapid >= 0].argsort()]
  m.body_fluid_ellipsoid = np.zeros(mjm.nbody, dtype=bool)
  m.body_fluid_ellipsoid[mjm.geom_bodyid[mjm.geom_fluid.reshape(mjm.ngeom, mujoco.mjNFLUID)[:, 0] > 0]] = True
  jnt_limited_slide_hinge = mjm.jnt_limited & np.isin(mjm.jnt_type, (mujoco.mjtJoint.mjJNT_SLIDE, mujoco.mjtJoint.mjJNT_HINGE))
  m.jnt_limited_slide_hinge_adr = np.nonzero(jnt_limited_slide_hinge)[0]
  m.jnt_limited_ball_adr = np.nonzero(mjm.jnt_limited & (mjm.jnt_type == mujoco.mjtJoint.mjJNT_BALL))[0]
  m.dof_tri_row, m.dof_tri_col = np.tril_indices(mjm.nv)

  # precalculated geom pairs: every geom1 < geom2 pair, in np.triu_indices(ngeom, k=1) order
  filterparent = not (mjm.opt.disableflags & types.DisableBit.FILTERPARENT)

  geom1, geom2 = np.triu_indices(mjm.ngeom, k=1)
  m.nxn_geom_pair = np.stack((geom1, geom2), axis=1)

  bodyid1 = mjm.geom_bodyid[geom1]
  bodyid2 = mjm.geom_bodyid[geom2]
  contype1 = mjm.geom_contype[geom1]
  contype2 = mjm.geom_contype[geom2]
  conaffinity1 = mjm.geom_conaffinity[geom1]
  conaffinity2 = mjm.geom_conaffinity[geom2]
  weldid1 = mjm.body_weldid[bodyid1]
  weldid2 = mjm.body_weldid[bodyid2]
  weld_parentid1 = mjm.body_weldid[mjm.body_parentid[weldid1]]
  weld_parentid2 = mjm.body_weldid[mjm.body_parentid[weldid2]]

  self_collision = weldid1 == weldid2
  parent_child_collision = (
    filterparent & (weldid1 != 0) & (weldid2 != 0) & ((weldid1 == weld_parentid2) | (weldid2 == weld_parentid1))
  )
  mask = np.array((contype1 & conaffinity2) | (contype2 & conaffinity1), dtype=bool)
  exclude = np.isin((bodyid1 << 16) + bodyid2, mjm.exclude_signature)

  # -1: candidate contact pair, -2: filtered out, >= 0: explicit <pair> id
  nxn_pairid_contact = -1 * np.ones(len(geom1), dtype=int)
  nxn_pairid_contact[~(mask & ~self_collision & ~parent_child_collision & ~exclude)] = -2

  # contact pairs
  def upper_tri_index(n, i, j):
    # index of (i, j) in the row-major strict upper triangle of an n x n matrix
    i, j = (j, i) if j < i else (i, j)
    return (i * (2 * n - i - 3)) // 2 + j - 1

  for i in range(mjm.npair):
    nxn_pairid_contact[upper_tri_index(mjm.ngeom, mjm.pair_geom1[i], mjm.pair_geom2[i])] = i

  sensor_collision_adr = np.nonzero(is_collision_sensor)[0]
  nxn_pairid_collision = -1 * np.ones(len(geom1), dtype=int)
  seen_pairids = set()
  collision_geom_adr = [0]
  m.sensor_collision_start_adr = []
  for i in range(sensor_collision_adr.size):
    sensorid = sensor_collision_adr[i]
    objtype = mjm.sensor_objtype[sensorid]
    objid = mjm.sensor_objid[sensorid]
    reftype = mjm.sensor_reftype[sensorid]
    refid = mjm.sensor_refid[sensorid]

    # get lists of geoms to collide
    if objtype == types.ObjType.BODY:
      n1 = mjm.body_geomnum[objid]
      id1 = mjm.body_geomadr[objid]
    else:
      n1 = 1
      id1 = objid
    if reftype == types.ObjType.BODY:
      n2 = mjm.body_geomnum[refid]
      id2 = mjm.body_geomadr[refid]
    else:
      n2 = 1
      id2 = refid

    # collide all pairs; a pair shared with an earlier sensor reuses that sensor's address
    geomid = 0
    for geom1id in range(id1, id1 + n1):
      for geom2id in range(id2, id2 + n2):
        pairid = upper_tri_index(mjm.ngeom, geom1id, geom2id)
        if pairid in seen_pairids:
          m.sensor_collision_start_adr.append(nxn_pairid_collision[pairid])
        else:
          seen_pairids.add(pairid)
          adr = collision_geom_adr[-1] + geomid
          nxn_pairid_collision[pairid] = adr
          m.sensor_collision_start_adr.append(adr)

        geomid += 1

    if i < sensor_collision_adr.size - 1:
      collision_geom_adr.append(collision_geom_adr[-1] + n1 * n2)

  m.nsensorcollision = (nxn_pairid_collision >= 0).sum()

  nxn_include = (nxn_pairid_contact > -2) | (nxn_pairid_collision >= 0)

  # pick a broadphase based on the number of candidate pairs and geoms
  if nxn_include.sum() < 250_000:
    opt.broadphase = types.BroadphaseType.NXN
  elif mjm.ngeom < 1000:
    opt.broadphase = types.BroadphaseType.SAP_TILE
  else:
    opt.broadphase = types.BroadphaseType.SAP_SEGMENTED

  m.nxn_geom_pair_filtered = m.nxn_geom_pair[nxn_include]
  m.nxn_pairid = np.hstack([nxn_pairid_contact.reshape((-1, 1)), nxn_pairid_collision.reshape((-1, 1))])
  m.nxn_pairid_filtered = m.nxn_pairid[nxn_include]

  # count contact pair types
  def geom_trid_index(i, j):
    # index of (i, j) in the row-major upper triangle (diagonal included) of the geom-type matrix
    i, j = (j, i) if j < i else (i, j)
    return (i * (2 * len(types.GeomType) - i - 1)) // 2 + j

  m.geom_pair_type_count = tuple(
    np.bincount(
      [geom_trid_index(mjm.geom_type[geom1[i]], mjm.geom_type[geom2[i]]) for i in np.arange(len(geom1)) if nxn_include[i]],
      minlength=len(types.GeomType) * (len(types.GeomType) + 1) // 2,
    )
  )

  # compute nmaxpolygon and nmaxmeshdeg given the geom pairs for the model
  nboxbox = m.geom_pair_type_count[geom_trid_index(types.GeomType.BOX, types.GeomType.BOX)]
  nboxmesh = m.geom_pair_type_count[geom_trid_index(types.GeomType.BOX, types.GeomType.MESH)]
  nmeshmesh = m.geom_pair_type_count[geom_trid_index(types.GeomType.MESH, types.GeomType.MESH)]

  # need at least 4 (square sides) if there's a box collision needing multiccd
  m.nmaxpolygon = 4 * (nboxbox + nboxmesh > 0)
  m.nmaxmeshdeg = 3 * (nboxbox + nboxmesh > 0)

  # possibly need to allocate more memory if there's meshes
  if nmeshmesh + nboxmesh > 0:
    # TODO(kbayes): remove nboxbox or enable ccd for box-box collisions
    m.nmaxpolygon = np.append(mjm.mesh_polyvertnum, m.nmaxpolygon).max()
    m.nmaxmeshdeg = np.append(mjm.mesh_polymapnum, m.nmaxmeshdeg).max()

  # filter plugins for only geom plugins, drop the rest
  m.plugin, m.plugin_attr = [], []
  m.geom_plugin_index = np.full_like(mjm.geom_type, -1)

  for i in range(len(mjm.geom_plugin)):
    if mjm.geom_plugin[i] == -1:
      continue
    p = mjm.geom_plugin[i]
    m.geom_plugin_index[i] = len(m.plugin)
    m.plugin.append(mjm.plugin[p])
    start = mjm.plugin_attradr[p]
    end = mjm.plugin_attradr[p + 1] if p + 1 < mjm.nplugin else len(mjm.plugin_attr)
    values = mjm.plugin_attr[start:end]
    # attributes are stored as zero-terminated character codes; each is parsed as a float
    attr_values = []
    current = []
    for v in values:
      if v == 0:
        if current:
          s = "".join(chr(int(x)) for x in current)
          attr_values.append(float(s))
          current = []
      else:
        current.append(v)
    # Pad with zeros if less than 3
    attr_values += [0.0] * (3 - len(attr_values))
    m.plugin_attr.append(attr_values[:3])

  # equality constraint addresses
  m.eq_connect_adr = np.nonzero(mjm.eq_type == types.EqType.CONNECT)[0]
  m.eq_wld_adr = np.nonzero(mjm.eq_type == types.EqType.WELD)[0]
  m.eq_jnt_adr = np.nonzero(mjm.eq_type == types.EqType.JOINT)[0]
  m.eq_ten_adr = np.nonzero(mjm.eq_type == types.EqType.TENDON)[0]

  # fixed tendon
  m.tendon_jnt_adr, m.wrap_jnt_adr = [], []
  for i in range(mjm.ntendon):
    adr = mjm.tendon_adr[i]
    if mjm.wrap_type[adr] == mujoco.mjtWrap.mjWRAP_JOINT:
      tendon_num = mjm.tendon_num[i]
      for j in range(tendon_num):
        m.tendon_jnt_adr.append(i)
        m.wrap_jnt_adr.append(adr + j)

  # spatial tendon
  m.tendon_site_pair_adr, m.tendon_geom_adr = [], []
  m.ten_wrapadr_site, m.ten_wrapnum_site = [0], []
  for i, tendon_num in enumerate(mjm.tendon_num):
    adr = mjm.tendon_adr[i]
    # sites: only tendons made entirely of site wraps contribute site wrap counts
    if (mjm.wrap_type[adr : adr + tendon_num] == mujoco.mjtWrap.mjWRAP_SITE).all():
      m.ten_wrapadr_site.append(m.ten_wrapadr_site[-1] + tendon_num)
      m.ten_wrapnum_site.append(tendon_num)
    else:
      m.ten_wrapadr_site.append(m.ten_wrapadr_site[-1])
      m.ten_wrapnum_site.append(0)

    # geoms
    for j in range(tendon_num):
      wrap_type = mjm.wrap_type[adr + j]
      if j < tendon_num - 1:
        next_wrap_type = mjm.wrap_type[adr + j + 1]
        if wrap_type == mujoco.mjtWrap.mjWRAP_SITE and next_wrap_type == mujoco.mjtWrap.mjWRAP_SITE:
          m.tendon_site_pair_adr.append(i)
      if wrap_type == mujoco.mjtWrap.mjWRAP_SPHERE or wrap_type == mujoco.mjtWrap.mjWRAP_CYLINDER:
        m.tendon_geom_adr.append(i)

  m.tendon_limited_adr = np.nonzero(mjm.tendon_limited)[0]
  m.wrap_site_adr = np.nonzero(mjm.wrap_type == mujoco.mjtWrap.mjWRAP_SITE)[0]
  # consecutive site wraps, excluding pairs that straddle the boundary between two tendons
  m.wrap_site_pair_adr = np.setdiff1d(m.wrap_site_adr[np.nonzero(np.diff(m.wrap_site_adr) == 1)[0]], mjm.tendon_adr[1:] - 1)
  m.wrap_geom_adr = np.nonzero(np.isin(mjm.wrap_type, [mujoco.mjtWrap.mjWRAP_SPHERE, mujoco.mjtWrap.mjWRAP_CYLINDER]))[0]

  # pulley scaling: wraps after a pulley (within the same tendon) are scaled by 1 / divisor
  m.wrap_pulley_scale = np.ones(mjm.nwrap, dtype=float)
  pulley_adr = np.nonzero(mjm.wrap_type == mujoco.mjtWrap.mjWRAP_PULLEY)[0]
  for tadr, tnum in zip(mjm.tendon_adr, mjm.tendon_num):
    for padr in pulley_adr:
      if tadr <= padr < tadr + tnum:
        m.wrap_pulley_scale[padr : tadr + tnum] = 1.0 / mjm.wrap_prm[padr]

  m.actuator_trntype_body_adr = np.nonzero(mjm.actuator_trntype == mujoco.mjtTrn.mjTRN_BODY)[0]

  # sensor addresses
  m.sensor_pos_adr = np.nonzero(
    (mjm.sensor_needstage == mujoco.mjtStage.mjSTAGE_POS)
    & (mjm.sensor_type != mujoco.mjtSensor.mjSENS_JOINTLIMITPOS)
    & (mjm.sensor_type != mujoco.mjtSensor.mjSENS_TENDONLIMITPOS)
  )[0]
  m.sensor_limitpos_adr = np.nonzero(
    (mjm.sensor_type == mujoco.mjtSensor.mjSENS_JOINTLIMITPOS) | (mjm.sensor_type == mujoco.mjtSensor.mjSENS_TENDONLIMITPOS)
  )[0]
  m.sensor_vel_adr = np.nonzero(
    (mjm.sensor_needstage == mujoco.mjtStage.mjSTAGE_VEL)
    & (mjm.sensor_type != mujoco.mjtSensor.mjSENS_JOINTLIMITVEL)
    & (mjm.sensor_type != mujoco.mjtSensor.mjSENS_TENDONLIMITVEL)
  )[0]
  m.sensor_limitvel_adr = np.nonzero(
    (mjm.sensor_type == mujoco.mjtSensor.mjSENS_JOINTLIMITVEL) | (mjm.sensor_type == mujoco.mjtSensor.mjSENS_TENDONLIMITVEL)
  )[0]
  # acc-stage sensors, excluding those with their own address lists below (sensor_touch_adr,
  # sensor_limitfrc_adr, sensor_tendonactfrc_adr), mirroring the pos / vel lists above.
  # note: the exclusions must be AND-ed; OR-ing "!=" tests is always true and excludes nothing
  m.sensor_acc_adr = np.nonzero(
    (mjm.sensor_needstage == mujoco.mjtStage.mjSTAGE_ACC)
    & (mjm.sensor_type != mujoco.mjtSensor.mjSENS_TOUCH)
    & (mjm.sensor_type != mujoco.mjtSensor.mjSENS_JOINTLIMITFRC)
    & (mjm.sensor_type != mujoco.mjtSensor.mjSENS_TENDONLIMITFRC)
    & (mjm.sensor_type != mujoco.mjtSensor.mjSENS_TENDONACTFRC)
  )[0]
  m.sensor_rangefinder_adr = np.nonzero(mjm.sensor_type == mujoco.mjtSensor.mjSENS_RANGEFINDER)[0]
  m.rangefinder_sensor_adr = np.full(mjm.nsensor, -1)
  m.rangefinder_sensor_adr[m.sensor_rangefinder_adr] = np.arange(len(m.sensor_rangefinder_adr))
  m.collision_sensor_adr = np.full(mjm.nsensor, -1)
  m.collision_sensor_adr[sensor_collision_adr] = np.arange(len(sensor_collision_adr))
  m.sensor_touch_adr = np.nonzero(mjm.sensor_type == mujoco.mjtSensor.mjSENS_TOUCH)[0]
  limitfrc_sensors = (mujoco.mjtSensor.mjSENS_JOINTLIMITFRC, mujoco.mjtSensor.mjSENS_TENDONLIMITFRC)
  m.sensor_limitfrc_adr = np.nonzero(np.isin(mjm.sensor_type, limitfrc_sensors))[0]
  m.sensor_e_potential = (mjm.sensor_type == mujoco.mjtSensor.mjSENS_E_POTENTIAL).any()
  m.sensor_e_kinetic = (mjm.sensor_type == mujoco.mjtSensor.mjSENS_E_KINETIC).any()
  m.sensor_tendonactfrc_adr = np.nonzero(mjm.sensor_type == mujoco.mjtSensor.mjSENS_TENDONACTFRC)[0]
  subtreevel_sensors = (mujoco.mjtSensor.mjSENS_SUBTREELINVEL, mujoco.mjtSensor.mjSENS_SUBTREEANGMOM)
  m.sensor_subtree_vel = np.isin(mjm.sensor_type, subtreevel_sensors).any()
  m.sensor_contact_adr = np.nonzero(mjm.sensor_type == mujoco.mjtSensor.mjSENS_CONTACT)[0]
  m.sensor_adr_to_contact_adr = np.clip(np.cumsum(mjm.sensor_type == mujoco.mjtSensor.mjSENS_CONTACT) - 1, a_min=0, a_max=None)
  m.sensor_rne_postconstraint = np.isin(
    mjm.sensor_type,
    [
      mujoco.mjtSensor.mjSENS_ACCELEROMETER,
      mujoco.mjtSensor.mjSENS_FORCE,
      mujoco.mjtSensor.mjSENS_TORQUE,
      mujoco.mjtSensor.mjSENS_FRAMELINACC,
      mujoco.mjtSensor.mjSENS_FRAMEANGACC,
    ],
  ).any()
  m.sensor_rangefinder_bodyid = mjm.site_bodyid[mjm.sensor_objid[mjm.sensor_type == mujoco.mjtSensor.mjSENS_RANGEFINDER]]
  # one entry per tactile taxel (mesh vertex of the tactile sensor's mesh)
  m.taxel_vertadr = [
    j + mjm.mesh_vertadr[mjm.sensor_objid[i]]
    for i in range(mjm.nsensor)
    if mjm.sensor_type[i] == mujoco.mjtSensor.mjSENS_TACTILE
    for j in range(mjm.mesh_vertnum[mjm.sensor_objid[i]])
  ]
  m.taxel_sensorid = [
    i
    for i in range(mjm.nsensor)
    if mjm.sensor_type[i] == mujoco.mjtSensor.mjSENS_TACTILE
    for j in range(mjm.mesh_vertnum[mjm.sensor_objid[i]])
  ]

  # qM_tiles records the block diagonal structure of qM
  tile_corners = [i for i in range(mjm.nv) if mjm.dof_parentid[i] == -1]
  tiles = {}
  for i in range(len(tile_corners)):
    tile_beg = tile_corners[i]
    tile_end = mjm.nv if i == len(tile_corners) - 1 else tile_corners[i + 1]
    tiles.setdefault(tile_end - tile_beg, []).append(tile_beg)
  m.qM_tiles = tuple(types.TileSet(adr=wp.array(tiles[sz], dtype=int), size=sz) for sz in sorted(tiles.keys()))

  # qLD_updates has dof tree ordering of qLD updates for sparse factor m
  qLD_updates, dof_depth = {}, np.zeros(mjm.nv, dtype=int) - 1
  for k in range(mjm.nv):
    # skip diagonal rows
    if mjm.M_rownnz[k] == 1:
      continue
    dof_depth[k] = dof_depth[mjm.dof_parentid[k]] + 1
    i = mjm.dof_parentid[k]
    diag_k = mjm.M_rowadr[k] + mjm.M_rownnz[k] - 1
    Madr_ki = diag_k - 1
    while i > -1:
      qLD_updates.setdefault(dof_depth[i], []).append((i, k, Madr_ki))
      i = mjm.dof_parentid[i]
      Madr_ki -= 1

  m.qLD_updates = tuple(wp.array(qLD_updates[i], dtype=wp.vec3i) for i in sorted(qLD_updates))

  # indices for sparse qM_fullm (used in solver): each dof paired with itself and all its ancestors
  m.qM_fullm_i, m.qM_fullm_j = [], []
  for i in range(mjm.nv):
    j = i
    while j > -1:
      m.qM_fullm_i.append(i)
      m.qM_fullm_j.append(j)
      j = mjm.dof_parentid[j]

  # indices for sparse qM mul_m (used in support): strict ancestors of each dof, with qM address
  m.qM_mulm_i, m.qM_mulm_j, m.qM_madr_ij = [], [], []
  for i in range(mjm.nv):
    madr_ij, j = mjm.dof_Madr[i], i

    while True:
      madr_ij, j = madr_ij + 1, mjm.dof_parentid[j]
      if j == -1:
        break
      m.qM_mulm_i.append(i)
      m.qM_mulm_j.append(j)
      m.qM_madr_ij.append(madr_ij)

  # place m on device
  sizes = dict({"*": 1}, **{f.name: getattr(m, f.name) for f in dataclasses.fields(types.Model) if f.type is int})
  for f in dataclasses.fields(types.Model):
    if isinstance(f.type, wp.array):
      setattr(m, f.name, _create_array(getattr(m, f.name), f.type, sizes))

  return m
def _get_padded_sizes(nv: int, njmax: int, is_sparse: bool, tile_size: int):
# if dense - we just pad to the next multiple of 4 for nv, to get the fast load path.
# we pad to the next multiple of tile_size for njmax to avoid out of bounds accesses.
# if sparse - we pad to the next multiple of tile_size for njmax, and nv.
def round_up(x, multiple):
return ((x + multiple - 1) // multiple) * multiple
njmax_padded = round_up(njmax, tile_size)
nv_padded = round_up(nv, tile_size) if is_sparse else round_up(nv, 4)
return njmax_padded, nv_padded
[docs]
def make_data(
  mjm: mujoco.MjModel,
  nworld: int = 1,
  nconmax: Optional[int] = None,
  njmax: Optional[int] = None,
  naconmax: Optional[int] = None,
) -> types.Data:
  """Creates a data object on device.

  All arrays are zero-initialized, except the geom poses, which are filled with
  the result of host forward kinematics at the model's default state. These poses
  are replicated into every world.

  Args:
    mjm: The model containing kinematic and dynamic information (host).
    nworld: Number of worlds.
    nconmax: Number of contacts to allocate per world. Contacts exist in large
      heterogeneous arrays: one world may have more than nconmax contacts.
      None or 0 selects the default (20).
    njmax: Number of constraints to allocate per world. Constraint arrays are
      batched by world: no world may have more than njmax constraints.
      None or 0 selects the default (6 * nconmax).
    naconmax: Number of contacts to allocate for all worlds. Overrides nconmax.

  Returns:
    The data object containing the current state and output arrays (device).

  Raises:
    ValueError: if nworld < 1 or any of nconmax / naconmax / njmax is negative.
  """
  # TODO(team): move nconmax, njmax to Model?
  # TODO(team): improve heuristic for nconmax and njmax
  if not nconmax:
    nconmax = 20
  if not njmax:
    njmax = nconmax * 6

  if nworld < 1:
    raise ValueError("nworld must be >= 1")

  if naconmax is not None:
    if naconmax < 0:
      raise ValueError("naconmax must be >= 0")
  elif nconmax < 0:
    raise ValueError("nconmax must be >= 0")
  else:
    naconmax = max(512, nworld * nconmax)

  if njmax < 0:
    raise ValueError("njmax must be >= 0")

  is_sparse = mujoco.mj_isSparse(mjm)

  # symbolic dimension sizes referenced by the array field specs
  sizes = {"*": 1}
  sizes.update({fld.name: getattr(mjm, fld.name, None) for fld in dataclasses.fields(types.Model) if fld.type is int})
  max_condim = np.concatenate(([0], mjm.geom_condim, mjm.pair_dim)).max()
  sizes["nmaxcondim"] = max_condim
  sizes["nmaxpyramid"] = np.maximum(1, 2 * (max_condim - 1))
  tile_size = types.TILE_SIZE_JTDAJ_SPARSE if is_sparse else types.TILE_SIZE_JTDAJ_DENSE
  sizes["njmax_pad"], sizes["nv_pad"] = _get_padded_sizes(mjm.nv, njmax, is_sparse, tile_size)
  sizes["nworld"] = nworld
  sizes["naconmax"] = naconmax
  sizes["njmax"] = njmax

  def _zeros_for(dataclass_type):
    # zero-allocated keyword arguments for every field of a device dataclass
    return {fld.name: _create_array(None, fld.type, sizes) for fld in dataclasses.fields(dataclass_type)}

  contact = types.Contact(**_zeros_for(types.Contact))
  efc = types.Constraint(**_zeros_for(types.Constraint))

  # fields that are passed through or filled in explicitly below instead of zero-allocated
  d_kwargs = {
    "contact": contact,
    "efc": efc,
    "nworld": nworld,
    "naconmax": naconmax,
    "njmax": njmax,
    "qM": None,
    "qLD": None,
    "geom_xpos": None,
    "geom_xmat": None,
  }
  preset = set(d_kwargs)
  d_kwargs.update(
    (fld.name, _create_array(None, fld.type, sizes)) for fld in dataclasses.fields(types.Data) if fld.name not in preset
  )

  d = types.Data(**d_kwargs)

  if is_sparse:
    qM_shape = (nworld, 1, mjm.nM)
    qLD_shape = (nworld, 1, mjm.nC)
  else:
    qM_shape = (nworld, sizes["nv_pad"], sizes["nv_pad"])
    qLD_shape = (nworld, mjm.nv, mjm.nv)
  d.qM = wp.zeros(qM_shape, dtype=float)
  d.qLD = wp.zeros(qLD_shape, dtype=float)

  # static geoms (attached to the world) have their poses calculated once during make_data instead
  # of during each physics step. this speeds up scenes with many static geoms (e.g. terrains)
  # TODO(team): remove this when we introduce dof islands + sleeping
  host_data = mujoco.MjData(mjm)
  mujoco.mj_kinematics(mjm, host_data)
  d.geom_xpos = wp.array(np.tile(host_data.geom_xpos, (nworld, 1)), shape=(nworld, mjm.ngeom), dtype=wp.vec3)
  d.geom_xmat = wp.array(np.tile(host_data.geom_xmat, (nworld, 1)), shape=(nworld, mjm.ngeom), dtype=wp.mat33)

  return d
[docs]
def put_data(
  mjm: mujoco.MjModel,
  mjd: mujoco.MjData,
  nworld: int = 1,
  nconmax: Optional[int] = None,
  njmax: Optional[int] = None,
  naconmax: Optional[int] = None,
) -> types.Data:
  """Moves data from host to a device.

  The host state in ``mjd`` is replicated into each of the ``nworld`` worlds.
  Contacts use a world-major layout: world ``w``'s copy of host contact ``i`` is
  stored at index ``w * mjd.ncon + i``. This layout is shared by the contact fields,
  ``contact.worldid`` and ``contact.efc_address``.

  Note: ``mujoco.mj_kinematics`` is called on ``mjd``, so the host data object
  is modified in place.

  Args:
    mjm: The model containing kinematic and dynamic information (host).
    mjd: The data object containing current state and output arrays (host).
    nworld: The number of worlds.
    nconmax: Number of contacts to allocate per world. Contacts exist in large
      heterogeneous arrays: one world may have more than nconmax contacts.
    njmax: Number of constraints to allocate per world. Constraint arrays are
      batched by world: no world may have more than njmax constraints.
    naconmax: Number of contacts to allocate for all worlds. Overrides nconmax.

  Returns:
    The data object containing the current state and output arrays (device).

  Raises:
    ValueError: if nworld < 1, a size is negative, or ``mjd`` holds more contacts
      or constraints than the requested allocation can store.
  """
  # TODO(team): move nconmax and njmax to Model?
  # TODO(team): decide what to do about uninitialized warp-only fields created by put_data
  #             we need to ensure these are only workspace fields and don't carry state

  # TODO(team): better heuristic for nconmax and njmax
  nconmax = nconmax or max(5, 4 * mjd.ncon)
  njmax = njmax or max(5, 4 * mjd.nefc)

  if nworld < 1:
    raise ValueError("nworld must be >= 1")

  if naconmax is None:
    if nconmax < 0:
      raise ValueError("nconmax must be >= 0")

    if mjd.ncon > nconmax:
      raise ValueError(f"nconmax overflow (nconmax must be >= {mjd.ncon})")

    naconmax = max(512, nworld * nconmax)
  elif naconmax < mjd.ncon * nworld:
    raise ValueError(f"naconmax overflow (naconmax must be >= {mjd.ncon * nworld})")

  if njmax < 0:
    raise ValueError("njmax must be >= 0")

  if mjd.nefc > njmax:
    raise ValueError(f"njmax overflow (njmax must be >= {mjd.nefc})")

  is_sparse = mujoco.mj_isSparse(mjm)

  # symbolic dimension sizes referenced by the array field specs
  sizes = dict({"*": 1}, **{f.name: getattr(mjm, f.name, None) for f in dataclasses.fields(types.Model) if f.type is int})
  sizes["nmaxcondim"] = np.concatenate(([0], mjm.geom_condim, mjm.pair_dim)).max()
  sizes["nmaxpyramid"] = np.maximum(1, 2 * (sizes["nmaxcondim"] - 1))
  tile_size = types.TILE_SIZE_JTDAJ_SPARSE if is_sparse else types.TILE_SIZE_JTDAJ_DENSE
  sizes["njmax_pad"], sizes["nv_pad"] = _get_padded_sizes(mjm.nv, njmax, is_sparse, tile_size)
  sizes["nworld"] = nworld
  sizes["naconmax"] = naconmax
  sizes["njmax"] = njmax

  # ensure static geom positions are computed
  # TODO: remove once MjData creation semantics are fixed
  mujoco.mj_kinematics(mjm, mjd)

  # create contact
  contact_kwargs = {"efc_address": None, "worldid": None, "type": None, "geomcollisionid": None}
  for f in dataclasses.fields(types.Contact):
    if f.name in contact_kwargs:
      continue
    val = getattr(mjd.contact, f.name)
    # replicate all host contacts once per world, world-major ([c0, c1, c0, c1, ...]), to
    # match worldid and efc_address below. np.repeat would interleave them contact-major
    # ([c0, c0, c1, c1, ...]) and pair contact data with the wrong world
    val = np.tile(val, (nworld,) + (1,) * (val.ndim - 1))
    width = ((0, naconmax - val.shape[0]),) + ((0, 0),) * (val.ndim - 1)
    val = np.pad(val, width)
    contact_kwargs[f.name] = _create_array(val, f.type, sizes)
  contact = types.Contact(**contact_kwargs)

  contact.efc_address = np.zeros((naconmax, sizes["nmaxpyramid"]), dtype=int)
  for i in range(mjd.ncon):
    efc_address = mjd.contact.efc_address[i]
    if efc_address == -1:
      continue
    condim = mjd.contact.dim[i]
    # number of efc rows used by this contact
    ndim = max(1, 2 * (condim - 1)) if mjm.opt.cone == mujoco.mjtCone.mjCONE_PYRAMIDAL else condim
    for j in range(nworld):
      contact.efc_address[j * mjd.ncon + i, :ndim] = efc_address + np.arange(ndim)

  contact.efc_address = wp.array(contact.efc_address, dtype=int)
  contact.worldid = np.pad(np.repeat(np.arange(nworld), mjd.ncon), (0, naconmax - nworld * mjd.ncon))
  contact.worldid = wp.array(contact.worldid, dtype=int)
  contact.type = wp.ones((naconmax,), dtype=int)  # TODO(team): set values
  contact.geomcollisionid = wp.empty((naconmax,), dtype=int)  # TODO(team): set values

  # create efc
  efc_kwargs = {"J": None}

  for f in dataclasses.fields(types.Constraint):
    if f.name in efc_kwargs:
      continue
    shape = tuple(sizes[dim] if isinstance(dim, str) else dim for dim in f.type.shape)
    val = np.zeros(shape, dtype=f.type.dtype)
    if f.name in ("type", "id", "pos", "margin", "D", "vel", "aref", "frictionloss", "force"):
      val[:, : mjd.nefc] = np.tile(getattr(mjd, "efc_" + f.name), (nworld, 1))
    efc_kwargs[f.name] = wp.array(val, dtype=f.type.dtype)

  efc = types.Constraint(**efc_kwargs)

  if is_sparse:
    efc_j = np.zeros((mjd.nefc, mjm.nv))
    mujoco.mju_sparse2dense(efc_j, mjd.efc_J, mjd.efc_J_rownnz, mjd.efc_J_rowadr, mjd.efc_J_colind)
  else:
    efc_j = mjd.efc_J.reshape((mjd.nefc, mjm.nv))
  # J holds floats: use an explicit dtype rather than the leaked loop variable `f`, whose
  # dtype belongs to whichever Constraint field happened to be iterated last
  efc_J = np.zeros((nworld, sizes["njmax_pad"], sizes["nv_pad"]), dtype=float)
  efc_J[:, : mjd.nefc, : mjm.nv] = np.tile(efc_j, (nworld, 1, 1))
  efc.J = wp.array(efc_J, dtype=float)

  # create data
  d_kwargs = {
    "contact": contact,
    "efc": efc,
    "nworld": nworld,
    "naconmax": naconmax,
    "njmax": njmax,
    # fields set after initialization:
    "solver_niter": None,
    "qM": None,
    "qLD": None,
    "ten_J": None,
    "actuator_moment": None,
    "nacon": None,
    "ne_connect": None,
    "ne_weld": None,
    "ne_jnt": None,
    "ne_ten": None,
    "nsolving": None,
  }
  for f in dataclasses.fields(types.Data):
    if f.name in d_kwargs:
      continue
    val = getattr(mjd, f.name, None)
    if val is not None:
      # broadcast the host value into a leading world axis
      shape = val.shape if hasattr(val, "shape") else ()
      val = np.full((nworld,) + shape, val)
    d_kwargs[f.name] = _create_array(val, f.type, sizes)

  d = types.Data(**d_kwargs)

  d.solver_niter = wp.full((nworld,), mjd.solver_niter[0], dtype=int)

  if is_sparse:
    d.qM = wp.array(np.full((nworld, 1, mjm.nM), mjd.qM), dtype=float)
    d.qLD = wp.array(np.full((nworld, 1, mjm.nC), mjd.qLD), dtype=float)
    ten_J = np.zeros((mjm.ntendon, mjm.nv))
    mujoco.mju_sparse2dense(ten_J, mjd.ten_J.reshape(-1), mjd.ten_J_rownnz, mjd.ten_J_rowadr, mjd.ten_J_colind.reshape(-1))
    d.ten_J = wp.array(np.full((nworld, mjm.ntendon, mjm.nv), ten_J), dtype=float)
  else:
    qM = np.zeros((mjm.nv, mjm.nv))
    mujoco.mj_fullM(mjm, qM, mjd.qM)
    qLD = np.linalg.cholesky(qM) if (mjd.qM != 0.0).any() and (mjd.qLD != 0.0).any() else np.zeros((mjm.nv, mjm.nv))
    padding = sizes["nv_pad"] - mjm.nv
    qM_padded = np.pad(qM, ((0, padding), (0, padding)), mode="constant", constant_values=0.0)
    d.qM = wp.array(np.full((nworld, sizes["nv_pad"], sizes["nv_pad"]), qM_padded), dtype=float)
    d.qLD = wp.array(np.full((nworld, mjm.nv, mjm.nv), qLD), dtype=float)
    ten_J = mjd.ten_J.reshape((mjm.ntendon, mjm.nv))
    d.ten_J = wp.array(np.full((nworld, mjm.ntendon, mjm.nv), ten_J), dtype=float)

  # TODO(taylorhowell): sparse actuator_moment
  actuator_moment = np.zeros((mjm.nu, mjm.nv))
  mujoco.mju_sparse2dense(actuator_moment, mjd.actuator_moment, mjd.moment_rownnz, mjd.moment_rowadr, mjd.moment_colind)
  d.actuator_moment = wp.array(np.full((nworld, mjm.nu, mjm.nv), actuator_moment), dtype=float)

  d.nacon = wp.array([mjd.ncon * nworld], dtype=int)
  # equality constraint row counts per world: connect = 3 rows, weld = 6, joint / tendon = 1
  d.ne_connect = wp.full(nworld, 3 * np.sum((mjm.eq_type == mujoco.mjtEq.mjEQ_CONNECT) & mjd.eq_active), dtype=int)
  d.ne_weld = wp.full(nworld, 6 * np.sum((mjm.eq_type == mujoco.mjtEq.mjEQ_WELD) & mjd.eq_active), dtype=int)
  d.ne_jnt = wp.full(nworld, np.sum((mjm.eq_type == mujoco.mjtEq.mjEQ_JOINT) & mjd.eq_active), dtype=int)
  d.ne_ten = wp.full(nworld, np.sum((mjm.eq_type == mujoco.mjtEq.mjEQ_TENDON) & mjd.eq_active), dtype=int)
  d.nsolving = wp.array([nworld], dtype=int)

  return d
[docs]
def get_data_into(
  result: mujoco.MjData,
  mjm: mujoco.MjModel,
  d: types.Data,
  world_id: int = 0,
):
  """Gets data from a device into an existing mujoco.MjData.

  The contact and constraint arrays of ``result`` are reallocated when the number
  of contacts or constraints in ``world_id`` differs from what ``result`` holds.
  Constraint rows are permuted so that the equality, friction and limit rows come
  first, followed by the rows of each contact in contact order. MuJoCo expects this
  contiguous layout; MuJoCo Warp does not guarantee it. Every per-constraint
  output, including ``efc_J``, uses the same permutation.

  Args:
    result: The data object containing the current state and output arrays (host).
    mjm: The model containing kinematic and dynamic information (host).
    d: The data object containing the current state and output arrays (device).
    world_id: The id of the world to get the data from.
  """
  nv = mjm.nv

  # nacon and nefc can overflow. in that case, only pull up to the max contacts and constraints
  nacon = int(min(d.nacon.numpy()[0], d.naconmax))
  nefc = int(min(d.nefc.numpy()[world_id], d.njmax))

  # contacts of all worlds share one array; keep the valid entries that belong to world_id
  contact_worldid = d.contact.worldid.numpy()
  ncon_filter = np.zeros_like(contact_worldid, dtype=bool)
  ncon_filter[:nacon] = contact_worldid[:nacon] == world_id
  ncon = int(ncon_filter.sum())

  if ncon != result.ncon or nefc != result.nefc:
    # TODO(team): if sparse, set nJ based on sparse efc_J
    mujoco._functions._realloc_con_efc(result, ncon=ncon, nefc=nefc, nJ=nefc * nv)

  ne = int(d.ne.numpy()[world_id])
  nf = int(d.nf.numpy()[world_id])
  nl = int(d.nl.numpy()[world_id])

  # efc indexing
  # mujoco expects contiguous efc ordering for contacts
  # this ordering is not guaranteed with mujoco warp, we enforce order here:
  # efc_idx[k] is the device constraint row that becomes host row k
  contact_dim = d.contact.dim.numpy()[ncon_filter]
  if ncon > 0:
    contact_efc_address = d.contact.efc_address.numpy()[ncon_filter]
    pyramidal = mjm.opt.cone == mujoco.mjtCone.mjCONE_PYRAMIDAL
    efc_idx_c = []
    contact_efc_address_ordered = []
    adr = ne + nf + nl  # contact rows follow the ne + nf + nl non-contact rows
    for i in range(ncon):
      dim = contact_dim[i]
      # number of constraint rows occupied by this contact
      ndim = max(1, 2 * (dim - 1)) if pyramidal else dim
      efc_idx_c.append(contact_efc_address[i, :ndim])
      contact_efc_address_ordered.append(adr)
      adr += ndim
    efc_idx = np.concatenate((np.arange(ne + nf + nl), *efc_idx_c))
    contact_efc_address_ordered = np.array(contact_efc_address_ordered)
  else:
    efc_idx = np.arange(nefc)
    contact_efc_address_ordered = np.empty(0)
  efc_idx = efc_idx[:nefc]  # dont emit indices for overflow constraints

  result.solver_niter[0] = d.solver_niter.numpy()[world_id]
  result.ncon = ncon
  result.ne = ne
  result.nf = nf
  result.nl = nl
  result.time = d.time.numpy()[world_id]

  # fields copied verbatim from the selected world.
  # qLDiagInv is copied here; in the dense branch below mj_factorM recomputes it from qM.
  world_fields = (
    # state, control, mocap
    "energy",
    "qpos",
    "qvel",
    "act",
    "qacc_warmstart",
    "ctrl",
    "qfrc_applied",
    "xfrc_applied",
    "eq_active",
    "mocap_pos",
    "mocap_quat",
    "qacc",
    "act_dot",
    # position-dependent
    "xpos",
    "xquat",
    "xipos",
    "xanchor",
    "xaxis",
    "geom_xpos",
    "site_xpos",
    "cam_xpos",
    "light_xpos",
    "light_xdir",
    "subtree_com",
    "cdof",
    "cinert",
    "flexvert_xpos",
    "flexedge_length",
    "flexedge_velocity",
    "actuator_length",
    "crb",
    "qLDiagInv",
    # velocity- and force-dependent
    "ten_velocity",
    "actuator_velocity",
    "cvel",
    "cdof_dot",
    "qfrc_bias",
    "qfrc_spring",
    "qfrc_damper",
    "qfrc_gravcomp",
    "qfrc_fluid",
    "qfrc_passive",
    "subtree_linvel",
    "subtree_angmom",
    "actuator_force",
    "qfrc_actuator",
    "qfrc_smooth",
    "qacc_smooth",
    "qfrc_constraint",
    "qfrc_inverse",
    # rne_postconstraint
    "cacc",
    "cfrc_int",
    "cfrc_ext",
    # tendon
    "ten_length",
    "ten_wrapadr",
    "ten_wrapnum",
    "wrap_obj",
    "wrap_xpos",
    # sensors
    "sensordata",
  )
  for name in world_fields:
    getattr(result, name)[:] = getattr(d, name).numpy()[world_id]

  # orientation matrices are flattened to rows of 9 on the host
  for name in ("xmat", "ximat", "geom_xmat", "site_xmat", "cam_xmat"):
    getattr(result, name)[:] = getattr(d, name).numpy()[world_id].reshape((-1, 9))

  # device actuator_moment is dense; the host stores it sparsely
  mujoco.mju_dense2sparse(
    result.actuator_moment,
    d.actuator_moment.numpy()[world_id],
    result.moment_rownnz,
    result.moment_rowadr,
    result.moment_colind,
  )

  # contact
  for name in ("dist", "pos", "includemargin", "friction", "solref", "solreffriction", "solimp", "geom"):
    getattr(result.contact, name)[:ncon] = getattr(d.contact, name).numpy()[ncon_filter]
  result.contact.frame[:ncon] = d.contact.frame.numpy()[ncon_filter].reshape((-1, 9))
  result.contact.dim[:ncon] = contact_dim
  result.contact.efc_address[:ncon] = contact_efc_address_ordered[:ncon]

  if mujoco.mj_isSparse(mjm):
    result.qM[:] = d.qM.numpy()[world_id, 0]
    result.qLD[:] = d.qLD.numpy()[world_id, 0]
    if nefc > 0:
      efc_J = d.efc.J.numpy()[world_id, efc_idx, :nv]
      mujoco.mju_dense2sparse(result.efc_J, efc_J, result.efc_J_rownnz, result.efc_J_rowadr, result.efc_J_colind)
    if mjm.ntendon:
      # host ten_J uses sparse storage here (mirrors the sparse2dense conversion in put_data);
      # writing the dense device matrix directly would desync it from ten_J_rownnz/rowadr/colind
      mujoco.mju_dense2sparse(
        result.ten_J.reshape(-1),
        d.ten_J.numpy()[world_id],
        result.ten_J_rownnz,
        result.ten_J_rowadr,
        result.ten_J_colind.reshape(-1),
      )
  else:
    # pack the dense device mass matrix into MuJoCo's qM layout: for each dof i,
    # M[i, i] followed by M[i, j] for every ancestor j along dof_parentid
    qM = d.qM.numpy()[world_id]
    adr = 0
    for i in range(nv):
      j = i
      while j >= 0:
        result.qM[adr] = qM[i, j]
        j = mjm.dof_parentid[j]
        adr += 1
    mujoco.mj_factorM(mjm, result)
    if nefc > 0:
      # apply the same row permutation as the other efc_* fields
      result.efc_J[: nefc * nv] = d.efc.J.numpy()[world_id, efc_idx, :nv].flatten()
    result.ten_J[:] = d.ten_J.numpy()[world_id]

  # efc
  for name in ("type", "id", "pos", "margin", "D", "vel", "aref", "frictionloss", "state", "force"):
    getattr(result, f"efc_{name}")[:] = getattr(d.efc, name).numpy()[world_id, efc_idx]
[docs]
def reset_data(m: types.Model, d: types.Data, reset: Optional[wp.array] = None):
  """Clear data, set defaults; optionally by world.

  Resets xfrc_applied, qM, mocap poses (from body_pos/body_quat), contact entries
  and the per-world state: counters, time, energy, qpos (to qpos0), qvel,
  qacc_warmstart, qfrc_applied, qacc, ctrl, act, act_dot, eq_active (to
  eq_active0) and sensordata.

  Args:
    m: The model containing kinematic and dynamic information (device).
    d: The data object containing the current state and output arrays (device).
    reset: Optional boolean mask indexed by world id; worlds whose entry is True
      are reset. If None, every world is reset.
  """

  @nested_kernel(module="unique", enable_backward=False)
  def reset_xfrc_applied(reset_in: wp.array(dtype=bool), xfrc_applied_out: wp.array2d(dtype=wp.spatial_vector)):
    worldid, bodyid, elemid = wp.tid()

    # wp.static is evaluated when the kernel is compiled: without a mask the check is omitted
    if wp.static(reset is not None):
      if not reset_in[worldid]:
        return

    xfrc_applied_out[worldid, bodyid][elemid] = 0.0

  @nested_kernel(module="unique", enable_backward=False)
  def reset_qM(reset_in: wp.array(dtype=bool), qM_out: wp.array3d(dtype=float)):
    worldid, elemid1, elemid2 = wp.tid()

    if wp.static(reset is not None):
      if not reset_in[worldid]:
        return

    qM_out[worldid, elemid1, elemid2] = 0.0

  @nested_kernel(module="unique", enable_backward=False)
  def reset_nworld(
    # Model:
    nq: int,
    nv: int,
    nu: int,
    na: int,
    neq: int,
    nsensordata: int,
    qpos0: wp.array2d(dtype=float),
    eq_active0: wp.array(dtype=bool),
    # Data in:
    nworld_in: int,
    # In:
    reset_in: wp.array(dtype=bool),
    # Data out:
    solver_niter_out: wp.array(dtype=int),
    ne_out: wp.array(dtype=int),
    nf_out: wp.array(dtype=int),
    nl_out: wp.array(dtype=int),
    nefc_out: wp.array(dtype=int),
    time_out: wp.array(dtype=float),
    energy_out: wp.array(dtype=wp.vec2),
    qpos_out: wp.array2d(dtype=float),
    qvel_out: wp.array2d(dtype=float),
    act_out: wp.array2d(dtype=float),
    qacc_warmstart_out: wp.array2d(dtype=float),
    ctrl_out: wp.array2d(dtype=float),
    qfrc_applied_out: wp.array2d(dtype=float),
    eq_active_out: wp.array2d(dtype=bool),
    qacc_out: wp.array2d(dtype=float),
    act_dot_out: wp.array2d(dtype=float),
    sensordata_out: wp.array2d(dtype=float),
    nacon_out: wp.array(dtype=int),
    ne_connect_out: wp.array(dtype=int),
    ne_weld_out: wp.array(dtype=int),
    ne_jnt_out: wp.array(dtype=int),
    ne_ten_out: wp.array(dtype=int),
    nsolving_out: wp.array(dtype=int),
  ):
    worldid = wp.tid()

    if wp.static(reset is not None):
      if not reset_in[worldid]:
        return

    solver_niter_out[worldid] = 0
    # NOTE(review): nacon and nsolving are single global counters; they are only reset
    # when world 0 is selected, even for a partial reset -- confirm this is intended
    if worldid == 0:
      nacon_out[0] = 0
    ne_out[worldid] = 0
    ne_connect_out[worldid] = 0
    ne_weld_out[worldid] = 0
    ne_jnt_out[worldid] = 0
    ne_ten_out[worldid] = 0
    nf_out[worldid] = 0
    nl_out[worldid] = 0
    nefc_out[worldid] = 0
    if worldid == 0:
      nsolving_out[0] = nworld_in
    time_out[worldid] = 0.0
    energy_out[worldid] = wp.vec2(0.0, 0.0)

    for i in range(nq):
      qpos_out[worldid, i] = qpos0[worldid, i]

    for i in range(nv):
      qvel_out[worldid, i] = 0.0
      qacc_warmstart_out[worldid, i] = 0.0
      qfrc_applied_out[worldid, i] = 0.0
      qacc_out[worldid, i] = 0.0

    for i in range(nu):
      ctrl_out[worldid, i] = 0.0

    # na is not bounded by nu (an actuator may own several activation states),
    # so activations get their own loop
    for i in range(na):
      act_out[worldid, i] = 0.0
      act_dot_out[worldid, i] = 0.0

    for i in range(neq):
      eq_active_out[worldid, i] = eq_active0[i]

    for i in range(nsensordata):
      sensordata_out[worldid, i] = 0.0

  @nested_kernel(module="unique", enable_backward=False)
  def reset_mocap(
    # Model:
    body_mocapid: wp.array(dtype=int),
    body_pos: wp.array2d(dtype=wp.vec3),
    body_quat: wp.array2d(dtype=wp.quat),
    # In:
    reset_in: wp.array(dtype=bool),
    # Data out:
    mocap_pos_out: wp.array2d(dtype=wp.vec3),
    mocap_quat_out: wp.array2d(dtype=wp.quat),
  ):
    worldid, bodyid = wp.tid()

    if wp.static(reset is not None):
      if not reset_in[worldid]:
        return

    # only bodies with a non-negative mocap id are mocap bodies
    mocapid = body_mocapid[bodyid]
    if mocapid >= 0:
      mocap_pos_out[worldid, mocapid] = body_pos[worldid, bodyid]
      mocap_quat_out[worldid, mocapid] = body_quat[worldid, bodyid]

  @nested_kernel(module="unique", enable_backward=False)
  def reset_contact(
    # Data in:
    nacon_in: wp.array(dtype=int),
    # In:
    reset_in: wp.array(dtype=bool),
    nefcaddress: int,
    # Data out:
    contact_dist_out: wp.array(dtype=float),
    contact_pos_out: wp.array(dtype=wp.vec3),
    contact_frame_out: wp.array(dtype=wp.mat33),
    contact_includemargin_out: wp.array(dtype=float),
    contact_friction_out: wp.array(dtype=types.vec5),
    contact_solref_out: wp.array(dtype=wp.vec2),
    contact_solreffriction_out: wp.array(dtype=wp.vec2),
    contact_solimp_out: wp.array(dtype=types.vec5),
    contact_dim_out: wp.array(dtype=int),
    contact_geom_out: wp.array(dtype=wp.vec2i),
    contact_efc_address_out: wp.array2d(dtype=int),
    contact_worldid_out: wp.array(dtype=int),
    contact_type_out: wp.array(dtype=int),
    contact_geomcollisionid_out: wp.array(dtype=int),
  ):
    conid = wp.tid()

    # only the first nacon entries hold contacts
    if conid >= nacon_in[0]:
      return

    worldid = contact_worldid_out[conid]
    if wp.static(reset is not None):
      if worldid >= 0:
        if not reset_in[worldid]:
          return

    contact_dist_out[conid] = 0.0
    contact_pos_out[conid] = wp.vec3(0.0)
    contact_frame_out[conid] = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
    contact_includemargin_out[conid] = 0.0
    contact_friction_out[conid] = types.vec5(0.0, 0.0, 0.0, 0.0, 0.0)
    contact_solref_out[conid] = wp.vec2(0.0, 0.0)
    contact_solreffriction_out[conid] = wp.vec2(0.0, 0.0)
    contact_solimp_out[conid] = types.vec5(0.0, 0.0, 0.0, 0.0, 0.0)
    contact_dim_out[conid] = 0
    contact_geom_out[conid] = wp.vec2i(0, 0)
    for i in range(nefcaddress):
      contact_efc_address_out[conid, i] = 0
    contact_worldid_out[conid] = 0
    contact_type_out[conid] = 0
    contact_geomcollisionid_out[conid] = 0

  # the kernels always take a mask argument; when reset is None it is never read
  reset_input = reset if reset is not None else wp.ones(d.nworld, dtype=bool)

  wp.launch(reset_xfrc_applied, dim=(d.nworld, m.nbody, 6), inputs=[reset_input], outputs=[d.xfrc_applied])
  wp.launch(
    reset_qM,
    dim=(d.nworld, d.qM.shape[1], d.qM.shape[2]),
    inputs=[reset_input],
    outputs=[d.qM],
  )

  # set mocap_pos/quat = body_pos/quat for mocap bodies
  wp.launch(
    reset_mocap,
    dim=(d.nworld, m.nbody),
    inputs=[m.body_mocapid, m.body_pos, m.body_quat, reset_input],
    outputs=[d.mocap_pos, d.mocap_quat],
  )

  # clear contacts; launched before reset_nworld, which may zero nacon
  wp.launch(
    reset_contact,
    dim=d.naconmax,
    inputs=[d.nacon, reset_input, d.contact.efc_address.shape[1]],
    outputs=[
      d.contact.dist,
      d.contact.pos,
      d.contact.frame,
      d.contact.includemargin,
      d.contact.friction,
      d.contact.solref,
      d.contact.solreffriction,
      d.contact.solimp,
      d.contact.dim,
      d.contact.geom,
      d.contact.efc_address,
      d.contact.worldid,
      d.contact.type,
      d.contact.geomcollisionid,
    ],
  )

  wp.launch(
    reset_nworld,
    dim=d.nworld,
    inputs=[m.nq, m.nv, m.nu, m.na, m.neq, m.nsensordata, m.qpos0, m.eq_active0, d.nworld, reset_input],
    outputs=[
      d.solver_niter,
      d.ne,
      d.nf,
      d.nl,
      d.nefc,
      d.time,
      d.energy,
      d.qpos,
      d.qvel,
      d.act,
      d.qacc_warmstart,
      d.ctrl,
      d.qfrc_applied,
      d.eq_active,
      d.qacc,
      d.act_dot,
      d.sensordata,
      d.nacon,
      d.ne_connect,
      d.ne_weld,
      d.ne_jnt,
      d.ne_ten,
      d.nsolving,
    ],
  )
def override_model(model: Union[types.Model, mujoco.MjModel], overrides: Union[dict[str, Any], Sequence[str]]):
  """Overrides model parameters in place.

  Overrides are either a dict mapping a dotted field path to a value, or a
  sequence of strings of the format:

    opt.iterations = 1
    opt.ls_parallel = True
    opt.cone = pyramidal
    opt.disableflags = contact | spring

  Enum-typed fields accept member names (case-insensitive) joined with ``|``.
  Boolean fields accept ``true``/``false`` in any case. Every other value is
  converted with the type of the field's current value. Overrides for fields
  that exist on only one of ``types.Model`` / ``mujoco.MjModel`` are silently
  skipped for the other kind of model.

  Args:
    model: The model to modify (mjw ``types.Model`` or ``mujoco.MjModel``).
    overrides: Mapping or sequence of overrides as described above.

  Raises:
    ValueError: If an override string has no ``=``, names an unknown field, or
      gives an unrecognized enum or boolean value.
  """
  enum_types = {
    "opt.broadphase": types.BroadphaseType,
    "opt.broadphase_filter": types.BroadphaseFilter,
    "opt.cone": types.ConeType,
    "opt.disableflags": types.DisableBit,
    "opt.enableflags": types.EnableBit,
    "opt.integrator": types.IntegratorType,
    "opt.solver": types.SolverType,
  }
  warp_only = {"opt.broadphase", "opt.broadphase_filter", "opt.ls_parallel", "opt.graph_conditional"}
  mujoco_only = {"opt.jacobian"}

  if not isinstance(overrides, dict):
    parsed = {}
    for entry in overrides:
      lhs, sep, rhs = entry.partition("=")
      if not sep:
        raise ValueError(f"Invalid override format: {entry}")
      parsed[lhs.strip()] = rhs.strip()
    overrides = parsed

  for key, val in overrides.items():
    # fields that only one of the two model kinds has are ignored on the other
    if isinstance(model, mujoco.MjModel) and key in warp_only:
      continue
    if isinstance(model, types.Model) and key in mujoco_only:
      continue

    *parents, leaf = key.split(".")
    target = model
    for part in parents:
      if not hasattr(target, part):
        raise ValueError(f"Unrecognized model field: {key}")
      target = getattr(target, part)
    if not hasattr(target, leaf):
      raise ValueError(f"Unrecognized model field: {key}")

    field_type = type(getattr(target, leaf))

    if key in enum_types and isinstance(val, str):
      # enum value: OR together every named member
      enum_cls = enum_types[key]
      combined = 0
      for member in val.split("|"):
        member = member.strip().upper()
        if member not in enum_cls.__members__:
          raise ValueError(f"Unrecognized enum value for {enum_cls.__name__}: {member}")
        combined |= int(enum_cls[member])
      val = combined
    elif field_type is bool and isinstance(val, str):
      # "true", "TRUE", "false", "FALSE" etc.
      upper = val.upper()
      if upper not in ("TRUE", "FALSE"):
        raise ValueError(f"Unrecognized value for field: {key}")
      val = upper == "TRUE"
    else:
      val = field_type(val)

    setattr(target, leaf, val)
def find_keys(model: mujoco.MjModel, keyname_prefix: str) -> list[int]:
  """Finds keyframes that start with keyname_prefix.

  Args:
    model: The model whose keyframes are searched (host).
    keyname_prefix: Prefix that a keyframe name must start with.

  Returns:
    Ids of the matching keyframes, in increasing order. Unnamed keyframes are
    treated as having an empty name, so they match only an empty prefix.
  """
  keys = []
  for keyid in range(model.nkey):
    # mj_id2name returns None for a keyframe without a name; treat it as ""
    name = mujoco.mj_id2name(model, mujoco.mjtObj.mjOBJ_KEY, keyid) or ""
    if name.startswith(keyname_prefix):
      keys.append(keyid)
  return keys
def make_trajectory(model: "mujoco.MjModel", keys: list[int]) -> np.ndarray:
  """Make a ctrl trajectory with linear interpolation.

  Samples are spaced ``model.opt.timestep`` apart, starting at time 0. Between
  consecutive keyframes, ctrl is linearly interpolated on their key times. Each
  keyframe's ctrl is emitted at the first sample at or after its key time.

  Args:
    model: Model providing ``nu``, ``key_ctrl``, ``key_time`` and ``opt.timestep``.
    keys: Keyframe ids, in increasing key-time order.

  Returns:
    Array with one row of ctrl values per sample.

  Raises:
    ValueError: If the first keyframe's time is not 0.0, or key times are not
      strictly increasing.
  """
  timestep = model.opt.timestep
  ctrls = []
  prev_ctrl = np.zeros(model.nu, dtype=np.float64)
  prev_time = 0.0
  for key in keys:
    ctrl_key, ctrl_time = model.key_ctrl[key], model.key_time[key]
    if not ctrls and ctrl_time != 0.0:
      raise ValueError("first keyframe must have time 0.0")
    elif ctrls and ctrl_time <= prev_time:
      raise ValueError("keyframes must be in time order")
    # sample k lies at k * timestep; recomputing it avoids accumulating float error
    time = len(ctrls) * timestep
    while time < ctrl_time:
      # interpolate between the previous keyframe and this one, on key times
      frac = (time - prev_time) / (ctrl_time - prev_time)
      ctrls.append(prev_ctrl * (1 - frac) + ctrl_key * frac)
      time = len(ctrls) * timestep
    ctrls.append(ctrl_key)
    prev_ctrl = ctrl_key
    prev_time = ctrl_time
  return np.array(ctrls)