QLearning on CartPole-v0 (Python)

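In this example we use Q-learning to train an agent on the CartPole-v0 environment. The state that this environment exposes to the agent consists of four floating-point values: cart position, cart velocity, pole angle and pole angular velocity. Because these values are continuous, they cannot be used directly as indices into a Q-table, so a tabular method such as Q-learning cannot be applied as is. Instead, we wrap the environment in the StateAggregationCartPoleEnv class, which splits the range of each state variable into bins and replaces each value with the integer index of the bin it falls into.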
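The binning idea behind the wrapper can be sketched with NumPy. This is a minimal sketch, not the wrapper's actual implementation: the function names `make_bins`, `aggregate` and `to_state_index` are hypothetical, the position and angle bounds are roughly CartPole's termination thresholds (±2.4 and about ±12°), and the velocity bounds are assumed, since CartPole leaves the velocities unbounded. Flattening the per-variable indices into one integer is an optional extra step that makes the result usable as a single Q-table row.

```python
import numpy as np

# Assumed bounds for: cart position, cart velocity, pole angle (rad), pole angular velocity.
# The velocity bounds are a guess; values outside the bounds fall into the end bins.
LOWER = np.array([-2.4, -3.0, -0.21, -3.5])
UPPER = np.array([2.4, 3.0, 0.21, 3.5])


def make_bins(n_bins, lower=LOWER, upper=UPPER):
    """Return, per state variable, the n_bins - 1 interior edges of n_bins equal-width bins."""
    return [np.linspace(lo, hi, n_bins + 1)[1:-1] for lo, hi in zip(lower, upper)]


def aggregate(obs, bins):
    """Map a continuous observation to a tuple of bin indices, each in [0, n_bins - 1]."""
    return tuple(int(np.digitize(x, edges)) for x, edges in zip(obs, bins))


def to_state_index(indices, n_bins):
    """Flatten the per-variable bin indices into one integer (hypothetical helper)."""
    return int(np.ravel_multi_index(indices, (n_bins,) * len(indices)))


if __name__ == "__main__":
    bins = make_bins(10)
    idx = aggregate(np.array([0.1, -0.5, 0.05, 1.0]), bins)
    print(idx, to_state_index(idx, 10))  # (5, 4, 6, 6) 5466
```

With 10 bins per variable the aggregated state space has 10^4 = 10,000 discrete states, which is small enough for a Q-table.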

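The script below trains the agent and then plots the average episode reward over a sliding window of the last 100 episodes.

```python
import matplotlib.pyplot as plt
import numpy as np

# StateAggregationCartPoleEnv, TDAlgoConfig, EpsilonGreedyPolicy, EpsilonDecayOption,
# QLearning, RLSerialTrainerConfig, RLSerialAgentTrainer and n_actions are provided
# by the library this example belongs to; their import paths are not shown here.


def plot_running_avg(rewards, window=100):
    """Plot the mean episode reward over a sliding window of `window` episodes."""
    rewards = np.asarray(rewards, dtype=float)  # accept a list or an array
    running_avg = np.empty(rewards.shape[0])
    for t in range(rewards.shape[0]):
        # Average over episodes max(0, t - window + 1) .. t, i.e. at most `window` values
        running_avg[t] = np.mean(rewards[max(0, t - window + 1):(t + 1)])
    plt.plot(running_avg)
    plt.xlabel("Number of episodes")
    plt.ylabel(f"Average reward (last {window} episodes)")
    plt.title("Running average")
    plt.show()


if __name__ == '__main__':
    GAMMA = 1.0  # no discounting; the task is episodic
    ALPHA = 0.1  # learning rate
    EPS = 1.0    # initial exploration rate, decayed during training

    # Wrapper that bins the continuous CartPole state into integers
    env = StateAggregationCartPoleEnv(n_states=10)

    agent_config = TDAlgoConfig(gamma=GAMMA, alpha=ALPHA,
                                n_itrs_per_episode=50000,  # maximum iterations per episode
                                n_episodes=10000,
                                policy=EpsilonGreedyPolicy(n_actions=n_actions(env),
                                                           eps=EPS, decay_op=EpsilonDecayOption.INVERSE_STEP))

    agent = QLearning(agent_config)

    # NOTE: the trainer runs 50000 episodes while agent_config above says 10000;
    # keeping the two values equal avoids ambiguity about the episode count.
    trainer_config = RLSerialTrainerConfig(n_episodes=50000, output_msg_frequency=5000)
    trainer = RLSerialAgentTrainer(trainer_config, agent=agent)
    trainer.train(env)

    plot_running_avg(agent.total_rewards)
```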
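For reference, the tabular update that Q-learning applies with the ALPHA and GAMMA values above is Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a)). The NumPy sketch below shows one such update together with epsilon-greedy action selection. It is an illustration under stated assumptions, not the library's QLearning or EpsilonGreedyPolicy: the Q-table layout (one row per aggregated state, one column per action), the function names, and the reading of INVERSE_STEP as eps_t = eps_0 / t are all hypothetical.

```python
import numpy as np


def epsilon_greedy(q_row, eps, rng):
    """With probability eps pick a uniformly random action, otherwise the greedy one."""
    if rng.random() < eps:
        return int(rng.integers(len(q_row)))
    return int(np.argmax(q_row))


def inverse_step_eps(eps0, step):
    """Hypothetical inverse-step schedule: eps_t = eps0 / t for t >= 1."""
    return eps0 / max(1, step)


def q_learning_update(q, s, a, r, s_next, done, alpha, gamma):
    """One tabular Q-learning step on the table q (rows = states, columns = actions).

    Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a));
    the bootstrap term is dropped when s_next is terminal.
    """
    target = r if done else r + gamma * np.max(q[s_next])
    q[s, a] += alpha * (target - q[s, a])
    return q


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    q = np.zeros((3, 2))  # 3 toy states, 2 actions (push left / push right)
    q[2] = [0.0, 5.0]
    a = epsilon_greedy(q[0], inverse_step_eps(1.0, 10), rng)
    q_learning_update(q, 0, 1, 1.0, 2, False, alpha=0.1, gamma=1.0)
    print(a, q[0, 1])
```

Because Q-learning bootstraps from the greedy value max_a' Q(s', a') rather than from the action the epsilon-greedy policy actually takes next, it is an off-policy method: the agent can keep exploring while still learning the values of the greedy policy.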