AVL树的Python实现

avatar 2021年3月27日18:05:26 评论 757 次浏览

AVL是平衡树,平衡因子概念什么的就不阐述了,主要是在不平衡时候如何旋转。(1)右子树右节点插入:左旋转。(2)左子树左节点插入:右旋转。(3)右子树左节点插入:右旋转后左旋转。(4)左子树右节点插入:左旋转后右旋转。深入了解可以参考:https://www.wulaoer.org/?p=1598

  • 所谓的左旋和右旋都是以子树为原点的:如b是a的子树,那么旋转就围绕b来进行。
  • 如果b是a的左子树,那么就围绕b将a向右旋转,看着就像是a直接掉下来了,掉成了b的右子树。
  • 如果b是a的右子树,那么就围绕b将a向左旋转,看着就像是a直接掉下来了,掉成了b的左子树。

AVL树有左右孩子的概念,所以,在实现AVL树之前,有必要先引入Python中类的概念,先来个MWE。

Python的类

#!/usr/bin/python3
#coding:utf-8
#~~~~~~~~~~~~www.wulaoer.org 吴老二个人博客~~~~~~~
class Car:
    # 初始化
    def __init__(self, brand, gas):
        self.brand = brand
        self.gas = gas
        print('一辆新的', self.brand, '被生产出来了!')
    # 定义方法
    def add_gas(self, amount):
        self.gas += amount
    # 定义方法
    def show_gas(self):
        print('剩余汽油:', self.gas)

# 实例化
benz = Car('Benz', 100)
# 调用方法
benz.add_gas(200)
benz.show_gas()

AVL树的实现

#!/usr/bin/python3
#coding:utf-8
#~~~~~~~~~~~~www.wulaoer.org 吴老二个人博客~~~~~~~
import numpy as np
import time


class TreeNode(object):
    # 定义每个节点的数据,左孩子右孩子,平衡因子
    def __init__(self):
        self.data = 0
        self.left = None
        self.right = None
        self.height = 0


class BTree(object):

    def __init__(self):
        self.root = None

    def __Max(self, h1, h2):
        if h1 > h2:
            return h1
        elif h1 <= h2:
            return h2

    # 左左情况,向右旋转
    def __LL(self, r):
        node = r.left
        r.left = node.right
        node.right = r
        r.height = self.__Max(self.getHeight(r.right), self.getHeight(r.left)) + 1
        node.height = self.__Max(self.getHeight(node.right), self.getHeight(node.left)) + 1
        return node

    # 右右,左旋
    def __RR(self, r):
        node = r.right
        r.right = node.left
        node.left = r
        r.height = self.__Max(self.getHeight(r.right), self.getHeight(r.left)) + 1
        node.height = self.__Max(self.getHeight(node.right), self.getHeight(node.left)) + 1
        return node

    # 左右,先左旋再右旋
    def __LR(self, r):
        r.left = self.__RR(r.left)
        return self.__LL(r)

    # 右左,先右旋再左旋
    def __RL(self, r):
        r.right = self.__LL(r.right)
        return self.__RR(r)

    # r是self.root
    def __insert(self, data, r):
        if r == None:
            node = TreeNode()
            node.data = data
            return node
        elif data == r.data:
            return r
        elif data < r.data:
            r.left = self.__insert(data, r.left)
            # 左高右低
            if self.getHeight(r.left) - self.getHeight(r.right) >= 2:
                if data < r.left.data:
                    r = self.__LL(r)
                else:
                    r = self.__LR(r)
        else:
            r.right = self.__insert(data, r.right)
            if self.getHeight(r.right) - self.getHeight(r.left) >= 2:
                if data > r.right.data:
                    r = self.__RR(r)
                else:
                    r = self.__RL(r)

        r.height = self.__Max(self.getHeight(r.left), self.getHeight(r.right)) + 1
        return r

    # 删除data节点
    def __delete(self, data, r):
        if r == None:
            return r

        elif r.data == data:
            # 如果只有右子树,直接将右子树赋值到此节点
            if r.left == None:
                return r.right
            # 如果只有左子树,直接将左子树赋值到此节点
            elif r.right == None:
                return r.left
            # 如果同时有左右子树
            else:
                # 左子树高度大于右子树
                if self.getHeight(r.left) > self.getHeight(r.right):
                    # 找到最右节点 返回节点值 并删除该节点
                    node = r.left
                    while(node.right != None):
                        node = node.right
                    r = self.__delete(node.data, r)
                    r.data = node.data
                    return r
                # 右子树高度大于左子树
                else:
                    node = r.right
                    while node.left != None:
                        node = node.left
                    r = self.__delete(node.data, r)
                    r.data = node.data
                    return r

        elif data < r.data:
            # 在左子树中删除
            r.left = self.__delete(data, r.left)
            # 右子树高度与左子树高度相差超过1
            if self.getHeight(r.right) - self.getHeight(r.left) >= 2:
                if self.getHeight(r.right.left) > self.getHeight(r.right.right):
                    r = self.__RL(r)
                else:
                    r = self.__RR(r)

        elif data > r.data:
            # 右子树中删除
            r.right = self.__delete(data, r.right)
            # 左子树与右子树高度相差超过1
            if self.getHeight(r.left)-self.getHeight(r.right) >= 2:
                if self.getHeight(r.left.right)>self.getHeight(r.left.left):
                    r = self.__LR(r)
                else:
                    r = self.__LL(r)
        # 更新高度
        r.height = self.__Max(self.getHeight(r.left), self.getHeight(r.right))+1
        return r

    # 先序遍历
    def __show(self, root):
        if root != None:
            # print (root.data)
            self.__show(root.left)
            self.__show(root.right)
        else:
            return 0

    def Insert(self, data):
        self.root = self.__insert(data, self.root)
        return self.root

    def Delete(self, data):
        self.root = self.__delete(data, self.root)

    # 求结点的高度
    def getHeight(self, node):
        if node == None:
            return -1
        # print node
        return node.height

    def Show(self):
        self.__show(self.root)


if __name__ == '__main__':
    bi = BTree()
    insert_time = []
    delete_time = []
    for right_interval in range(1000, 500000, 50000):
        array = np.random.randint(1, 100, right_interval)
        since = time.time()
        for i in array:
            bi.Insert(i)
        end = time.time() - since
        insert_time.append(end)
        print('AVL insert : ' + str(right_interval) + ' Data: ' + str(end) + 's')

    for right_interval in range(1000, 500000, 50000):
        array = np.random.randint(1, 100, right_interval)
        since = time.time()
        for i in array:
            bi.Delete(i)
        end = time.time() - since
        delete_time.append(end)
        print('AVL delete : ' + str(right_interval) + ' Data: ' + str(end) + 's')
        for i in array:
            bi.Insert(i)

运行结果

AVL insert : 1000 Data: 0.0071680545806884766s
AVL insert : 51000 Data: 0.3300008773803711s
AVL insert : 101000 Data: 0.6365611553192139s
AVL insert : 151000 Data: 0.969146728515625s
AVL insert : 201000 Data: 1.2616446018218994s
AVL insert : 251000 Data: 1.572361946105957s
AVL insert : 301000 Data: 1.8841772079467773s
AVL insert : 351000 Data: 2.36479115486145s
AVL insert : 401000 Data: 2.8726420402526855s
AVL insert : 451000 Data: 3.1849629878997803s
AVL delete : 1000 Data: 0.0023190975189208984s
AVL delete : 51000 Data: 0.018627166748046875s
AVL delete : 101000 Data: 0.037735939025878906s
AVL delete : 151000 Data: 0.05629992485046387s
AVL delete : 201000 Data: 0.06780004501342773s
AVL delete : 251000 Data: 0.12815308570861816s
AVL delete : 301000 Data: 0.1400139331817627s
AVL delete : 351000 Data: 0.12156200408935547s
AVL delete : 401000 Data: 0.1407301425933838s
AVL delete : 451000 Data: 0.16689515113830566s

Process finished with exit code 0
avatar

发表评论

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen: