The Algorithms logo
算法
关于我们捐赠

Q学习

H
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "GAMMA = 0.9\n",
    "LEARNING_RATE = 0.6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    with open(\"Qtable.txt\", \"rb\") as f:\n",
    "        \"\"\"\n",
    "        This code is responsible for loading Qtable.txt if already present\n",
    "        \"\"\"\n",
    "        Q_table = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def next_best_action(state: int, Q_table: np.ndarray) -> int:\n",
    "    \"\"\"\n",
    "    Return the most suitable action value from Q_table or a random action\n",
    "    if there is no np.argmax(Q_table[state]).\n",
    "    >>> next_best_action(1, []) in [0, 1, 2, 3]\n",
    "    True\n",
    "    \"\"\"\n",
    "    action = np.argmax(Q_table[state])\n",
    "    return action if action is not None else np.random.choice([0, 1, 2, 3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def state_action_reward(player: object, x_food: int, y_food: int) -> tuple:\n",
    "    \"\"\"\n",
    "    This function returns state, action and reward to update the Qtable.\n",
    "    \"\"\"\n",
    "    x_agent, y_agent = player.body[0].pos\n",
    "    current_state = state(player, x_food, y_food)\n",
    "    current_action = next_best_action(state, Q_table)\n",
    "    current_reward = reward(player, x_food, y_food)\n",
    "    return (current_state, current_action, current_reward)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# States to consider:\n",
    "#   * Food's relative positioning wrt the head\n",
    "# Checking for food in nine directions\n",
    "###\n",
    "# #\n",
    "###\n",
    "#   * obstruction Ahead, Right, Left\n",
    "\n",
    "def state(player: object, x_food: int, y_food: int) -> int:\n",
    "    \"\"\"\n",
    "    This function Checks for the food in nine directions and returns state.\n",
    "    \"\"\"\n",
    "    x_agent, y_agent = player.body[0].pos\n",
    "    states = []\n",
    "\n",
    "    # Code to check for obstacles in front of the agent\n",
    "    DangerAhead = (\n",
    "        (player.dirnx == -1 and x_agent + player.delx < 0)\n",
    "        or (player.dirnx == -1 and ((x_agent + player.delx, y_agent) in player.body))\n",
    "        or (player.dirnx == 1 and x_agent + player.delx > 14)\n",
    "        or (player.dirnx == 1 and ((x_agent + player.delx, y_agent) in player.body))\n",
    "        or (player.dirny == 1 and y_agent + player.dely > 14)\n",
    "        or (player.dirny == 1 and ((x_agent, y_agent + player.dely) in player.body))\n",
    "        or (player.dirny == -1 and y_agent + player.dely < 0)\n",
    "        or (player.dirny == -1 and ((x_agent, y_agent + player.dely) in player.body))\n",
    "    )\n",
    "    # Code to check for obstacles to the left of the agent\n",
    "    DangerLeft = (\n",
    "        (player.dirnx == -1 and y_agent + 1 > 14)\n",
    "        or (player.dirnx == -1 and ((x_agent, y_agent + 1) in player.body))\n",
    "        or (player.dirnx == 1 and y_agent - 1 < 0)\n",
    "        or (player.dirnx == 1 and ((x_agent, y_agent - 1) in player.body))\n",
    "        or (player.dirny == 1 and x_agent + 1 > 14)\n",
    "        or (player.dirny == 1 and ((x_agent + 1, y_agent) in player.body))\n",
    "        or (player.dirny == -1 and x_agent - 1 < 0)\n",
    "        or (player.dirny == -1 and ((x_agent - 1, y_agent) in player.body))\n",
    "    )\n",
    "    # Code to check for obstacles to the right of the agent\n",
    "    DangerRight = (\n",
    "        (player.dirnx == -1 and y_agent - 1 < 0)\n",
    "        or (player.dirnx == -1 and ((x_agent, y_agent - 1) in player.body))\n",
    "        or (player.dirnx == 1 and y_agent + 1 > 14)\n",
    "        or (player.dirnx == 1 and ((x_agent, y_agent + 1) in player.body))\n",
    "        or (player.dirny == 1 and x_agent - 1 < 0)\n",
    "        or (player.dirny == 1 and ((x_agent - 1, y_agent) in player.body))\n",
    "        or (player.dirny == -1 and x_agent + 1 > 14)\n",
    "        or (player.dirny == -1 and ((x_agent + 1, y_agent) in player.body))\n",
    "    )\n",
    "    # Code for food straight wrt head\n",
    "    FoodStraightAhead = (\n",
    "        (player.dirnx == 1 and y_agent == y_food and x_food > x_agent)\n",
    "        or (player.dirnx == -1 and y_agent == y_food and x_food < x_agent)\n",
    "        or (player.dirny == 1 and y_agent < y_food and x_food == x_agent)\n",
    "        or (player.dirny == -1 and y_agent > y_food and x_food == x_agent)\n",
    "    )\n",
    "    # Code for food which is ahead and right wrt head\n",
    "    FoodAheadRight = (\n",
    "        (player.dirnx == 1 and y_agent < y_food and x_food > x_agent)\n",
    "        or (player.dirnx == -1 and y_agent > y_food and x_food < x_agent)\n",
    "        or (player.dirny == 1 and y_agent < y_food and x_food < x_agent)\n",
    "        or (player.dirny == -1 and y_agent > y_food and x_food > x_agent)\n",
    "    )\n",
    "    # Code for food which is ahead and right wrt head\n",
    "    FoodAheadLeft = (\n",
    "        (player.dirnx == 1 and y_agent > y_food and x_food > x_agent)\n",
    "        or (player.dirnx == -1 and y_agent < y_food and x_food < x_agent)\n",
    "        or (player.dirny == 1 and y_agent < y_food and x_food > x_agent)\n",
    "        or (player.dirny == -1 and y_agent > y_food and x_food < x_agent)\n",
    "    )\n",
    "    # Code for food which is ahead and right wrt head\n",
    "    FoodBehindRight = (\n",
    "        (player.dirnx == 1 and y_agent < y_food and x_food < x_agent)\n",
    "        or (player.dirnx == -1 and y_agent > y_food and x_food > x_agent)\n",
    "        or (player.dirny == 1 and y_agent > y_food and x_food < x_agent)\n",
    "        or (player.dirny == -1 and y_agent < y_food and x_food > x_agent)\n",
    "    )\n",
    "    # Code for food which is ahead and right wrt head\n",
    "    FoodBehindLeft = (\n",
    "        (player.dirnx == 1 and y_agent > y_food and x_food < x_agent)\n",
    "        or (player.dirnx == -1 and y_agent < y_food and x_food > x_agent)\n",
    "        or (player.dirny == 1 and y_agent > y_food and x_food > x_agent)\n",
    "        or (player.dirny == -1 and y_agent < y_food and x_food < x_agent)\n",
    "    )\n",
    "    # Code for food exactly behind\n",
    "    FoodBehind = (\n",
    "        (player.dirnx == 1 and y_agent == y_food and x_food < x_agent)\n",
    "        or (player.dirnx == -1 and y_agent == y_food and x_food > x_agent)\n",
    "        or (player.dirny == 1 and y_agent > y_food and x_food == x_agent)\n",
    "        or (player.dirny == -1 and y_agent < y_food and x_food == x_agent)\n",
    "    )\n",
    "    # Code for food left\n",
    "    FoodLeft = (\n",
    "        (player.dirnx == 1 and y_agent > y_food and x_food == x_agent)\n",
    "        or (player.dirnx == -1 and y_agent < y_food and x_food == x_agent)\n",
    "        or (player.dirny == 1 and y_agent == y_food and x_food > x_agent)\n",
    "        or (player.dirny == -1 and y_agent == y_food and x_food < x_agent)\n",
    "    )\n",
    "    # Code for food right\n",
    "    FoodRight = (\n",
    "        (player.dirnx == 1 and y_agent < y_food and x_food == x_agent)\n",
    "        or (player.dirnx == -1 and y_agent > y_food and x_food == x_agent)\n",
    "        or (player.dirny == 1 and y_agent == y_food and x_food < x_agent)\n",
    "        or (player.dirny == -1 and y_agent == y_food and x_food > x_agent)\n",
    "    )\n",
    "    # Adding to states list while priortizing danger over eating food\n",
    "    states = [int(x) for x in (DangerAhead, DangerLeft, DangerRight)]\n",
    "    if sum(states) == 0:\n",
    "        states += [int(x) for x in (\n",
    "            FoodStraightAhead, FoodAheadRight, FoodAheadLeft, FoodBehindRight,\n",
    "            FoodBehindLeft, FoodBehind, FoodLeft, FoodRight\n",
    "        )]\n",
    "    else:\n",
    "        states += [0] * 8\n",
    "    state = 0\n",
    "    for i in range(11):\n",
    "        if states[i] == 1:\n",
    "            state = i\n",
    "    return state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def euler_dist(x1: int, y1: int, x2: int, y2: int) -> int:\n",
    "    \"\"\"\n",
    "    For calculation of Euler Distance.\n",
    "    \"\"\"\n",
    "    dist = (x1 - x2) ** 2 + (y1 - y2) ** 2\n",
    "    return dist ** 0.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reward conditions:\n",
    "#   * +10 for eating food\n",
    "#   * -12 for dying\n",
    "#   * -2 for getting closer\n",
    "#   * -7 for going away from the fruit\n",
    "\n",
    "\n",
    "def reward(player: object, x_food: int, y_food: int) -> int:\n",
    "    \"\"\"\n",
    "    This function assigns the reward to the agent according to the action taken\n",
    "    \"\"\"\n",
    "    x_agent, y_agent = player.body[0].pos\n",
    "    if x_agent == x_food and y_agent == y_food:\n",
    "        return +10\n",
    "    elif (\n",
    "        (player.dirnx == -1 and x_agent + player.delx <= 0)\n",
    "        or (\n",
    "            player.dirnx == -1\n",
    "            and ((x_agent + player.delx, y_agent + player.dely) in player.body)\n",
    "        )\n",
    "        or (player.dirnx == 1 and x_agent + player.delx >= 14)\n",
    "        or (\n",
    "            player.dirnx == 1\n",
    "            and ((x_agent + player.dirnx, y_agent + player.dely) in player.body)\n",
    "        )\n",
    "        or (player.dirny == 1 and y_agent + player.dely >= 14)\n",
    "        or (\n",
    "            player.dirny == 1\n",
    "            and ((x_agent + player.delx, y_agent + player.dely) in player.body)\n",
    "        )\n",
    "        or (player.dirny == -1 and y_agent + player.dely <= 0)\n",
    "        or (\n",
    "            player.dirny == -1\n",
    "            and ((x_agent + player.delx, y_agent + player.dely) in player.body)\n",
    "        )\n",
    "        or (player.resetDone is True)\n",
    "    ):\n",
    "        return -12\n",
    "    elif (\n",
    "        euler_dist(x_agent + player.delx, y_agent + player.dely, x_food, y_food)\n",
    "        - euler_dist(x_agent, y_agent, x_food, y_food)\n",
    "        > 0\n",
    "    ):\n",
    "        return -2\n",
    "    elif (\n",
    "        euler_dist(x_agent + player.delx, y_agent + player.dely, x_food, y_food)\n",
    "        - euler_dist(x_agent, y_agent, x_food, y_food)\n",
    "        < 0\n",
    "    ):\n",
    "        return -7\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def learn(\n",
    "        state: int, action: int, reward: int, next_state: int,\n",
    "        next_action: int, i: int, trialNumber: int, epsilon: float) -> type(None):\n",
    "    \"\"\"\n",
    "    This function is for iteratively updating the Qtable.\n",
    "    \"\"\"\n",
    "    currentQ = Q_table[state][action]\n",
    "    nextQ = Q_table[next_state][next_action]\n",
    "    # Qlearning Algorithm to get new q value for the q table.\n",
    "    newQ = (1 - LEARNING_RATE) * currentQ + LEARNING_RATE * (reward + GAMMA * nextQ)\n",
    "    Q_table[state][action] = newQ\n",
    "    state = next_state\n",
    "    currentQ = nextQ\n",
    "    if trialNumber % 100 == 0:\n",
    "        print(\"Printing Q_table: \")\n",
    "        print(Q_table)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
关于此算法
import numpy as np
import pickle
# Discount factor applied to the next state's Q-value in learn().
GAMMA = 0.9
# Step size used in learn() when blending the old Q-value with the new estimate.
LEARNING_RATE = 0.6
if __name__ == "__main__":
    # Load a previously saved Q-table; start from zeros when none exists yet.
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file -- only load Q-tables that this project wrote itself.
    try:
        with open("Qtable.txt", "rb") as f:
            Q_table = pickle.load(f)
    except FileNotFoundError:
        # 11 rows: state() returns an index in 0..10.
        # 4 columns: the actions 0..3 listed in next_best_action().
        Q_table = np.zeros((11, 4))
def next_best_action(state: int, Q_table: np.ndarray) -> int:
    """
    Return the action with the highest Q-value for ``state``.

    A random action from ``[0, 1, 2, 3]`` is returned when ``Q_table`` has no
    usable row for ``state`` (the state is missing or its row is empty).

    :param state: index of the row to look up in ``Q_table``.
    :param Q_table: table indexed as ``Q_table[state][action]``.
    :return: the chosen action as a plain ``int``.

    >>> next_best_action(1, []) in [0, 1, 2, 3]
    True
    >>> next_best_action(0, [[0.1, 0.9, 0.2, 0.0]])
    1
    """
    try:
        q_values = np.asarray(Q_table[state])
    except (IndexError, KeyError):
        # No row for this state: fall through to the random choice below.
        q_values = np.empty(0)
    if q_values.size == 0:
        # np.argmax raises on an empty sequence, so pick an action at random.
        return int(np.random.choice([0, 1, 2, 3]))
    return int(np.argmax(q_values))
def state_action_reward(player: object, x_food: int, y_food: int) -> tuple:
    """
    Return the (state, action, reward) triple used to update the Q-table.

    :param player: the agent; it is passed unchanged to ``state`` and ``reward``.
    :param x_food: x coordinate of the food.
    :param y_food: y coordinate of the food.
    :return: ``(current_state, current_action, current_reward)``.
    """
    current_state = state(player, x_food, y_food)
    # Look up the action for the state just computed (not the ``state``
    # function object itself) in the module-level Q_table.
    current_action = next_best_action(current_state, Q_table)
    current_reward = reward(player, x_food, y_food)
    return (current_state, current_action, current_reward)
# States to consider:
#   * Obstruction ahead, left and right of the head
#   * Food's position relative to the head, checked in the eight
#     surrounding directions:
#       ###
#       # #
#       ###

def state(player: object, x_food: int, y_food: int) -> int:
    """
    Encode the agent's situation as a single state index in 0..10.

    Danger takes priority over food: when any danger flag is set, all food
    flags are ignored. The returned value is the index of the LAST set flag in
    this list:

        0 danger ahead        1 danger left         2 danger right
        3 food straight ahead 4 food ahead-right    5 food ahead-left
        6 food behind-right   7 food behind-left    8 food behind
        9 food left          10 food right

    NOTE(review): 0 is also returned when no flag is set at all, so "danger
    ahead only" and "nothing detected" share a state.

    :param player: agent exposing ``body`` (``body[0].pos`` is the head
        position), ``dirnx``/``dirny`` (heading) and ``delx``/``dely`` (step).
    :param x_food: x coordinate of the food.
    :param y_food: y coordinate of the food.
    :return: the state index described above.
    """
    x_agent, y_agent = player.body[0].pos

    # NOTE(review): the ``(x, y) in player.body`` tests compare a tuple with
    # the body elements, which are read via ``.pos`` above -- confirm that the
    # elements compare equal to position tuples.

    # Obstacle (wall or body) directly in front of the agent
    danger_ahead = (
        (player.dirnx == -1 and x_agent + player.delx < 0)
        or (player.dirnx == -1 and ((x_agent + player.delx, y_agent) in player.body))
        or (player.dirnx == 1 and x_agent + player.delx > 14)
        or (player.dirnx == 1 and ((x_agent + player.delx, y_agent) in player.body))
        or (player.dirny == 1 and y_agent + player.dely > 14)
        or (player.dirny == 1 and ((x_agent, y_agent + player.dely) in player.body))
        or (player.dirny == -1 and y_agent + player.dely < 0)
        or (player.dirny == -1 and ((x_agent, y_agent + player.dely) in player.body))
    )
    # Obstacle to the left of the agent (relative to its heading)
    danger_left = (
        (player.dirnx == -1 and y_agent + 1 > 14)
        or (player.dirnx == -1 and ((x_agent, y_agent + 1) in player.body))
        or (player.dirnx == 1 and y_agent - 1 < 0)
        or (player.dirnx == 1 and ((x_agent, y_agent - 1) in player.body))
        or (player.dirny == 1 and x_agent + 1 > 14)
        or (player.dirny == 1 and ((x_agent + 1, y_agent) in player.body))
        or (player.dirny == -1 and x_agent - 1 < 0)
        or (player.dirny == -1 and ((x_agent - 1, y_agent) in player.body))
    )
    # Obstacle to the right of the agent (relative to its heading)
    danger_right = (
        (player.dirnx == -1 and y_agent - 1 < 0)
        or (player.dirnx == -1 and ((x_agent, y_agent - 1) in player.body))
        or (player.dirnx == 1 and y_agent + 1 > 14)
        or (player.dirnx == 1 and ((x_agent, y_agent + 1) in player.body))
        or (player.dirny == 1 and x_agent - 1 < 0)
        or (player.dirny == 1 and ((x_agent - 1, y_agent) in player.body))
        or (player.dirny == -1 and x_agent + 1 > 14)
        or (player.dirny == -1 and ((x_agent + 1, y_agent) in player.body))
    )
    # Food straight ahead of the head
    food_straight_ahead = (
        (player.dirnx == 1 and y_agent == y_food and x_food > x_agent)
        or (player.dirnx == -1 and y_agent == y_food and x_food < x_agent)
        or (player.dirny == 1 and y_agent < y_food and x_food == x_agent)
        or (player.dirny == -1 and y_agent > y_food and x_food == x_agent)
    )
    # Food ahead and to the right of the head
    food_ahead_right = (
        (player.dirnx == 1 and y_agent < y_food and x_food > x_agent)
        or (player.dirnx == -1 and y_agent > y_food and x_food < x_agent)
        or (player.dirny == 1 and y_agent < y_food and x_food < x_agent)
        or (player.dirny == -1 and y_agent > y_food and x_food > x_agent)
    )
    # Food ahead and to the left of the head
    food_ahead_left = (
        (player.dirnx == 1 and y_agent > y_food and x_food > x_agent)
        or (player.dirnx == -1 and y_agent < y_food and x_food < x_agent)
        or (player.dirny == 1 and y_agent < y_food and x_food > x_agent)
        or (player.dirny == -1 and y_agent > y_food and x_food < x_agent)
    )
    # Food behind and to the right of the head
    food_behind_right = (
        (player.dirnx == 1 and y_agent < y_food and x_food < x_agent)
        or (player.dirnx == -1 and y_agent > y_food and x_food > x_agent)
        or (player.dirny == 1 and y_agent > y_food and x_food < x_agent)
        or (player.dirny == -1 and y_agent < y_food and x_food > x_agent)
    )
    # Food behind and to the left of the head
    food_behind_left = (
        (player.dirnx == 1 and y_agent > y_food and x_food < x_agent)
        or (player.dirnx == -1 and y_agent < y_food and x_food > x_agent)
        or (player.dirny == 1 and y_agent > y_food and x_food > x_agent)
        or (player.dirny == -1 and y_agent < y_food and x_food < x_agent)
    )
    # Food exactly behind the head
    food_behind = (
        (player.dirnx == 1 and y_agent == y_food and x_food < x_agent)
        or (player.dirnx == -1 and y_agent == y_food and x_food > x_agent)
        or (player.dirny == 1 and y_agent > y_food and x_food == x_agent)
        or (player.dirny == -1 and y_agent < y_food and x_food == x_agent)
    )
    # Food exactly to the left of the head
    food_left = (
        (player.dirnx == 1 and y_agent > y_food and x_food == x_agent)
        or (player.dirnx == -1 and y_agent < y_food and x_food == x_agent)
        or (player.dirny == 1 and y_agent == y_food and x_food > x_agent)
        or (player.dirny == -1 and y_agent == y_food and x_food < x_agent)
    )
    # Food exactly to the right of the head
    food_right = (
        (player.dirnx == 1 and y_agent < y_food and x_food == x_agent)
        or (player.dirnx == -1 and y_agent > y_food and x_food == x_agent)
        or (player.dirny == 1 and y_agent == y_food and x_food < x_agent)
        or (player.dirny == -1 and y_agent == y_food and x_food > x_agent)
    )
    # Build the flag list, prioritizing danger over food: the food flags are
    # only recorded when no danger flag is set.
    flags = [int(flag) for flag in (danger_ahead, danger_left, danger_right)]
    if sum(flags) == 0:
        flags += [int(flag) for flag in (
            food_straight_ahead, food_ahead_right, food_ahead_left,
            food_behind_right, food_behind_left, food_behind, food_left,
            food_right
        )]
    else:
        flags += [0] * 8
    # The state is the index of the last set flag (0 when none is set).
    state_index = 0
    for index, flag in enumerate(flags):
        if flag == 1:
            state_index = index
    return state_index
def euler_dist(x1: int, y1: int, x2: int, y2: int) -> float:
    """
    Return the Euclidean (straight-line) distance between two points.

    :param x1: x coordinate of the first point.
    :param y1: y coordinate of the first point.
    :param x2: x coordinate of the second point.
    :param y2: y coordinate of the second point.
    :return: ``sqrt((x1 - x2)**2 + (y1 - y2)**2)`` as a float.

    >>> euler_dist(0, 0, 3, 4)
    5.0
    """
    squared_dist = (x1 - x2) ** 2 + (y1 - y2) ** 2
    return squared_dist ** 0.5
# Reward conditions:
#   * +10 for eating food
#   * -12 for dying
#   * -2 for getting closer to the fruit
#   * -7 for going away from the fruit
#   *  0 when the distance to the fruit does not change


def reward(player: object, x_food: int, y_food: int) -> int:
    """
    Return the reward for the agent's current position and pending step.

    :param player: agent exposing ``body`` (``body[0].pos`` is the head
        position), ``dirnx``/``dirny`` (heading), ``delx``/``dely`` (step)
        and ``resetDone``.
    :param x_food: x coordinate of the food.
    :param y_food: y coordinate of the food.
    :return: +10 when the head is on the food; -12 when the next step hits a
        wall or the body, or ``player.resetDone`` is True; -2 when the next
        step is closer to the food; -7 when it is further away; 0 when the
        distance is unchanged.
    """
    x_agent, y_agent = player.body[0].pos
    if x_agent == x_food and y_agent == y_food:
        return 10

    next_x = x_agent + player.delx
    next_y = y_agent + player.dely

    # NOTE(review): these bounds use <= 0 / >= 14 while state() uses < 0 /
    # > 14 for the same walls -- confirm which one matches the board size.
    hits_wall = (
        (player.dirnx == -1 and next_x <= 0)
        or (player.dirnx == 1 and next_x >= 14)
        or (player.dirny == 1 and next_y >= 14)
        or (player.dirny == -1 and next_y <= 0)
    )
    # The body check is the same for every heading: is the next cell occupied?
    # NOTE(review): this compares a tuple with the body elements, which are
    # read via ``.pos`` above -- confirm they compare equal to position tuples.
    is_moving = player.dirnx in (-1, 1) or player.dirny in (-1, 1)
    hits_body = is_moving and (next_x, next_y) in player.body
    if hits_wall or hits_body or player.resetDone is True:
        return -12

    # Negative change: the next step is closer to the food.
    distance_change = (
        euler_dist(next_x, next_y, x_food, y_food)
        - euler_dist(x_agent, y_agent, x_food, y_food)
    )
    if distance_change < 0:
        return -2
    if distance_change > 0:
        return -7
    # Distance unchanged (e.g. a zero step): neutral reward instead of None.
    return 0
def learn(
        state: int, action: int, reward: int, next_state: int,
        next_action: int, i: int, trialNumber: int, epsilon: float) -> None:
    """
    Update one entry of the module-level Q_table in place.

    The new value blends the current estimate with the observed reward plus
    the discounted Q-value of the next state/action pair:

        Q[s][a] = (1 - LEARNING_RATE) * Q[s][a]
                  + LEARNING_RATE * (reward + GAMMA * Q[s'][a'])

    :param state: row index of the entry to update.
    :param action: column index of the entry to update.
    :param reward: reward observed for taking ``action`` in ``state``.
    :param next_state: row index of the follow-up state.
    :param next_action: column index of the follow-up action.
    :param i: not used by this function.
    :param trialNumber: the Q-table is printed whenever this is a multiple
        of 100.
    :param epsilon: not used by this function.
    """
    current_q = Q_table[state][action]
    next_q = Q_table[next_state][next_action]
    # Q-learning update for the (state, action) entry.
    Q_table[state][action] = (
        (1 - LEARNING_RATE) * current_q
        + LEARNING_RATE * (reward + GAMMA * next_q)
    )
    if trialNumber % 100 == 0:
        print("Printing Q_table: ")
        print(Q_table)