Skip to content

Commit

Permalink
Add MJX predictive control, a bimanual hand-over task, and a visualization colab.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 609711577
Change-Id: I11b348d9d94e5b7ba953d855969e751de936749e
  • Loading branch information
erez-tom authored and copybara-github committed Feb 23, 2024
1 parent 063ff93 commit 5311e3e
Show file tree
Hide file tree
Showing 6 changed files with 326 additions and 3 deletions.
179 changes: 179 additions & 0 deletions python/mujoco_mpc/mjx/predictive_sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
# Copyright 2023 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Predictive sampling for MPC."""

import time
from typing import Callable, Tuple

import jax
from jax import numpy as jnp
import mujoco
from mujoco import mjx
from mujoco.mjx._src import dataclasses

CostFn = Callable[[mjx.Model, mjx.Data], jax.Array]


class Planner(dataclasses.PyTreeNode):
  """Configuration of a predictive-sampling planner.

  Attributes:
    model: MJX model whose dynamics are rolled out during planning.
    cost: callable mapping (model, data) to the cost of a single timestep.
    noise_scale: standard deviation of the zero-mean Gaussian perturbation
      added to the nominal policy, expressed as a fraction of each actuator's
      control range.
    horizon: number of simulation steps in each planning rollout.
    nspline: number of spline points (rows) in a policy.
    nsample: number of perturbed policies sampled per improvement step; the
      nominal policy is evaluated in addition to these.
    interp: how spline points are expanded into per-step actions, either
      'zero' or 'linear'.
  """

  model: mjx.Model
  cost: CostFn
  noise_scale: jax.Array
  horizon: int
  nspline: int
  nsample: int
  interp: str


def _rollout(p: Planner, d: mjx.Data, policy: jax.Array) -> jax.Array:
  """Returns the summed cost of rolling out `policy` from the state in `d`.

  The policy is expanded into one action per step with `get_actions`. At each
  step the action is written to `ctrl`, the cost is evaluated on that data,
  and only then are the dynamics advanced by one step.
  """

  def advance(data, ctrl):
    data = data.replace(ctrl=ctrl)
    # cost is evaluated before stepping, on the state the action is applied to
    step_cost = p.cost(p.model, data)
    return mjx.step(p.model, data), step_cost

  controls = get_actions(p, policy)
  _, step_costs = jax.lax.scan(advance, d, controls, length=p.horizon)
  return jnp.sum(step_costs)


def get_actions(p: Planner, policy: jax.Array) -> jax.Array:
  """Expands a spline policy into one action per planning step.

  Args:
    p: planner; its `interp`, `nspline` and `horizon` fields are read.
    policy: spline points, indexed along axis 0.

  Returns:
    Array with `p.horizon` rows, one action per step.

  Raises:
    ValueError: if `p.interp` is neither 'zero' nor 'linear'.
  """
  if p.interp == 'zero':
    # zero-order hold: every step uses the spline point at or before it
    indices = [i * p.nspline // p.horizon for i in range(p.horizon)]
    actions = policy[jnp.array(indices)]
  elif p.interp == 'linear':
    # fractional position of every step on the spline-point axis
    locs = jnp.array([i * p.nspline / p.horizon for i in range(p.horizon)])
    idx = locs.astype(int)
    # Steps in the last segment have no right-hand neighbour. Clamp the upper
    # index explicitly so those steps hold the last spline point, rather than
    # depending on how out-of-bounds indices happen to be treated.
    upper = jnp.minimum(idx + 1, policy.shape[0] - 1)
    actions = jax.vmap(jnp.multiply)(policy[idx], 1 - locs + idx)
    actions += jax.vmap(jnp.multiply)(policy[upper], locs - idx)
  else:
    raise ValueError(f'unimplemented interp: {p.interp}')

  return actions


def improve_policy(
    p: Planner, d: mjx.Data, policy: jax.Array, rng: jax.Array
) -> Tuple[jax.Array, jax.Array]:
  """Samples perturbed policies and returns the cheapest one with its cost.

  Args:
    p: planner configuration.
    d: data holding the state every rollout starts from.
    policy: nominal policy to perturb.
    rng: random key used to draw the perturbations.

  Returns:
    (best policy, its total rollout cost). The nominal policy is one of the
    candidates, placed at index 0.
  """
  ctrl_range = p.model.actuator_ctrlrange
  lower, upper = ctrl_range[:, 0], ctrl_range[:, 1]

  # Gaussian perturbations scaled by each actuator's control range
  perturbation = jax.random.normal(rng, (p.nsample, p.nspline, p.model.nu))
  perturbation = perturbation * p.noise_scale * (upper - lower)

  # nominal policy first, then the perturbed ones, all clamped to ctrlrange
  candidates = jnp.concatenate((policy[None], policy + perturbation))
  candidates = jnp.clip(candidates, lower, upper)

  # nsample + 1 rollouts in parallel; a NaN cost must never win the argmin
  batched_rollout = jax.vmap(_rollout, in_axes=(None, None, 0))
  totals = jnp.nan_to_num(batched_rollout(p, d, candidates), nan=jnp.inf)
  winner = jnp.argmin(totals)

  return candidates[winner], totals[winner]


def resample(p: Planner, policy: jax.Array, steps_per_plan: int) -> jax.Array:
  """Shifts the policy forward in time for the next planning step.

  The policy is rolled towards the front by `splinesteps // steps_per_plan`
  rows, where `splinesteps = horizon // nspline`, and the rows that wrapped
  around to the tail are overwritten with the row just before them.

  Raises:
    ValueError: if `horizon` is not a multiple of `nspline`, or if
      `splinesteps` is not a multiple of `steps_per_plan`.
  """
  splinesteps, leftover = divmod(p.horizon, p.nspline)
  if leftover:
    raise ValueError("horizon must be divisible by nspline")
  if splinesteps % steps_per_plan:
    raise ValueError(
        f'splinesteps ({splinesteps}) must be divisible by steps_per_plan'
        f' ({steps_per_plan})'
    )

  # NOTE(review): shift grows as steps_per_plan shrinks and is never checked
  # against nspline; confirm this is the intended shift when shift >= nspline.
  shift = splinesteps // steps_per_plan
  shifted = jnp.roll(policy, -shift, axis=0)
  return shifted.at[-shift:].set(shifted[-shift - 1])


def set_state(d_out, d_in):
  """Returns `d_out` with time, qpos, qvel, act and ctrl taken from `d_in`."""
  copied = ('time', 'qpos', 'qvel', 'act', 'ctrl')
  return d_out.replace(**{name: getattr(d_in, name) for name in copied})


def receding_horizon_optimization(
    p: Planner,
    plan_model_cpu: mujoco.MjModel,
    sim_model_cpu: mujoco.MjModel,
    nsteps: int,
    steps_per_plan: int,
    frame_skip: int,
) -> Tuple[list, list, float]:
  """Runs a closed-loop simulation, re-planning every `steps_per_plan` steps.

  Args:
    p: planner configuration; its `model` field is replaced by the MJX version
      of `plan_model_cpu`.
    plan_model_cpu: MuJoCo model used for the planning rollouts.
    sim_model_cpu: MuJoCo model used for the simulated system. It is reset to
      keyframe 0 before the loop starts.
    nsteps: number of simulation steps to run.
    steps_per_plan: number of simulation steps between policy improvements.
    frame_skip: a qpos is recorded every `frame_skip` steps.

  Returns:
    (trajectory, costs, plan_time): the recorded qpos arrays, the cost at every
    simulation step, and the wall-clock seconds spent improving the policy.
  """
  d = mujoco.MjData(plan_model_cpu)
  d = mjx.put_data(plan_model_cpu, d)
  m = mjx.put_model(plan_model_cpu)
  p = p.replace(model=m)
  jitted_cost = jax.jit(p.cost)

  policy = jnp.zeros((p.nspline, m.nu))
  # compile ahead of time so compilation is not counted as planning time
  rng = jax.random.key(0)
  improve_fn = (
      jax.jit(improve_policy)
      .lower(p, d, policy, rng)
      .compile()
  )
  step_fn = jax.jit(mjx.step).lower(m, d).compile()

  trajectory, costs = [], []
  plan_time = 0.0
  sim_data = mujoco.MjData(sim_model_cpu)
  mujoco.mj_resetDataKeyframe(sim_model_cpu, sim_data, 0)
  # without kinematics, the first cost is off:
  mujoco.mj_forward(sim_model_cpu, sim_data)
  sim_data = mjx.put_data(sim_model_cpu, sim_data)
  sim_model = mjx.put_model(sim_model_cpu)

  for step in range(nsteps):
    # number of steps taken since the current plan was computed
    plan_step = step % steps_per_plan
    if plan_step == 0:
      # resample policy to new advanced time
      print('re-planning')
      policy = resample(p, policy, steps_per_plan)
      beg = time.perf_counter()
      d = set_state(d, sim_data)
      policy, _ = improve_fn(p, d, policy, jax.random.key(step))
      # JAX dispatches asynchronously: wait for the result before reading the
      # clock, otherwise only the dispatch overhead is measured
      policy.block_until_ready()
      plan_time += time.perf_counter() - beg
      actions = get_actions(p, policy)

    # actions[i] is the action for the i-th step after planning, so advance
    # through the plan instead of re-applying its first action
    sim_data = sim_data.replace(ctrl=actions[plan_step])
    cost = jitted_cost(sim_model, sim_data)
    sim_data = step_fn(sim_model, sim_data)
    costs.append(cost)
    print(f'step: {step}')
    print(f'cost: {cost}')
    if step % frame_skip == 0:
      trajectory.append(jax.device_get(sim_data.qpos))

  return trajectory, costs, plan_time
69 changes: 69 additions & 0 deletions python/mujoco_mpc/mjx/tasks/bimanual/handover.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright 2023 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from etils import epath
import jax
from jax import numpy as jp
import mujoco
from mujoco import mjx
from mujoco_mpc.mjx import predictive_sampling


def bring_to_target(m: mjx.Model, d: mjx.Data) -> jax.Array:
  """Returns the cost of the bimanual bring-to-target task.

  The cost is a weighted sum of smoothed L2 norms of three residuals: each
  gripper site relative to the box, and the box relative to a fixed target.
  Site indices 3 and 6 and the last body index are hard-coded for this scene.
  """
  box_pos = d.xpos[m.nbody - 1]
  goal = jp.array([-0.4, -0.2, 0.3])

  # (residual, weight, smoothing parameter p)
  spec = (
      (d.site_xpos[3] - box_pos, 0.1, 0.005),  # reach: left gripper to box
      (d.site_xpos[6] - box_pos, 0.1, 0.005),  # reach: right gripper to box
      (box_pos - goal, 1, 0.003),  # bring: box to target
  )

  # NormType::kL2: y = sqrt(x*x' + p^2) - p
  terms = [w * (jp.sqrt(jp.dot(r, r) + p**2) - p) for r, w, p in spec]

  return jp.sum(jp.array(terms))


def get_models_and_cost_fn() -> (
    tuple[mujoco.MjModel, mujoco.MjModel, predictive_sampling.CostFn]
):
  """Loads the scene and returns (simulation model, planning model, cost fn).

  Both models are built from the same XML and the same assets; the planning
  model additionally has its timestep set to 0.01.
  """
  model_file_name = 'mjx_scene.xml'
  root = epath.Path(
      'build/mujoco_menagerie/aloha/'
  )
  xml = (root / model_file_name).read_text()

  # every other XML file next to the scene, keyed by file name ...
  assets = {
      f.name: f.read_bytes()
      for f in root.glob('*.xml')
      if f.name != model_file_name
  }
  # ... followed by everything in the assets directory
  assets.update({f.name: f.read_bytes() for f in (root / 'assets').glob('*')})

  sim_model = mujoco.MjModel.from_xml_string(xml, assets)
  plan_model = mujoco.MjModel.from_xml_string(xml, assets)
  plan_model.opt.timestep = 0.01  # incidentally, already the case
  return sim_model, plan_model, bring_to_target
72 changes: 72 additions & 0 deletions python/mujoco_mpc/mjx/visualize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Copyright 2023 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import matplotlib.pyplot as plt
import mediapy
import mujoco
from mujoco import mjx
from mujoco_mpc.mjx import predictive_sampling
from mujoco_mpc.mjx.tasks.bimanual import handover
# %%
# Run predictive sampling on the bimanual hand-over task.
nsteps = 500
steps_per_plan = 4
frame_skip = 5  # how many steps between each rendered frame


sim_model, plan_model, cost_fn = handover.get_models_and_cost_fn()
# Planner has no field defaults and `model` is its first field, so both
# `model` and `cost` must be supplied; pass them by keyword so the cost
# function is not bound to `model` by position.
p = predictive_sampling.Planner(
    model=mjx.put_model(plan_model),
    cost=cost_fn,
    noise_scale=0.3,
    horizon=128,
    nspline=4,
    nsample=128 - 1,
    interp='zero',
)

trajectory, costs, plan_time = (
    predictive_sampling.receding_horizon_optimization(
        p,
        plan_model,
        sim_model,
        nsteps,
        steps_per_plan,
        frame_skip,
    )
)
# %%
# Plot the per-step cost over simulated time and report planning speed.
plt.xlim([0, nsteps * sim_model.opt.timestep])
plt.ylim([0, max(costs)])
plt.xlabel('time')
plt.ylabel('cost')
plt.plot([i * sim_model.opt.timestep for i in range(nsteps)], costs)
plt.show()

sim_time = nsteps * sim_model.opt.timestep
# planning happens on steps 0, steps_per_plan, 2 * steps_per_plan, ..., so the
# count is rounded up when nsteps is not a multiple of steps_per_plan
plan_steps = -(-nsteps // steps_per_plan)
real_factor = sim_time / plan_time
print(f'Total wall time ({plan_steps} planning steps): {plan_time} s'
      f' ({real_factor:.2f}x realtime)')
# %%
# Render one frame for every recorded qpos.
frames = []
renderer = mujoco.Renderer(sim_model)
d = mujoco.MjData(sim_model)

for qpos in trajectory:
  d.qpos = qpos
  # recompute kinematics for the new qpos before rendering
  mujoco.mj_forward(sim_model, d)
  renderer.update_scene(d)
  frames.append(renderer.render())
# %%
# one frame was recorded every frame_skip steps, hence the reduced frame rate
mediapy.show_video(frames, fps=1/sim_model.opt.timestep/frame_skip)
# %%
3 changes: 2 additions & 1 deletion python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,8 @@ def _configure_and_build_agent_server(self):
],
install_requires=[
"grpcio",
"mujoco >= 2.3.3",
"mujoco >= 3.1.1",
"mujoco-mjx",
"protobuf",
],
extras_require={
Expand Down
3 changes: 2 additions & 1 deletion python/setup_direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,8 @@ def _configure_and_build_direct_server(self):
extras_require={
"test": [
"absl-py",
"mujoco >= 2.3.3",
"mujoco >= 3.1.1",
"mujoco-mjx",
],
},
ext_modules=[CMakeExtension("direct_server")],
Expand Down
3 changes: 2 additions & 1 deletion python/setup_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,8 @@ def _configure_and_build_filter_server(self):
extras_require={
"test": [
"absl-py",
"mujoco >= 2.3.3",
"mujoco >= 3.1.1",
"mujoco-mjx",
],
},
ext_modules=[CMakeExtension("filter_server")],
Expand Down

0 comments on commit 5311e3e

Please sign in to comment.