Source code for mujoco_warp._src.ray

# Copyright 2025 The Newton Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import Optional, Tuple

import warp as wp

from mujoco_warp._src.math import safe_div
from mujoco_warp._src.types import MJ_MINVAL
from mujoco_warp._src.types import Data
from mujoco_warp._src.types import GeomType
from mujoco_warp._src.types import Model
from mujoco_warp._src.types import vec6

wp.set_module_options({"enable_backward": False})


@wp.func
def _ray_map(pos: wp.vec3, mat: wp.mat33, pnt: wp.vec3, vec: wp.vec3) -> Tuple[wp.vec3, wp.vec3]:
  """Maps ray to local geom frame coordinates.

  Args:
      pos: position of geom frame
      mat: orientation of geom frame
      pnt: starting point of ray in world coordinates
      vec: direction of ray in world coordinates

  Returns:
      3D point and 3D direction in local geom frame
  """
  matT = wp.transpose(mat)
  lpnt = matT @ (pnt - pos)
  lvec = matT @ vec

  return lpnt, lvec


@wp.func
def _ray_eliminate(
  # Model:
  body_weldid: wp.array(dtype=int),
  geom_bodyid: wp.array(dtype=int),
  geom_matid: wp.array(dtype=int),  # kernel_analyzer: ignore
  geom_group: wp.array(dtype=int),
  geom_rgba: wp.array(dtype=wp.vec4),  # kernel_analyzer: ignore
  mat_rgba: wp.array(dtype=wp.vec4),  # kernel_analyzer: ignore
  # In:
  geomid: int,
  geomgroup: vec6,
  flg_static: bool,
  bodyexclude: int,
) -> bool:
  """Eliminate ray."""
  bodyid = geom_bodyid[geomid]
  matid = geom_matid[geomid]

  # body exclusion
  if bodyid == bodyexclude:
    return True

  # invisible geom exclusion
  if matid < 0 and geom_rgba[geomid][3] == 0.0:
    return True

  # invisible material exclusion
  if matid >= 0:
    if mat_rgba[matid][3] == 0.0:
      return True

  # static exclusion
  if not flg_static and body_weldid[bodyid] == 0:
    return True

  # no geomgroup inclusion
  if (
    geomgroup[0] == -1
    and geomgroup[1] == -1
    and geomgroup[2] == -1
    and geomgroup[3] == -1
    and geomgroup[4] == -1
    and geomgroup[5] == -1
  ):
    return False

  # group inclusion/exclusion
  groupid = wp.min(5, wp.max(0, geom_group[geomid]))

  return geomgroup[groupid] == 0


@wp.func
def _ray_quad(a: float, b: float, c: float) -> Tuple[float, wp.vec2]:
  """Compute solutions from quadratic: a*x^2 + 2*b*x + c = 0."""
  det = b * b - a * c
  if det < MJ_MINVAL:
    return wp.inf, wp.vec2(wp.inf, wp.inf)
  det = wp.sqrt(det)

  # compute the two solutions
  den = safe_div(1.0, a)
  x0 = (-b - det) * den
  x1 = (-b + det) * den
  x = wp.vec2(x0, x1)

  # finalize result
  if x0 >= 0.0:
    return x0, x
  elif x1 >= 0.0:
    return x1, x
  else:
    return wp.inf, x


@wp.func
def _ray_triangle(v0: wp.vec3, v1: wp.vec3, v2: wp.vec3, pnt: wp.vec3, vec: wp.vec3, b0: wp.vec3, b1: wp.vec3) -> float:
  """Returns the distance at which a ray intersects with a triangle."""
  dif0 = v0 - pnt
  dif1 = v1 - pnt
  dif2 = v2 - pnt

  # project difference vectors in normal plane
  planar_00 = wp.dot(dif0, b0)
  planar_01 = wp.dot(dif0, b1)
  planar_10 = wp.dot(dif1, b0)
  planar_11 = wp.dot(dif1, b1)
  planar_20 = wp.dot(dif2, b0)
  planar_21 = wp.dot(dif2, b1)

  # reject if on the same side of any coordinate axis
  if (
    (planar_00 > 0.0 and planar_10 > 0.0 and planar_20 > 0.0)
    or (planar_00 < 0.0 and planar_10 < 0.0 and planar_20 < 0.0)
    or (planar_01 > 0.0 and planar_11 > 0.0 and planar_21 > 0.0)
    or (planar_01 < 0.0 and planar_11 < 0.0 and planar_21 < 0.0)
  ):
    return float(wp.inf)

  # determine if origin is inside planar projection of triangle
  # A = (p0-p2, p1-p2), b = -p2, solve A*t = b
  A00 = planar_00 - planar_20
  A10 = planar_10 - planar_20
  A01 = planar_01 - planar_21
  A11 = planar_11 - planar_21

  b = wp.vec2(-planar_20, -planar_21)

  det = A00 * A11 - A10 * A01
  if wp.abs(det) < MJ_MINVAL:
    return float(wp.inf)

  t0 = (A11 * b[0] - A10 * b[1]) / det
  t1 = (-A01 * b[0] + A00 * b[1]) / det

  # check if outside
  if t0 < 0.0 or t1 < 0.0 or t0 + t1 > 1.0:
    return float(wp.inf)

  # intersect ray with plane of triangle
  dif0 = v0 - v2
  dif1 = v1 - v2
  dif2 = pnt - v2
  nrm = wp.cross(dif0, dif1)  # normal to triangle plane
  denom = wp.dot(vec, nrm)
  if wp.abs(denom) < MJ_MINVAL:
    return float(wp.inf)

  dist = -wp.dot(dif2, nrm) / denom
  return wp.where(dist >= 0.0, dist, float(wp.inf))


@wp.func
def _ray_plane(pos: wp.vec3, mat: wp.mat33, size: wp.vec3, pnt: wp.vec3, vec: wp.vec3) -> float:
  """Returns the distance at which a ray intersects with a plane."""
  # map to local frame
  lpnt, lvec = _ray_map(pos, mat, pnt, vec)

  # z-vec not pointing towards front face: reject
  if lvec[2] > -MJ_MINVAL:
    return wp.inf

  # intersection with plane
  x = -lpnt[2] / lvec[2]
  if x < 0.0:
    return wp.inf

  p = wp.vec2(lpnt[0] + x * lvec[0], lpnt[1] + x * lvec[1])

  # accept only within rendered rectangle
  if (size[0] <= 0.0 or wp.abs(p[0]) <= size[0]) and (size[1] <= 0.0 or wp.abs(p[1]) <= size[1]):
    return x
  else:
    return wp.inf


@wp.func
def _ray_sphere(pos: wp.vec3, dist_sqr: float, pnt: wp.vec3, vec: wp.vec3) -> float:
  """Returns the distance at which a ray intersects with a sphere."""
  dif = pnt - pos

  a = wp.dot(vec, vec)
  b = wp.dot(vec, dif)
  c = wp.dot(dif, dif) - dist_sqr

  sol, _ = _ray_quad(a, b, c)
  return sol


@wp.func
def _ray_capsule(pos: wp.vec3, mat: wp.mat33, size: wp.vec3, pnt: wp.vec3, vec: wp.vec3) -> float:
  """Returns the distance at which a ray intersects with a capsule."""
  # bounding sphere test
  ssz = size[0] + size[1]
  if _ray_sphere(pos, ssz * ssz, pnt, vec) < 0.0:
    return wp.inf

  # map to local frame
  lpnt, lvec = _ray_map(pos, mat, pnt, vec)

  # init solution
  x = -1.0

  # cylinder round side: (x * lvec + lpnt)' * (x * lvec + lpnt) = size[0] * size[0]
  sq_size0 = size[0] * size[0]
  a = lvec[0] * lvec[0] + lvec[1] * lvec[1]
  b = lvec[0] * lpnt[0] + lvec[1] * lpnt[1]
  c = lpnt[0] * lpnt[0] + lpnt[1] * lpnt[1] - sq_size0

  # solve a * x^2 + 2 * b * x + c = 0
  sol, xx = _ray_quad(a, b, c)

  # make sure round solution is between flat sides
  if sol >= 0.0 and wp.abs(lpnt[2] + sol * vec[2]) <= size[1]:
    if x < 0.0 or sol < x:
      x = sol

  # top cap
  ldif = wp.vec3(lpnt[0], lpnt[1], lpnt[2] - size[1])
  a += lvec[2] * lvec[2]
  b = wp.dot(lvec, ldif)
  c = wp.dot(ldif, ldif) - sq_size0
  _, xx = _ray_quad(a, b, c)

  # accept only top half of sphere
  for i in range(2):
    if xx[i] >= 0.0 and lpnt[2] + xx[i] * lvec[2] >= size[1]:
      if x < 0.0 or xx[i] < x:
        x = xx[i]

  # bottom cap
  ldif = wp.vec3(ldif[0], ldif[1], lpnt[2] + size[1])
  b = wp.dot(lvec, ldif)
  c = wp.dot(ldif, ldif) - sq_size0
  _, xx = _ray_quad(a, b, c)

  # accept only bottom half of sphere
  for i in range(2):
    if xx[i] >= 0.0 and lpnt[2] + xx[i] * lvec[2] <= -size[1]:
      if x < 0.0 or xx[i] < x:
        x = xx[i]

  return x


@wp.func
def _ray_ellipsoid(pos: wp.vec3, mat: wp.mat33, size: wp.vec3, pnt: wp.vec3, vec: wp.vec3) -> float:
  """Returns the distance at which a ray intersects with an ellipsoid."""
  # map to local frame
  lpnt, lvec = _ray_map(pos, mat, pnt, vec)

  # invert size^2
  s = wp.vec3(safe_div(1.0, size[0] * size[0]), safe_div(1.0, size[1] * size[1]), safe_div(1.0, size[2] * size[2]))

  # (x * lvec + lpnt)' * diag(1 / size^2) * (x * lvec + lpnt) = 1
  slvec = wp.cw_mul(s, lvec)
  a = wp.dot(slvec, lvec)
  b = wp.dot(slvec, lpnt)
  c = wp.dot(wp.cw_mul(s, lpnt), lpnt) - 1.0

  # solve a * x^2 + 2 * b * x + c = 0
  sol, _ = _ray_quad(a, b, c)
  return sol


@wp.func
def _ray_cylinder(pos: wp.vec3, mat: wp.mat33, size: wp.vec3, pnt: wp.vec3, vec: wp.vec3) -> float:
  """Returns the distance at which a ray intersects with a cylinder."""
  # bounding sphere test
  ssz = size[0] * size[0] + size[1] * size[1]
  if _ray_sphere(pos, ssz, pnt, vec) < 0.0:
    return wp.inf

  # map to local frame
  lpnt, lvec = _ray_map(pos, mat, pnt, vec)

  # init solution
  x = wp.inf

  # flat sides
  if wp.abs(lvec[2]) > MJ_MINVAL:
    for side in range(-1, 2, 2):
      # solution of: lpnt[2] + x * lvec[2] = side * height_size
      sol = (float(side) * size[1] - lpnt[2]) / lvec[2]

      # process if non-negative
      if sol >= 0.0:
        # intersection with horizontal face
        p = wp.vec2(lpnt[0] + sol * lvec[0], lpnt[1] + sol * lvec[1])

        # accept within radius
        if wp.dot(p, p) <= size[0] * size[0]:
          if x < 0.0 or sol < x:
            x = sol

  # (x * lvec + lpnt)' * (x * lvec + lpnt) = size[0] * size[0]
  a = lvec[0] * lvec[0] + lvec[1] * lvec[1]
  b = lvec[0] * lpnt[0] + lvec[1] * lpnt[1]
  c = lpnt[0] * lpnt[0] + lpnt[1] * lpnt[1] - size[0] * size[0]

  # solve a * x^2 + 2 * b * x + c = 0
  sol, _ = _ray_quad(a, b, c)

  # make sure round solution is between flat sides
  if sol >= 0.0 and wp.abs(lpnt[2] + sol * lvec[2]) <= size[1]:
    if x < 0.0 or sol < x:
      x = sol

  return x


_IFACE = wp.types.matrix((3, 2), dtype=int)(1, 2, 0, 2, 0, 1)


@wp.func
def _ray_box(pos: wp.vec3, mat: wp.mat33, size: wp.vec3, pnt: wp.vec3, vec: wp.vec3) -> Tuple[float, vec6]:
  """Returns the distance at which a ray intersects with a box."""
  all = vec6(-1.0, -1.0, -1.0, -1.0, -1.0, -1.0)

  # bounding sphere test
  ssz = wp.dot(size, size)
  if _ray_sphere(pos, ssz, pnt, vec) < 0.0:
    return wp.inf, all

  # map to local frame
  lpnt, lvec = _ray_map(pos, mat, pnt, vec)

  # init solution
  x = wp.inf

  # loop over axes with non-zero vec
  for i in range(3):
    if wp.abs(lvec[i]) > MJ_MINVAL:
      for side in range(-1, 2, 2):
        # solution of: lpnt[i] + x * lvec[i] = side * size[i]
        sol = (float(side) * size[i] - lpnt[i]) / lvec[i]

        # process if non-negative
        if sol >= 0.0:
          id0 = _IFACE[i][0]
          id1 = _IFACE[i][1]

          # intersection with face
          p0 = lpnt[id0] + sol * lvec[id0]
          p1 = lpnt[id1] + sol * lvec[id1]

          # accept within rectangle
          if (wp.abs(p0) <= size[id0]) and (wp.abs(p1) <= size[id1]):
            # update
            if (x < 0.0) or (sol < x):
              x = sol

            # save in all
            all[2 * i + (side + 1) // 2] = sol

  return x, all


@wp.func
def _ray_hfield(
  # Model:
  geom_type: wp.array(dtype=int),
  geom_dataid: wp.array(dtype=int),
  hfield_size: wp.array(dtype=wp.vec4),
  hfield_nrow: wp.array(dtype=int),
  hfield_ncol: wp.array(dtype=int),
  hfield_adr: wp.array(dtype=int),
  hfield_data: wp.array(dtype=float),
  # In:
  pos: wp.vec3,
  mat: wp.mat33,
  pnt: wp.vec3,
  vec: wp.vec3,
  id: int,
):
  # check geom type
  if geom_type[id] != GeomType.HFIELD:
    return wp.inf

  # hfield id and dimensions
  hid = geom_dataid[id]
  nrow = hfield_nrow[hid]
  ncol = hfield_ncol[hid]

  size = hfield_size[hid]
  adr = hfield_adr[hid]

  mat_col = wp.vec3(mat[0, 2], mat[1, 2], mat[2, 2])

  # compute size and pos of base box
  base_scale = size[3] * 0.5
  base_size = wp.vec3(size[0], size[1], base_scale)
  base_pos = pos + mat_col * base_scale

  # compute size and pos of top box
  top_scale = size[2] * 0.5
  top_size = wp.vec3(size[0], size[1], top_scale)
  top_pos = pos + mat_col * top_scale

  # init: intersection with base box
  x, _ = _ray_box(base_pos, mat, base_size, pnt, vec)

  # check top box: done if no intersection
  top_intersect, all = _ray_box(top_pos, mat, top_size, pnt, vec)

  if top_intersect < 0.0:
    return x

  # map to local frame
  lpnt, lvec = _ray_map(pos, mat, pnt, vec)

  # construct basis vectors of normal plane
  b0 = wp.vec3(1.0, 1.0, 1.0)

  if wp.abs(lvec[0]) >= wp.abs(lvec[1]) and wp.abs(lvec[0]) >= wp.abs(lvec[2]):
    b0[0] = 0.0
  elif wp.abs(lvec[1]) >= wp.abs(lvec[2]):
    b0[1] = 0.0
  else:
    b0[2] = 0.0
  b1 = b0 + lvec * -safe_div(wp.dot(lvec, b0), wp.dot(lvec, lvec))
  b1 = wp.normalize(b1)

  b2 = wp.cross(b1, lvec)
  b2 = wp.normalize(b2)

  # find ray segment intersecting top box
  seg = wp.vec2(0.0, top_intersect)
  for i in range(6):
    if all[i] > seg[1]:
      seg[0] = top_intersect
      seg[1] = all[i]

  # project segment endpoints in horizontal plane, discretize
  dx = safe_div(2.0 * size[0], float(ncol - 1))
  dy = safe_div(2.0 * size[1], float(nrow - 1))
  SX = wp.vec2(safe_div(lpnt[0] * seg[0] * lvec[0] + size[0], dx), safe_div(lpnt[0] * seg[1] * lvec[0] + size[0], dx))
  SY = wp.vec2(safe_div(lpnt[1] + seg[0] * lvec[1] + size[1], dy), safe_div(lpnt[1] + seg[1] * lvec[1] + size[1], dy))

  # compute ranges, with +1 padding
  cmin = wp.max(0, int(wp.floor(wp.min(SX[0], SX[1])) - 1.0))
  cmax = wp.min(ncol - 1, int(wp.ceil(wp.max(SX[0], SX[1])) + 1.0))
  rmin = wp.max(0, int(wp.floor(wp.min(SY[0], SY[1])) - 1.0))
  rmax = wp.min(nrow - 1, int(wp.ceil(wp.max(SY[0], SY[1])) + 1.0))

  # check triangles within bounds
  for r in range(rmin, rmax):
    for c in range(cmin, cmax):
      # first triangle
      v0 = wp.vec3(dx * float(c) - size[0], dy * float(r) - size[1], hfield_data[adr + r * ncol + c] * size[2])
      v1 = wp.vec3(
        dx * float(c + 1) - size[0], dy * float(r + 1) - size[1], hfield_data[adr + (r + 1) * ncol + (c + 1)] * size[2]
      )
      v2 = wp.vec3(dx * float(c + 1) - size[0], dy * float(r) - size[1], hfield_data[adr + r * ncol + (c + 1)] * size[2])
      sol = _ray_triangle(v0, v1, v2, pnt, vec, b0, b1)
      if sol >= 0.0 and (x < 0.0 or sol < x):
        x = sol

      # second triangle
      v0 = wp.vec3(dx * float(c) - size[0], dy * float(r) - size[1], hfield_data[adr + r * ncol + c] * size[2])
      v1 = wp.vec3(
        dx * float(c + 1) - size[0], dy * float(r + 1) - size[1], hfield_data[adr + (r + 1) * ncol + (c + 1)] * size[2]
      )
      v2 = wp.vec3(dx * float(c) - size[0], dy * float(r + 1) - size[1], hfield_data[adr + (r + 1) * ncol + c] * size[2])
      sol = _ray_triangle(v0, v1, v2, pnt, vec, b0, b1)
      if sol >= 0.0 and (x < 0.0 or sol < x):
        x = sol

  # check viable sides of top box
  for i in range(4):
    if all[i] >= 0.0 and (all[i] < x or x < 0.0):
      # normalized height of intersection point
      z = safe_div(lpnt[2] + all[i] * lvec[2], size[2])

      # rectangle points: y, y0, z0, z1
      # side normal to x-axis
      if i < 2:
        y = safe_div(lpnt[1] + all[i] * lvec[1] + size[1], dy)
        y0 = wp.max(0.0, wp.min(float(nrow - 2), wp.floor(y)))
        if i == 1:
          z0 = hfield_data[adr + int(wp.round(y0 + 0.0)) * ncol + ncol - 1]
          z1 = hfield_data[adr + int(wp.round(y0 + 1.0)) * ncol + ncol - 1]
        else:
          z0 = hfield_data[adr + int(wp.round(y0 + 0.0)) * ncol]
          z1 = hfield_data[adr + int(wp.round(y0 + 1.0)) * ncol]
      # side normal to y-axis
      else:
        y = safe_div(lpnt[0] + all[i] * lvec[0] + size[0], dx)
        y0 = wp.max(0.0, wp.min(float(ncol - 2), wp.floor(y)))
        if i == 3:
          z0 = hfield_data[adr + int(wp.round(y0 + 0.0)) + (nrow - 1) * ncol]
          z1 = hfield_data[adr + int(wp.round(y0 + 1.0)) + (nrow - 1) * ncol]
        else:
          z0 = hfield_data[adr + int(wp.round(y0 + 0.0))]
          z1 = hfield_data[adr + int(wp.round(y0 + 1.0))]

      # check if point is below line segments
      if z < z0 * (y0 + 1.0 - y) + z1 * (y - y0):
        x = all[i]

  return x


@wp.func
def ray_mesh(
  # Model:
  nmeshface: int,
  mesh_vertadr: wp.array(dtype=int),
  mesh_faceadr: wp.array(dtype=int),
  mesh_vert: wp.array(dtype=wp.vec3),
  mesh_face: wp.array(dtype=wp.vec3i),
  # In:
  data_id: int,
  pos: wp.vec3,
  mat: wp.mat33,
  pnt: wp.vec3,
  vec: wp.vec3,
) -> float:
  """Returns the distance and geomid for ray mesh intersections."""
  pnt, vec = _ray_map(pos, mat, pnt, vec)

  # compute orthogonal basis vectors
  if wp.abs(vec[0]) < wp.abs(vec[1]):
    if wp.abs(vec[0]) < wp.abs(vec[2]):
      b0 = wp.vec3(0.0, vec[2], -vec[1])
    else:
      b0 = wp.vec3(vec[1], -vec[0], 0.0)
  else:
    if wp.abs(vec[1]) < wp.abs(vec[2]):
      b0 = wp.vec3(-vec[2], 0.0, vec[0])
    else:
      b0 = wp.vec3(vec[1], -vec[0], 0.0)

  # normalize first vector
  b0 = wp.normalize(b0)

  # compute second vector as cross product
  b1 = wp.cross(vec, b0)
  b1 = wp.normalize(b1)

  min_dist = float(wp.inf)

  # get mesh vertex data range
  vert_start = mesh_vertadr[data_id]

  # get mesh face and vertex data
  face_start = mesh_faceadr[data_id]

  if data_id + 1 < mesh_faceadr.shape[0]:
    face_end = mesh_faceadr[data_id + 1]
  else:
    face_end = nmeshface

  # iterate through all faces
  for i in range(face_start, face_end):
    # get vertices for this face
    v_idx = mesh_face[i]

    # create triangle struct
    v0 = mesh_vert[vert_start + v_idx.x]
    v1 = mesh_vert[vert_start + v_idx.y]
    v2 = mesh_vert[vert_start + v_idx.z]

    # calculate intersection
    dist = _ray_triangle(v0, v1, v2, pnt, vec, b0, b1)
    if dist < min_dist:
      min_dist = dist

  return min_dist


@wp.func
def ray_geom(pos: wp.vec3, mat: wp.mat33, size: wp.vec3, pnt: wp.vec3, vec: wp.vec3, geomtype: int) -> float:
  """Returns distance along ray to intersection with geom, or infinity if none."""
  # TODO(team): static loop unrolling to remove unnecessary branching
  if geomtype == GeomType.PLANE:
    return _ray_plane(pos, mat, size, pnt, vec)
  elif geomtype == GeomType.SPHERE:
    return _ray_sphere(pos, size[0] * size[0], pnt, vec)
  elif geomtype == GeomType.CAPSULE:
    return _ray_capsule(pos, mat, size, pnt, vec)
  elif geomtype == GeomType.ELLIPSOID:
    return _ray_ellipsoid(pos, mat, size, pnt, vec)
  elif geomtype == GeomType.CYLINDER:
    return _ray_cylinder(pos, mat, size, pnt, vec)
  elif geomtype == GeomType.BOX:
    dist, _ = _ray_box(pos, mat, size, pnt, vec)
    return dist
  else:
    return wp.inf


@wp.func
def _ray_geom_mesh(
  # Model:
  nmeshface: int,
  body_weldid: wp.array(dtype=int),
  geom_type: wp.array(dtype=int),
  geom_bodyid: wp.array(dtype=int),
  geom_dataid: wp.array(dtype=int),
  geom_matid: wp.array2d(dtype=int),
  geom_group: wp.array(dtype=int),
  geom_size: wp.array2d(dtype=wp.vec3),
  geom_rgba: wp.array2d(dtype=wp.vec4),
  mesh_vertadr: wp.array(dtype=int),
  mesh_faceadr: wp.array(dtype=int),
  mesh_vert: wp.array(dtype=wp.vec3),
  mesh_face: wp.array(dtype=wp.vec3i),
  hfield_size: wp.array(dtype=wp.vec4),
  hfield_nrow: wp.array(dtype=int),
  hfield_ncol: wp.array(dtype=int),
  hfield_adr: wp.array(dtype=int),
  hfield_data: wp.array(dtype=float),
  mat_rgba: wp.array2d(dtype=wp.vec4),
  # Data in:
  geom_xpos_in: wp.array2d(dtype=wp.vec3),
  geom_xmat_in: wp.array2d(dtype=wp.mat33),
  # In:
  worldid: int,
  pnt: wp.vec3,
  vec: wp.vec3,
  geomgroup: vec6,
  flg_static: bool,
  bodyexclude: int,
  geomid: int,
) -> float:
  if not _ray_eliminate(
    body_weldid,
    geom_bodyid,
    geom_matid[worldid % geom_matid.shape[0]],
    geom_group,
    geom_rgba[worldid % geom_rgba.shape[0]],
    mat_rgba[worldid % mat_rgba.shape[0]],
    geomid,
    geomgroup,
    flg_static,
    bodyexclude,
  ):
    pos = geom_xpos_in[worldid, geomid]
    mat = geom_xmat_in[worldid, geomid]
    type = geom_type[geomid]

    if type == GeomType.MESH:
      return ray_mesh(
        nmeshface,
        mesh_vertadr,
        mesh_faceadr,
        mesh_vert,
        mesh_face,
        geom_dataid[geomid],
        pos,
        mat,
        pnt,
        vec,
      )
    elif type == GeomType.HFIELD:
      return _ray_hfield(
        geom_type,
        geom_dataid,
        hfield_size,
        hfield_nrow,
        hfield_ncol,
        hfield_adr,
        hfield_data,
        pos,
        mat,
        pnt,
        vec,
        geomid,
      )
    else:
      return ray_geom(pos, mat, geom_size[worldid % geom_size.shape[0], geomid], pnt, vec, type)
  else:
    return wp.inf


@wp.kernel
def _ray(
  # Model:
  ngeom: int,
  nmeshface: int,
  body_weldid: wp.array(dtype=int),
  geom_type: wp.array(dtype=int),
  geom_bodyid: wp.array(dtype=int),
  geom_dataid: wp.array(dtype=int),
  geom_matid: wp.array2d(dtype=int),
  geom_group: wp.array(dtype=int),
  geom_size: wp.array2d(dtype=wp.vec3),
  geom_rgba: wp.array2d(dtype=wp.vec4),
  mesh_vertadr: wp.array(dtype=int),
  mesh_faceadr: wp.array(dtype=int),
  mesh_vert: wp.array(dtype=wp.vec3),
  mesh_face: wp.array(dtype=wp.vec3i),
  hfield_size: wp.array(dtype=wp.vec4),
  hfield_nrow: wp.array(dtype=int),
  hfield_ncol: wp.array(dtype=int),
  hfield_adr: wp.array(dtype=int),
  hfield_data: wp.array(dtype=float),
  mat_rgba: wp.array2d(dtype=wp.vec4),
  # Data in:
  geom_xpos_in: wp.array2d(dtype=wp.vec3),
  geom_xmat_in: wp.array2d(dtype=wp.mat33),
  # In:
  pnt: wp.array2d(dtype=wp.vec3),
  vec: wp.array2d(dtype=wp.vec3),
  geomgroup: vec6,
  flg_static: bool,
  bodyexclude: wp.array(dtype=int),
  # Out:
  dist_out: wp.array(dtype=float, ndim=2),
  geomid_out: wp.array(dtype=int, ndim=2),
):
  worldid, rayid, tid = wp.tid()

  num_threads = wp.block_dim()

  min_dist = float(wp.inf)
  min_geomid = int(-1)

  upper = ((ngeom + num_threads - 1) // num_threads) * num_threads
  for geomid in range(tid, upper, num_threads):
    if geomid < ngeom:
      dist = _ray_geom_mesh(
        nmeshface,
        body_weldid,
        geom_type,
        geom_bodyid,
        geom_dataid,
        geom_matid,
        geom_group,
        geom_size,
        geom_rgba,
        mesh_vertadr,
        mesh_faceadr,
        mesh_vert,
        mesh_face,
        hfield_size,
        hfield_nrow,
        hfield_ncol,
        hfield_adr,
        hfield_data,
        mat_rgba,
        geom_xpos_in,
        geom_xmat_in,
        worldid,
        pnt[worldid, rayid],
        vec[worldid, rayid],
        geomgroup,
        flg_static,
        bodyexclude[rayid],
        geomid,
      )
    else:
      dist = wp.inf

    tile_dist = wp.tile(dist)
    local_min_geomid = wp.tile_argmin(tile_dist)
    local_min_dist = tile_dist[local_min_geomid[0]]

    tile_geomid = wp.tile(geomid)

    if local_min_dist < min_dist:
      min_dist = local_min_dist
      min_geomid = tile_geomid[local_min_geomid[0]]

  if wp.isinf(min_dist):
    dist_out[worldid, rayid] = -1.0
  else:
    dist_out[worldid, rayid] = min_dist
  geomid_out[worldid, rayid] = min_geomid


[docs] def ray( m: Model, d: Data, pnt: wp.array2d(dtype=wp.vec3), vec: wp.array2d(dtype=wp.vec3), geomgroup: Optional[vec6] = None, flg_static: bool = True, bodyexclude: int = -1, ) -> Tuple[wp.array, wp.array]: """Returns the distance at which rays intersect with primitive geoms. Args: m: The model containing kinematic and dynamic information (device). d: The data object containing the current state and output arrays (device). pnt: Ray origin points. vec: Ray directions. geomgroup: Group inclusion/exclusion mask. If all are wp.inf, ignore. flg_static: If True, allows rays to intersect with static geoms. bodyexclude: Ignore geoms on specified body id (-1 to disable). Returns: Distances from ray origins to geom surfaces and IDs of intersected geoms (-1 if none). """ assert pnt.shape[0] == 1 assert pnt.shape[0] == vec.shape[0] if geomgroup is None: geomgroup = vec6(-1, -1, -1, -1, -1, -1) ray_bodyexclude = wp.empty(1, dtype=int) ray_bodyexclude.fill_(bodyexclude) ray_dist = wp.empty((d.nworld, 1), dtype=float) ray_geomid = wp.empty((d.nworld, 1), dtype=int) rays(m, d, pnt, vec, geomgroup, flg_static, ray_bodyexclude, ray_dist, ray_geomid) return ray_dist, ray_geomid
def rays( m: Model, d: Data, pnt: wp.array2d(dtype=wp.vec3), vec: wp.array2d(dtype=wp.vec3), geomgroup: vec6, flg_static: bool, bodyexclude: wp.array(dtype=int), dist: wp.array2d(dtype=wp.vec3), geomid: wp.array2d(dtype=int), ): wp.launch_tiled( _ray, dim=(d.nworld, pnt.shape[1]), inputs=[ m.ngeom, m.nmeshface, m.body_weldid, m.geom_type, m.geom_bodyid, m.geom_dataid, m.geom_matid, m.geom_group, m.geom_size, m.geom_rgba, m.mesh_vertadr, m.mesh_faceadr, m.mesh_vert, m.mesh_face, m.hfield_size, m.hfield_nrow, m.hfield_ncol, m.hfield_adr, m.hfield_data, m.mat_rgba, d.geom_xpos, d.geom_xmat, pnt, vec, geomgroup, flg_static, bodyexclude, dist, geomid, ], block_dim=m.block_dim.ray, )