from __future__ import print_function

import argparse
from pyglet.window import key
import gym
import numpy as np
import pickle
import os
from datetime import datetime
import gzip
import json

import copy


def key_press(k, mod):
    """Update the shared control vector ``a`` for a key-press event.

    ``a`` is the module-level array whose first three entries are sent to
    ``env.step``; this handler writes steering to ``a[0]`` and gas to
    ``a[1]``, and uses ``a[3]`` as an "UP is currently held" marker.

    * ESCAPE sets the module-level ``restart`` flag.
    * UP marks the key as held and applies gas only while not steering.
    * LEFT / RIGHT steer fully to that side and cut the gas.

    Args:
        k: Key symbol of the pressed key, compared against ``key.*``.
        mod: Modifier state passed with the event (unused).
    """
    global restart

    if k == key.ESCAPE:
        restart = True
    elif k == key.UP:
        a[3] = +1.0
        going_straight = a[0] == 0.0
        if going_straight:
            a[1] = +1.0
    elif k in (key.LEFT, key.RIGHT):
        a[0] = -1.0 if k == key.LEFT else +1.0
        a[1] = 0.0  # Cut gas while turning
    # The DOWN-arrow brake (a[2]) is currently disabled, so a[2] is never set.

def key_release(k, mod):
    """Update the shared control vector ``a`` for a key-release event.

    Releasing UP clears both the gas (``a[1]``) and the "UP held" marker
    (``a[3]``).  Releasing LEFT or RIGHT re-centres the steering, but only
    if ``a[0]`` still points to that side, and restores the gas when UP is
    still marked as held.

    Args:
        k: Key symbol of the released key, compared against ``key.*``.
        mod: Modifier state passed with the event (unused).
    """
    if k == key.UP:
        a[1] = 0.0
        a[3] = 0.0
        return

    # (key, steering value that key is responsible for)
    steering_keys = ((key.LEFT, -1.0), (key.RIGHT, +1.0))
    for arrow, direction in steering_keys:
        if k != arrow or a[0] != direction:
            continue
        a[0] = 0.0
        up_still_held = a[3] == 1.0
        if up_still_held:
            a[1] = 1.0
    # The DOWN-arrow brake (a[2]) is currently disabled, so there is
    # nothing to reset for it here.


def store_data(data, datasets_dir="./data"):
    """Pickle ``data`` into ``<datasets_dir>/data.pkl.gzip``.

    The target directory is created (including any missing parent
    directories) when it does not exist yet.  An existing file with the
    same name is overwritten.

    Args:
        data: Any picklable object.
        datasets_dir: Directory that receives ``data.pkl.gzip``.

    Raises:
        OSError: If the directory cannot be created or the file cannot be
            written.
    """
    if not os.path.exists(datasets_dir):
        # makedirs (not mkdir) so nested targets such as "runs/01/data" work.
        os.makedirs(datasets_dir)
    data_file = os.path.join(datasets_dir, 'data.pkl.gzip')
    # The context manager closes the stream deterministically, which is
    # what flushes the gzip trailer; without it a complete file on disk
    # depends on garbage collection.
    with gzip.open(data_file, 'wb') as f:
        pickle.dump(data, f)


if __name__ == "__main__":
    # Interactive data collection: drive CarRacing-v0 with the arrow keys and
    # record every episode that runs to completion.  ESC aborts the current
    # episode and discards its samples.
    parser = argparse.ArgumentParser()
    parser.add_argument("--collect_data", action="store_true", default=True,
                        help="Collect the data in a pickle file.")
    # --collect_data already defaults to True, so by itself it could never be
    # switched off; this opt-out makes the option effective while keeping the
    # old flag and the old default behaviour.
    parser.add_argument("--no_collect_data", dest="collect_data",
                        action="store_false",
                        help="Drive without saving any data.")
    args = parser.parse_args()

    # Each key maps to a list with one entry per completed episode; every
    # entry is itself the list of per-step values of that episode.
    good_samples = {
        "state": [],
        "next_state": [],
        "reward": [],
        "action": [],
        "terminal": [],
    }

    # Shared control vector mutated by key_press/key_release.  Only a[:3] is
    # sent to the environment; a[3] is a private "UP is held" marker.  It is
    # created before the handlers are registered because they write to it.
    a = np.zeros(4, dtype=np.float32)

    env = gym.make('CarRacing-v0').unwrapped
    env.reset()
    env.viewer.window.on_key_press = key_press
    env.viewer.window.on_key_release = key_release

    episode_rewards = []
    good_steps = episode_steps = 0
    try:
        # Episode loop; runs until the process is interrupted.
        while True:
            # Fresh lists every episode so entries already appended to
            # good_samples are never mutated.
            episode_samples = {name: [] for name in good_samples}
            episode_reward = 0
            state = env.reset()
            restart = False
            # The step counter continues from the last kept episode, so steps
            # of a discarded episode are not counted.
            episode_steps = good_steps

            # State loop
            while True:
                next_state, r, done, _info = env.step(a[:3])
                episode_reward += r

                episode_samples["state"].append(state)  # state has shape (96, 96, 3)
                # np.array copies: a[:3] is a view of the live control vector
                # and would otherwise change with later key events.  The
                # stored action has shape (3,).
                episode_samples["action"].append(np.array(a[:3]))
                episode_samples["next_state"].append(next_state)
                episode_samples["reward"].append(r)
                episode_samples["terminal"].append(done)

                state = next_state
                episode_steps += 1

                if episode_steps % 1000 == 0 or done:
                    print("\nstep {}".format(episode_steps))

                env.render()
                if done or restart:
                    break

            if restart:
                # Aborted with ESC: drop this episode's samples.
                continue

            good_steps = episode_steps
            episode_rewards.append(episode_reward)

            if args.collect_data:
                for name, values in episode_samples.items():
                    good_samples[name].append(values)

                # The whole accumulated data set is rewritten after every
                # kept episode.
                print('... saving data')
                store_data(good_samples, "./data")
    finally:
        # The episode loop never exits normally, so this is the only place
        # where the environment gets closed (e.g. on Ctrl+C or an error).
        env.close()