キン〇マハムスター佐藤  2021/02/23更新

Pythonでの優先度付きキュー まとめ


はじめに

Pythonの優先度付きキューについてまとめます。

バージョンはPython3です。


優先度付きキューとは:

優先度付きキューはリストに似た「最小値を素早く取り出す」ことのできるデータ型です。 具体的には、

・最小値を O(logN)で取り出す

・要素を O(logN)で追加する

ことが出来ます(通常だとO(N))

「要素の追加」と「最小値の抽出」を繰り返すような時に使います。


Pythonでの実装方法:

Pythonにはheapqという優先度付きキューを実装したモジュールが用意されています。

基本例:

import heapq
a = [5,7,9,2,4,1]
print(a)

#実行結果: [5, 7, 9, 2, 4, 1]

heapq.heapify(a)
for _ in range(len(a)):
  print(heapq.heappop(a),end=" ")

#実行結果: 1 2 4 5 7 9 
#小さい順に取り出されている

#最大値を求める場合はあらかじめ要素に-1を掛け、最後にもう一度-1を掛けます。
a = list(map(lambda x: int(x) * (-1), [5,7,9,2,4,1]))
print(a)
#実行結果: [-5, -7, -9, -2, -4, -1]

heapq.heapify(a)
for _ in range(len(a)):
  print(heapq.heappop(a)*-1,end=" ,")

#実行結果: 9 7 5 4 2 1 


元の配列が2次元の場合:


import heapq
a = [[10,5],[25,7],[3,9],[4,2],[1,4],[6,1]]
print(a)

#実行結果: [[10, 5], [25, 7], [3, 9], [4, 2], [1, 4], [6, 1]]

heapq.heapify(a)
for _ in range(len(a)):
  print(heapq.heappop(a),end=" ")

#実行結果: [1, 4] [3, 9] [4, 2] [6, 1] [10, 5] [25, 7] 
#2次元配列の場合は要素内の先頭の数字順になることがわかると思います。


要素を追加する実装例はこちら

import heapq

a = []
for value in [5,7,9,2,4,1]:
  heapq.heappush(a,value)
print(a)
#実行結果: [1, 4, 2, 7, 5, 9] *昇順ソートされていない

for _ in range(len(a)):
  print(heapq.heappop(a),end=" ")

#実行結果:  1 2 4 5 7 9 


#新しく追加する場合はheapifyする必要がないこと
#heappushした時点では昇順ソートされていないことに注意

使うメソッドは3つです。

・heapq.heapify(リスト)  リストを優先度付きキューに変換

・heapq.heappop( (リスト) )   優先度付きキューから最小値を取り出す

・heapq.heappush( (リスト) , 要素)   優先度付きキューに要素を挿入



Classに応用する:

 リストから最小値順に取り出す方法は簡単ですが要素が複雑な場合もありますね。そのときはclassを使います。 例えば座標系と移動costがあって、最小コストのルートを求める場合とか


#5 x 5の座標と移動コストの表があり、左上から右下へ移動するときの最小コストとルートを求めたい
#表
# [1, 7, 1, 2, 1]
# [3, 1, 3, 4, 6]
# [2, 1, 8, 1, 3]
# [4, 2, 2, 2, 5]
# [1, 7, 3, 2, 1]


import heapq

class State:
    def __init__(self, x, y, cost, ref):
        self.x = x
        self.y = y
        self.cost = cost
        self.ref = ref

    def __lt__(self, state):
        return self.cost < state.cost

    def disp(self):
        print("--")
        for i in range(h):
            for j in range(w):
                if i == self.y and j == self.x:
                    print("*", end="")
                else:
                    print(" ", end="")
                print(t[i][j], end="")
            print("")

def test(sx,sy,gx,gy):
    dy = [0, 1, 0, -1]
    dx = [1, 0, -1, 0]
    openque = []
    closedque = set()
    state = State(sx, sy, t[sy][sx], None)
    heapq.heappush(openque, state)
    while len(openque) > 0:
        state = heapq.heappop(openque)
        if state.x == gx and state.y == gy:
            return state
        if state in closedque:
            continue
        closedque.add(state)

        for i in range(4):
            nx = state.x + dx[i]
            ny = state.y + dy[i]
            if nx < 0 or w <= nx or ny < 0 or h <= ny:
                continue
            ncost = state.cost + int(t[ny][nx])
            heapq.heappush(openque, State(nx, ny, ncost, state))

t = [
 [1, 7, 1, 2, 1],
 [3, 1, 3, 4, 6],
 [2, 1, 8, 1, 3],
 [4, 2, 2, 2, 5],
 [1, 7, 3, 2, 1]
]
h=5
w=5

st = test(0,0,w-1,h-1)
print(st.cost)
while not st == None:
    st.disp()
    st = st.ref



実行結果:



15
--
 1 7 1 2 1
 3 1 3 4 6
 2 1 8 1 3
 4 2 2 2 5
 1 7 3 2*1
--
 1 7 1 2 1
 3 1 3 4 6
 2 1 8 1 3
 4 2 2 2 5
 1 7 3*2 1
--
 1 7 1 2 1
 3 1 3 4 6
 2 1 8 1 3
 4 2 2*2 5
 1 7 3 2 1
--
 1 7 1 2 1
 3 1 3 4 6
 2 1 8 1 3
 4 2*2 2 5
 1 7 3 2 1
--
 1 7 1 2 1
 3 1 3 4 6
 2 1 8 1 3
 4*2 2 2 5
 1 7 3 2 1
--
 1 7 1 2 1
 3 1 3 4 6
 2*1 8 1 3
 4 2 2 2 5
 1 7 3 2 1
--
 1 7 1 2 1
 3*1 3 4 6
 2 1 8 1 3
 4 2 2 2 5
 1 7 3 2 1
--
 1 7 1 2 1
*3 1 3 4 6
 2 1 8 1 3
 4 2 2 2 5
 1 7 3 2 1
--
*1 7 1 2 1
 3 1 3 4 6
 2 1 8 1 3
 4 2 2 2 5
 1 7 3 2 1




タイトルとURLをコピーしました