-
Notifications
You must be signed in to change notification settings - Fork 148
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add MJX predictive control, a bimanual hand-over task, and a visualization colab.
PiperOrigin-RevId: 609711577 Change-Id: I11b348d9d94e5b7ba953d855969e751de936749e
- Loading branch information
1 parent
063ff93
commit 5311e3e
Showing
6 changed files
with
326 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,179 @@ | ||
# Copyright 2023 DeepMind Technologies Limited | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
|
||
"""Predictive sampling for MPC.""" | ||
|
||
import time | ||
from typing import Callable, Tuple | ||
|
||
import jax | ||
from jax import numpy as jnp | ||
import mujoco | ||
from mujoco import mjx | ||
from mujoco.mjx._src import dataclasses | ||
|
||
CostFn = Callable[[mjx.Model, mjx.Data], jax.Array] | ||
|
||
|
||
class Planner(dataclasses.PyTreeNode):
  """Predictive sampling planner.

  Bundles the planning model, the cost function and the sampling
  hyper-parameters that the planning functions in this module read.

  Attributes:
    model: MuJoCo (MJX) model that planning rollouts are stepped with
    cost: function returning per-timestep cost; called as cost(model, data)
      once per rollout step
    noise_scale: standard deviation of zero-mean Gaussian; multiplied by the
      width of each actuator's ctrlrange when candidate policies are sampled
    horizon: planning duration (steps)
    nspline: number of spline points to explore (rows of a policy array)
    nsample: number of action sequence candidates sampled, in addition to the
      nominal policy
    interp: type of action interpolation; 'zero' and 'linear' are implemented
  """
  model: mjx.Model
  cost: CostFn
  noise_scale: jax.Array
  horizon: int
  nspline: int
  nsample: int
  interp: str
|
||
|
||
def _rollout(p: Planner, d: mjx.Data, policy: jax.Array) -> jax.Array:
  """Returns the summed cost of rolling out a policy from a start state.

  The policy is expanded into one action per planning step. At every step the
  action is written into ctrl, the cost is evaluated, and then the dynamics
  are advanced.

  Args:
    p: planner providing the model, cost function and horizon
    d: data holding the state the rollout starts from
    policy: spline points to expand into actions

  Returns:
    Sum of the per-step costs over the planning horizon.
  """

  def advance(data, ctrl):
    data = data.replace(ctrl=ctrl)
    # the cost is evaluated before the dynamics are advanced
    step_cost = p.cost(p.model, data)
    return mjx.step(p.model, data), step_cost

  controls = get_actions(p, policy)
  _, step_costs = jax.lax.scan(advance, d, controls, length=p.horizon)
  return jnp.sum(step_costs)
|
||
|
||
def get_actions(p: Planner, policy: jax.Array) -> jax.Array:
  """Gets actions over a planning duration from a policy.

  Args:
    p: planner; `interp`, `nspline` and `horizon` are read
    policy: spline points, indexed along axis 0

  Returns:
    Array with `p.horizon` entries along axis 0, one action per planning step.

  Raises:
    ValueError: if `p.interp` is neither 'zero' nor 'linear'.
  """
  if p.interp == 'zero':
    # zero-order hold: every step uses the spline point at or before it
    indices = [i * p.nspline // p.horizon for i in range(p.horizon)]
    actions = policy[jnp.array(indices)]
  elif p.interp == 'linear':
    # fractional position of every step, measured in spline points
    locs = jnp.array([i * p.nspline / p.horizon for i in range(p.horizon)])
    lower = locs.astype(int)
    # steps in the last segment have no upper neighbour; hold the final spline
    # point explicitly instead of relying on out-of-bounds index handling
    upper = jnp.minimum(lower + 1, policy.shape[0] - 1)
    frac = locs - lower
    actions = jax.vmap(jnp.multiply)(policy[lower], 1 - frac)
    actions += jax.vmap(jnp.multiply)(policy[upper], frac)
  else:
    raise ValueError(f'unimplemented interp: {p.interp}')

  return actions
|
||
|
||
def improve_policy(
    p: Planner, d: mjx.Data, policy: jax.Array, rng: jax.Array
) -> Tuple[jax.Array, jax.Array]:
  """Samples noisy variations of a policy and keeps the cheapest one.

  Args:
    p: planner providing the model, sampling parameters and cost
    d: data holding the state every rollout starts from
    policy: nominal policy; always included as a candidate
    rng: random key used to draw the perturbations

  Returns:
    The lowest-cost candidate policy and its total rollout cost.
  """
  ctrlrange = p.model.actuator_ctrlrange
  lower, upper = ctrlrange[:, 0], ctrlrange[:, 1]

  # Gaussian perturbations, scaled by noise_scale and each actuator's range
  shape = (p.nsample, p.nspline, p.model.nu)
  perturbation = jax.random.normal(rng, shape) * p.noise_scale * (upper - lower)

  # candidate 0 is the unperturbed nominal policy; all are clamped to ctrlrange
  candidates = jnp.concatenate((policy[None], policy + perturbation))
  candidates = jnp.clip(candidates, lower, upper)

  # nsample + 1 rollouts in parallel; a NaN cost counts as infinitely bad
  rollout_all = jax.vmap(_rollout, in_axes=(None, None, 0))
  totals = jnp.nan_to_num(rollout_all(p, d, candidates), nan=jnp.inf)
  winner = jnp.argmin(totals)

  return candidates[winner], totals[winner]
|
||
|
||
def resample(p: Planner, policy: jax.Array, steps_per_plan: int) -> jax.Array:
  """Shifts the policy forward in time ahead of the next planning step.

  The policy is rolled towards the front along axis 0 and the vacated trailing
  entries are filled with the last entry that was not vacated.

  Args:
    p: planner; `horizon` and `nspline` are read
    policy: spline points, indexed along axis 0
    steps_per_plan: number of steps taken between two planning calls

  Returns:
    The shifted policy, same shape as the input.

  Raises:
    ValueError: if `p.horizon` is not a multiple of `p.nspline`, or if the
      number of steps per spline point is not a multiple of `steps_per_plan`.
  """
  steps_per_point, leftover = divmod(p.horizon, p.nspline)
  if leftover:
    raise ValueError("horizon must be divisible by nspline")
  shift, leftover = divmod(steps_per_point, steps_per_plan)
  if leftover:
    raise ValueError(
        f'splinesteps ({steps_per_point}) must be divisible by steps_per_plan'
        f' ({steps_per_plan})'
    )

  # NOTE(review): `shift` is applied in units of spline points but is derived
  # as steps-per-point // steps_per_plan. It can be >= the number of spline
  # points (horizon=128, nspline=4, steps_per_plan=4 gives 8), in which case
  # the indices below fall outside the policy - confirm this is intended.
  shifted = jnp.roll(policy, -shift, axis=0)
  fill = shifted[-shift - 1]
  return shifted.at[-shift:].set(fill)
|
||
|
||
def set_state(d_out, d_in):
  """Returns `d_out` with time, qpos, qvel, act and ctrl taken from `d_in`."""
  copied = {
      name: getattr(d_in, name)
      for name in ('time', 'qpos', 'qvel', 'act', 'ctrl')
  }
  return d_out.replace(**copied)
|
||
|
||
def receding_horizon_optimization(
    p: Planner,
    plan_model_cpu: mujoco.MjModel,
    sim_model_cpu: mujoco.MjModel,
    nsteps: int,
    steps_per_plan: int,
    frame_skip: int,
) -> Tuple[list, list, float]:
  """Runs predictive sampling in closed loop against a simulated model.

  Args:
    p: planner configuration; its `model` is replaced by `plan_model_cpu`
      placed on device
    plan_model_cpu: model used for the planning rollouts
    sim_model_cpu: model that is simulated; its data is reset to keyframe 0
    nsteps: number of simulation steps to run
    steps_per_plan: the policy is re-planned every this many steps
    frame_skip: qpos is recorded every this many steps

  Returns:
    A (trajectory, costs, plan_time) tuple: the recorded qpos frames, the cost
    at every simulation step, and the wall time in seconds spent improving the
    policy.
  """
  d = mujoco.MjData(plan_model_cpu)
  d = mjx.put_data(plan_model_cpu, d)
  m = mjx.put_model(plan_model_cpu)
  p = p.replace(model=m)
  jitted_cost = jax.jit(p.cost)

  policy = jnp.zeros((p.nspline, m.nu))
  rng = jax.random.key(0)
  # compile ahead of time so that compilation is not billed to plan_time
  improve_fn = (
      jax.jit(improve_policy)
      .lower(p, d, policy, rng)
      .compile()
  )
  step_fn = jax.jit(mjx.step).lower(m, d).compile()

  trajectory, costs = [], []
  plan_time = 0.0
  sim_data = mujoco.MjData(sim_model_cpu)
  mujoco.mj_resetDataKeyframe(sim_model_cpu, sim_data, 0)
  # without kinematics, the first cost is off:
  mujoco.mj_forward(sim_model_cpu, sim_data)
  sim_data = mjx.put_data(sim_model_cpu, sim_data)
  sim_model = mjx.put_model(sim_model_cpu)
  actions = get_actions(p, policy)

  for step in range(nsteps):
    if step % steps_per_plan == 0:
      # resample policy to new advanced time
      print('re-planning')
      policy = resample(p, policy, steps_per_plan)
      # JAX dispatches work asynchronously: finish what is still pending so
      # that it is not attributed to planning
      jax.block_until_ready((sim_data, policy))
      beg = time.perf_counter()
      d = set_state(d, sim_data)
      policy, _ = improve_fn(p, d, policy, jax.random.key(step))
      # wait for the result; otherwise only the enqueue would be timed
      policy = jax.block_until_ready(policy)
      plan_time += time.perf_counter() - beg
      actions = get_actions(p, policy)

    sim_data = sim_data.replace(ctrl=actions[0])
    cost = jitted_cost(sim_model, sim_data)
    sim_data = step_fn(sim_model, sim_data)
    costs.append(cost)
    print(f'step: {step}')
    print(f'cost: {cost}')
    if step % frame_skip == 0:
      trajectory.append(jax.device_get(sim_data.qpos))

  return trajectory, costs, plan_time
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# Copyright 2023 DeepMind Technologies Limited | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
|
||
from etils import epath | ||
import jax | ||
from jax import numpy as jp | ||
import mujoco | ||
from mujoco import mjx | ||
from mujoco_mpc.mjx import predictive_sampling | ||
|
||
|
||
def bring_to_target(m: mjx.Model, d: mjx.Data) -> jax.Array:
  """Returns cost for bimanual bring to target task.

  The cost is a weighted sum of three smoothed L2 norms: left gripper to box,
  right gripper to box, and box to a fixed target position.

  Args:
    m: model; only `nbody` is read, the box is taken to be the last body
    d: data providing `site_xpos` and `xpos`

  Returns:
    Scalar cost.
  """
  box = d.xpos[m.nbody - 1]
  left_gripper = d.site_xpos[3]
  right_gripper = d.site_xpos[6]
  target = jp.array([-0.4, -0.2, 0.3])

  # one (residual, weight, smoothing parameter) triple per cost term
  spec = (
      (left_gripper - box, 0.1, 0.005),   # reach, left
      (right_gripper - box, 0.1, 0.005),  # reach, right
      (box - target, 1, 0.003),           # bring
  )

  # NormType::kL2: y = sqrt(x*x' + p^2) - p
  terms = [w * (jp.sqrt(jp.dot(r, r) + p**2) - p) for r, w, p in spec]

  return jp.sum(jp.array(terms))
|
||
|
||
def get_models_and_cost_fn() -> (
    tuple[mujoco.MjModel, mujoco.MjModel, predictive_sampling.CostFn]
):
  """Returns the simulation model, the planning model and the task cost.

  Both models are built from the same scene XML and the same asset files; the
  planning model additionally has its timestep set to 0.01.
  """
  model_file_name = 'mjx_scene.xml'
  root = epath.Path(
      'build/mujoco_menagerie/aloha/'
  )
  scene_xml = (root / model_file_name).read_text()

  # sibling XML files (all but the scene itself) are collected first, then
  # everything under assets/; on a name clash the assets/ entry wins
  assets = {
      f.name: f.read_bytes()
      for f in root.glob('*.xml')
      if f.name != model_file_name
  }
  assets.update({f.name: f.read_bytes() for f in (root / 'assets').glob('*')})

  sim_model = mujoco.MjModel.from_xml_string(scene_xml, assets)
  plan_model = mujoco.MjModel.from_xml_string(scene_xml, assets)
  plan_model.opt.timestep = 0.01  # incidentally, already the case
  return sim_model, plan_model, bring_to_target
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
# Copyright 2023 DeepMind Technologies Limited | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
|
||
import matplotlib.pyplot as plt
import mediapy
import mujoco
from mujoco import mjx
from mujoco_mpc.mjx import predictive_sampling
from mujoco_mpc.mjx.tasks.bimanual import handover
# %%
# Configuration of the closed-loop run.
nsteps = 500
steps_per_plan = 4
frame_skip = 5  # how many steps between each rendered frame

sim_model, plan_model, cost_fn = handover.get_models_and_cost_fn()
# Planner's first two fields are `model` and `cost`. Both are required, so
# they are passed by keyword; a lone positional cost_fn would bind to `model`.
p = predictive_sampling.Planner(
    model=mjx.put_model(plan_model),
    cost=cost_fn,
    noise_scale=0.3,
    horizon=128,
    nspline=4,
    nsample=128 - 1,
    interp='zero',
)

trajectory, costs, plan_time = (
    predictive_sampling.receding_horizon_optimization(
        p,
        plan_model,
        sim_model,
        nsteps,
        steps_per_plan,
        frame_skip,
    )
)
# %%
# Plot the cost at every simulation step against simulated time.
plt.xlim([0, nsteps * sim_model.opt.timestep])
plt.ylim([0, max(costs)])
plt.xlabel('time')
plt.ylabel('cost')
plt.plot([i * sim_model.opt.timestep for i in range(nsteps)], costs)
plt.show()

# Compare the simulated duration with the wall time spent planning.
sim_time = nsteps * sim_model.opt.timestep
plan_steps = nsteps // steps_per_plan
real_factor = sim_time / plan_time
print(f'Total wall time ({plan_steps} planning steps): {plan_time} s'
      f' ({real_factor:.2f}x realtime)')
# %%
# Render one frame per recorded qpos.
frames = []
renderer = mujoco.Renderer(sim_model)
d = mujoco.MjData(sim_model)

for qpos in trajectory:
  d.qpos = qpos
  mujoco.mj_forward(sim_model, d)
  renderer.update_scene(d)
  frames.append(renderer.render())
# %%
# A frame was recorded every frame_skip steps, hence the reduced frame rate.
mediapy.show_video(frames, fps=1/sim_model.opt.timestep/frame_skip)
# %%
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters