Multi-armed bandit simulation

2017/11/09

とりあえず ε-greedy 法から実装する（first pass: the ε-greedy strategy）

import random
import numpy as np

class Arm:
    """A slot-machine arm that hits with a fixed probability and counts its pulls.

    Attributes:
        probability: chance that a pull hits (a pull hits when
            ``random.random() < probability``).
        no: identifier of the arm.
        hit_count: number of pulls that hit so far.
        total_count: number of pulls so far.
    """

    def __init__(self, prob, no):
        self.probability = prob
        self.no = no
        self.hit_count = 0
        self.total_count = 0

    def lot(self):
        """Pull the arm once.

        Updates ``total_count`` (and ``hit_count`` on a hit) as a side effect.

        Returns:
            True on a hit, False otherwise.
        """
        random_float = random.random()
        self.increment_total_count()

        if random_float < self.probability:
            self.increment_hit_count()
            return True

        return False

    def increment_total_count(self):
        """Record one more pull."""
        self.total_count += 1

    def increment_hit_count(self):
        """Record one more hit."""
        self.hit_count += 1

    def observable_expected_hit(self):
        """Return the empirical hit rate ``hit_count / total_count``.

        Returns 0.0 for an arm that has never been pulled.
        """
        if self.total_count == 0:
            return 0.0

        # float() forces true division: under Python 2, int / int truncates,
        # which turned every rate below 1 into 0 and broke exploitation.
        return float(self.hit_count) / self.total_count

class Recorder:
    """Running log of the pulls made by a bandit strategy.

    Attributes:
        trial_count: number of recorded pulls.
        hit_count: number of recorded hits.
        probability_list: values passed to ``push_probability``, in call order.
        arm_no_list: values passed to ``push_arm_no``, in call order.
    """

    def __init__(self):
        self.trial_count = 0
        self.probability_list = []
        self.arm_no_list = []
        self.hit_count = 0

    def increment_hit_count(self):
        """Record one more hit."""
        self.hit_count += 1

    def increment_trial_count(self):
        """Record one more pull."""
        self.trial_count += 1

    def push_probability(self, prob):
        """Append the true probability of the arm just pulled."""
        self.probability_list.append(prob)

    def push_arm_no(self, arm_no):
        """Append the number of the arm just pulled."""
        self.arm_no_list.append(arm_no)

    def output(self):
        """Print the trial and hit totals to stdout."""
        # Single-argument print(...) prints the same text on Python 2 and 3;
        # the bare print statement is a syntax error on Python 3.
        print("trial_count : %d" % self.trial_count)
        print("hit_count   : %d" % self.hit_count)

class Bandit:
    """Base class for bandit strategies.

    Subclasses must override ``explore``, ``exploit``, ``execute`` and
    ``lot_arm``; calling them on the base class raises NotImplementedError.

    Args:
        recorder: log object the strategy reports its pulls to.
    """

    def __init__(self, recorder):
        # Keep the recorder so subclasses that call this initializer have it.
        self.recorder = recorder
        # Not used by this class; kept so existing attribute access still works.
        self.counter = {}

    def explore(self, arms):
        """Pull an arm chosen for exploration from ``arms``."""
        raise NotImplementedError("explore() must be overridden")

    def exploit(self, arms):
        """Pull the arm currently judged best among ``arms``."""
        raise NotImplementedError("exploit() must be overridden")

    def execute(self, arms):
        """Play one round against ``arms``."""
        raise NotImplementedError("execute() must be overridden")

    def lot_arm(self, arm):
        """Pull ``arm`` once and record the outcome."""
        raise NotImplementedError("lot_arm() must be overridden")

class EpsilonGreedyBandit(Bandit):
    """Epsilon-greedy strategy.

    Each round explores a uniformly random arm with probability ``epsilon``
    and otherwise exploits the arm with the highest observed hit rate.
    Every pull is logged to ``recorder``.

    Args:
        recorder: log that receives trial/hit counts, arm numbers and
            probabilities of each pull.
        init_epsilon: exploration probability.
    """

    def __init__(self, recorder, init_epsilon):
        # Run the base initializer so base-class attributes exist here too.
        Bandit.__init__(self, recorder)
        self.recorder = recorder
        self.epsilon = init_epsilon

    def explore(self, arms):
        """Pull one arm chosen uniformly at random from ``arms``."""
        # random.choice replaces ``int(n * random.random()) - 1`` index
        # arithmetic that was uniform only because index -1 wraps to the end.
        self.lot_arm(random.choice(arms))

    def exploit(self, arms):
        """Pull the arm with the highest observed hit rate.

        Ties are broken uniformly at random. np.argmax always picked the first
        maximum, so while every rate was still 0 it kept favoring arms[0].
        """
        expected_hit_list = [arm.observable_expected_hit() for arm in arms]
        best = max(expected_hit_list)
        best_indices = [
            index for index, value in enumerate(expected_hit_list) if value == best
        ]
        self.lot_arm(arms[random.choice(best_indices)])

    def lot_arm(self, arm):
        """Pull ``arm`` once and log the trial, arm number, probability and hit."""
        is_hit = arm.lot()
        self.recorder.increment_trial_count()
        self.recorder.push_arm_no(arm.no)
        self.recorder.push_probability(arm.probability)

        if is_hit:
            self.recorder.increment_hit_count()

    def execute(self, arms):
        """Play one round: exploit if a uniform draw exceeds epsilon, else explore."""
        if self.epsilon < random.random():
            self.exploit(arms)
        else:
            self.explore(arms)

# Simulation settings.
EPSILON   = 0.1
TRIAL_NUM = 100

# (true hit probability, arm number) for each arm.
arm_attributes = [
    (0.1, 1),
    (0.2, 2),
    (0.4, 3),
]

arms = [Arm(probability, arm_no) for probability, arm_no in arm_attributes]

recorder = Recorder()
epsilon_greedy_bandit = EpsilonGreedyBandit(recorder, EPSILON)

# ``counter`` tracks how many rounds were actually played.
counter = 0
for _ in range(TRIAL_NUM):
    counter += 1
    epsilon_greedy_bandit.execute(arms)

recorder.output()

# Single-argument print(...) prints exactly what the print statement did.
print(recorder.arm_no_list)
print(counter)