State

State objects can be used to represent the agent’s state. They can save the most recent observations seen by the agent and process them before they are passed to the act() method. The following example saves the last 2 observations (images) after transforming them (crop, grayscale conversion, scaling) and computes the difference between them, which can be useful for tracking motion:

import numpy as np
from train import State

class MyState(State):

    def __init__(self, **kwargs):
        super(MyState, self).__init__(length=2, **kwargs)

    def process_observation(self, observation):
        x = observation
        x = x[35:-15, :, :] # crop
        x = np.dot(x, [.299, .587, .114]) # grayscale
        x = x / 255 # scale
        return x

    def process_state(self, state):
        prev, current = state
        diff = current - prev
        return diff.reshape(diff.shape + (1, ))
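The preprocessing in MyState is plain NumPy, so it can be checked without the train package. The sketch below copies the two methods into standalone functions and runs them on two synthetic 210x160 RGB frames. That frame size is a hypothetical, Atari-sized input chosen so the 35:-15 crop yields square 160x160 frames. Only the region that changes between the frames produces a non-zero difference:

```python
import numpy as np


def process_observation(observation):
    # Same steps as MyState.process_observation above
    x = observation[35:-15, :, :]          # crop rows: 210 -> 160
    x = np.dot(x, [.299, .587, .114])      # RGB -> grayscale, drops the channel axis
    return x / 255                         # scale to [0, 1]


def process_state(state):
    # Same steps as MyState.process_state above
    prev, current = state
    diff = current - prev                  # pixel-wise change between frames
    return diff.reshape(diff.shape + (1,))  # add a trailing channel axis


# Two synthetic frames; random pixel values in [0, 255)
rng = np.random.default_rng(0)
frame_a = rng.integers(0, 255, size=(210, 160, 3), dtype=np.uint8)
frame_b = frame_a.copy()
frame_b[100:110, 50:60, :] = 255           # a bright "object" appears

diff = process_state([process_observation(frame_a), process_observation(frame_b)])
print(diff.shape)  # (160, 160, 1)
```

Unchanged pixels give exactly 0.0, so the motion shows up only in cropped rows 65 to 74 (original rows 100 to 109).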

Custom state objects can be passed to the agent during initialization:

state = MyState()
agent = MyAgent(state=state, env=env)

State

class train.State(length=0, zeros=None)

Core class to represent the agent’s state. Saves the recent observations seen by the agent.

Parameters:
  • length (int) – Number of recent observations to save.
  • zeros (array_like) – Array of zeros with the same shape as each observation, used to pad the initial state while fewer than length observations have been seen.

get(asarray=True, dtype='float32')

Get the current state.

Parameters:
  • asarray (bool) – If True, returns an ndarray.
  • dtype (dtype) – Data type of the returned value.
Returns: Processed state.
Return type: (array_like, list)

process_observation(observation)

Process observation before saving it.

Parameters: observation (array_like) – Observation returned by the environment.
Returns: Processed observation.
Return type: array_like

process_state(state)

Process state before passing it to act().

Parameters: state (array_like, list) – List of recent observations.
Returns: Processed state.
Return type: (array_like, list)

reset()

Reset the current state.

update(observation)

Update the current state with a new observation.

Parameters: observation (array_like) – Observation returned by the environment.
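The methods above describe a fixed-length buffer with two override hooks. The sketch below is a hypothetical re-implementation of that interface, not train.State's source. SketchState is an illustrative name. The following are assumptions, since the reference does not specify them: observations are processed once in update(), padding with zeros is prepended in get(), and length must be at least 1 (the documented default length=0 is not reproduced):

```python
from collections import deque

import numpy as np


class SketchState:
    """Hypothetical fixed-length observation buffer, not train.State's actual source."""

    def __init__(self, length, zeros=None):
        # Assumption: length >= 1; the default length=0 behaviour is not reproduced.
        self.length = length
        self.zeros = zeros
        self.reset()

    def reset(self):
        # A bounded deque drops the oldest observation once `length` are stored.
        self.observations = deque(maxlen=self.length)

    def update(self, observation):
        # Assumption: each observation is processed once, when it arrives.
        self.observations.append(self.process_observation(observation))

    def process_observation(self, observation):
        return observation  # hook: override to crop, scale, etc.

    def process_state(self, state):
        return state  # hook: override to combine recent observations

    def get(self, asarray=True, dtype='float32'):
        state = list(self.observations)
        # Assumption: padding goes in front, so the newest observation stays last.
        if self.zeros is not None:
            state = [self.zeros] * (self.length - len(state)) + state
        state = self.process_state(state)
        return np.asarray(state, dtype=dtype) if asarray else state


state = SketchState(length=2, zeros=np.zeros(3))
state.update(np.ones(3))
print(state.get().tolist())  # [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]
```

Padding in get() rather than in update() keeps the stored observations real. It also lets process_state always receive exactly length entries, which is what MyState's `prev, current = state` unpacking relies on.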

Transitions

class train.Transitions(maxlen)

Queue-like data structure that saves the recent transitions observed by the agent. It can be used as a replay buffer for algorithms such as DQN.

Parameters: maxlen (int) – Number of recent transitions to save. If negative, the number of saved transitions is unlimited.

get(**kwargs)

Get all transitions.

Returns: List of transitions or a Transition object containing lists of values.
Return type: (list, Transition)

last()

Return the last transition.

Returns: Last transition.
Return type: Transition
Raises: IndexError – If no transitions have been saved.

reset()

Reset the saved transitions.

sample(batch_size, **kwargs)

Randomly sample transitions.

Parameters: batch_size (int) – Number of transitions to sample.
Returns: List of transitions or a Transition object containing lists of values.
Return type: (list, Transition)
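A bounded deque gives the documented Transitions behaviour almost directly. The sketch below is hypothetical: SketchTransitions is an illustrative name, and its append method is an assumption because the reference lists no method for adding transitions. It omits the **kwargs of get() and sample(), whose options are not documented, and always returns plain lists. Sampling here is uniform and without replacement, via Python's random.sample:

```python
import random
from collections import deque, namedtuple

# Stand-in for train.Transition, using its documented field order.
Transition = namedtuple('Transition', ['state', 'action', 'reward', 'next_state', 'done'])


class SketchTransitions:
    """Hypothetical replay buffer mirroring the documented Transitions interface."""

    def __init__(self, maxlen):
        # A negative maxlen means "no limit", which deque expresses as maxlen=None.
        self.data = deque(maxlen=maxlen if maxlen >= 0 else None)

    def append(self, transition):
        # Hypothetical: once full, the oldest transition is dropped.
        self.data.append(transition)

    def get(self):
        return list(self.data)

    def last(self):
        return self.data[-1]  # raises IndexError when empty, as documented

    def reset(self):
        self.data.clear()

    def sample(self, batch_size):
        # Uniform sampling without replacement; ValueError if batch_size > len(self.data).
        return random.sample(list(self.data), batch_size)


buffer = SketchTransitions(maxlen=3)
for t in range(5):
    buffer.append(Transition(state=t, action=0, reward=1.0, next_state=t + 1, done=False))
print([tr.state for tr in buffer.get()])  # [2, 3, 4]
```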

Transition

class train.Transition(state, action, reward, next_state, done)
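The reference documents Transition only by its constructor. Assume, as the source does not say, that it behaves like a namedtuple with these five fields. Under that assumption, the "Transition object containing lists of values" that get() and sample() can return corresponds to transposing a list of transitions, which is a common way to build DQN training batches. The stand-in namedtuple below is hypothetical:

```python
from collections import namedtuple

# Hypothetical stand-in for train.Transition, using its documented field order.
Transition = namedtuple('Transition', ['state', 'action', 'reward', 'next_state', 'done'])

batch = [
    Transition(state=[0.0], action=1, reward=0.5, next_state=[0.1], done=False),
    Transition(state=[0.1], action=0, reward=1.0, next_state=[0.2], done=True),
]

# zip(*batch) groups the i-th field of every transition; turning each group into a
# list yields one Transition whose fields are lists of values.
columns = Transition(*map(list, zip(*batch)))
print(columns.action)  # [1, 0]
print(columns.done)    # [False, True]
```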