From 6d55a09f810bc0d7d38ad04ade92f2b720308b58 Mon Sep 17 00:00:00 2001
From: Alex Pasquali
Date: Mon, 12 Dec 2022 16:19:51 +0100
Subject: [PATCH] Updated custom policy docs to better explain the ``mlp_extractor``'s dimensions (#1196)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Updated custom policy docs

Better explained how the dimensions of the mlp_extractor work, including the action net and the value net after the layers specified in net_arch.

* Improved custom policy doc

Section: Custom Network Architecture.
Explained with greater detail that an action net and a value net will be added on top of the net_arch.

* Improved custom policy doc

Section: Custom Network Architecture.
Merged a comment into a note

* Alignment

Co-authored-by: Quentin GALLOUÉDEC
---
 docs/guide/custom_policy.rst | 20 ++++++++++++++++++++
 docs/misc/changelog.rst      |  1 +
 2 files changed, 21 insertions(+)

diff --git a/docs/guide/custom_policy.rst b/docs/guide/custom_policy.rst
index 4ba3203..fb25c18 100644
--- a/docs/guide/custom_policy.rst
+++ b/docs/guide/custom_policy.rst
@@ -60,6 +60,25 @@ Custom Network Architecture
 One way of customising the policy network architecture is to pass arguments when creating the model,
 using ``policy_kwargs`` parameter:
 
+.. note::
+    An extra linear layer will be added on top of the layers specified in ``net_arch``, in order to have the right output dimensions and activation functions (e.g. Softmax for discrete actions).
+
+    In the following example, as CartPole's action space has a dimension of 2, the final dimensions of the ``net_arch``'s layers will be:
+
+
+    .. code-block:: none
+
+               obs
+               <4>
+          /            \
+        <32>          <32>
+         |              |
+        <32>          <32>
+         |              |
+        <2>            <1>
+       action         value
+
+
 .. code-block:: python
 
   import gym
@@ -69,6 +88,7 @@ using ``policy_kwargs`` parameter:
 
   # Custom actor (pi) and value function (vf) networks
   # of two layers of size 32 each with Relu activation function
+  # Note: an extra linear layer will be added on top of the pi and the vf nets, respectively
   policy_kwargs = dict(activation_fn=th.nn.ReLU,
                        net_arch=[dict(pi=[32, 32], vf=[32, 32])])
   # Create the agent
diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index 9c65909..65c1487 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -52,6 +52,7 @@ Documentation:
 ^^^^^^^^^^^^^^
 - Updated Hugging Face Integration page (@simoninithomas)
 - Changed ``env`` to ``vec_env`` when environment is vectorized
+- Updated custom policy docs to better explain the ``mlp_extractor``'s dimensions (@AlexPasqua)
 - Update custom policy documentation (@athatheo)
 
 Release 1.6.2 (2022-10-10)
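
The layer dimensions drawn in the note added by this patch can be reproduced with a small standalone sketch. It is plain Python and does not call Stable-Baselines3. ``head_layer_shapes`` is a hypothetical helper introduced here for illustration only, and the sketch assumes the 4-dimensional observation is fed directly into separate pi and vf heads with no shared layers, as in the diagram.

```python
from typing import Dict, List, Tuple


def head_layer_shapes(
    obs_dim: int, hidden: List[int], out_dim: int
) -> List[Tuple[int, int]]:
    """Return (in_features, out_features) for every linear layer of one head.

    Hypothetical helper, not part of Stable-Baselines3: it chains the
    observation size, the hidden sizes from net_arch, and the extra
    output layer that is added on top.
    """
    sizes = [obs_dim, *hidden, out_dim]
    return list(zip(sizes[:-1], sizes[1:]))


# Values taken from the diagram in the patch: CartPole has a
# 4-dimensional observation and 2 discrete actions.
obs_dim, n_actions = 4, 2
net_arch: Dict[str, List[int]] = dict(pi=[32, 32], vf=[32, 32])

# The action head ends in n_actions outputs, the value head in a single output.
pi_shapes = head_layer_shapes(obs_dim, net_arch["pi"], n_actions)
vf_shapes = head_layer_shapes(obs_dim, net_arch["vf"], 1)

print("pi:", pi_shapes)  # pi: [(4, 32), (32, 32), (32, 2)]
print("vf:", vf_shapes)  # vf: [(4, 32), (32, 32), (32, 1)]
```

The last tuple in each list is the extra linear layer the note refers to: ``(32, 2)`` for the action net and ``(32, 1)`` for the value net.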