# Source code for mujoco_warp._src.render

# Copyright 2026 The Newton Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import Tuple

import warp as wp

from mujoco_warp._src import math
from mujoco_warp._src.ray import ray_box
from mujoco_warp._src.ray import ray_capsule
from mujoco_warp._src.ray import ray_cylinder
from mujoco_warp._src.ray import ray_ellipsoid
from mujoco_warp._src.ray import ray_flex_with_bvh
from mujoco_warp._src.ray import ray_flex_with_bvh_anyhit
from mujoco_warp._src.ray import ray_mesh_with_bvh
from mujoco_warp._src.ray import ray_mesh_with_bvh_anyhit
from mujoco_warp._src.ray import ray_plane
from mujoco_warp._src.ray import ray_sphere
from mujoco_warp._src.render_util import compute_ray
from mujoco_warp._src.render_util import pack_rgba_to_uint32
from mujoco_warp._src.types import MJ_MAXVAL
from mujoco_warp._src.types import Data
from mujoco_warp._src.types import GeomType
from mujoco_warp._src.types import Model
from mujoco_warp._src.types import ObjType
from mujoco_warp._src.types import RenderContext
from mujoco_warp._src.warp_util import event_scope

wp.set_module_options({"enable_backward": False})


@wp.func
def sample_texture(
  # Model:
  geom_type: wp.array[int],
  mesh_faceadr: wp.array[int],
  # In:
  geom_id: int,
  tex_repeat: wp.vec2,
  tex: wp.Texture2D,
  pos: wp.vec3,
  rot: wp.mat33,
  mesh_facetexcoord: wp.array[wp.vec3i],
  mesh_texcoord: wp.array[wp.vec2],
  mesh_texcoord_offsets: wp.array[int],
  hit_point: wp.vec3,
  bary_u: float,
  bary_v: float,
  f: int,
  mesh_id: int,
) -> wp.vec3:
  """Sample the RGB color of `tex` at a surface hit point.

  Texture coordinates are derived per geom type:
    - PLANE: the (x, y) of the hit point expressed in the plane's local frame
      (`pos`, `rot`).
    - MESH: the per-corner texcoords of face `f` of mesh `mesh_id`, blended with
      weights (bary_u, bary_v, 1 - bary_u - bary_v) for corners 0, 1, 2.
    - any other type: the constant (0, 0).
  The coordinates are scaled by `tex_repeat` and wrapped to [0, 1) before the
  lookup. A MESH hit without a valid face or mesh id yields black.
  """
  gtype = geom_type[geom_id]
  tex_uv = wp.vec2(0.0, 0.0)

  if gtype == GeomType.PLANE:
    # The plane-local x/y of the hit point act directly as texture coordinates.
    plane_local = wp.transpose(rot) @ (hit_point - pos)
    tex_uv = wp.vec2(plane_local[0], plane_local[1])

  if gtype == GeomType.MESH:
    if f < 0 or mesh_id < 0:
      return wp.vec3(0.0, 0.0, 0.0)

    # Per-corner texcoord indices of the hit face, relative to this mesh's
    # block in mesh_texcoord.
    corner_ids = mesh_facetexcoord[mesh_faceadr[mesh_id] + f]
    tc_base = mesh_texcoord_offsets[mesh_id]
    tex_uv = (
      mesh_texcoord[tc_base + corner_ids[0]] * bary_u
      + mesh_texcoord[tc_base + corner_ids[1]] * bary_v
      + mesh_texcoord[tc_base + corner_ids[2]] * (1.0 - bary_u - bary_v)
    )

  # Tile by tex_repeat and keep only the fractional part (wrap into [0, 1)).
  tiled_u = tex_uv[0] * tex_repeat[0]
  tiled_v = tex_uv[1] * tex_repeat[1]
  wrapped = wp.vec2(tiled_u - wp.floor(tiled_u), tiled_v - wp.floor(tiled_v))

  rgba = wp.texture_sample(tex, wrapped, dtype=wp.vec4)
  return wp.vec3(rgba[0], rgba[1], rgba[2])


@wp.func
def sample_skybox(
  # In:
  skybox_tex: wp.Texture2D,
  face_width_inv: float,
  ray_dir_world: wp.vec3,
) -> wp.vec3:
  """Return the skybox RGB color seen along world-space direction `ray_dir_world`.

  The skybox texture is a vertical strip of six cube-map faces stacked in
  OpenGL cube-face order (+X, -X, +Y, -Y, +Z, -Z), each occupying 1/6 of the
  strip's height. `face_width_inv` is used to keep the in-face vertical
  coordinate half of that amount away from face borders (presumably
  1 / face size in texels, i.e. half a texel — confirm against the caller).
  """
  # Rotate the world direction 90 degrees about X into cube-map space, as MuJoCo
  # does in render_gl3.c (S=x, T=z, R=-y).
  cx = ray_dir_world[0]
  cy = ray_dir_world[2]
  cz = -ray_dir_world[1]

  abs_x = wp.abs(cx)
  abs_y = wp.abs(cy)
  abs_z = wp.abs(cz)

  face_idx = int(0)
  s_num = float(0.0)
  t_num = float(0.0)
  major = float(1.0)

  # Pick the face from the dominant axis, then the in-face (s, t) numerators
  # from the OpenGL major-axis table.
  if abs_x >= abs_y and abs_x >= abs_z:
    major = abs_x
    t_num = -cy
    if cx > 0.0:
      face_idx = 0
      s_num = -cz
    else:
      face_idx = 1
      s_num = cz
  elif abs_y >= abs_z:
    major = abs_y
    s_num = cx
    if cy > 0.0:
      face_idx = 2
      t_num = cz
    else:
      face_idx = 3
      t_num = -cz
  else:
    major = abs_z
    t_num = -cy
    if cz > 0.0:
      face_idx = 4
      s_num = cx
    else:
      face_idx = 5
      s_num = -cx

  # Project onto the face and remap from [-1, 1] to [0, 1].
  s = (math.safe_div(s_num, major) + 1.0) * 0.5
  t = (math.safe_div(t_num, major) + 1.0) * 0.5

  # Keep t away from the face border so linear filtering does not blend in rows
  # of the neighbouring face in the vertical strip.
  border = 0.5 * face_width_inv
  t = wp.clamp(t, border, 1.0 - border)

  strip_v = (float(face_idx) + t) * wp.static(1.0 / 6.0)
  rgba = wp.texture_sample(skybox_tex, wp.vec2(s, strip_v), dtype=wp.vec4)
  return wp.vec3(rgba[0], rgba[1], rgba[2])


def _make_cast_ray(geom_ray_types: Tuple[int, ...], first_hit: bool = False) -> wp.Function:
  """Build a ray-cast func specialized to the geom types present in the scene.

  geom_ray_types is the collection of GeomType int values that actually occur
  (any container supporting `in`). Per-type intersection branches for absent
  types are eliminated at compile time via wp.static, which avoids the register
  pressure of unreachable code paths.

  first_hit selects the variant (also resolved at compile time via wp.static):
    - False: full closest-hit cast. Returns the closest hit's full surface data.
    - True: any-hit cast (shadow rays). Uses the cheaper any-hit mesh/flex
      intersections and returns on the first hit within max_dist. The result is
      still the full tuple; callers test geom_id != -1 to detect a hit. For
      any-hit mesh/flex-BVH hits the distance is reported as 0.0 and the
      normal / barycentrics / face are left at their defaults.

  Args:
    geom_ray_types: GeomType int values to generate intersection code for.
    first_hit: whether to build the any-hit (True) or closest-hit (False) cast.

  Returns:
    A `wp.func` returning `(geom_id, dist, normal, bary_u, bary_v, face_idx,
    mesh_id)`. geom_id is -1 on a miss (dist is then max_dist) and -2 for a flex
    hit, in which case mesh_id holds the flex id.
  """

  @wp.func
  def cast_ray(
    # Model:
    geom_type: wp.array[int],
    geom_dataid: wp.array2d[int],
    geom_size: wp.array2d[wp.vec3],
    flex_vertadr: wp.array[int],
    flex_edge: wp.array[wp.vec2i],
    flex_radius: wp.array[float],
    # Data in:
    geom_xpos_in: wp.array2d[wp.vec3],
    geom_xmat_in: wp.array2d[wp.mat33],
    flexvert_xpos_in: wp.array2d[wp.vec3],
    # In:
    bvh_id: wp.uint64,
    group_root: int,
    worldid: int,
    bvh_ngeom: int,
    flex_bvh_ngeom: int,
    enabled_geom_ids: wp.array[int],
    mesh_bvh_id: wp.array[wp.uint64],
    hfield_bvh_id: wp.array[wp.uint64],
    flex_geom_flexid: wp.array[int],
    flex_geom_edgeid: wp.array[int],
    flex_bvh_id: wp.array[wp.uint64],
    flex_group_root: wp.array2d[int],
    ray_origin_world: wp.vec3,
    ray_dir_world: wp.vec3,
    max_dist: float,
    cull_backfaces: bool,
  ) -> Tuple[int, float, wp.vec3, float, float, int, int]:
    dist = max_dist
    normal = wp.vec3(0.0, 0.0, 0.0)
    geom_id = int(-1)
    bary_u = float(0.0)
    bary_v = float(0.0)
    face_idx = int(-1)
    geom_mesh_id = int(-1)

    query = wp.bvh_query_ray(bvh_id, ray_origin_world, ray_dir_world, group_root)
    bounds_nr = int(0)
    # Each world owns a contiguous range of ngeom BVH leaves: the enabled rigid
    # geoms first, followed by the flex geoms.
    ngeom = bvh_ngeom + flex_bvh_ngeom

    # Passing `dist` shrinks the traversal range as closer hits are found.
    while wp.bvh_query_next(query, bounds_nr, dist):
      gi_global = bounds_nr
      local_id = gi_global - (worldid * ngeom)

      d = float(-1.0)
      hit_mesh_id = int(-1)
      u = float(0.0)
      v = float(0.0)
      f = int(-1)
      n = wp.vec3(0.0, 0.0, 0.0)
      hit_geom_id = int(-1)

      if local_id < bvh_ngeom:
        gi = enabled_geom_ids[local_id]
        gtype = geom_type[gi]
      else:
        gi = local_id - bvh_ngeom
        gtype = GeomType.FLEX

      hit_geom_id = gi

      if wp.static(int(GeomType.PLANE) in geom_ray_types):
        if gtype == GeomType.PLANE:
          d, n = ray_plane(
            geom_xpos_in[worldid, gi],
            geom_xmat_in[worldid, gi],
            geom_size[worldid % geom_size.shape[0], gi],
            ray_origin_world,
            ray_dir_world,
          )
      if wp.static(int(GeomType.HFIELD) in geom_ray_types):
        if gtype == GeomType.HFIELD:
          # The returned hfield id is not needed: hit_mesh_id stays -1 for hfields.
          d, n, u, v, f, geom_hfield_id = ray_mesh_with_bvh(
            hfield_bvh_id,
            geom_dataid[worldid % geom_dataid.shape[0], gi],
            geom_xpos_in[worldid, gi],
            geom_xmat_in[worldid, gi],
            ray_origin_world,
            ray_dir_world,
            dist,
            cull_backfaces,
          )
      if wp.static(int(GeomType.SPHERE) in geom_ray_types):
        if gtype == GeomType.SPHERE:
          sphere_size = geom_size[worldid % geom_size.shape[0], gi]
          # ray_sphere is given the squared radius.
          d, n = ray_sphere(
            geom_xpos_in[worldid, gi],
            sphere_size[0] * sphere_size[0],
            ray_origin_world,
            ray_dir_world,
          )
      if wp.static(int(GeomType.ELLIPSOID) in geom_ray_types):
        if gtype == GeomType.ELLIPSOID:
          d, n = ray_ellipsoid(
            geom_xpos_in[worldid, gi],
            geom_xmat_in[worldid, gi],
            geom_size[worldid % geom_size.shape[0], gi],
            ray_origin_world,
            ray_dir_world,
          )
      if wp.static(int(GeomType.CAPSULE) in geom_ray_types):
        if gtype == GeomType.CAPSULE:
          d, n = ray_capsule(
            geom_xpos_in[worldid, gi],
            geom_xmat_in[worldid, gi],
            geom_size[worldid % geom_size.shape[0], gi],
            ray_origin_world,
            ray_dir_world,
          )
      if wp.static(int(GeomType.CYLINDER) in geom_ray_types):
        if gtype == GeomType.CYLINDER:
          d, n = ray_cylinder(
            geom_xpos_in[worldid, gi],
            geom_xmat_in[worldid, gi],
            geom_size[worldid % geom_size.shape[0], gi],
            ray_origin_world,
            ray_dir_world,
          )
      if wp.static(int(GeomType.BOX) in geom_ray_types):
        if gtype == GeomType.BOX:
          # The middle output of ray_box is unused here.
          d, box_all, n = ray_box(
            geom_xpos_in[worldid, gi],
            geom_xmat_in[worldid, gi],
            geom_size[worldid % geom_size.shape[0], gi],
            ray_origin_world,
            ray_dir_world,
          )
      if wp.static(int(GeomType.MESH) in geom_ray_types):
        if gtype == GeomType.MESH:
          if wp.static(first_hit):
            hit = ray_mesh_with_bvh_anyhit(
              mesh_bvh_id,
              geom_dataid[worldid % geom_dataid.shape[0], gi],
              geom_xpos_in[worldid, gi],
              geom_xmat_in[worldid, gi],
              ray_origin_world,
              ray_dir_world,
              dist,
            )
            # Any-hit only reports occlusion; 0.0 is a placeholder in-range distance.
            d = 0.0 if hit else -1.0
          else:
            d, n, u, v, f, hit_mesh_id = ray_mesh_with_bvh(
              mesh_bvh_id,
              geom_dataid[worldid % geom_dataid.shape[0], gi],
              geom_xpos_in[worldid, gi],
              geom_xmat_in[worldid, gi],
              ray_origin_world,
              ray_dir_world,
              dist,
              cull_backfaces,
            )
      if wp.static(int(GeomType.FLEX) in geom_ray_types):
        if gtype == GeomType.FLEX:
          # Flex hits are reported with the -2 sentinel; the flex id travels in
          # hit_mesh_id instead.
          hit_geom_id = -2
          flexid = flex_geom_flexid[gi]
          edge_id = flex_geom_edgeid[gi]

          if edge_id >= 0:
            # Edge-backed flex geom: intersect the capsule spanning the edge.
            edge = flex_edge[edge_id]
            vert_adr = flex_vertadr[flexid]
            v0 = flexvert_xpos_in[worldid, vert_adr + edge[0]]
            v1 = flexvert_xpos_in[worldid, vert_adr + edge[1]]
            pos = 0.5 * (v0 + v1)
            vec = v1 - v0

            length = wp.length(vec)
            edgeq = math.quat_z2vec(vec)
            mat = math.quat_to_mat(edgeq)
            size = wp.vec3(flex_radius[flexid], 0.5 * length, 0.0)

            d, n = ray_capsule(pos, mat, size, ray_origin_world, ray_dir_world)
            hit_mesh_id = flexid
          else:
            if wp.static(first_hit):
              hit = ray_flex_with_bvh_anyhit(
                flex_bvh_id,
                flexid,
                flex_group_root[worldid, flexid],
                ray_origin_world,
                ray_dir_world,
                dist,
              )
              d = 0.0 if hit else -1.0
            else:
              flex_gr = flex_group_root[worldid, flexid]
              d, n, u, v, f = ray_flex_with_bvh(flex_bvh_id, flexid, flex_gr, ray_origin_world, ray_dir_world, dist)
              if d >= 0.0:
                hit_mesh_id = flexid

      # Backface cull: drop exit-face hits when the ray origin is inside the geom,
      # matching ray_mesh_with_bvh's `dot(lvec, n) < 0` rule. Strict `> 0` keeps
      # tangent hits and skips branches with a zero-vector normal (any-hit).
      if cull_backfaces and d >= 0.0 and wp.dot(ray_dir_world, n) > 0.0:
        d = -1.0

      if wp.static(first_hit):
        # Any-hit: return as soon as anything is in range; surface data is unused.
        if d >= 0.0 and d < dist:
          return hit_geom_id, d, n, u, v, f, hit_mesh_id
      else:
        if d >= 0.0 and d < dist:
          dist = d
          normal = n
          geom_id = hit_geom_id
          bary_u = u
          bary_v = v
          face_idx = f
          geom_mesh_id = hit_mesh_id

    return geom_id, dist, normal, bary_u, bary_v, face_idx, geom_mesh_id

  return cast_ray


def _make_compute_lighting(cast_ray_first_hit: wp.Function) -> wp.Function:
  """Build a compute_lighting func whose shadow test uses `cast_ray_first_hit`.

  Args:
    cast_ray_first_hit: the any-hit ray-cast func built by
      `_make_cast_ray(..., first_hit=True)`. Only the geom id it returns is
      inspected: any value other than -1 counts as an occluder.

  Returns:
    A `wp.func` computing the scalar diffuse contribution of one light at a
    surface point.
  """

  @wp.func
  def compute_lighting(
    # Model:
    geom_type: wp.array[int],
    geom_dataid: wp.array2d[int],
    geom_size: wp.array2d[wp.vec3],
    flex_vertadr: wp.array[int],
    flex_edge: wp.array[wp.vec2i],
    flex_radius: wp.array[float],
    # Data in:
    geom_xpos_in: wp.array2d[wp.vec3],
    geom_xmat_in: wp.array2d[wp.mat33],
    flexvert_xpos_in: wp.array2d[wp.vec3],
    # In:
    use_shadows: bool,
    bvh_id: wp.uint64,
    group_root: int,
    bvh_ngeom: int,
    bvh_nflexgeom: int,
    enabled_geom_ids: wp.array[int],
    worldid: int,
    mesh_bvh_id: wp.array[wp.uint64],
    hfield_bvh_id: wp.array[wp.uint64],
    flex_geom_flexid: wp.array[int],
    flex_geom_edgeid: wp.array[int],
    flex_bvh_id: wp.array[wp.uint64],
    flex_group_root: wp.array2d[int],
    lightactive: bool,
    lighttype: int,
    lightcastshadow: bool,
    lightpos: wp.vec3,
    lightdir: wp.vec3,
    normal: wp.vec3,
    hitpoint: wp.vec3,
    cull_backfaces: bool,
  ) -> float:
    """Return ndotl * attenuation * visibility for one light at `hitpoint`.

    lighttype 1 is treated as directional (light arrives along -lightdir), 0 as a
    spot light, and any other value as a point light. Non-directional lights fall
    off as 1 / (1 + 0.02 * d^2); spot lights additionally ramp from zero at
    cos(angle to lightdir) = 0.85 to full at 0.95. When shadows are enabled and
    the shadow ray is blocked, the contribution is scaled by 0.3.
    """
    light_contribution = float(0.0)

    # TODO: We should probably only be looping over active lights
    # in the first place with a static loop of enabled light idx?
    if not lightactive:
      return light_contribution

    L = wp.vec3(0.0, 0.0, 0.0)
    dist_to_light = float(MJ_MAXVAL)
    attenuation = float(1.0)

    if lighttype == 1:  # directional light
      L = wp.normalize(-lightdir)
    else:
      L, dist_to_light = math.normalize_with_norm(lightpos - hitpoint)
      attenuation = 1.0 / (1.0 + 0.02 * dist_to_light * dist_to_light)
      if lighttype == 0:  # spot light
        spot_dir = wp.normalize(lightdir)
        cos_theta = wp.dot(-L, spot_dir)
        spot_factor = wp.min(1.0, wp.max(0.0, (cos_theta - 0.85) * 10.0))
        attenuation = attenuation * spot_factor

    ndotl = wp.max(0.0, wp.dot(normal, L))

    # Nothing to add when the surface faces away from the light or the point lies
    # outside a spot light's cone: return before tracing a shadow ray whose
    # result would only be multiplied by zero.
    if ndotl == 0.0 or attenuation == 0.0:
      return light_contribution

    visible = float(1.0)

    if use_shadows and lightcastshadow:
      # Nudge the origin slightly along the surface normal to avoid
      # self-intersection when casting shadow rays
      shadow_origin = hitpoint + normal * 1.0e-4
      # Distance-limited shadows: cap by dist_to_light (for non-directional)
      max_t = dist_to_light - 1.0e-3
      if lighttype == 1:  # directional light
        max_t = 1.0e8

      shadow_geom_id, shadow_d, shadow_n, shadow_u, shadow_v, shadow_f, shadow_mesh_id = cast_ray_first_hit(
        geom_type,
        geom_dataid,
        geom_size,
        flex_vertadr,
        flex_edge,
        flex_radius,
        geom_xpos_in,
        geom_xmat_in,
        flexvert_xpos_in,
        bvh_id,
        group_root,
        worldid,
        bvh_ngeom,
        bvh_nflexgeom,
        enabled_geom_ids,
        mesh_bvh_id,
        hfield_bvh_id,
        flex_geom_flexid,
        flex_geom_edgeid,
        flex_bvh_id,
        flex_group_root,
        shadow_origin,
        L,
        max_t,
        cull_backfaces,
      )

      # Any occluder (rigid geom id >= 0 or flex sentinel -2) darkens the point.
      if shadow_geom_id != -1:
        visible = 0.3

    return ndotl * attenuation * visible

  return compute_lighting


@event_scope
def render(m: Model, d: Data, rc: RenderContext):
  """Render the current frame.

  Outputs are stored in buffers within the render context. The RGB buffer is
  first filled with `rc.background_color`, depth with 0.0 and segmentation with
  (-1, -1); one kernel thread per (world, ray) then overwrites the pixels whose
  camera has the corresponding output enabled.

  Args:
    m: The model on device.
    d: The data on device.
    rc: The render context on device.
  """
  rc.rgb_data.fill_(rc.background_color)
  rc.depth_data.fill_(0.0)
  rc.seg_data.fill_(wp.vec2i(-1, -1))

  # Specialize the ray-cast helpers to the geom types present in the scene so the
  # compiler eliminates intersection branches for absent types.
  geom_ray_types = rc.geom_ray_types
  cast_ray = _make_cast_ray(geom_ray_types, first_hit=False)
  cast_ray_first_hit = _make_cast_ray(geom_ray_types, first_hit=True)
  compute_lighting = _make_compute_lighting(cast_ray_first_hit)

  @wp.kernel(module="unique", enable_backward=False)
  def _render_megakernel(
    # Model:
    geom_type: wp.array[int],
    geom_dataid: wp.array2d[int],
    geom_matid: wp.array2d[int],
    geom_size: wp.array2d[wp.vec3],
    geom_rgba: wp.array2d[wp.vec4],
    cam_projection: wp.array[int],
    cam_fovy: wp.array2d[float],
    cam_sensorsize: wp.array[wp.vec2],
    cam_intrinsic: wp.array2d[wp.vec4],
    light_type: wp.array2d[int],
    light_castshadow: wp.array2d[bool],
    light_active: wp.array2d[bool],
    flex_vertadr: wp.array[int],
    flex_edge: wp.array[wp.vec2i],
    flex_radius: wp.array[float],
    mesh_faceadr: wp.array[int],
    mat_texid: wp.array3d[int],
    mat_texrepeat: wp.array2d[wp.vec2],
    mat_rgba: wp.array2d[wp.vec4],
    # Data in:
    geom_xpos_in: wp.array2d[wp.vec3],
    geom_xmat_in: wp.array2d[wp.mat33],
    cam_xpos_in: wp.array2d[wp.vec3],
    cam_xmat_in: wp.array2d[wp.mat33],
    light_xpos_in: wp.array2d[wp.vec3],
    light_xdir_in: wp.array2d[wp.vec3],
    flexvert_xpos_in: wp.array2d[wp.vec3],
    # In:
    nrender: int,
    use_shadows: bool,
    bvh_ngeom: int,
    bvh_nflexgeom: int,
    cam_res: wp.array[wp.vec2i],
    cam_id_map: wp.array[int],
    ray: wp.array[wp.vec3],
    rgb_adr: wp.array[int],
    depth_adr: wp.array[int],
    seg_adr: wp.array[int],
    render_rgb: wp.array[bool],
    render_depth: wp.array[bool],
    render_seg: wp.array[bool],
    bvh_id: wp.uint64,
    group_root: wp.array[int],
    flex_bvh_id: wp.array[wp.uint64],
    flex_group_root: wp.array2d[int],
    enabled_geom_ids: wp.array[int],
    mesh_bvh_id: wp.array[wp.uint64],
    mesh_facetexcoord: wp.array[wp.vec3i],
    mesh_texcoord: wp.array[wp.vec2],
    mesh_texcoord_offsets: wp.array[int],
    hfield_bvh_id: wp.array[wp.uint64],
    flex_rgba: wp.array[wp.vec4],
    flex_geom_flexid: wp.array[int],
    flex_geom_edgeid: wp.array[int],
    textures: wp.array[wp.Texture2D],
    # Out:
    rgb_out: wp.array2d[wp.uint32],
    depth_out: wp.array2d[float],
    seg_out: wp.array2d[wp.vec2i],
  ):
    worldid, rayid = wp.tid()

    # Map global rayid -> (camid, rayid_local) using cumulative sizes
    camid = int(-1)
    rayid_local = int(-1)
    accum = int(0)
    for i in range(nrender):
      num_i = cam_res[i][0] * cam_res[i][1]
      if rayid < accum + num_i:
        camid = i
        rayid_local = rayid - accum
        break
      accum += num_i

    if camid == -1 or rayid_local < 0:
      return

    if not render_rgb[camid] and not render_depth[camid] and not render_seg[camid]:
      return

    # Map active camera index to MuJoCo camera ID
    mujoco_cam_id = cam_id_map[camid]

    if wp.static(rc.use_precomputed_rays):
      ray_dir_local_cam = ray[rayid]
    else:
      img_w = cam_res[camid][0]
      img_h = cam_res[camid][1]
      px = rayid_local % img_w
      py = rayid_local // img_w
      ray_dir_local_cam = compute_ray(
        cam_projection[mujoco_cam_id],
        cam_fovy[worldid % cam_fovy.shape[0], mujoco_cam_id],
        cam_sensorsize[mujoco_cam_id],
        cam_intrinsic[worldid % cam_intrinsic.shape[0], mujoco_cam_id],
        img_w,
        img_h,
        px,
        py,
        wp.static(rc.znear),
      )

    ray_dir_world = cam_xmat_in[worldid, mujoco_cam_id] @ ray_dir_local_cam
    ray_origin_world = cam_xpos_in[worldid, mujoco_cam_id]

    geom_id, dist, normal, u, v, f, mesh_id = cast_ray(
      geom_type,
      geom_dataid,
      geom_size,
      flex_vertadr,
      flex_edge,
      flex_radius,
      geom_xpos_in,
      geom_xmat_in,
      flexvert_xpos_in,
      bvh_id,
      group_root[worldid],
      worldid,
      bvh_ngeom,
      bvh_nflexgeom,
      enabled_geom_ids,
      mesh_bvh_id,
      hfield_bvh_id,
      flex_geom_flexid,
      flex_geom_edgeid,
      flex_bvh_id,
      flex_group_root,
      ray_origin_world,
      ray_dir_world,
      float(MJ_MAXVAL),
      wp.static(rc.enable_backface_culling),
    )

    # Segmentation stores (object id, object type); flex hits (geom_id == -2)
    # carry their flex id in mesh_id.
    if render_seg[camid] and geom_id != -1:
      if geom_id == -2:
        seg_out[worldid, seg_adr[camid] + rayid_local] = wp.vec2i(mesh_id, int(ObjType.FLEX))
      else:
        seg_out[worldid, seg_adr[camid] + rayid_local] = wp.vec2i(geom_id, int(ObjType.GEOM))

    # Early Out
    if geom_id == -1:
      if wp.static(rc.render_skybox) and render_rgb[camid]:
        skybox_color = sample_skybox(
          textures[wp.static(rc.skybox_tex_id)],
          wp.static(1.0 / float(rc.skybox_face_width)),
          ray_dir_world,
        )
        rgb_out[worldid, rgb_adr[camid] + rayid_local] = pack_rgba_to_uint32(
          skybox_color[0] * 255.0,
          skybox_color[1] * 255.0,
          skybox_color[2] * 255.0,
          255.0,
        )
      return

    if render_depth[camid]:
      # Planar depth: project Euclidean distance onto the camera's optical axis.
      # In camera-local coordinates, the optical axis is -Z. The Z-component of the
      # normalized ray direction is negative, so -ray_dir_local_cam[2] gives cos(θ)
      # between the ray and the optical axis.
      depth_out[worldid, depth_adr[camid] + rayid_local] = dist * (-ray_dir_local_cam[2])

    if not render_rgb[camid]:
      return

    # Shade the pixel
    hit_point = ray_origin_world + ray_dir_world * dist

    # Base color: flex rgba for flex hits (cast_ray encodes the flex id in
    # mesh_id); otherwise the geom rgba when the geom has no material, else the
    # material rgba. mat_id stays -1 for flex hits, which also skips texturing.
    if geom_id == -2:
      color = flex_rgba[mesh_id]
      mat_id = int(-1)
    else:
      mat_id = geom_matid[worldid % geom_matid.shape[0], geom_id]
      if mat_id == -1:
        color = geom_rgba[worldid % geom_rgba.shape[0], geom_id]
      else:
        color = mat_rgba[worldid % mat_rgba.shape[0], mat_id]

    base_color = wp.vec3(color[0], color[1], color[2])

    if wp.static(rc.use_textures):
      if mat_id >= 0:
        # Slot 1 of mat_texid is the material's RGB texture.
        tex_id = mat_texid[worldid % mat_texid.shape[0], mat_id, 1]
        if tex_id >= 0:
          tex_color = sample_texture(
            geom_type,
            mesh_faceadr,
            geom_id,
            mat_texrepeat[worldid % mat_texrepeat.shape[0], mat_id],
            textures[tex_id],
            geom_xpos_in[worldid, geom_id],
            geom_xmat_in[worldid, geom_id],
            mesh_facetexcoord,
            mesh_texcoord,
            mesh_texcoord_offsets,
            hit_point,
            u,
            v,
            f,
            mesh_id,
          )
          base_color = wp.cw_mul(base_color, tex_color)

    result = wp.vec3(0.0, 0.0, 0.0)

    if wp.static(rc.use_ambient_lighting):
      # Hemispheric ambient term: blend a sky tint (normal up, +Z) and a ground
      # tint (normal down) by the normal's Z component.
      len_n = wp.length(normal)
      n = normal if len_n > 0.0 else wp.vec3(0.0, 0.0, 1.0)
      n = wp.normalize(n)
      hemispheric = 0.5 * (n[2] + 1.0)
      ambient_color = wp.vec3(0.4, 0.4, 0.45) * hemispheric + wp.vec3(0.1, 0.1, 0.12) * (1.0 - hemispheric)
      result = 0.5 * wp.cw_mul(base_color, ambient_color)

    # Apply lighting and shadows
    for l in range(wp.static(m.nlight)):
      light_contribution = compute_lighting(
        geom_type,
        geom_dataid,
        geom_size,
        flex_vertadr,
        flex_edge,
        flex_radius,
        geom_xpos_in,
        geom_xmat_in,
        flexvert_xpos_in,
        use_shadows,
        bvh_id,
        group_root[worldid],
        bvh_ngeom,
        bvh_nflexgeom,
        enabled_geom_ids,
        worldid,
        mesh_bvh_id,
        hfield_bvh_id,
        flex_geom_flexid,
        flex_geom_edgeid,
        flex_bvh_id,
        flex_group_root,
        light_active[worldid % light_active.shape[0], l],
        light_type[worldid % light_type.shape[0], l],
        light_castshadow[worldid % light_castshadow.shape[0], l],
        light_xpos_in[worldid, l],
        light_xdir_in[worldid, l],
        normal,
        hit_point,
        wp.static(rc.enable_backface_culling),
      )
      result = result + base_color * light_contribution

    # Clamp to [0, 1] before packing to 8-bit channels.
    hit_color = wp.min(result, wp.vec3(1.0, 1.0, 1.0))
    hit_color = wp.max(hit_color, wp.vec3(0.0, 0.0, 0.0))

    rgb_out[worldid, rgb_adr[camid] + rayid_local] = pack_rgba_to_uint32(
      hit_color[0] * 255.0,
      hit_color[1] * 255.0,
      hit_color[2] * 255.0,
      255.0,
    )

  wp.launch(
    kernel=_render_megakernel,
    dim=(d.nworld, rc.total_rays),
    inputs=[
      m.geom_type,
      m.geom_dataid,
      m.geom_matid,
      m.geom_size,
      m.geom_rgba,
      m.cam_projection,
      m.cam_fovy,
      m.cam_sensorsize,
      m.cam_intrinsic,
      m.light_type,
      m.light_castshadow,
      m.light_active,
      m.flex_vertadr,
      m.flex_edge,
      m.flex_radius,
      m.mesh_faceadr,
      m.mat_texid,
      m.mat_texrepeat,
      m.mat_rgba,
      d.geom_xpos,
      d.geom_xmat,
      d.cam_xpos,
      d.cam_xmat,
      d.light_xpos,
      d.light_xdir,
      d.flexvert_xpos,
      rc.nrender,
      rc.use_shadows,
      rc.bvh_ngeom,
      rc.bvh_nflexgeom,
      rc.cam_res,
      rc.cam_id_map,
      rc.ray,
      rc.rgb_adr,
      rc.depth_adr,
      rc.seg_adr,
      rc.render_rgb,
      rc.render_depth,
      rc.render_seg,
      rc.bvh_id,
      rc.group_root,
      rc.flex_bvh_id,
      rc.flex_group_root,
      rc.enabled_geom_ids,
      rc.mesh_bvh_id,
      rc.mesh_facetexcoord,
      rc.mesh_texcoord,
      rc.mesh_texcoord_offsets,
      rc.hfield_bvh_id,
      rc.flex_rgba,
      rc.flex_geom_flexid,
      rc.flex_geom_edgeid,
      rc.textures,
    ],
    outputs=[
      rc.rgb_data,
      rc.depth_data,
      rc.seg_data,
    ],
    block_dim=m.block_dim.render,
  )