diff --git a/docs/optax-101.ipynb b/docs/optax-101.ipynb index 6fc9dea5b..aa9d32d6f 100644 --- a/docs/optax-101.ipynb +++ b/docs/optax-101.ipynb @@ -1,20 +1,4 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "Optax 101", - "provenance": [], - "collapsed_sections": [] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - } - }, "cells": [ { "cell_type": "markdown", @@ -40,27 +24,23 @@ }, { "cell_type": "code", + "execution_count": 1, "metadata": { "id": "Gg6zyMBqydty" }, + "outputs": [], "source": [ - "import random\n", - "from typing import Tuple\n", - "\n", "import optax\n", "import jax.numpy as jnp\n", "import jax\n", - "import numpy as np\n", "\n", "BATCH_SIZE = 5\n", "NUM_TRAIN_STEPS = 1_000\n", - "RAW_TRAINING_DATA = np.random.randint(255, size=(NUM_TRAIN_STEPS, BATCH_SIZE, 1))\n", + "RAW_TRAINING_DATA = jax.random.randint(jax.random.PRNGKey(42), (NUM_TRAIN_STEPS, BATCH_SIZE, 1), 0, 255)\n", "\n", - "TRAINING_DATA = np.unpackbits(RAW_TRAINING_DATA.astype(np.uint8), axis=-1)\n", + "TRAINING_DATA = jnp.unpackbits(RAW_TRAINING_DATA.astype(jnp.uint8), axis=-1)\n", "LABELS = jax.nn.one_hot(RAW_TRAINING_DATA % 2, 2).astype(jnp.float32).reshape(NUM_TRAIN_STEPS, BATCH_SIZE, 2)" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -77,9 +57,11 @@ }, { "cell_type": "code", + "execution_count": 2, "metadata": { "id": "Syp9LJ338h9-" }, + "outputs": [], "source": [ "initial_params = {\n", " 'hidden': jax.random.normal(shape=[8, 32], key=jax.random.PRNGKey(0)),\n", @@ -101,9 +83,7 @@ " loss_value = optax.sigmoid_binary_cross_entropy(y_hat, labels).sum(axis=-1)\n", "\n", " return loss_value.mean()" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -118,36 +98,55 @@ }, { "cell_type": "code", + "execution_count": 3, "metadata": { - "id": "JsbPBTF09FGY", "executionInfo": { + "elapsed": 6046, "status": "ok", "timestamp": 1636155226542, - "user_tz": 0, - "elapsed": 6046, "user": { "displayName": "Ross Hemsley", "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjSZqBnQizDvVofyb2N_r9W3cP4duk9mv1mxCb9=s64", "userId": "11415908946302743815" - } + }, + "user_tz": 0 }, + "id": "JsbPBTF09FGY", "outputId": "c427f94f-a605-44fc-b519-707bc5d47b7d" }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Step: 0, Loss: 5.624\n", + "Step: 100, Loss: 0.188\n", + "Step: 200, Loss: 0.053\n", + "Step: 300, Loss: 0.025\n", + "Step: 400, Loss: 0.004\n", + "Step: 500, Loss: 0.028\n", + "Step: 600, Loss: 0.002\n", + "Step: 700, Loss: 0.025\n", + "Step: 800, Loss: 0.017\n", + "Step: 900, Loss: 0.003\n" + ] + } + ], "source": [ + "@jax.jit\n", + "def step(params, opt_state, batch, labels):\n", + " loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)\n", + " updates, opt_state = optimizer.update(grads, opt_state, params)\n", + " params = optax.apply_updates(params, updates)\n", + " return params, opt_state, loss_value\n", + "\n", "def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:\n", " opt_state = optimizer.init(params)\n", "\n", - " @jax.jit\n", - " def step(params, opt_state, batch, labels):\n", - " loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)\n", - " updates, opt_state = optimizer.update(grads, opt_state, params)\n", - " params = optax.apply_updates(params, updates)\n", - " return params, opt_state, loss_value\n", - "\n", " for i, (batch, labels) in 
enumerate(zip(TRAINING_DATA, LABELS)):\n", "      params, opt_state, loss_value = step(params, opt_state, batch, labels)\n", "      if i % 100 == 0:\n", - "        print(f'step {i}, loss: {loss_value}')\n", + "        print(f'Step: {i:3}, Loss: {loss_value:.3f}')\n", "\n", "  return params\n", "\n", @@ -155,25 +154,6 @@ "# provided by optax.\n", "optimizer = optax.adam(learning_rate=1e-2)\n", "params = fit(initial_params, optimizer)" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "step 0, loss: 5.60183048248291\n", - "step 100, loss: 0.14773361384868622\n", - "step 200, loss: 0.28999248147010803\n", - "step 300, loss: 0.05951451137661934\n", - "step 400, loss: 0.08592046797275543\n", - "step 500, loss: 0.005035111214965582\n", - "step 600, loss: 0.0028563595842570066\n", - "step 700, loss: 0.013286210596561432\n", - "step 800, loss: 0.01311601884663105\n", - "step 900, loss: 0.003692328929901123\n" - ] - } ] }, { @@ -200,21 +180,40 @@ }, { "cell_type": "code", + "execution_count": 4, "metadata": { - "id": "SZegYQajDtLi", "executionInfo": { + "elapsed": 734, "status": "ok", "timestamp": 1636155227388, - "user_tz": 0, - "elapsed": 734, "user": { "displayName": "Ross Hemsley", "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjSZqBnQizDvVofyb2N_r9W3cP4duk9mv1mxCb9=s64", "userId": "11415908946302743815" - } + }, + "user_tz": 0 }, + "id": "SZegYQajDtLi", "outputId": "f65f9fd8-8e9c-4ae6-e759-62362ff94f53" }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Step: 0, Loss: 5.624\n", + "Step: 100, Loss: 0.000\n", + "Step: 200, Loss: 0.000\n", + "Step: 300, Loss: 0.000\n", + "Step: 400, Loss: 0.000\n", + "Step: 500, Loss: 0.000\n", + "Step: 600, Loss: 0.000\n", + "Step: 700, Loss: 0.000\n", + "Step: 800, Loss: 0.000\n", + "Step: 900, Loss: 0.000\n" + ] + } + ], "source": [ "schedule = optax.warmup_cosine_decay_schedule(\n", "    init_value=0.0,\n", @@ -230,26 +229,112 @@ ")\n", "\n", "params = fit(initial_params, optimizer)" - ], - "execution_count": null, + ] }, + { + "cell_type": "markdown", + "metadata": { + "id": "N7-efvtM16pO" + }, + "source": [ + "## Reading the Learning Rate inside the Train Loop" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GzPJMRYV16pP" + }, + "source": [ + "Sometimes we want to access certain hyperparameters of the optimizer inside the train loop, for example to log the current learning rate to a monitoring service.\n", + "\n", + "To extract the learning rate inside the train loop, we can use the [inject_hyperparams](https://optax.readthedocs.io/en/latest/api.html#optax.inject_hyperparams) wrapper, which makes any hyperparameter a modifiable part of the optimizer state. In particular, the learning rate is promoted to a field of the optimizer state, so it can be read directly from the state at each step.\n", + "\n", + "The following example demonstrates how to extend the previous code to extract the learning rate."
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "FIT1aO9_16pP", + "outputId": "f90205ee-9359-42b3-f745-15aa67d33b62" + }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ - "step 0, loss: 5.60183048248291\n", - "step 100, loss: 1.0181801179953709e-08\n", - "step 200, loss: 0.27725887298583984\n", - "step 300, loss: 0.0\n", - "step 400, loss: 0.0\n", - "step 500, loss: 0.0\n", - "step 600, loss: 0.0\n", - "step 700, loss: 0.0\n", - "step 800, loss: 0.0\n", - "step 900, loss: 0.0\n" + "Available hyperparams: b1 b2 eps eps_root weight_decay learning_rate\n", + "\n", + "Step 0, Loss: 5.624, Learning rate: 0.020\n", + "Step 100, Loss: 0.000, Learning rate: 0.993\n", + "Step 200, Loss: 0.000, Learning rate: 0.939\n", + "Step 300, Loss: 0.000, Learning rate: 0.837\n", + "Step 400, Loss: 0.000, Learning rate: 0.699\n", + "Step 500, Loss: 0.000, Learning rate: 0.540\n", + "Step 600, Loss: 0.000, Learning rate: 0.376\n", + "Step 700, Loss: 0.000, Learning rate: 0.225\n", + "Step 800, Loss: 0.000, Learning rate: 0.104\n", + "Step 900, Loss: 0.000, Learning rate: 0.027\n" ] } + ], + "source": [ + "# Wrap the optimizer to inject the hyperparameters\n", + "optimizer = optax.inject_hyperparams(optax.adamw)(learning_rate=schedule)\n", + "\n", + "def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:\n", + " opt_state = optimizer.init(params)\n", + "\n", + " # Since we injected hyperparams, we can access them directly here\n", + " print(f'Available hyperparams: {\" \".join(opt_state.hyperparams.keys())}\\n')\n", + "\n", + " for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):\n", + " params, opt_state, loss_value = step(params, opt_state, batch, labels)\n", + " if i % 100 == 0:\n", + " # Get the updated learning rate\n", + " lr = opt_state.hyperparams['learning_rate']\n", + " print(f'Step {i:3}, Loss: {loss_value:.3f}, Learning rate: {lr:.3f}')\n", + "\n", + " return params\n", + "\n", + "params = fit(initial_params, optimizer)" ] } - ] + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "name": "optax-101.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3.9.13 ('base')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + }, + "vscode": { + "interpreter": { + "hash": "626d743d6476408aa1b36c3ff0d1f9d9d03e37c6879626ddfcdd13d658004bbf" + } + } + }, + "nbformat": 4, + "nbformat_minor": 0 }