Learn What Not to Learn
Introduction
Learning how to act when the action space is large is challenging for reinforcement learning. In the specific case where many actions are irrelevant, it is sometimes easier for an algorithm to learn which actions not to take. The paper proposes a new reinforcement learning approach for dealing with large action spaces by restricting the available actions in each state to a subset of the most likely ones. More specifically, it proposes a system that learns an approximation of the Q-function and concurrently learns to eliminate actions. The method relies on an additional elimination signal that incorporates domain-specific prior knowledge. For example, in parser-based text games, the parser gives feedback about irrelevant actions after they are played (e.g., Player: "Climb the tree." Parser: "There are no trees to climb"). A machine learning model can then be trained on this signal to generalize to unseen states.
The paper focuses mainly on tasks where both states and actions are expressed in natural language. It introduces a novel deep reinforcement learning approach consisting of a DQN and an Action Elimination Network (AEN), both built on CNNs, which are well suited to NLP tasks. The AEN is trained to predict invalid actions, supervised by the elimination signal from the environment. Note that the core assumption is that it is easy to predict which actions are invalid or inferior in each state and to leverage that information for control.
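To make the AEN's supervision concrete, here is a minimal NumPy sketch of its training signal. It assumes a squared-error loss applied only to the output for the action actually taken, and it replaces the CNN with a hypothetical linear output head over fixed state encodings; `aen_loss_and_grad` and `sgd_step` are illustrative names, not the paper's code.

```python
import numpy as np


# Hypothetical linear stand-in for the AEN output layer: W maps a
# d-dimensional state encoding to one elimination score per action.
def aen_loss_and_grad(W, features, actions, elim):
    """Squared error on the taken action's elimination score only.

    W        : (d, num_actions) output weights
    features : (batch, d) state encodings (stand-in for the CNN encoder)
    actions  : (batch,) integer index of the action taken in each transition
    elim     : (batch,) observed elimination signal e(s, a)
    """
    n = len(actions)
    preds = features @ W                        # (batch, num_actions)
    err = preds[np.arange(n), actions] - elim   # untaken actions get no supervision
    loss = np.mean(err ** 2)
    grad = np.zeros_like(W)
    # Column a of the gradient accumulates (2/n) * err * x over transitions that took a.
    np.add.at(grad.T, actions, (2.0 / n) * err[:, None] * features)
    return loss, grad


def sgd_step(W, features, actions, elim, lr=0.1):
    """One plain gradient-descent step on the loss above."""
    loss, grad = aen_loss_and_grad(W, features, actions, elim)
    return W - lr * grad, loss
```

Supervising only the taken action reflects the fact that the environment reports e(s,a) for the executed action alone, so the other outputs receive no gradient from that transition.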
Related Work
Text-Based Games (TBG): In a TBG, the state of the environment is described in simple language. The player interacts with the environment by issuing text commands as actions, which follow a pre-defined grammar. A popular example is Zork, on which the paper evaluates its method.
Representations for TBG: Good word representations are necessary in order to learn control policies from text. Previous work on TBG used pre-trained embeddings directly for control; other works combined pre-trained embeddings with neural networks.
DRL with linear function approximation: DRL methods such as the DQN have achieved state-of-the-art results in a variety of challenging, high-dimensional domains. This success is mainly attributed to the ability of neural networks to learn rich domain representations for approximating the value function and policy.
RL in Large Action Spaces: Prior work concentrated on factorizing the action space into binary subspaces (Pazis and Parr, 2011; Dulac-Arnold et al., 2012; Lagoudakis and Parr, 2003). Other works proposed embedding the discrete actions into a continuous space and then choosing the discrete action nearest to the optimal action in the continuous space (Dulac-Arnold et al., 2015; Van Hasselt and Wiering, 2009). He et al. (2015) extended DQN to unbounded (natural language) action spaces. Learning to eliminate actions was first studied by Even-Dar, Mannor, and Mansour (2003), who proposed learning confidence intervals around the value function in each state. Lipton et al. (2016a) proposed learning a classifier that detects hazardous states and then using it to shape the reward. Fulda et al. (2017) presented a method for affordance extraction via inner products of pre-trained word embeddings.
Action Elimination
The approach in the paper builds on the standard RL formulation: at each time step t, the agent observes a state [math]\displaystyle{ s_t }[/math] and chooses a discrete action [math]\displaystyle{ a_t\in\{1,...,|A|\} }[/math]. The agent then receives a reward [math]\displaystyle{ r_t(s_t,a_t) }[/math] and observes the next state [math]\displaystyle{ s_{t+1} }[/math]. The goal is to learn a policy [math]\displaystyle{ \pi(a|s) }[/math] that maximizes the expected discounted future return [math]\displaystyle{ V^\pi(s)=E^\pi[\sum_{t=0}^{\infty}\gamma^tr(s_t,a_t)|s_0=s]. }[/math]
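As a small illustration of this objective, the function below computes the discounted return of one finite, sampled reward sequence; averaging it over episodes started from s under π gives a Monte Carlo estimate of V^π(s). Truncating the infinite sum to a finite episode is an assumption of this sketch, and `discounted_return` is an illustrative name.

```python
def discounted_return(rewards, gamma):
    """Return sum_t gamma**t * rewards[t] for a finite reward sequence.

    Accumulates backwards (g <- r + gamma * g), which avoids computing
    powers of gamma explicitly.
    """
    g = 0.0
    for r in reversed(rewards):
        g = r + gamma * g
    return g
```

For example, rewards (1, 0, 2) with γ = 0.5 give 1 + 0 + 0.25 · 2 = 1.5.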
After executing an action, the agent also observes a binary elimination signal e(s,a), which equals 1 if action a can be eliminated in state s and 0 otherwise.
Definition 1. Valid state-action pairs with respect to an elimination signal are state-action pairs that the elimination process should not eliminate.
Definition 2. Admissible state-action pairs with respect to an elimination algorithm are state-action pairs that the elimination algorithm does not eliminate.
Action elimination with contextual bandits: Let [math]\displaystyle{ x(s_t)\in R^d }[/math] be the feature representation of [math]\displaystyle{ s_t }[/math]. We assume that under this representation there exists a set of parameters [math]\displaystyle{ \theta_a^*\in R^d }[/math] such that the elimination signal in state [math]\displaystyle{ s_t }[/math] is [math]\displaystyle{ e_t(s_t,a) = \theta_a^{*T}x(s_t)+\eta_t }[/math], where [math]\displaystyle{ \Vert\theta_a^*\Vert_2\leq S }[/math] and [math]\displaystyle{ \eta_t }[/math] is a zero-mean R-subgaussian random variable that models additive noise in the elimination signal. When the elimination signal is noiseless, R=0; otherwise, [math]\displaystyle{ R\leq 1 }[/math] since the elimination signal is bounded in [0,1]. We further relax the assumption and only require that [math]\displaystyle{ 0\leq E[e_t(s_t,a)]\leq l }[/math] for any valid action and [math]\displaystyle{ u\leq E[e_t(s_t, a)]\leq 1 }[/math] for any invalid action, with [math]\displaystyle{ l\lt u }[/math].
Next, denote by [math]\displaystyle{ X_{t,a} }[/math] the matrix whose rows are the observed state representation vectors of the states in which action a was chosen, up to time t, and by [math]\displaystyle{ E_{t,a} }[/math] the vector whose elements are the elimination signals observed in those states. The solution to the regularized linear regression problem [math]\displaystyle{ \min_{\theta}\Vert X_{t,a}\theta-E_{t,a}\Vert_2^2+\lambda\Vert \theta\Vert_2^2 }[/math] (for some [math]\displaystyle{ \lambda\gt 0 }[/math]) is [math]\displaystyle{ \hat{\theta}_{t,a}=\bar{V}_{t,a}^{-1}X_{t,a}^TE_{t,a} }[/math], where [math]\displaystyle{ \bar{V}_{t,a}=\lambda I + X_{t,a}^TX_{t,a} }[/math].
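The closed-form ridge estimate can be sketched in NumPy as follows; `ridge_estimate` is a hypothetical helper for illustration. Solving the linear system instead of forming the inverse of V̄ explicitly is a standard numerical choice, not something the paper specifies.

```python
import numpy as np


def ridge_estimate(X, E, lam):
    """Minimizer of ||X theta - E||_2^2 + lam * ||theta||_2^2 for one action a.

    X   : (n, d) rows are x(s) for the states in which action a was chosen
    E   : (n,) elimination signals observed in those states
    lam : regularization weight lambda > 0
    Returns (theta_hat, V_bar) with V_bar = lam*I + X^T X and
    theta_hat = V_bar^{-1} X^T E.
    """
    d = X.shape[1]
    V_bar = lam * np.eye(d) + X.T @ X
    theta_hat = np.linalg.solve(V_bar, X.T @ E)  # avoids an explicit matrix inverse
    return theta_hat, V_bar
```

V_bar is returned as well because the confidence width used for elimination depends on it.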
Concurrent Learning: This section shows that Q-learning and contextual bandit algorithms can learn simultaneously and that both converge, i.e., the combination finds an optimal policy and a minimal valid action space.
Definition 3. Action Elimination Q-learning is a Q-learning algorithm that updates only admissible state-action pairs and chooses the best action in the next state from its admissible actions. We allow the base Q-learning algorithm to be any algorithm that converges to [math]\displaystyle{ Q^* }[/math] with probability 1 after observing each state-action pair infinitely often.
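A minimal tabular sketch of one Action Elimination Q-learning update follows, assuming a Q-table indexed by integer states and actions and admissible action sets supplied by some elimination procedure; the function name and its arguments are hypothetical. Compared with standard Q-learning, the only changes are the admissibility check on (s, a) and the max over admissible actions in the next state.

```python
import numpy as np


def ae_q_update(Q: np.ndarray, s: int, a: int, r: float, s_next: int,
                admissible_s, admissible_next, alpha: float, gamma: float,
                terminal: bool = False) -> np.ndarray:
    """Apply one Action Elimination Q-learning step to the table Q in place."""
    # Only admissible state-action pairs are updated.
    if a not in admissible_s:
        return Q
    if terminal:
        target = r
    else:
        # Bootstrap from the best admissible action in s_next
        # (admissible_next is assumed to be non-empty).
        target = r + gamma * max(Q[s_next, b] for b in admissible_next)
    Q[s, a] += alpha * (target - Q[s, a])
    return Q
```

Eliminated actions never enter the max, so a large but eliminated Q-value in the next state cannot inflate the target.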
If elimination is based on the concentration bounds of the linear contextual bandit, Action Elimination Q-learning is guaranteed to converge, as stated in Proposition 1.
Proposition 1: Assume that all state-action pairs (s,a) are visited infinitely often, unless eliminated according to [math]\displaystyle{ \hat{\theta}_{t-1,a}^Tx(s)-\sqrt{\beta_{t-1}(\tilde{\delta})\,x(s)^T\bar{V}_{t-1,a}^{-1}x(s)}\gt l }[/math]. Then, with probability at least [math]\displaystyle{ 1-\delta }[/math], Action Elimination Q-learning converges to the optimal Q-function for all valid state-action pairs. In addition, actions that should be eliminated are visited at most [math]\displaystyle{ T_{s,a}(t)\leq 4\beta_t/(u-l)^2 +1 }[/math] times.
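The elimination rule in Proposition 1 and the visit bound can be written as two small NumPy functions. This is a sketch: the confidence parameter `beta` (β in the text) is treated as a given number rather than computed, and `should_eliminate`, `max_visits`, and `ell` (standing for l) are illustrative names.

```python
import numpy as np


def should_eliminate(theta_hat, V_bar, x, beta, ell):
    """True when the lower confidence bound on the elimination signal exceeds ell:
        theta_hat^T x - sqrt(beta * x^T V_bar^{-1} x) > ell
    """
    width = np.sqrt(beta * (x @ np.linalg.solve(V_bar, x)))
    return float(theta_hat @ x - width) > ell


def max_visits(beta, u, ell):
    """Bound on visits to an action that should be eliminated: 4*beta/(u-ell)^2 + 1."""
    return 4.0 * beta / (u - ell) ** 2 + 1.0
```

For instance, with β = 1, u = 1, and l = 0.5, the bound gives 4/0.25 + 1 = 17 visits. The bound grows as the gap u - l shrinks, since valid and invalid actions then become harder to separate.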
Notice that when there is no noise in the elimination signal (R=0), actions are eliminated correctly with probability 1, so invalid actions are sampled only a finite number of times.
Method
The linearity assumption [math]\displaystyle{ e_t(s_t,a)=\theta_a^{*T}x(s_t)+\eta_t }[/math] may not hold when using raw features such as word2vec embeddings, so the paper proposes using the last-layer activations of the neural network as features. A practical challenge is that the contextual bandit requires these features to be fixed over time, whereas the network's activations change during training. The paper therefore uses a batch-update framework (Levine et al., 2017; Riquelme, Tucker, and Snoek, 2018), in which a new contextual bandit model is learned every few steps using the last-layer activations of the AEN as features.
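The batch-update scheme can be sketched as a function that rebuilds one ridge-regression bandit per action from stored transitions, using a snapshot of the feature extractor so that features stay fixed between refits. `fit_bandits`, its `phi` argument, and the replay arrays are hypothetical; in this sketch `phi` stands in for the AEN's last hidden layer with its weights frozen at refit time.

```python
import numpy as np


def fit_bandits(phi, states, actions, elims, num_actions, lam):
    """Rebuild one linear contextual bandit per action from replay data.

    phi     : feature function, assumed to be a frozen snapshot of the AEN's
              last hidden layer, so features do not change until the next refit
    states  : sequence of stored states
    actions : (n,) integer array of actions taken
    elims   : (n,) array of observed elimination signals
    Returns a list of (theta_a, V_a) with V_a = lam*I + sum x x^T and
    theta_a = V_a^{-1} sum x e over the transitions where action a was taken.
    """
    feats = np.stack([phi(s) for s in states])  # computed once with fixed weights
    d = feats.shape[1]
    models = []
    for a in range(num_actions):
        mask = actions == a
        X, E = feats[mask], elims[mask]
        V_a = lam * np.eye(d) + X.T @ X
        theta_a = np.linalg.solve(V_a, X.T @ E)
        models.append((theta_a, V_a))
    return models
```

A training loop would call it every L steps (for example, `if step % L == 0: models = fit_bandits(...)`) and keep using the stored models for elimination until the next refit.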
The architecture of the action elimination framework is shown below: after taking action [math]\displaystyle{ a_t }[/math], the agent observes [math]\displaystyle{ (r_t,s_{t+1},e_t) }[/math] and uses it to train two deep neural network function approximators, a DQN and an AEN. The AEN provides a set of admissible actions [math]\displaystyle{ A' }[/math] to the DQN.
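The interaction between the two networks at decision time can be sketched as ε-greedy action selection over the admissible set A' only. Both the ε-greedy exploration scheme and the helper `select_action` are assumptions for illustration; the text only states that the AEN supplies A' to the DQN.

```python
import numpy as np


def select_action(q_values, admissible, epsilon, rng):
    """Epsilon-greedy action selection restricted to the admissible set A'.

    q_values   : (num_actions,) DQN outputs Q(s, .)
    admissible : non-empty list of action indices the AEN did not eliminate
    rng        : numpy.random.Generator used for exploration
    """
    admissible = np.asarray(admissible)
    if rng.random() < epsilon:
        # Explore, but only among admissible actions.
        return int(rng.choice(admissible))
    # Exploit: greedy with respect to Q over A' only.
    return int(admissible[np.argmax(q_values[admissible])])
```

Restricting both exploration and exploitation to A' is what lets the DQN ignore eliminated actions entirely.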
Pseudocode of the algorithm: AE-DQN trains two networks: a DQN denoted by Q and an AEN denoted by E. Every L iterations, the algorithm creates a linear contextual bandit model from E with the procedure AENUpdate(). This procedure uses the activations of the last hidden layer of E as features, which are then used to create a contextual linear bandit model, i.e., <math display="inline">V_a=\sum_