State
State objects can be used to represent the agent's state. They can save the recent observations seen by the agent and process them before they are passed to the act() method. The following example saves the last 2 observations (images) after transforming them (crop, grayscale, scale) and computes the difference between them, which can be useful for tracking motion:
import numpy as np

from train import State


class MyState(State):
    def __init__(self, **kwargs):
        # Keep the 2 most recent observations
        super(MyState, self).__init__(length=2, **kwargs)

    def process_observation(self, observation):
        x = observation
        x = x[35:-15, :, :]  # crop rows
        x = np.dot(x, [.299, .587, .114])  # grayscale
        x = x / 255  # scale to [0, 1]
        return x

    def process_state(self, state):
        prev, current = state
        diff = current - prev  # frame difference highlights motion
        return diff.reshape(diff.shape + (1, ))  # add a trailing channel axis
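The sketch below reproduces the preprocessing from MyState with plain NumPy, so its effect on array shapes can be checked without the train package. The 210×160 RGB frame size is an assumption (typical of Atari screens) and is not stated in this documentation; preprocess and frame_difference are hypothetical helper names.

```python
import numpy as np


def preprocess(frame):
    # Same steps as MyState.process_observation
    x = frame[35:-15, :, :]  # drop the top 35 and bottom 15 rows
    x = np.dot(x, [.299, .587, .114])  # weighted sum over RGB -> grayscale
    return x / 255  # scale to [0, 1]


def frame_difference(prev, current):
    # Same steps as MyState.process_state
    diff = current - prev
    return diff.reshape(diff.shape + (1,))  # add a trailing channel axis


rng = np.random.default_rng(0)
# Two random frames of an assumed 210x160 RGB size
frame_a = rng.integers(0, 256, size=(210, 160, 3)).astype(np.uint8)
frame_b = rng.integers(0, 256, size=(210, 160, 3)).astype(np.uint8)

prev = preprocess(frame_a)
diff = frame_difference(prev, preprocess(frame_b))
print(prev.shape)  # (160, 160)
print(diff.shape)  # (160, 160, 1)
```

The crop removes 50 rows (210 - 35 - 15 = 160), grayscale removes the channel axis, and process_state adds a single channel axis back so the difference image can be fed to layers that expect a channel dimension.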
Custom state objects can be passed to the agent during initialization:
state = MyState()
agent = MyAgent(state=state, env=env)
State
class train.State(length=0, zeros=None)
Core class to represent the agent's state. Saves recent observations seen by the agent.
Parameters:
- length (int) – Number of recent observations to save.
- zeros (array_like) – Array of zeros with the same shape as each observation, used to pad the initial state when the number of recent observations is smaller than the length of the state.
get(asarray=True, dtype='float32')
Get the current state.
Returns: Processed state.
Return type: (array_like, list)
process_observation(observation)
Process an observation before saving it.
Parameters: observation (array_like) – Observation returned by the environment.
Returns: Processed observation.
Return type: array_like
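The documentation does not show how State stores and pads observations internally. The following is a minimal, self-contained sketch of the documented behavior, assuming a deque of processed observations padded at the front with zeros. SimpleState and its add method are hypothetical names (the real class may add observations through a different method), and both the meaning of asarray and the assumption that get() applies process_state are guesses based on the documented "Processed state" return value.

```python
from collections import deque

import numpy as np


class SimpleState:
    """Hypothetical stand-in for train.State (not the library's code)."""

    def __init__(self, length=0, zeros=None):
        self.length = length
        # Assumed to match the shape of a processed observation
        self.zeros = zeros
        # Keeps at most `length` recent processed observations
        self.observations = deque(maxlen=length)

    def process_observation(self, observation):
        # Hook: subclasses transform each observation before it is saved
        return observation

    def process_state(self, state):
        # Hook: subclasses combine the saved observations into one state
        return state

    def add(self, observation):
        # Hypothetical name; the documented API does not show this method
        self.observations.append(self.process_observation(observation))

    def get(self, asarray=True, dtype='float32'):
        # Pad at the front with `zeros` until there are `length` entries
        missing = self.length - len(self.observations)
        state = [self.zeros] * missing + list(self.observations)
        state = self.process_state(state)
        # Assumed meaning of `asarray`: return a NumPy array instead of a list
        return np.asarray(state, dtype=dtype) if asarray else state


# Usage: a state of length 3 after a single observation
state = SimpleState(length=3, zeros=np.zeros(2))
state.add(np.array([1.0, 2.0]))
print(state.get().shape)  # (3, 2)
```

After one observation, the first two rows of the state are the zeros padding and the last row is the observation; once more than length observations have been added, the deque silently drops the oldest one.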
Transitions
class train.Transitions(maxlen)
Queue-like data structure to save recent transitions observed by the agent. Can be used as a replay buffer for algorithms like DQN.
Parameters: maxlen (int) – Number of recent transitions to save. When negative, there is no limit on the number of transitions saved.
get(**kwargs)
Get all transitions.
Returns: List of transitions or a Transition object containing lists of values.
Return type: (list, Transition)
last()
Return the last transition.
Returns: Last transition.
Return type: Transition
Raises: IndexError – If there are no saved transitions.
reset()
Reset transitions.
sample(batch_size, **kwargs)
Randomly sample transitions.
Parameters: batch_size (int) – Number of transitions to sample.
Returns: List of transitions or a Transition object containing lists of values.
Return type: (list, Transition)
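A replay buffer with these semantics can be sketched on top of collections.deque. This is not the library's implementation: SimpleTransitions, its append method, the batched flag (standing in for the undocumented **kwargs) and the Transition field names are all hypothetical. The sketch shows how a negative maxlen can map to an unbounded buffer, how last() raises IndexError when empty, and how a list of transitions can be regrouped into one Transition of lists.

```python
import random
from collections import deque, namedtuple

# Assumed field names; the documentation does not list Transition's fields
Transition = namedtuple('Transition', ['state', 'action', 'reward', 'next_state', 'done'])


class SimpleTransitions:
    """Hypothetical stand-in for train.Transitions, built on collections.deque."""

    def __init__(self, maxlen):
        # A negative maxlen means "no limit", which deque expresses as maxlen=None
        self.buffer = deque(maxlen=None if maxlen < 0 else maxlen)

    def append(self, transition):
        # Hypothetical method; once full, the oldest transition is dropped
        self.buffer.append(transition)

    @staticmethod
    def _format(transitions, batched):
        if not batched:
            return list(transitions)
        # Regroup into one Transition whose fields are lists of values
        return Transition._make([[getattr(t, f) for t in transitions]
                                 for f in Transition._fields])

    def get(self, batched=False):
        return self._format(self.buffer, batched)

    def last(self):
        # Indexing an empty deque raises IndexError
        return self.buffer[-1]

    def reset(self):
        self.buffer.clear()

    def sample(self, batch_size, batched=False):
        # Uniform sampling without replacement
        return self._format(random.sample(list(self.buffer), batch_size), batched)


# Usage: a buffer limited to the 3 most recent transitions
memory = SimpleTransitions(maxlen=3)
for step in range(5):
    memory.append(Transition(step, 0, 1.0, step + 1, False))
print([t.state for t in memory.get()])        # [2, 3, 4]
print(memory.last().state)                    # 4
print(memory.sample(2, batched=True).reward)  # [1.0, 1.0]
```

The batched form is convenient for DQN-style updates, where each field (states, rewards, and so on) is usually stacked into its own array. Whether train.Transitions.sample draws with or without replacement, and which keyword arguments its get and sample accept, is not stated in this documentation.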