#include #include #include #include #include #define ALIGNMENT 8 #define ALIGN(size) (((size) + (ALIGNMENT - 1)) & ~(ALIGNMENT - 1)) typedef struct block_header { size_t size; struct block_header* next; int free; } block_header_t; #define BLOCK_SIZE sizeof(block_header_t) static uint8_t* heap_base = (uint8_t*)HEAP_START; static uint8_t* heap_end = (uint8_t*)(HEAP_START + HEAP_SIZE); static block_header_t* free_list = NULL; void heap_init(void) { free_list = (block_header_t*)heap_base; free_list->size = HEAP_SIZE - BLOCK_SIZE; free_list->next = NULL; free_list->free = 1; } void* malloc(size_t size) { size = ALIGN(size); block_header_t* curr = free_list; while (curr) { if (curr->free && curr->size >= size) { /* Split if there's space for another block */ if (curr->size >= size + BLOCK_SIZE + ALIGNMENT) { block_header_t* new_block = (block_header_t*)((uint8_t*)curr + BLOCK_SIZE + size); new_block->size = curr->size - size - BLOCK_SIZE; new_block->next = curr->next; new_block->free = 1; curr->next = new_block; curr->size = size; } curr->free = 0; return (void*)((uint8_t*)curr + BLOCK_SIZE); } curr = curr->next; } printd("Malloc failed due to lack of free memory\n"); return NULL; } void* calloc(size_t nmemb, size_t size) { size_t total = nmemb * size; void* ptr = malloc(total); if (ptr) { memset(ptr, 0, total); } return ptr; } void* realloc(void* ptr, size_t size) { if (!ptr) { return malloc(size); } if (size == 0) { free(ptr); return NULL; } block_header_t* block = (block_header_t*)((uint8_t*)ptr - BLOCK_SIZE); if (block->size >= size) { return ptr; } void* new_ptr = malloc(size); if (new_ptr) { memcpy(new_ptr, ptr, block->size); free(ptr); } return new_ptr; } void free(void* ptr) { if (!ptr) { return; } block_header_t* block = (block_header_t*)((uint8_t*)ptr - BLOCK_SIZE); block->free = 1; /* Forward coalescing */ if (block->next && block->next->free) { block->size += BLOCK_SIZE + block->next->size; block->next = block->next->next; } /* Backward coalescing */ block_header_t* prev = NULL; block_header_t* curr = free_list; while (curr && curr != block) { prev = curr; curr = curr->next; } if (prev && prev->free) { prev->size += BLOCK_SIZE + block->size; prev->next = block->next; } }