REINFORCE on CartPole-v0
(C++)
Overview
Code
#include "cubeai/base/cubeai_config.h"
#if defined(USE_PYTORCH) && defined(USE_GYMFCPP)
#include "cubeai/base/cubeai_types.h"
#include "cubeai/rl/algorithms/pg/simple_reinforce.h"
#include "cubeai/rl/trainers/pytorch_rl_agent_trainer.h"
#include "cubeai/ml/distributions/torch_categorical.h"
#include "cubeai/optimization/optimizer_type.h"
#include "cubeai/optimization/pytorch_optimizer_factory.h"
#include "gymfcpp/gymfcpp_types.h"
#include "gymfcpp/cart_pole_env.h"
#include "gymfcpp/time_step.h"
#include <torch/torch.h>
#include <boost/python.hpp>
#include <iostream>
#include <string>
#include <any>
namespace rl_example_13{
namespace F = torch::nn::functional;
using cubeai::real_t;
using cubeai::uint_t;
using cubeai::torch_tensor_t;
using cubeai::rl::algos::pg::SimpleReinforce;
using cubeai::rl::algos::pg::ReinforceConfig;
using cubeai::rl::PyTorchRLTrainer;
using cubeai::rl::PyTorchRLTrainerConfig;
using cubeai::ml::stats::TorchCategorical;
using gymfcpp::CartPole;
class PolicyImpl: public torch::nn::Module
{
public:
PolicyImpl();
torch_tensor_t forward(torch_tensor_t);
template<typename StateTp>
std::tuple<uint_t, real_t> act(const StateTp& state);
template<typename LossValuesTp>
void update_policy_loss(const LossValuesTp& vals);
void step_backward_policy_loss();
torch_tensor_t compute_loss(){return loss_;}
private:
torch::nn::Linear fc1_;
torch::nn::Linear fc2_;
// placeholder for the loss
torch_tensor_t loss_;
};
PolicyImpl::PolicyImpl()
:
fc1_(torch::nn::Linear(4, 16)),
fc2_(torch::nn::Linear(16, 2))
{
register_module("fc1", fc1_);
register_module("fc2", fc2_);
}
template<typename LossValuesTp>
void
PolicyImpl::update_policy_loss(const LossValuesTp& vals){
torch_tensor_t torch_vals = torch::tensor(vals);
// specify that we require the gradient
loss_ = torch::cat(torch::tensor(vals, torch::requires_grad())).sum();
}
void
PolicyImpl::step_backward_policy_loss(){
loss_.backward();
}
torch_tensor_t
PolicyImpl::forward(torch_tensor_t x){
x = F::relu(fc1_->forward(x));
x = fc2_->forward(x);
return F::softmax(x, F::SoftmaxFuncOptions(0));
}
template<typename StateTp>
std::tuple<uint_t, real_t>
PolicyImpl::act(const StateTp& state){
auto torch_state = torch::tensor(state);
auto probs = forward(torch_state);
auto m = TorchCategorical(&probs, nullptr);
auto action = m.sample();
return std::make_tuple(action.item().toLong(), m.log_prob(action).item().to<real_t>());
}
TORCH_MODULE(Policy);
}
int main(){
using namespace example;
try{
Py_Initialize();
auto gym_module = boost::python::import("__main__");
auto gym_namespace = gym_module.attr("__dict__");
auto env = CartPole("v0", gym_namespace, false);
env.make();
Policy policy;
auto optimizer_ptr = std::make_unique<torch::optim::Adam>(policy->parameters(), torch::optim::AdamOptions(1e-2));
// reinforce options
ReinforceConfig opts = {1000, 100, 100, 100, 1.0e-2, 0.1, 195.0, true};
SimpleReinforce<CartPole, Policy> algorithm(opts, policy);
PyTorchRLTrainerConfig trainer_config{1.0e-8, 1001, 50};
PyTorchRLTrainer<CartPole, SimpleReinforce<CartPole, Policy>> trainer(trainer_config, algorithm, std::move(optimizer_ptr));
trainer.train(env);
}
catch(const boost::python::error_already_set&){
PyErr_Print();
}
catch(std::exception& e){
std::cout<<e.what()<<std::endl;
}
catch(...){
std::cout<<"Unknown exception occured"<<std::endl;
}
return 0;
}
#else
#include <iostream>
int main(){
std::cout<<"This example requires PyTorch and gymfcpp. Reconfigure cubeai with PyTorch and gymfcpp support."<<std::endl;
return 0;
}
#endif