learn what not to learn

From statwiki
Revision as of 19:41, 19 November 2018 by J385chen

Introduction

In reinforcement learning, learning how to act when the action space is large is challenging. In the specific case where many actions are irrelevant, it is sometimes easier for the algorithm to learn which actions not to take. The paper proposes a new reinforcement learning approach for dealing with large action spaces by restricting the available actions in each state to a subset of the most likely ones. More specifically, it proposes a system that learns an approximation of the Q-function and concurrently learns to eliminate actions. The method relies on an external elimination signal that incorporates domain-specific prior knowledge. For example, in parser-based text games, the parser gives feedback about irrelevant actions after they are played (e.g., Player: "Climb the tree." Parser: "There are no trees to climb"). A machine learning model can then be trained on this signal to generalize to unseen states.

The paper focuses mainly on tasks where both the states and the actions are expressed in natural language. It introduces a novel deep reinforcement learning approach consisting of a DQN and an Action Elimination Network (AEN), both built on a convolutional neural network (CNN) architecture suited to NLP tasks. The AEN is trained to predict invalid actions, supervised by the elimination signal from the environment. The core assumption is that it is easy to predict which actions are invalid or inferior in each state, and that this information can be leveraged for control.

The method is tested on the text-based game Zork, which lets the player interact with a virtual world through a text-based interface. The game has more than a thousand possible actions in each state. By eliminating irrelevant actions, the algorithm learns faster than the baseline agents.

The figure below shows an example of the Zork interface.

Related Work

Text-Based Games (TBG): The state of the environment in a TBG is described in simple language. The player interacts with the environment by issuing text commands as actions, which follow a pre-defined grammar. A popular example is Zork, which is used for evaluation in the paper.

Representations for TBG: Good word representations are necessary in order to learn control policies from text. Previous work on TBG used pre-trained embeddings directly for control; other works combined pre-trained embeddings with neural networks.

DRL with linear function approximation: DRL methods such as the DQN have achieved state-of-the-art results in a variety of challenging, high-dimensional domains, mainly because deep neural networks can learn rich domain representations for approximating the value function or the policy.

RL in Large Action Spaces: Prior work concentrated on factorizing the action space into binary subspaces (Pazis and Parr, 2011; Dulac-Arnold et al., 2012; Lagoudakis and Parr, 2003). Other works proposed embedding the discrete actions into a continuous space and then choosing the discrete action nearest to the optimal action in the continuous space (Dulac-Arnold et al., 2015; Van Hasselt and Wiering, 2009). He et al. (2015) extended DQN to unbounded (natural language) action spaces. Learning to eliminate actions was first mentioned by Even-Dar, Mannor, and Mansour (2003), who proposed learning confidence intervals around the value function in each state. Lipton et al. (2016a) proposed learning a classifier that detects hazardous states and then using it to shape the reward. Fulda et al. (2017) presented a method for affordance extraction via inner products of pre-trained word embeddings.

Action Elimination

Definition 1:

Valid state-action pairs with respect to an elimination signal are state-action pairs that the elimination process should not eliminate.

Definition 2:

Admissible state-action pairs with respect to an elimination algorithm are state-action pairs that the elimination algorithm does not eliminate.

Definition 3:

Action Elimination Q-learning is a Q-learning algorithm that updates only admissible state-action pairs and chooses the best action in the next state from among its admissible actions. We allow the base Q-learning algorithm to be any algorithm that converges to [math]\displaystyle{ Q^* }[/math] with probability 1 after observing each state-action pair infinitely often.
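To make Definition 3 concrete, here is a minimal tabular sketch, assuming the admissible pairs are already given as a boolean mask (in the paper they come from the elimination algorithm). The function name `ae_q_update` and its parameters are hypothetical. It shows the two differences from plain Q-learning: inadmissible pairs are never updated, and the bootstrap max in the next state ranges only over admissible actions.

```python
import numpy as np


def ae_q_update(Q, admissible, s, a, r, s_next, alpha=0.1, gamma=0.99, terminal=False):
    """One tabular Action Elimination Q-learning step (hypothetical helper).

    Q          -- float array, shape (num_states, num_actions), updated in place
    admissible -- bool array of the same shape; True where the pair was NOT eliminated
    Assumes at least one admissible action exists in s_next.
    """
    # Difference 1: only admissible state-action pairs are updated.
    if not admissible[s, a]:
        return Q
    if terminal:
        target = r
    else:
        # Difference 2: the max over the next state ignores eliminated actions.
        next_values = np.where(admissible[s_next], Q[s_next], -np.inf)
        target = r + gamma * next_values.max()
    Q[s, a] += alpha * (target - Q[s, a])
    return Q


# Usage: action 2 in state 1 has a high Q-value but is eliminated.
Q = np.zeros((2, 3))
Q[1] = [1.0, 0.5, 10.0]
admissible = np.ones((2, 3), dtype=bool)
admissible[1, 2] = False
ae_q_update(Q, admissible, s=0, a=0, r=1.0, s_next=1, alpha=0.5, gamma=0.9)
# target = 1 + 0.9 * max(1.0, 0.5) = 1.9 (the eliminated 10.0 is ignored), so Q[0, 0] is about 0.95
```

With plain Q-learning the same step would bootstrap from the eliminated value 10.0, which is exactly the kind of error that restricting the max is meant to avoid.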

The approach in the paper builds on the standard RL formulation. At each time step t, the agent observes state [math]\displaystyle{ s_t }[/math] and chooses a discrete action [math]\displaystyle{ a_t\in\{1,...,|A|\} }[/math]. The agent then obtains a reward [math]\displaystyle{ r_t(s_t,a_t) }[/math] and the next state [math]\displaystyle{ s_{t+1} }[/math]. The goal of the algorithm is to learn a policy [math]\displaystyle{ \pi(a|s) }[/math] that maximizes the expected discounted return [math]\displaystyle{ V^\pi(s)=E^\pi\left[\sum_{t=0}^{\infty}\gamma^t r(s_t,a_t)\,\middle|\,s_0=s\right] }[/math]. After executing an action, the agent observes a binary elimination signal e(s,a), which equals 1 if action a can be eliminated in state s and 0 otherwise.
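As a small worked example of the return inside [math]\displaystyle{ V^\pi(s) }[/math], the hypothetical helper below computes the discounted sum for one finite (truncated) episode, i.e. a single sample of the quantity whose expectation defines the value function. It accumulates backwards using the recursion [math]\displaystyle{ G_t = r_t + \gamma G_{t+1} }[/math].

```python
def discounted_return(rewards, gamma):
    """Sum of gamma**t * r_t over a finite reward sequence r_0, r_1, ..."""
    g = 0.0
    # Work backwards with the recursion G_t = r_t + gamma * G_{t+1}.
    for r in reversed(rewards):
        g = r + gamma * g
    return g


print(discounted_return([1.0, 0.0, 2.0], gamma=0.5))  # 1 + 0.5 * 0 + 0.25 * 2 = 1.5
```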


Action elimination with contextual bandits

Let [math]\displaystyle{ x(s_t)\in \mathbb{R}^d }[/math] be the feature representation of [math]\displaystyle{ s_t }[/math]. We assume that under this representation there exists a set of parameters [math]\displaystyle{ \theta_a^*\in \mathbb{R}^d }[/math] such that the elimination signal in state [math]\displaystyle{ s_t }[/math] is [math]\displaystyle{ e_t(s_t,a) = \theta_a^{*T}x(s_t)+\eta_t }[/math], where [math]\displaystyle{ \Vert\theta_a^*\Vert_2\leq S }[/math]. Here [math]\displaystyle{ \eta_t }[/math] is a zero-mean R-subgaussian random variable that models additive noise in the elimination signal. When the elimination signal is noiseless, R=0; otherwise [math]\displaystyle{ R\leq 1 }[/math], since the elimination signal is bounded in [0,1]. We further assume that the elimination signal satisfies [math]\displaystyle{ 0\leq E[e_t(s_t,a)]\leq l }[/math] for any valid action and [math]\displaystyle{ u\leq E[e_t(s_t, a)]\leq 1 }[/math] for any invalid action, with [math]\displaystyle{ l\lt u }[/math].

Denote by [math]\displaystyle{ X_{t,a} }[/math] the matrix whose rows are the representation vectors of the states in which action a was chosen up to time t, and by [math]\displaystyle{ E_{t,a} }[/math] the vector whose elements are the elimination signals observed in those states. The solution to the regularized linear regression [math]\displaystyle{ \min_{\theta}\Vert X_{t,a}\theta-E_{t,a}\Vert_2^2+\lambda\Vert \theta\Vert_2^2 }[/math] (for some [math]\displaystyle{ \lambda\gt 0 }[/math]) is denoted by [math]\displaystyle{ \hat{\theta}_{t,a}=\bar{V}_{t,a}^{-1}X_{t,a}^TE_{t,a} }[/math], where [math]\displaystyle{ \bar{V}_{t,a}=\lambda I + X_{t,a}^TX_{t,a} }[/math].
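The closed-form estimate [math]\displaystyle{ \hat{\theta}_{t,a}=\bar{V}_{t,a}^{-1}X_{t,a}^TE_{t,a} }[/math] takes only a few lines of NumPy. This is a minimal sketch for a single action a; the function name is hypothetical, and it solves the linear system with `np.linalg.solve` rather than forming the inverse of V̄ explicitly, which yields the same estimate with better numerical behaviour.

```python
import numpy as np


def fit_elimination_model(X, E, lam=1.0):
    """Regularized least squares for one action a (hypothetical helper).

    X   -- (n, d) matrix X_{t,a}: rows are x(s) for states where a was chosen
    E   -- (n,) vector E_{t,a}: elimination signals observed in those states
    lam -- regularization strength lambda > 0
    Returns (theta_hat, V_bar) with V_bar = lam*I + X^T X and
    theta_hat = V_bar^{-1} X^T E.
    """
    d = X.shape[1]
    V_bar = lam * np.eye(d) + X.T @ X
    # Solve V_bar @ theta = X^T E instead of inverting V_bar.
    theta_hat = np.linalg.solve(V_bar, X.T @ E)
    return theta_hat, V_bar


# Usage on synthetic, noise-free signals (R = 0): theta_hat recovers theta_star.
rng = np.random.default_rng(0)
theta_star = np.array([0.2, -0.1, 0.3, 0.05])
X = rng.normal(size=(500, 4))
E = X @ theta_star
theta_hat, V_bar = fit_elimination_model(X, E, lam=1e-6)
```

The synthetic signals here are not confined to [0,1]; they only serve to check that the estimator recovers the parameters when the linear model holds exactly.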

Concurrent Learning

This section shows that Q-learning and contextual bandit algorithms can learn simultaneously, resulting in the convergence of both algorithms, i.e., finding an optimal policy and a minimal valid action space.

If the elimination is based on the concentration bounds of linear contextual bandits, Action Elimination Q-learning is guaranteed to converge, as stated in Proposition 1.

Proposition 1: Assume that all state-action pairs (s,a) are visited infinitely often, unless eliminated according to [math]\displaystyle{ \hat{\theta}_{t-1,a}^Tx(s)-\sqrt{\beta_{t-1}(\tilde{\delta})\,x(s)^T\bar{V}_{t-1,a}^{-1}x(s)}\gt l }[/math]. Then, with probability at least [math]\displaystyle{ 1-\delta }[/math], Action Elimination Q-learning converges to the optimal Q-function for any valid state-action pair. In addition, actions that should be eliminated are visited at most [math]\displaystyle{ T_{s,a}(t)\leq 4\beta_t/(u-l)^2 +1 }[/math] times.
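The elimination condition and the visit bound in Proposition 1 translate directly into code. In this sketch `beta` stands in for [math]\displaystyle{ \beta_{t-1}(\tilde{\delta}) }[/math] and is passed in as a given number, since computing it requires confidence-bound constants that this summary does not spell out; both function names are hypothetical.

```python
import numpy as np


def should_eliminate(theta_hat, V_bar, x, beta, l):
    """Eliminate action a in state s when its lower confidence bound exceeds l:
    theta_hat^T x - sqrt(beta * x^T V_bar^{-1} x) > l.
    """
    mean = float(theta_hat @ x)
    width = np.sqrt(beta * float(x @ np.linalg.solve(V_bar, x)))
    return bool(mean - width > l)


def max_invalid_visits(beta, u, l):
    """Upper bound T_{s,a}(t) <= 4 * beta / (u - l)^2 + 1 on visits to an invalid action."""
    if not u > l:
        raise ValueError("the bound requires l < u")
    return 4.0 * beta / (u - l) ** 2 + 1.0


# One-dimensional example: mean 0.9, confidence width sqrt(1.0 * 1/100) = 0.1.
theta_hat = np.array([0.9])
V_bar = np.array([[100.0]])
x = np.array([1.0])
print(should_eliminate(theta_hat, V_bar, x, beta=1.0, l=0.5))   # 0.8 > 0.5, prints True
print(should_eliminate(theta_hat, V_bar, x, beta=1.0, l=0.85))  # 0.8 > 0.85 fails, prints False
print(max_invalid_visits(beta=1.0, u=0.75, l=0.25))             # 4 / 0.25 + 1, prints 17.0
```

The bound shrinks as the gap u - l between invalid and valid actions grows, which matches the intuition that well-separated signals are easier to learn.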

Notice that when there is no noise in the elimination signal (R=0), actions are correctly eliminated with probability 1, so invalid actions are sampled only a finite number of times.

Method

The assumption that [math]\displaystyle{ e_t(s_t,a)=\theta_a^{*T}x(s_t)+\eta_t }[/math] might not hold when using raw features such as word2vec, so the paper proposes using the activations of the neural network's last layer as features. A practical challenge is that the features must stay fixed over time while the contextual bandit uses them. To address this, a batch-update framework (Levine et al., 2017; Riquelme, Tucker, and Snoek, 2018) is used: every few steps, a new contextual bandit model is learned on top of the last-layer activations of the AEN.
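The batch-update idea can be sketched as follows, assuming the caller passes in `features_fn`, a frozen snapshot of the AEN's last hidden layer, and a replay buffer of `(state, action, e)` tuples. The class and method names are hypothetical. Between refits the feature map does not change, so each per-action ridge model stays consistent with the features it was fit on; every `period` steps the features are re-snapshotted and all models are refit from scratch.

```python
import numpy as np


class BatchLinearEliminator:
    """Hypothetical batch-update wrapper: refit one ridge model per action
    on frozen last-layer features every `period` steps."""

    def __init__(self, num_actions, lam=1.0, period=1000):
        self.num_actions = num_actions
        self.lam = lam
        self.period = period
        self.features_fn = None  # frozen feature map used by the current models
        self.models = {}         # action -> (theta_hat, V_bar)

    def maybe_update(self, step, features_fn, buffer):
        """Refit when step is a multiple of period; return True if a refit happened."""
        if step % self.period != 0:
            return False
        self.features_fn = features_fn  # assumed to be a snapshot, not a live network
        self.models = {}
        for a in range(self.num_actions):
            rows = [(features_fn(s), e) for s, act, e in buffer if act == a]
            if not rows:
                continue  # no data for this action yet
            X = np.stack([phi for phi, _ in rows])
            E = np.array([e for _, e in rows], dtype=float)
            V_bar = self.lam * np.eye(X.shape[1]) + X.T @ X
            self.models[a] = (np.linalg.solve(V_bar, X.T @ E), V_bar)
        return True

    def predict(self, state, a):
        """Point estimate theta_hat^T x(s) of the elimination signal."""
        theta_hat, _ = self.models[a]
        return float(theta_hat @ self.features_fn(state))


# Usage with a toy feature map x(s) = [1, s] standing in for the AEN's last layer.
def features_fn(s):
    return np.array([1.0, s])


buffer = [(0.0, 0, 0.0), (0.5, 0, 0.5), (1.0, 0, 1.0), (0.0, 1, 1.0)]
elim = BatchLinearEliminator(num_actions=2, lam=1e-9, period=10)
elim.maybe_update(step=0, features_fn=features_fn, buffer=buffer)
```

Refitting from scratch on each snapshot is the simplest choice here; it avoids mixing statistics computed under different feature maps.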

The figure below shows the architecture of the action elimination framework.

After taking action [math]\displaystyle{ a_t }[/math], the agent observes [math]\displaystyle{ (r_t,s_{t+1},e_t) }[/math] and uses this tuple to train two deep neural networks for function approximation: a DQN and an AEN. The AEN provides an admissible action set [math]\displaystyle{ A' }[/math] to the DQN. Both the AEN and the DQN use an NLP CNN architecture based on Kim (2014), with 100 convolutional filters for the AEN and 500 for the DQN, and three different 1D kernels of lengths 1, 2, and 3. The state is represented as a sequence of words composed of the game descriptor and the player's inventory. These are truncated or zero-padded to a length of 50 descriptor words plus 15 inventory words, and each word is embedded into a continuous vector in [math]\displaystyle{ \mathbb{R}^{300} }[/math] using word2vec. The features of the last four states are then concatenated, so the final state representation s lies in [math]\displaystyle{ \mathbb{R}^{78000} }[/math] (4 × (50 + 15) × 300 = 78000). The AEN is trained to minimize the mean squared error (MSE) loss, using the elimination signal as the label.
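The shapes in this paragraph can be checked with a short NumPy sketch. It assumes a word-to-vector dictionary `embeddings` standing in for word2vec, maps padding and out-of-vocabulary words to the zero vector (an assumption), and implements a Kim (2014)-style forward pass (convolution over windows of k consecutive words, ReLU, max-over-time pooling, concatenation across kernel sizes) without bias terms. Treating the four stacked states as one 260-word sequence and using 100 filters per kernel size are illustrative choices, not details confirmed by the summary; all function names are hypothetical.

```python
import numpy as np

DESC_LEN, INV_LEN, EMB_DIM, HISTORY = 50, 15, 300, 4


def pad_or_truncate(tokens, length, pad="<pad>"):
    """Cut to `length` tokens or right-pad with a padding token."""
    return tokens[:length] + [pad] * max(0, length - len(tokens))


def embed_state(descriptor, inventory, embeddings):
    """Map one observation to a (65, 300) matrix; padding/unknown words -> zero vector."""
    words = pad_or_truncate(descriptor, DESC_LEN) + pad_or_truncate(inventory, INV_LEN)
    zero = np.zeros(EMB_DIM)
    return np.stack([embeddings.get(w, zero) for w in words])


def stack_history(last_states):
    """Concatenate the last four embedded states: 4 * 65 * 300 = 78000 numbers."""
    assert len(last_states) == HISTORY
    return np.concatenate([m.reshape(-1) for m in last_states])


def kim_cnn_features(sentence, filters):
    """Kim (2014)-style text CNN features (no biases).

    sentence -- (length, dim) word-vector sequence
    filters  -- dict kernel_size k -> (num_filters, k, dim) weights
    """
    feats = []
    for k, W in sorted(filters.items()):
        # All windows of k consecutive words: (length - k + 1, k, dim).
        windows = np.stack([sentence[i:i + k] for i in range(sentence.shape[0] - k + 1)])
        conv = np.einsum("lkd,fkd->lf", windows, W)      # (length - k + 1, num_filters)
        feats.append(np.maximum(conv, 0.0).max(axis=0))  # ReLU, then max over time
    return np.concatenate(feats)


# Usage with a toy vocabulary and random filters (100 per kernel size).
emb = {"tree": np.ones(EMB_DIM), "lamp": np.full(EMB_DIM, 0.5)}
m = embed_state(["there", "is", "a", "tree"], ["lamp"], emb)
s = stack_history([m] * HISTORY)
rng = np.random.default_rng(0)
filters = {k: rng.normal(size=(100, k, EMB_DIM)) for k in (1, 2, 3)}
f = kim_cnn_features(s.reshape(HISTORY * (DESC_LEN + INV_LEN), EMB_DIM), filters)
```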

Pseudocode of the algorithm:

AE-DQN trains two networks: a DQN denoted by Q and an AEN denoted by E. Every L iterations, the algorithm builds a linear contextual bandit model from E with the procedure AENUpdate(). This procedure uses the activations of the last hidden layer of E as features for a contextual linear bandit model; AENUpdate() then solves this model and plugs the solution into the target AEN. The contextual linear bandit model [math]\displaystyle{ (E^-,V) }[/math] is then used to eliminate actions via the ACT() and Targets() functions. ACT() follows an [math]\displaystyle{ \epsilon }[/math]-greedy mechanism on the admissible action set: if it decides to exploit, it selects the action with the highest Q-value among [math]\displaystyle{ A' }[/math]; if it decides to explore, it selects an action uniformly at random from [math]\displaystyle{ A' }[/math]. The Targets() procedure estimates the value function by taking the max over Q-values only among admissible actions, which reduces function approximation errors.
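The roles of ACT() and Targets() can be illustrated with a minimal NumPy sketch. It assumes the bandit model is given as per-action estimates `theta_hat` and inverse matrices `V_inv` computed on the frozen last-layer features `phi`, and that at least one action survives elimination; `admissible_set`, `act`, and `target` are hypothetical names, not the paper's exact procedure signatures. The elimination test is the same lower-confidence-bound rule as in Proposition 1.

```python
import numpy as np


def admissible_set(phi, theta_hat, V_inv, beta, l):
    """Indices of actions that are NOT eliminated in the current state.

    phi       -- (d,) last-layer features of the frozen AEN for the state
    theta_hat -- (num_actions, d) per-action bandit estimates
    V_inv     -- (num_actions, d, d) per-action inverses of V_bar
    An action is eliminated when theta_hat^T phi - sqrt(beta * phi^T V^-1 phi) > l.
    """
    mean = theta_hat @ phi
    width = np.sqrt(beta * np.einsum("d,ade,e->a", phi, V_inv, phi))
    return np.flatnonzero(mean - width <= l)


def act(q_values, admissible, epsilon, rng):
    """Epsilon-greedy action selection restricted to the admissible set A'."""
    if rng.random() < epsilon:
        return int(rng.choice(admissible))  # explore: uniform over A'
    return int(admissible[np.argmax(q_values[admissible])])  # exploit: argmax over A'


def target(r, q_next, admissible_next, gamma, terminal):
    """Bootstrap target whose max is taken only over admissible next actions."""
    if terminal:
        return r
    return r + gamma * float(np.max(q_next[admissible_next]))


# Usage: action 0 has the highest Q-value but is confidently eliminated.
rng = np.random.default_rng(0)
phi = np.array([1.0])
theta_hat = np.array([[0.9], [0.1], [0.2]])
V_inv = np.tile(0.01 * np.eye(1), (3, 1, 1))
A = admissible_set(phi, theta_hat, V_inv, beta=1.0, l=0.5)  # lower bounds 0.8, 0.0, 0.1
q = np.array([100.0, 1.0, 2.0])
greedy = act(q, A, epsilon=0.0, rng=rng)          # picks action 2, the best admissible one
y = target(1.0, q, A, gamma=0.9, terminal=False)  # 1 + 0.9 * 2
```

Without the restriction, both the greedy choice and the target would be driven by the eliminated action's inflated Q-value of 100.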