#include <stdio.h>
#include "string.h"
#include "pico/stdlib.h"
#include "psram/psmalloc.h"
#include "psram/rp2_psram.h"


void psmalloc_init(void) {
    psmalloc_list.idxfree = 0;
    psmalloc_list.idxused = 0;
    for (uint32_t i = 0; i < PSMALLOCLIST; i++) {
        psmalloc_list.size[i] = 0;
        psmalloc_list.ptr[i] = NULL;
    }
}


void* psram_malloc(size_t size) {
    static const size_t align = 16;
    if (psmalloc_list.idxfree >= PSMALLOCLIST) {
        return NULL;} // No free slots

    size_t aligned_size = (size + align - 1) & ~(align - 1);
    for (uint32_t i = 0; i < PSMALLOCLIST; i++) {
        if (psmalloc_list.ptr[i] == NULL) {
            uint8_t* ptr = (uint8_t*)PSRAM_LOCATION;
            if (!ptr) {
                return NULL;} // PSRAM not initialized  
            ptr += i * aligned_size;
            psmalloc_list.ptr[i] = ptr;
            psmalloc_list.size[i] = aligned_size;
            psmalloc_list.idxused++;
            if (i == psmalloc_list.idxfree) {
                // Update idxfree to next free slot
                while (psmalloc_list.idxfree < PSMALLOCLIST && psmalloc_list.ptr[psmalloc_list.idxfree] != NULL) {
                    psmalloc_list.idxfree++;
                }
            }
            return (void*)ptr;
        }
    }
    return NULL; // No free slots found
}


void* psram_calloc(size_t num, size_t size) {
    size_t total_size = num * size;
    void* ptr = psram_malloc(total_size);
    if (ptr) {
        memset(ptr, 0, total_size);
    }
    return ptr;
}

void psram_free(void* ptr) {
    for (uint32_t i = 0; i < PSMALLOCLIST; i++) {
        if (psmalloc_list.ptr[i] == ptr) {
            psmalloc_list.ptr[i] = NULL;
            psmalloc_list.size[i] = 0;
            psmalloc_list.idxused--;
            if (i < psmalloc_list.idxfree) {
                psmalloc_list.idxfree = i;
            }
            return;
        }
    }
}   