本文介绍LRU的原理及其具体的C语言实现。

LRU的英文全称为Least Recently Used,即最近最少使用。该算法是一种cache淘汰算法:当缓存已满时,淘汰最长时间未被访问的元素。具体代码由哈希表(hashTable)和一个双向链表实现。
使用双向链表存放key-value,使用哈希表存储双向链表结点的地址,从而保证访问结点的时间复杂度为 $O(1)$。通常来讲,对数据的读和写都算作对数据的访问。

1 LRU的C代码实现

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <assert.h>

/* define struct  start */
/*
 * Free `ptr` and reset it to NULL, so a later FREE/free of the same variable
 * is harmless. The body is wrapped in do { } while (0) so the macro behaves
 * as a single statement, for example as the body of an unbraced if/else.
 * `ptr` must be an lvalue and is evaluated twice.
 */
#define FREE(ptr)       \
    do {                \
        free(ptr);      \
        (ptr) = NULL;   \
    } while (0)

/* Key and value types stored in the cache. */
typedef int KEY;
typedef int VAL;

/*
 * Node of the recency list. Despite its name it has both prev and next
 * links: the cache links these nodes into a circular doubly linked list
 * around a sentinel node (LRUCache.head).
 */
typedef struct ForwardListNode {
    KEY key;
    VAL val; 
    struct ForwardListNode* prev;
    struct ForwardListNode* next;
} ForwardListNode;

/*
 * Entry in a hash bucket's singly linked chain. dNode points at the list
 * node that holds the key/value. ListNodeDestroy frees only HashNodes, so
 * the referenced list node is released through the recency list instead.
 */
typedef struct HashNode {
    ForwardListNode* dNode;
    struct HashNode* next;
} HashNode;

/*
 * Hash table with separate chaining. bucket is an array of bucket_num
 * chain heads. initializeHashTable installs the function pointers.
 */
typedef struct HashTable {
    HashNode** bucket;
    int (*destroy)(struct HashTable* tb);
    int (*hash)(struct HashTable* tb, KEY key);
    int (*insert)(struct HashTable* tb, HashNode* node);
    int (*remove)(struct HashTable* tb, KEY key);
    HashNode* (*find)(struct HashTable* tb, KEY key);
    int bucket_num;
} HashTable;

/*
 * LRU cache state.
 *   head     - sentinel of the recency list; new and accessed entries are
 *              linked right after it, and eviction takes head->prev.
 *   ht       - maps each key to its list node.
 *   capacity - maximum number of entries.
 *   size     - current number of entries.
 */
typedef struct {
    ForwardListNode* head;
    HashTable*  ht;
    int capacity;
    int size;
} LRUCache;
/* ------------------------ */

/* LinkList function start */
/*
 * Push `node` onto the front of the singly linked bucket chain `*head`.
 *
 * node->next is always overwritten with the old head, even when the bucket
 * is empty. This prevents a node with a stale `next` pointer from splicing
 * a foreign chain into this bucket.
 *
 * Returns 0 on success, or -1 if `head` or `node` is NULL.
 */
int insertListNode(HashNode** head, HashNode* node)
{
    if (head == NULL || node == NULL) {
        return -1;
    }

    node->next = *head;   /* NULL when the bucket was empty */
    *head = node;
    return 0;
}

/*
 * Unlink `node` from the bucket chain `*head` and free it (the ForwardListNode
 * it references is not touched).
 *
 * Returns 0 when the node was found and freed. Returns -1, and frees nothing,
 * when the chain is empty, when `head`/`node` is NULL, or when `node` is not
 * in the chain.
 */
int removeListNode(HashNode** head, HashNode* node)
{
    if (head == NULL || node == NULL) {
        return -1;
    }

    /* Walk the links themselves so removing the first entry is not a special case. */
    for (HashNode** link = head; *link != NULL; link = &(*link)->next) {
        if (*link == node) {
            *link = node->next;
            FREE(node);
            return 0;
        }
    }
    return -1;
}

/*
 * Return the first entry in the bucket chain whose list node carries `key`.
 * Returns NULL when no entry matches, including when the chain is empty.
 */
HashNode* findListNode(HashNode* head, KEY key)
{
    HashNode* it;
    for (it = head; it != NULL; it = it->next) {
        if (it->dNode->key == key) {
            break;
        }
    }
    return it;
}

/*
 * Free every HashNode in a bucket chain. The ForwardListNode that each
 * entry points to is NOT freed here.
 */
void ListNodeDestroy(HashNode* head)
{
    while (head != NULL) {
        HashNode* following = head->next;   /* read before the node is freed */
        free(head);
        head = following;
    }
}
/* ------------------------ */

/* HashTable function */
/*
 * Allocate a bucket-chain entry that refers to `node`.
 * Returns the new entry (with next == NULL), or NULL if malloc fails.
 */
HashNode* initHashNode(ForwardListNode* node)
{
    HashNode* entry = malloc(sizeof *entry);
    if (entry != NULL) {
        entry->next = NULL;
        entry->dNode = node;
    }
    return entry;
}

/*
 * Map `key` to a bucket index in [0, tb->bucket_num).
 *
 * In C, the result of `%` has the sign of the dividend. A negative key would
 * therefore give a negative index, which trips the bounds asserts in the
 * callers (or indexes out of bounds under NDEBUG). Negative remainders are
 * folded back into range. Assumes tb->bucket_num > 0.
 */
int hashFunc(HashTable* tb, KEY key)
{
    int index = key % tb->bucket_num;
    return (index < 0) ? index + tb->bucket_num : index;
}

/* Look up the entry for `key` in its bucket. Returns NULL if it is absent. */
HashNode* findHashNode(HashTable* tb, KEY key)
{
    const int slot = tb->hash(tb, key);
    assert(0 <= slot && slot < tb->bucket_num);

    HashNode* chain = tb->bucket[slot];
    return findListNode(chain, key);
}

/*
 * Push `node` onto the chain of the bucket selected by its list node's key.
 * Returns whatever insertListNode returns.
 */
int insertHashNode(HashTable* tb, HashNode* node)
{
    const KEY k = node->dNode->key;
    const int slot = tb->hash(tb, k);
    assert(0 <= slot && slot < tb->bucket_num);

    HashNode** chain = &tb->bucket[slot];
    return insertListNode(chain, node);
}

/*
 * Remove and free the bucket entry for `key`. The ForwardListNode it refers
 * to is left alone. Returns 0 on success, or -1 if `key` is not present.
 */
int removeHashNode(HashTable* tb, KEY key)
{
    const int slot = tb->hash(tb, key);
    assert(0 <= slot && slot < tb->bucket_num);

    HashNode** chain = &tb->bucket[slot];
    if (*chain == NULL) {
        return -1;
    }

    HashNode* victim = findListNode(*chain, key);
    if (victim == NULL) {
        return -1;   /* same result removeListNode gives for an absent node */
    }
    return removeListNode(chain, victim);
}

/*
 * Allocate a zero-filled HashTable: no buckets and no callbacks yet. The
 * caller is expected to run initializeHashTable on it before use.
 * Returns NULL if the allocation fails.
 */
HashTable* initHashTableMem(void)
{
    return calloc(1, sizeof(HashTable));
}

/*
 * Free every bucket chain, the bucket array, and the table itself.
 * The ForwardListNodes referenced by the entries are not freed.
 * Always returns 0.
 */
int destroyHashTable(struct HashTable* tb)
{
    int i = 0;
    while (i < tb->bucket_num) {
        ListNodeDestroy(tb->bucket[i]);   /* returns immediately for an empty bucket */
        ++i;
    }
    free(tb->bucket);
    free(tb);
    return 0;
}

/*
 * Install the hash-table callbacks and allocate `num` empty buckets.
 *
 * Returns 0 on success, or -1 if `num` is not positive or the allocation
 * fails. The callbacks are installed even on failure. bucket_num is set only
 * after the bucket array exists, so a failed table never advertises buckets
 * it does not have.
 */
int initializeHashTable(struct HashTable* tb, int num)
{
    tb->hash = hashFunc;
    tb->insert = insertHashNode;
    tb->remove = removeHashNode;
    tb->find = findHashNode;
    tb->destroy = destroyHashTable;

    /* Reject negatives as well: converted to size_t they become a huge size. */
    if (num <= 0) {
        printf("input bucket num must greater then 0.\n");
        return -1;
    }

    /* calloc zero-fills, so every bucket starts as an empty (NULL) chain. */
    tb->bucket = calloc((size_t)num, sizeof(HashNode*));
    if (tb->bucket == NULL) {
        printf("malloc failed.\n");
        return -1;
    }
    tb->bucket_num = num;
    return 0;
}
/*------------------------------------------*/

/* forward_list function */
/*
 * Allocate a list node holding (key, val). The node is self-linked
 * (prev == next == the node itself), so it is also a valid empty circular
 * list and can serve as the sentinel head. Returns NULL if malloc fails.
 */
ForwardListNode* initForwardListNode(KEY key, VAL val)
{
    ForwardListNode* n = malloc(sizeof *n);
    if (n == NULL) {
        return NULL;
    }
    n->key = key;
    n->val = val;
    n->prev = n->next = n;
    return n;
}

/*
 * Link `node` in directly after the sentinel `head`, making it the most
 * recently used entry. No special case is needed for an empty list: there
 * head->next is head itself, and the same four assignments produce
 * head <-> node <-> head.
 */
void insertToHead(ForwardListNode* head, ForwardListNode* node)
{
    ForwardListNode* first = head->next;   /* == head when the list is empty */

    node->next = first;
    first->prev = node;
    head->next = node;
    node->prev = head;
}

/*
 * Unlink and free the node just before the sentinel, i.e. the least recently
 * used entry. Does nothing when the list holds only the sentinel.
 */
void removeFromTail(ForwardListNode* head)
{
    if (head->next == head) {
        return;
    }

    ForwardListNode* last = head->prev;
    ForwardListNode* after = last->next;
    ForwardListNode* before = last->prev;

    after->prev = before;
    before->next = after;
    free(last);
}

/*
 * Move `node`, which is already linked into the list headed by `head`, to the
 * front (most recently used position). A NULL head is ignored.
 */
void updateForwardList(ForwardListNode* head, ForwardListNode* node)
{
    if (head != NULL) {
        ForwardListNode* before = node->prev;
        ForwardListNode* after = node->next;

        /* Detach node from its current neighbours ... */
        before->next = after;
        after->prev = before;

        /* ... and re-attach it right after the sentinel. */
        insertToHead(head, node);
    }
}

/*
 * Free every node of the circular list, then the sentinel `head` itself.
 * A NULL head is ignored.
 */
void destroyForwardListNode(ForwardListNode* head)
{
    if (head == NULL) {
        return;
    }

    for (ForwardListNode* cur = head->next; cur != NULL && cur != head; ) {
        ForwardListNode* following = cur->next;   /* read before freeing */
        free(cur);
        cur = following;
    }
    free(head);
}
/*------------------------------------------*/

/* LRUCache function */
/*
 * Create an LRU cache that holds at most `capacity` entries. The internal
 * hash table gets `capacity` buckets.
 *
 * Returns NULL if initializeHashTable rejects `capacity` or any allocation
 * fails. Everything allocated so far is released on those paths.
 */
LRUCache* lRUCacheCreate(int capacity)
{
    LRUCache* lru = calloc(1, sizeof(LRUCache));
    if (lru == NULL) {
        return NULL;
    }

    /* Sentinel of the circular recency list. */
    lru->head = initForwardListNode(-1, -1);
    if (lru->head == NULL) {
        goto fail_lru;   /* must stop here: lru is about to be freed */
    }

    lru->ht = initHashTableMem();
    if (lru->ht == NULL) {
        goto fail_head;
    }

    if (initializeHashTable(lru->ht, capacity) != 0) {
        goto fail_ht;
    }

    lru->size = 0;
    lru->capacity = capacity;
    return lru;

fail_ht:
    free(lru->ht);
fail_head:
    free(lru->head);
fail_lru:
    free(lru);
    return NULL;
}

/*
 * Return the value stored for `key` and mark it most recently used.
 * Returns -1 if `key` is not cached; the recency order is then unchanged.
 */
int lRUCacheGet(LRUCache* obj, int key)
{
    HashNode* hit = obj->ht->find(obj->ht, key);
    if (hit != NULL) {
        ForwardListNode* entry = hit->dNode;
        updateForwardList(obj->head, entry);
        return entry->val;
    }
    return -1;
}

/*
 * Insert or update `key` -> `value` and mark the key most recently used.
 *
 * If a new key pushes size past capacity, the least recently used entry (the
 * node just before the sentinel) is removed from both the hash table and the
 * list. On allocation failure a message is printed and the cache is left
 * unchanged.
 */
void lRUCachePut(LRUCache* obj, int key, int value)
{
    HashTable* table = obj->ht;
    ForwardListNode* sentinel = obj->head;
    const int slot = table->hash(table, key);
    HashNode* existing = findListNode(table->bucket[slot], key);

    if (existing != NULL) {
        /* Known key: overwrite the value and refresh its recency. */
        existing->dNode->val = value;
        updateForwardList(sentinel, existing->dNode);
        return;
    }

    ForwardListNode* entry = initForwardListNode(key, value);
    if (entry == NULL) {
        printf("mem alloc failed.\n");
        return;
    }

    HashNode* link = initHashNode(entry);
    if (link == NULL) {
        free(entry);
        printf("mem alloc failed.\n");
        return;
    }

    insertToHead(sentinel, entry);
    insertListNode(&table->bucket[slot], link);
    obj->size++;

    if (obj->size > obj->capacity) {
        /* Drop the hash entry first, while the tail node's key is still readable. */
        ForwardListNode* oldest = sentinel->prev;
        table->remove(table, oldest->key);
        removeFromTail(sentinel);
        obj->size--;
    }
}

/*
 * Release the cache: every list node (including the sentinel), the hash
 * table with all of its entries, and the cache object itself.
 * A NULL `obj` is accepted and ignored, as free() does, because
 * lRUCacheCreate can return NULL.
 */
void lRUCacheFree(LRUCache* obj)
{
    if (obj == NULL) {
        return;
    }

    destroyForwardListNode(obj->head);   /* tolerates a NULL head */
    if (obj->ht != NULL) {
        obj->ht->destroy(obj->ht);
    }
    free(obj);
}

1.1 关键函数功能简介

LRU的代码实现可以分为4部分。

1.1.1 单向链表的操作

  1. int insertListNode(HashNode** head, HashNode* node)
    功能说明:向单链表的头部插入一个结点

  2. int removeListNode(HashNode** head, HashNode* node)
    功能说明:从单链表中移出指定的结点

  3. HashNode* findListNode(HashNode* head, KEY key)
    功能说明: 在单链表中查找包含指定key的结点

  4. void ListNodeDestroy(HashNode* head)
    功能说明:销毁单链表

1.1.2 双向链表的操作

  1. ForwardListNode* initForwardListNode(KEY key, VAL val)
    功能说明:构造双向链表的结点。

  2. void insertToHead(ForwardListNode* head, ForwardListNode* node)
    功能说明:向双向链表的头部插入一个结点。

  3. void removeFromTail(ForwardListNode* head)
    功能说明:从双向链表的尾部删除一个结点

  4. void updateForwardList(ForwardListNode* head, ForwardListNode* node)
    功能说明:更新双向链表中的结点,即将指定结点移动到双向链表的头部。

  5. void destroyForwardListNode(ForwardListNode* head)
    功能说明:销毁双向链表

温馨提示

双向链表的遍历方式与单向链表不同:本实现中的双向链表是带哨兵头结点(head)的循环链表,最后一个结点的next指回head,因此遍历时以 cur == head 作为循环终止条件;而单向链表以 cur == NULL 作为终止条件。

1.1.3 哈希表的操作

  1. HashNode* initHashNode(ForwardListNode* node)
    功能说明:初始化哈希表的结点

  2. int hashFunc(HashTable* tb, KEY key)
    功能说明:哈希函数

  3. HashNode* findHashNode(HashTable* tb, KEY key)
    功能说明:根据key查找哈希结点

  4. int insertHashNode(HashTable* tb, HashNode* node)
    功能说明: 插入一个哈希结点

  5. int removeHashNode(HashTable* tb, KEY key)
    功能说明: 删除一个哈希结点

  6. HashTable* initHashTableMem(void)
    功能说明: 为哈希表结构体分配内存并清零(桶数组和函数指针由initializeHashTable进一步初始化)

  7. int destroyHashTable(struct HashTable* tb)
    功能说明: 销毁哈希表

  8. int initializeHashTable(struct HashTable* tb, int num)
    功能说明: 初始化哈希表中的成员

1.1.4 LRU相关的操作

  1. LRUCache* lRUCacheCreate(int capacity)
    功能说明:创建LRU

  2. int lRUCacheGet(LRUCache* obj, int key)
    功能说明:查询key对应的value;若key不存在则返回-1,若存在则将该结点移动到双向链表头部(标记为最近访问)

  3. void lRUCachePut(LRUCache* obj, int key, int value)
    功能说明:向LRU中插入指定的key-value

  4. void lRUCacheFree(LRUCache* obj)
    功能说明: 销毁LRU

2 LRU程序的测试代码

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
void printHashTable(HashTable* tb);
void printLruCache(LRUCache* lru);

/*
 * Demo driver: runs a fixed sequence of put/get operations against a
 * capacity-2 cache and prints the cache after each step.
 * Returns 1 if the cache cannot be created, otherwise 0.
 */
int main(void)
{
    LRUCache* lru = lRUCacheCreate(2);
    if (lru == NULL) {
        /* lRUCacheCreate returns NULL on allocation failure. */
        printf("mem alloc failed.\n");
        return 1;
    }

    printf("Put (1,1) ");
    lRUCachePut(lru, 1, 1);
    printLruCache(lru);

    printf("Put (2,2) ");
    lRUCachePut(lru, 2, 2);
    printLruCache(lru);

    printf("Get 1 ");
    int ret = lRUCacheGet(lru, 1);
    printf("get(1) ret:%d.\n", ret);
    printLruCache(lru);

    printf("Put (3,3) ");
    lRUCachePut(lru, 3, 3);
    printLruCache(lru);

    printf("Get 2 2 ");
    ret = lRUCacheGet(lru, 2);
    printf("get(2) ret:%d.\n", ret);
    printLruCache(lru);

    printf("Put (4,4) ");
    lRUCachePut(lru, 4, 4);
    printLruCache(lru);

    printf("Get 1 ");
    ret = lRUCacheGet(lru, 1);
    printf("get(1) ret:%d.\n", ret);
    printLruCache(lru);

    printf("Put (3,3) ");
    lRUCachePut(lru, 3, 3);
    printLruCache(lru);

    lRUCacheFree(lru);
    return 0;
}

/*
 * Print the entry count, then every cached (key, value) pair from most to
 * least recently used, i.e. walking forward from the sentinel.
 * The node type is ForwardListNode; the undeclared name DouListNode would
 * not compile.
 */
void printLruCache(LRUCache* lru)
{
    printf("LRUCache num: %d.\n", lru->size);
    for (ForwardListNode* cur = lru->head->next; cur != lru->head; cur = cur->next) {
        printf("[%d|%d|-]->", cur->key, cur->val);
    }
    printf("\n\n");
}

/*
 * Print one bucket chain as "[key|val|-]->..." followed by a newline.
 * Prints nothing at all for an empty chain.
 */
void printListNode(HashNode* head)
{
    if (head != NULL) {
        for (HashNode* it = head; it != NULL; it = it->next) {
            printf("[%d|%d|-]->", it->dNode->key, it->dNode->val);
        }
        printf("\n");
    }
}

/* Print every non-empty bucket as "index:<i>, " followed by its chain. */
void printHashTable(HashTable* tb)
{
    int i;
    for (i = 0; i < tb->bucket_num; i++) {
        if (tb->bucket[i] != NULL) {
            printf("index:%d, ", i);
            printListNode(tb->bucket[i]);
        }
    }
}