from dataclasses import dataclass
import logging
import numpy as np


class RNNType:
    """Integer codes for the ``RNN_TYPE`` configuration field."""
    LSTM = 0  # default
    LTSM = LSTM  # misspelled legacy name, kept so existing references keep working
    GRU = 1


class MCTSIndex:
    """Integer codes for the ``MCTS_INDEX`` configuration field."""
    # UCB is the default choice.
    UCB, AlphaZero = 0, 1


class MCTSMode:
    """Integer codes for the ``MCTS_MODE`` configuration field."""
    # Lipschitz is the default choice.
    Standard, Lipschitz = 0, 1


class MCTSPolicy:
    """Integer codes for the ``MCTS_POLICY`` configuration field."""
    # Greedy is the default choice.
    Greedy, Softmax = 0, 1


@dataclass
class Config:
    """Settings for one experiment run.

    Fields are grouped by the component they configure (task, rpg, ppo, mcts,
    lazy-mcts, pg-mcts / alphaZero, logging).  Every field has a default, so
    an instance can be created from any subset of keyword arguments.
    """
    # task-setting
    CORRIDOR_LENGTH: int = 100
    INITIAL_POSITION: int = 50
    NUM_EPISODE: int = 30_000
    GAMMA: float = 0.98
    MAX_TIMESTEP: int = 1000
    STATE_ACTION: bool = False  # TODO: set prob state action?
    # rpg
    ALPHA: float = 0.05  # 0.02
    CLIP_GRAD_NORM: float = 1.0
    # One of the RNNType constants; they are plain ints, hence the annotation.
    RNN_TYPE: int = RNNType.LTSM
    RNN_SIZE: int = 8
    DROPOUT: float = 0
    # ppo
    # NOTE: no trailing comma here -- `= 3,` would make the default the
    # tuple (3,) instead of the integer 3.
    NUM_EPOCHS: int = 3
    CLIP_PPO: float = 0.2
    IS_LOSS_DISCOUNTED: bool = True
    # mcts
    C: float = np.sqrt(2)
    MCTS_INDEX: int = MCTSIndex.UCB
    MCTS_MODE: int = MCTSMode.Lipschitz
    MCTS_POLICY: int = MCTSPolicy.Greedy
    # lazy-mcts
    EXPANSION_ALL_AT_ONCE: bool = True
    # Used if the policy is not greedy (the original note says POLICY_TYPE;
    # presumably this refers to MCTS_POLICY).
    BETA: float = 10.0
    M: float = 10_000_000  # <- length_of_corridor * 100 if algo is pgmcts
    # pg-mcts (or alphaZero)
    MIXING_COEFFICIENT_ADAPTATION: bool = False
    MIN_RHO: float = 0.1
    # [pg-mcts case] if MIXING_COEFFICIENT_ADAPTATION is False, LAMBDA is the
    #                mixing probability, otherwise a coefficient.
    # [alphaZero case] defines the rollout-policy as
    #                (1 - LAMBDA) * rpg_policy + LAMBDA * mcts_policy.
    LAMBDA: float = 0.2
    # log-setting
    LOG_INTERVAL: int = 100
    LOG_LEVEL: int = logging.DEBUG
    LOG_DEFAULT_FORMAT: str = "[%(levelname)1.1s %(asctime)s %(stack_module)s:%(stack_lineno)d] %(message)s"
    LOG_SIMPLE_FORMAT: str = "\t%(message)s"
    LOG_OUTPUT_FILES: bool = True
    LOG_RESULT_FILENAME: str = 'results/hoge.csv'
    LOG_CONFIG_FILENAME: str = 'results/config.csv'


#############################
# 30 - 0
#############################
# PPO preset: corridor length 30, initial position 0.
ppo_30_0 = Config(
    # task-setting
    CORRIDOR_LENGTH=30,
    INITIAL_POSITION=0,
    NUM_EPISODE=10_000,
    GAMMA=0.98,
    MAX_TIMESTEP=1000,
    STATE_ACTION=False,
    # rpg
    ALPHA=0.06,
    DROPOUT=0,
    CLIP_GRAD_NORM=1.0,
    # ppo
    NUM_EPOCHS=3,
    CLIP_PPO=0.2,
    IS_LOSS_DISCOUNTED=False,
    # log
    LOG_INTERVAL=100,
    LOG_LEVEL=logging.INFO,
    LOG_OUTPUT_FILES=True,
    LOG_CONFIG_FILENAME="results/config_{0:%Y%m%d_%H%M%S}.csv",
)
# Hyper-parameters are baked into the result file name here; the {0}/{1}/{2}
# placeholders stay in the string (presumably filled in later by the caller).
ppo_30_0.LOG_RESULT_FILENAME = "".join((
    "results/ppo",
    "_lossDiscounted", "{}".format(ppo_30_0.IS_LOSS_DISCOUNTED),
    "_epoch", "{}".format(ppo_30_0.NUM_EPOCHS),
    "_alp", "{:.5g}".format(ppo_30_0.ALPHA),
    "_clip", "{:.5g}".format(ppo_30_0.CLIP_PPO),
    "_len{1:03d}_ini{2:03d}_{0:%Y%m%d_%H%M%S}.csv",
))

# RPG preset: corridor length 30, initial position 0.
rpg_30_0 = Config(
    # task-setting
    CORRIDOR_LENGTH=30,
    INITIAL_POSITION=0,
    NUM_EPISODE=10_000,
    GAMMA=0.98,
    MAX_TIMESTEP=1000,
    STATE_ACTION=False,
    # rpg
    ALPHA=0.2,
    DROPOUT=0,
    CLIP_GRAD_NORM=1.0,
    # log
    LOG_INTERVAL=100,
    LOG_LEVEL=logging.INFO,
    LOG_OUTPUT_FILES=True,
    LOG_CONFIG_FILENAME="results/config_{0:%Y%m%d_%H%M%S}.csv",
)
# The learning rate goes into the file name; the {0}/{1}/{2} placeholders stay
# in the string (presumably filled in later by the caller).
rpg_30_0.LOG_RESULT_FILENAME = "".join((
    "results/rpg",
    "_alp", "{:.5g}".format(rpg_30_0.ALPHA),
    "_len{1:03d}_ini{2:03d}_{0:%Y%m%d_%H%M%S}.csv",
))

# MCTS preset: corridor length 30, initial position 0.
mcts_30_0 = Config(
    # task-setting
    CORRIDOR_LENGTH=30,
    INITIAL_POSITION=0,
    NUM_EPISODE=10_000,
    GAMMA=0.98,
    MAX_TIMESTEP=1000,
    STATE_ACTION=False,
    # mcts
    C=0.3,
    MCTS_INDEX=MCTSIndex.UCB,
    MCTS_MODE=MCTSMode.Lipschitz,
    MCTS_POLICY=MCTSPolicy.Greedy,
    EXPANSION_ALL_AT_ONCE=True,
    M=10_000_000,  # <- length_of_corridor * 100 if algo is pgmcts
    # log
    LOG_INTERVAL=100,
    LOG_LEVEL=logging.INFO,
    LOG_OUTPUT_FILES=True,
    LOG_CONFIG_FILENAME="results/config_{0:%Y%m%d_%H%M%S}.csv",
)
# The exploration constant goes into the file name; the {0}/{1}/{2}
# placeholders stay in the string (presumably filled in later by the caller).
mcts_30_0.LOG_RESULT_FILENAME = "".join((
    "results/mcts",
    "_c", "{:.5g}".format(mcts_30_0.C),
    "_len{1:03d}_ini{2:03d}_{0:%Y%m%d_%H%M%S}.csv",
))

# PG-MCTS preset (fixed mixing): corridor length 30, initial position 0.
pgmcts_30_0 = Config(
    # task-setting
    CORRIDOR_LENGTH=30,
    INITIAL_POSITION=0,
    NUM_EPISODE=10_000,
    GAMMA=0.98,
    MAX_TIMESTEP=1000,
    STATE_ACTION=False,
    # rpg
    ALPHA=0.2,
    DROPOUT=0,
    CLIP_GRAD_NORM=1.0,
    # mcts
    C=0.1,
    MCTS_INDEX=MCTSIndex.UCB,
    MCTS_MODE=MCTSMode.Lipschitz,
    MCTS_POLICY=MCTSPolicy.Softmax,
    EXPANSION_ALL_AT_ONCE=True,
    BETA=100.0,  # used if the policy is not greedy
    M=3_000,  # <- length_of_corridor * 100
    # pgmcts
    MIXING_COEFFICIENT_ADAPTATION=False,
    MIN_RHO=0.1,
    # Mixing probability when MIXING_COEFFICIENT_ADAPTATION is False,
    # a coefficient when it is True.
    LAMBDA=0.2,
    # log
    LOG_INTERVAL=100,
    LOG_LEVEL=logging.INFO,
    LOG_OUTPUT_FILES=True,
    LOG_CONFIG_FILENAME="results/config_{0:%Y%m%d_%H%M%S}.csv",
)
# ALPHA and C go into the file name; the {0}/{1}/{2} placeholders stay in the
# string (presumably filled in later by the caller).
pgmcts_30_0.LOG_RESULT_FILENAME = "".join((
    "results/pgmcts",
    "_alp", "{:.5g}".format(pgmcts_30_0.ALPHA),
    "_c", "{:.5g}".format(pgmcts_30_0.C),
    "_len{1:03d}_ini{2:03d}_{0:%Y%m%d_%H%M%S}.csv",
))

# PG-MCTS preset (adaptive mixing): corridor length 30, initial position 0.
pgmcts_adpt_30_0 = Config(
    # task-setting
    CORRIDOR_LENGTH=30,
    INITIAL_POSITION=0,
    NUM_EPISODE=10_000,
    GAMMA=0.98,
    MAX_TIMESTEP=1000,
    STATE_ACTION=False,
    # rpg
    ALPHA=0.2,
    DROPOUT=0,
    CLIP_GRAD_NORM=1.0,
    # mcts
    C=0.1,
    MCTS_INDEX=MCTSIndex.UCB,
    MCTS_MODE=MCTSMode.Lipschitz,
    MCTS_POLICY=MCTSPolicy.Softmax,
    EXPANSION_ALL_AT_ONCE=True,
    BETA=100.0,  # used if the policy is not greedy
    M=3_000,  # <- length_of_corridor * 100
    # pgmcts (or alphazero)
    MIXING_COEFFICIENT_ADAPTATION=True,
    MIN_RHO=0.1,
    # Mixing probability when MIXING_COEFFICIENT_ADAPTATION is False,
    # a coefficient when it is True.
    LAMBDA=1.0,
    # log
    LOG_INTERVAL=100,
    LOG_LEVEL=logging.INFO,
    LOG_OUTPUT_FILES=True,
    LOG_CONFIG_FILENAME="results/config_{0:%Y%m%d_%H%M%S}.csv",
)
# ALPHA and C go into the file name; the {0}/{1}/{2} placeholders stay in the
# string (presumably filled in later by the caller).
pgmcts_adpt_30_0.LOG_RESULT_FILENAME = "".join((
    "results/pgmcts_adpt",
    "_alp", "{:.5g}".format(pgmcts_adpt_30_0.ALPHA),
    "_c", "{:.5g}".format(pgmcts_adpt_30_0.C),
    "_len{1:03d}_ini{2:03d}_{0:%Y%m%d_%H%M%S}.csv",
))

# AlphaZero preset: corridor length 30, initial position 0.
alphazero_30_0 = Config(
    # task-setting
    CORRIDOR_LENGTH=30,
    INITIAL_POSITION=0,
    NUM_EPISODE=10_000,
    GAMMA=0.98,
    MAX_TIMESTEP=1000,
    STATE_ACTION=False,
    # rpg
    ALPHA=0.02,
    DROPOUT=0,
    CLIP_GRAD_NORM=1.0,
    # mcts
    C=1,  # np.sqrt(2),
    MCTS_INDEX=MCTSIndex.AlphaZero,
    MCTS_MODE=MCTSMode.Lipschitz,
    MCTS_POLICY=MCTSPolicy.Greedy,
    EXPANSION_ALL_AT_ONCE=True,
    M=10_000_000,  # <- length_of_corridor * 100 if algo is pg-mcts
    # pgmcts (or alphazero)
    # rollout-policy = (1 - LAMBDA) * rpg_policy + LAMBDA * mcts_policy
    LAMBDA=1.0,
    # log
    LOG_INTERVAL=100,
    LOG_LEVEL=logging.INFO,
    LOG_OUTPUT_FILES=True,
    LOG_CONFIG_FILENAME="results/config_{0:%Y%m%d_%H%M%S}.csv",
)
# ALPHA and C go into the file name; the {0}/{1}/{2} placeholders stay in the
# string (presumably filled in later by the caller).
alphazero_30_0.LOG_RESULT_FILENAME = "".join((
    "results/alphazero",
    "_alp", "{:.5g}".format(alphazero_30_0.ALPHA),
    "_c", "{:.5g}".format(alphazero_30_0.C),
    "_len{1:03d}_ini{2:03d}_{0:%Y%m%d_%H%M%S}.csv",
))

#############################
# 100 - 50
#############################
# PPO preset: corridor length 100, initial position 50.
ppo_100_50 = Config(
    # task-setting
    CORRIDOR_LENGTH=100,
    INITIAL_POSITION=50,
    NUM_EPISODE=10_000,
    GAMMA=0.98,
    MAX_TIMESTEP=1000,
    STATE_ACTION=False,
    # rpg
    ALPHA=0.03,
    DROPOUT=0,
    CLIP_GRAD_NORM=1.0,
    # ppo
    NUM_EPOCHS=3,
    CLIP_PPO=0.2,
    IS_LOSS_DISCOUNTED=False,
    # log
    LOG_INTERVAL=100,
    LOG_LEVEL=logging.DEBUG,
    LOG_OUTPUT_FILES=True,
    LOG_CONFIG_FILENAME="results/config_{0:%Y%m%d_%H%M%S}.csv",
)
# Hyper-parameters are baked into the result file name here; the {0}/{1}/{2}
# placeholders stay in the string (presumably filled in later by the caller).
ppo_100_50.LOG_RESULT_FILENAME = "".join((
    "results/ppo",
    "_lossDiscounted", "{}".format(ppo_100_50.IS_LOSS_DISCOUNTED),
    "_epoch", "{}".format(ppo_100_50.NUM_EPOCHS),
    "_alp", "{:.5g}".format(ppo_100_50.ALPHA),
    "_clip", "{:.5g}".format(ppo_100_50.CLIP_PPO),
    "_len{1:03d}_ini{2:03d}_{0:%Y%m%d_%H%M%S}.csv",
))


# RPG preset: corridor length 100, initial position 50.
rpg_100_50 = Config(
    # task-setting
    CORRIDOR_LENGTH=100,
    INITIAL_POSITION=50,
    NUM_EPISODE=10_000,
    GAMMA=0.98,
    MAX_TIMESTEP=1000,
    STATE_ACTION=False,
    # rpg
    ALPHA=0.1,
    DROPOUT=0,
    CLIP_GRAD_NORM=1.0,
    # log
    LOG_INTERVAL=100,
    LOG_LEVEL=logging.DEBUG,
    LOG_OUTPUT_FILES=True,
    LOG_CONFIG_FILENAME="results/config_{0:%Y%m%d_%H%M%S}.csv",
)
# The learning rate goes into the file name; the {0}/{1}/{2} placeholders stay
# in the string (presumably filled in later by the caller).
rpg_100_50.LOG_RESULT_FILENAME = "".join((
    "results/rpg",
    "_alp", "{:.5g}".format(rpg_100_50.ALPHA),
    "_len{1:03d}_ini{2:03d}_{0:%Y%m%d_%H%M%S}.csv",
))

# MCTS preset: corridor length 100, initial position 50.
mcts_100_50 = Config(
    # task-setting
    CORRIDOR_LENGTH=100,
    INITIAL_POSITION=50,
    NUM_EPISODE=10_000,
    GAMMA=0.98,
    MAX_TIMESTEP=1000,
    STATE_ACTION=False,
    # mcts
    C=0.3,
    MCTS_INDEX=MCTSIndex.UCB,
    MCTS_MODE=MCTSMode.Lipschitz,
    MCTS_POLICY=MCTSPolicy.Greedy,
    EXPANSION_ALL_AT_ONCE=True,
    M=10_000_000,  # <- length_of_corridor * 100 if algo is pgmcts
    # log
    LOG_INTERVAL=100,
    LOG_LEVEL=logging.INFO,
    LOG_OUTPUT_FILES=True,
    LOG_CONFIG_FILENAME="results/config_{0:%Y%m%d_%H%M%S}.csv",
)
# The exploration constant goes into the file name; the {0}/{1}/{2}
# placeholders stay in the string (presumably filled in later by the caller).
mcts_100_50.LOG_RESULT_FILENAME = "".join((
    "results/mcts",
    "_c", "{:.5g}".format(mcts_100_50.C),
    "_len{1:03d}_ini{2:03d}_{0:%Y%m%d_%H%M%S}.csv",
))

# PG-MCTS preset (fixed mixing): corridor length 100, initial position 50.
pgmcts_100_50 = Config(
    # task-setting
    CORRIDOR_LENGTH=100,
    INITIAL_POSITION=50,
    NUM_EPISODE=10_000,
    GAMMA=0.98,
    MAX_TIMESTEP=1000,
    STATE_ACTION=False,
    # rpg
    ALPHA=0.1,  # 0,3
    DROPOUT=0,
    CLIP_GRAD_NORM=1.0,
    # mcts
    C=0.1,
    MCTS_INDEX=MCTSIndex.UCB,
    MCTS_MODE=MCTSMode.Lipschitz,
    MCTS_POLICY=MCTSPolicy.Softmax,
    EXPANSION_ALL_AT_ONCE=True,
    BETA=100.0,  # used if the policy is not greedy
    # NOTE(review): the original note says "length_of_corridor * 100", which
    # would be 10_000 for CORRIDOR_LENGTH=100 -- confirm 5_000 is intended.
    M=5_000,
    # pgmcts
    MIXING_COEFFICIENT_ADAPTATION=False,
    MIN_RHO=0.1,
    # Mixing probability when MIXING_COEFFICIENT_ADAPTATION is False,
    # a coefficient when it is True.
    LAMBDA=0.2,
    # log
    LOG_INTERVAL=100,
    LOG_LEVEL=logging.INFO,
    LOG_OUTPUT_FILES=True,
    LOG_CONFIG_FILENAME="results/config_{0:%Y%m%d_%H%M%S}.csv",
)
# ALPHA and C go into the file name; the {0}/{1}/{2} placeholders stay in the
# string (presumably filled in later by the caller).
pgmcts_100_50.LOG_RESULT_FILENAME = "".join((
    "results/pgmcts",
    "_alp", "{:.5g}".format(pgmcts_100_50.ALPHA),
    "_c", "{:.5g}".format(pgmcts_100_50.C),
    "_len{1:03d}_ini{2:03d}_{0:%Y%m%d_%H%M%S}.csv",
))

# PG-MCTS preset (adaptive mixing): corridor length 100, initial position 50.
pgmcts_adpt_100_50 = Config(
    # task-setting
    CORRIDOR_LENGTH=100,
    INITIAL_POSITION=50,
    NUM_EPISODE=10_000,
    GAMMA=0.98,
    MAX_TIMESTEP=1000,
    STATE_ACTION=False,
    # rpg
    ALPHA=0.1,
    DROPOUT=0,
    CLIP_GRAD_NORM=1.0,
    # mcts
    C=0.1,
    MCTS_INDEX=MCTSIndex.UCB,
    MCTS_MODE=MCTSMode.Lipschitz,
    MCTS_POLICY=MCTSPolicy.Softmax,
    EXPANSION_ALL_AT_ONCE=True,
    BETA=100.0,  # used if the policy is not greedy
    # NOTE(review): the original note says "length_of_corridor * 100", which
    # would be 10_000 for CORRIDOR_LENGTH=100 -- confirm 5_000 is intended.
    M=5_000,
    # pgmcts (or alphazero)
    MIXING_COEFFICIENT_ADAPTATION=True,
    MIN_RHO=0.1,
    # Mixing probability when MIXING_COEFFICIENT_ADAPTATION is False,
    # a coefficient when it is True.
    LAMBDA=1.0,
    # log
    LOG_INTERVAL=100,
    LOG_LEVEL=logging.INFO,
    LOG_OUTPUT_FILES=True,
    LOG_CONFIG_FILENAME="results/config_{0:%Y%m%d_%H%M%S}.csv",
)
# ALPHA and C go into the file name; the {0}/{1}/{2} placeholders stay in the
# string (presumably filled in later by the caller).
pgmcts_adpt_100_50.LOG_RESULT_FILENAME = "".join((
    "results/pgmcts_adpt",
    "_alp", "{:.5g}".format(pgmcts_adpt_100_50.ALPHA),
    "_c", "{:.5g}".format(pgmcts_adpt_100_50.C),
    "_len{1:03d}_ini{2:03d}_{0:%Y%m%d_%H%M%S}.csv",
))

# AlphaZero preset: corridor length 100, initial position 50.
alphazero_100_50 = Config(
    # task-setting
    CORRIDOR_LENGTH=100,
    INITIAL_POSITION=50,
    NUM_EPISODE=10_000,
    GAMMA=0.98,
    MAX_TIMESTEP=1000,
    STATE_ACTION=False,
    # rpg
    ALPHA=0.01,
    DROPOUT=0,
    CLIP_GRAD_NORM=1.0,
    # mcts
    C=1,  # np.sqrt(2),
    MCTS_INDEX=MCTSIndex.AlphaZero,
    MCTS_MODE=MCTSMode.Lipschitz,
    MCTS_POLICY=MCTSPolicy.Greedy,
    EXPANSION_ALL_AT_ONCE=True,
    M=10_000_000,  # <- length_of_corridor * 100 if algo is pg-mcts
    # pgmcts (or alphazero)
    # rollout-policy = (1 - LAMBDA) * rpg_policy + LAMBDA * mcts_policy
    LAMBDA=1.0,
    # log
    LOG_INTERVAL=100,
    LOG_LEVEL=logging.INFO,
    LOG_OUTPUT_FILES=True,
    LOG_CONFIG_FILENAME="results/config_{0:%Y%m%d_%H%M%S}.csv",
)
# ALPHA and C go into the file name; the {0}/{1}/{2} placeholders stay in the
# string (presumably filled in later by the caller).
alphazero_100_50.LOG_RESULT_FILENAME = "".join((
    "results/alphazero",
    "_alp", "{:.5g}".format(alphazero_100_50.ALPHA),
    "_c", "{:.5g}".format(alphazero_100_50.C),
    "_len{1:03d}_ini{2:03d}_{0:%Y%m%d_%H%M%S}.csv",
))

###################
# for demonstration
###################
# rpg_conf_demo = rpg_30_0
# rpg_conf_demo.CORRIDOR_LENGTH = 20
# rpg_conf_demo.LOG_OUTPUT_FILES = False
# rpg_conf_demo.LOG_RESULT_FILENAME = "results/rpg_demo_{0:%Y%m%d_%H%M%S}.csv"
# rpg_conf_demo.LOG_LEVEL = logging.DEBUG
#
# lazymcts_conf_demo = mcts_30_0
# lazymcts_conf_demo.CORRIDOR_LENGTH = 20
# lazymcts_conf_demo
#
# pgmcts_conf_demo = pgmcts_30_0
# pgmcts_conf_demo.CORRIDOR_LENGTH = 20
# pgmcts_conf_demo    STATE_ACTION = False,
#     LOG_RESULT_FILENAME="results/mcts_demo_{0:%Y%m%d_%H%M%S}.csv",
#
# lazymcts_conf_demo = alphazero_30_0
# lazymcts_conf_demo.CORRIDOR_LENGTH = 20

# PPO demo preset: corridor length 20, 3_000 episodes, DEBUG logging.
ppo_conf_demo = Config(
    # task-setting
    CORRIDOR_LENGTH=20,
    INITIAL_POSITION=0,
    NUM_EPISODE=3_000,
    GAMMA=0.98,
    MAX_TIMESTEP=1000,
    STATE_ACTION=False,
    # rpg
    ALPHA=0.05,  # 0.1,
    DROPOUT=0,
    CLIP_GRAD_NORM=1.0,
    # ppo
    NUM_EPOCHS=3,
    CLIP_PPO=0.2,
    IS_LOSS_DISCOUNTED=False,
    # log
    LOG_INTERVAL=100,
    LOG_LEVEL=logging.DEBUG,
    LOG_OUTPUT_FILES=True,
    LOG_RESULT_FILENAME="results/ppo_demo_{0:%Y%m%d_%H%M%S}.csv",
    LOG_CONFIG_FILENAME="results/config_{0:%Y%m%d_%H%M%S}.csv",
)

# RPG demo preset: corridor length 20, 3_000 episodes, DEBUG logging.
rpg_conf_demo = Config(
    # task-setting
    CORRIDOR_LENGTH=20,
    INITIAL_POSITION=0,
    NUM_EPISODE=3_000,
    GAMMA=0.98,
    MAX_TIMESTEP=1000,
    STATE_ACTION=False,
    # rpg
    ALPHA=0.1,
    DROPOUT=0,
    CLIP_GRAD_NORM=1.0,
    # log
    LOG_INTERVAL=100,
    LOG_LEVEL=logging.DEBUG,
    LOG_OUTPUT_FILES=True,
    LOG_RESULT_FILENAME="results/rpg_demo_{0:%Y%m%d_%H%M%S}.csv",
    LOG_CONFIG_FILENAME="results/config_{0:%Y%m%d_%H%M%S}.csv",
)

# MCTS demo preset: corridor length 20, 3_000 episodes.
lazymcts_conf_demo = Config(
    # task-setting
    CORRIDOR_LENGTH=20,
    INITIAL_POSITION=0,
    NUM_EPISODE=3_000,
    GAMMA=0.98,
    MAX_TIMESTEP=1000,
    STATE_ACTION=False,
    # mcts
    C=0.3,
    MCTS_INDEX=MCTSIndex.UCB,
    MCTS_MODE=MCTSMode.Lipschitz,
    MCTS_POLICY=MCTSPolicy.Greedy,
    EXPANSION_ALL_AT_ONCE=True,
    M=10_000_000,  # <- length_of_corridor * 100 if algo is pgmcts
    # log
    LOG_INTERVAL=100,
    LOG_LEVEL=logging.INFO,
    LOG_OUTPUT_FILES=True,
    LOG_RESULT_FILENAME="results/mcts_demo_{0:%Y%m%d_%H%M%S}.csv",
    LOG_CONFIG_FILENAME="results/config_{0:%Y%m%d_%H%M%S}.csv",
)

# PG-MCTS demo preset (fixed mixing): corridor length 20, 3_000 episodes.
pgmcts_conf_demo = Config(
    # task-setting
    CORRIDOR_LENGTH=20,
    INITIAL_POSITION=0,
    NUM_EPISODE=3_000,
    GAMMA=0.98,
    MAX_TIMESTEP=1000,
    STATE_ACTION=False,
    # rpg
    ALPHA=0.2,
    DROPOUT=0,
    CLIP_GRAD_NORM=1.0,
    # mcts
    C=0.1,
    MCTS_INDEX=MCTSIndex.UCB,
    MCTS_MODE=MCTSMode.Lipschitz,
    MCTS_POLICY=MCTSPolicy.Softmax,
    EXPANSION_ALL_AT_ONCE=True,
    BETA=100.0,  # used if the policy is not greedy
    M=2_000,  # <- length_of_corridor * 100
    # pgmcts
    MIXING_COEFFICIENT_ADAPTATION=False,
    MIN_RHO=0.1,
    # Mixing probability when MIXING_COEFFICIENT_ADAPTATION is False,
    # a coefficient when it is True.
    LAMBDA=0.2,
    # log
    LOG_INTERVAL=100,
    LOG_LEVEL=logging.INFO,
    LOG_OUTPUT_FILES=True,
    LOG_RESULT_FILENAME="results/pgmcts_demo_{0:%Y%m%d_%H%M%S}.csv",
    LOG_CONFIG_FILENAME="results/config_{0:%Y%m%d_%H%M%S}.csv",
)


# AlphaZero demo preset: corridor length 20, 3_000 episodes.
# This block used to be bound to `lazymcts_conf_demo`, which silently replaced
# the MCTS demo preset defined earlier in the file.  Its contents
# (MCTSIndex.AlphaZero, "alphazero_demo" result file) are the AlphaZero demo,
# so it gets its own name and the earlier MCTS demo preset stays reachable.
alphazero_conf_demo = Config(
    # task-setting
    CORRIDOR_LENGTH=20,
    INITIAL_POSITION=0,
    NUM_EPISODE=3_000,
    GAMMA=0.98,
    MAX_TIMESTEP=1000,
    STATE_ACTION=False,
    # rpg
    ALPHA=0.02,
    DROPOUT=0,
    CLIP_GRAD_NORM=1.0,
    # mcts
    C=1,  # np.sqrt(2),
    MCTS_INDEX=MCTSIndex.AlphaZero,
    MCTS_MODE=MCTSMode.Lipschitz,
    MCTS_POLICY=MCTSPolicy.Greedy,
    EXPANSION_ALL_AT_ONCE=True,
    M=10_000_000,  # <- length_of_corridor * 100 if algo is pg-mcts
    # pgmcts (or alphazero)
    # rollout-policy = (1 - LAMBDA) * rpg_policy + LAMBDA * mcts_policy
    LAMBDA=1.0,
    # log
    LOG_INTERVAL=100,
    LOG_LEVEL=logging.INFO,
    LOG_OUTPUT_FILES=True,
    LOG_RESULT_FILENAME="results/alphazero_demo_{0:%Y%m%d_%H%M%S}.csv",
    LOG_CONFIG_FILENAME="results/config_{0:%Y%m%d_%H%M%S}.csv",
)
