Source code for train.state


:class:`~train.State` objects can be used to represent the agent's state. They save the most recent observations seen by the agent and can process them before they are passed to the :func:`~train.Agent.act` method. The following example saves the last 2 observations (images) after transforming them (crop, grayscale, scale) and computes the difference between them, which can be useful for tracking motion:

.. code:: python

    import numpy as np
    from train import State

    class MyState(State):

        def __init__(self, **kwargs):
            super(MyState, self).__init__(length=2, **kwargs)

        def process_observation(self, observation):
            x = observation
            x = x[35:-15, :, :] # crop
            x = np.dot(x, [.299, .587, .114]) # grayscale
            x = x / 255 # scale
            return x

        def process_state(self, state):
            prev, current = state
            diff = current - prev
            return diff.reshape(diff.shape + (1, ))

Custom state objects can be passed to an agent during initialization:

.. code:: python

    state = MyState()
    agent = MyAgent(state=state, env=env)

import random
from collections import namedtuple, deque

import numpy as np

from .utils import zeros_like

# One step of agent experience: (state, action, reward, next_state, done).
# Field order matters: positional construction and zip-based transposition
# in the replay buffer rely on it.
Transition = namedtuple('Transition', 'state action reward next_state done')

class State:
    """Core class to represent an agent's state.

    Saves the most recent observations seen by the agent.

    Args:
        length (int): Number of recent observations to save. When ``0``,
            only the latest observation is stored, as-is (not in a list).
        zeros (array_like): Array of zeros with the same shape as each
            observation that will be used to pad initial states when the
            number of recent observations is smaller than the length of
            the state. If ``None``, it is built from the first observation
            passed to :meth:`update`.
    """

    def __init__(self, length=0, zeros=None):
        self.length = length
        self.zeros = zeros
        self.reset()

    def update(self, observation):
        """Update the current state based on a new observation.

        Args:
            observation (array_like): Observation returned by environment.
        """
        assert observation is not None
        observation = self.process_observation(observation)
        if self.zeros is None:
            # First observation seen: derive the padding value from it and
            # fill the state to full length before storing the observation.
            self.zeros = zeros_like(observation)
            self.pad()
        if self.length == 0:
            self.state = observation
        else:
            # The deque was created with maxlen=length, so appending drops
            # the oldest observation once the state is full.
            self.state.append(observation)

    def process_observation(self, observation):
        """Process observation before saving it.

        Args:
            observation (array_like): Observation returned by environment.

        Returns:
            array_like: Processed observation.
        """
        return observation

    def get(self, asarray=True, dtype='float32'):
        """Get the current state.

        Args:
            asarray (bool): If ``True`` returns an :class:`~numpy.ndarray`.
            dtype (~numpy.dtype): Data type of the returned value.

        Returns:
            (array_like, list): Processed state.
        """
        if self.length == 0:
            state = self.state
        else:
            # Snapshot the deque as a list, oldest observation first.
            state = list(self.state)
        state = self.process_state(state)
        if asarray:
            state = np.array(state, dtype=dtype)
        return state

    def process_state(self, state):
        """Process state before passing it to :func:`~train.Agent.act`.

        Args:
            state (array_like, list): List of recent observations.

        Returns:
            (array_like, list): Processed state.
        """
        return state

    def reset(self):
        """Reset current state."""
        if self.length == 0:
            self.state = None
        else:
            self.state = deque(maxlen=self.length)
        if self.zeros is not None:
            self.pad()

    def pad(self):
        """Fill the state with ``self.zeros`` up to its full length."""
        assert self.zeros is not None
        if self.length == 0:
            # A single-observation state holds a bare observation, not a
            # sequence, so only fill it when nothing is stored yet.
            if self.state is None:
                self.state = self.zeros
            return
        while len(self.state) < self.length:
            # Pad on the oldest side so stored observations stay the newest.
            self.state.appendleft(self.zeros)
# See class RingBuffer(): def __init__(self, maxlen): self.maxlen = maxlen self.reset() def append(self, item): maxlen = self.maxlen if len( < maxlen or maxlen <= 0:[self.pos] = item self.pos += 1 if maxlen > 0: self.pos %= maxlen def get(self): return[self.pos:] +[:self.pos] def last(self): """Return last transition. Returns: Transition: Last transition. Raises: IndexError: When it is empty. """ return[self.pos - 1] def sample(self, batch_size): return random.sample(, batch_size) def reset(self): """Reset transitions. """ = [] self.pos = 0 def __len__(self): return len(
class Transitions(RingBuffer):
    """Queue-like data structure to save recent transitions observed by the agent.

    Can be used as a replay buffer for algorithms like DQN.

    Args:
        maxlen (int): Number of recent transitions to save. When zero or
            negative, there is no limit on the number of transitions saved.
    """

    def get(self, **kwargs):
        """Get all transitions.

        Keyword arguments are forwarded to :meth:`get_transitions`.

        Returns:
            (list, Transition): List of transitions or a Transition object
            containing lists of values.
        """
        data = super(Transitions, self).get()
        return self.get_transitions(data, **kwargs)

    def sample(self, batch_size, **kwargs):
        """Randomly sample transitions.

        Args:
            batch_size (int): Number of transitions to sample.

        Keyword arguments are forwarded to :meth:`get_transitions`.

        Returns:
            (list, Transition): List of transitions or a Transition object
            containing lists of values.
        """
        data = super(Transitions, self).sample(batch_size)
        return self.get_transitions(data, **kwargs)

    # NOTE(review): a "See <link>" reference for this method was lost when
    # the source was extracted; restore it if the original link is known.
    def get_transitions(self, data, **kwargs):
        """Optionally transpose a list of transitions into per-field columns.

        Args:
            data (list): Transitions (5-element sequences in ``Transition``
                field order).

        Keyword Args:
            transpose (bool): If ``False``, return ``data`` unchanged.
                Defaults to ``True``.
            asarray (bool): If ``True``, convert each column to a
                :class:`~numpy.ndarray`. Defaults to ``True``.
            dtype (~numpy.dtype): Data type used for states, next states and
                rewards. Defaults to ``'float32'``.

        Returns:
            (list, Transition): ``data`` itself, or a Transition whose fields
            hold tuples (``asarray=False``) or arrays of values.
        """
        transpose = kwargs.get('transpose', True)
        asarray = kwargs.get('asarray', True)
        dtype = kwargs.get('dtype', 'float32')
        if not transpose:
            return data
        if data:
            data = Transition(*zip(*data))
        else:
            # zip(*[]) yields nothing, which would make Transition(...) fail
            # with a missing-arguments TypeError; return empty columns.
            data = Transition._make(() for _ in Transition._fields)
        if not asarray:
            return data
        states = np.array(data.state, dtype=dtype)
        # Actions are treated as discrete indices, dones as 0/1 flags.
        actions = np.array(data.action, dtype='int32')
        next_states = np.array(data.next_state, dtype=dtype)
        rewards = np.array(data.reward, dtype=dtype)
        dones = np.array(data.done, dtype='uint8')
        return Transition(state=states, action=actions,
                          next_state=next_states, reward=rewards,
                          done=dones)