#ifndef ZONE_H
#define ZONE_H

#include <stdint.h>
#include <stddef.h>

/* Backend selectors: the values ZONE_BACKEND is compared against. */
#define ZONE_USE_MALLOC         0
#define ZONE_USE_MMAP           1
#define ZONE_USE_WASM_BULKMEM   2

/*
 * Pick a default backend unless the includer already defined ZONE_BACKEND:
 * mmap when __unix__ is defined, malloc otherwise.
 */
#ifndef ZONE_BACKEND
	#ifdef __unix__
		#define ZONE_BACKEND ZONE_USE_MMAP
	#else
		#define ZONE_BACKEND ZONE_USE_MALLOC
	#endif
#endif

/*
 * Any backend other than mmap or malloc implies ZONE_NOSTDLIB.
 * When ZONE_NOSTDLIB is defined, zn_strdup() is not compiled.
 */
#if !defined(ZONE_NOSTDLIB) && (ZONE_BACKEND != ZONE_USE_MMAP && ZONE_BACKEND != ZONE_USE_MALLOC)
	#define ZONE_NOSTDLIB
#endif

/*
 * One contiguous chunk of zone memory. Frames are chained into a doubly
 * linked list and the header is followed directly by the payload bytes.
 */
typedef struct ZoneFrame ZoneFrame;
struct ZoneFrame {
	ZoneFrame *prev;  /* frame linked before this one, or NULL */
	ZoneFrame *next;  /* frame linked after this one, or NULL */
	uint8_t *beg;     /* next free byte; allocations advance it */
	uint8_t *end;     /* one past the last byte of the frame */
	uint8_t data[];   /* payload; beg starts out pointing here */
};

/*
 * Snapshot of a zone's allocation position, written by zn_save() and
 * consumed by zn_load().
 */
typedef struct ZoneState {
	ZoneFrame *zf;            /* frame that was current when saved (NULL if none) */
	struct ZoneState *last;   /* the zone's `last` pointer at save time */
	uint8_t *beg;             /* fill position of `zf` at save time */
} ZoneState;

/*
 * A zone: a list of frames plus the most recent scope snapshot.
 * A zero-initialized Zone is a valid empty zone (the first allocation
 * creates the first frame).
 */
typedef struct {
	ZoneFrame *cur;   /* frame allocations are attempted from first */
	ZoneFrame *tail;  /* most recently created frame */
	ZoneState *last;  /* snapshot stored by the latest zn_begin(), if any */
} Zone;

/* Hand every frame owned by the zone back to the backend. */
void zn_free(Zone *z);
/* Rewind every frame to empty and make the oldest frame current; no memory is released. */
void zn_reset(Zone *z);

/* Open a scope: snapshot the zone, store the snapshot inside the zone and point z->last at it. */
void zn_begin(Zone *z);
/* Close the innermost scope by restoring the snapshot z->last points at. */
void zn_end(Zone *z);
/* Record the current frame, its fill position and z->last into *m. */
void zn_save(Zone *z, ZoneState *m);
/* Rewind the zone to the position recorded in *m; a snapshot with no frame triggers zn_reset(). */
void zn_load(Zone *z, ZoneState *m);

/* Allocate n bytes aligned to the smallest power of two that is >= n. */
void *zn_alloc(Zone *z, ptrdiff_t n);
/* Allocate n bytes aligned to `align`; a new frame is created when no existing frame fits. */
void *zn_alloc_align(Zone *z, ptrdiff_t n, size_t align);
/* zn_realloc_align() with the alignment derived from newsz the same way zn_alloc() derives it from n. */
void *zn_realloc(Zone *z, void *ptr, ptrdiff_t oldsz, ptrdiff_t newsz);
/*
 * Resize an allocation of oldsz bytes to newsz bytes. The block is resized in
 * place when it is the latest allocation in the current frame and the new
 * size fits; otherwise a new block is allocated and the contents copied.
 * A NULL ptr or a zero oldsz behaves like zn_alloc_align().
 */
void *zn_realloc_align(Zone *z, void *ptr, ptrdiff_t oldsz, ptrdiff_t newsz, size_t align);
/* Allocate n * size zero-filled bytes. */
void *zn_calloc(Zone *z, size_t n, size_t size);

/* Copy the NUL-terminated string s into the zone (only defined when ZONE_NOSTDLIB is not defined). */
char *zn_strdup(Zone *z, const char *s);
/* Fill n bytes at ptr with zero and return ptr. */
void *zn_zeroed(void *ptr, size_t n);

/* Provide the `alignof` spelling unless something already defined it as a macro. */
#ifndef alignof
#define alignof _Alignof
#endif
/*
 * Allocate one zero-filled object of type t from zone z and yield a t*.
 * The expansion is fully parenthesized so it can be used directly in larger
 * expressions such as zn_new(z, T)->field.
 */
#define zn_new(z, t) ((t *)zn_zeroed(zn_alloc_align((z), sizeof(t), _Alignof(t)), sizeof(t)))

#ifdef ZONE_IMPL

/*
 * Multiplier the mmap and malloc backends apply to their base page size to
 * form ZONE_PAGE_SIZE. Define it before this point to override the default.
 */
#ifndef ZONE_PAGE_MULT
#define ZONE_PAGE_MULT 2
#endif

/*
 * Round the unsigned integer x up to the next multiple of align.
 * align must be a power of two. Arguments are parenthesized so compound
 * expressions such as `a + 1` are handled correctly.
 */
#define UPTR_ALIGN(x, align) ((x) + (-(x) & ((align) - 1)))
/*
 * Round pointer ptr up to the next address that is a multiple of align
 * (a power of two); the result has the same type as ptr.
 * The mask is computed in uintptr_t so that a narrower unsigned `align`
 * cannot clear the upper bits of the address. __typeof__ is used because
 * GCC and Clang accept it in every language mode.
 */
#define PTR_ALIGN(ptr, align) \
	((__typeof__(ptr))(((uintptr_t)(ptr) + ((uintptr_t)(align) - 1)) & -(uintptr_t)(align)))

/*
 * zn_abort(msg): report a fatal allocator error and terminate. Never returns.
 *   - mmap / malloc backends: print msg plus a newline to stderr, then abort().
 *   - WASM backend: pass msg to the imported "alert" function, then call the
 *     imported "abort" function.
 *   - any other backend: emit a compile-time warning and trap without
 *     reporting msg.
 */
#if ZONE_BACKEND == ZONE_USE_MMAP || ZONE_BACKEND == ZONE_USE_MALLOC
	#include <stdio.h>
	#include <stdlib.h>
	[[noreturn]] void zn_abort(const char *msg) {
		fprintf(stderr, "%s\n", msg);
		abort();
	}
#elif ZONE_BACKEND == ZONE_USE_WASM_BULKMEM
	__attribute((import_name("alert"))) void js_alert(const char *msg);
	__attribute((import_name("abort"))) [[noreturn]] void js_abort(void);
	[[noreturn]] void zn_abort(const char *msg) {
		js_alert(msg);
		js_abort();
	}
#else
	/* A warning, not an error: the fallback below is a usable definition. */
	#warning "zn_abort not implemented for the current platform -- using __builtin_trap() instead"
	[[noreturn]] void zn_abort(const char *msg) {
		(void)msg;
		__builtin_trap();
	}
#endif

/*
 * Memory fill/copy primitives used by the implementation, mapped to the
 * compiler builtins (presumably so <string.h> is not required -- confirm).
 */
#define ZONE_MEMSET __builtin_memset
#define ZONE_MEMCPY __builtin_memcpy

#if ZONE_BACKEND == ZONE_USE_MMAP

	#include <sys/mman.h>
	#include <unistd.h>

	/* Frame granularity in bytes: ZONE_PAGE_MULT times the value sysconf() reports; sysconf() is called at every use. */
	#define ZONE_PAGE_SIZE (ZONE_PAGE_MULT * sysconf(_SC_PAGE_SIZE))
	/* Value zn_pg_alloc() yields on failure for this backend. */
	#define ZONE_PAGE_FAIL MAP_FAILED

	/*
	 * Map n units of ZONE_PAGE_SIZE bytes of private, anonymous, read/write
	 * memory. Returns whatever mmap() returns (MAP_FAILED on failure).
	 */
	static void *zn_pg_alloc(size_t n) {
		const size_t bytes = n * ZONE_PAGE_SIZE;
		const int prot = PROT_READ | PROT_WRITE;
		const int flags = MAP_ANONYMOUS | MAP_PRIVATE;
		return mmap(NULL, bytes, prot, flags, -1, 0);
	}
	/* Unmap n bytes starting at p. The result of munmap() is deliberately ignored. */
	static void zn_pg_free(void *p, size_t n) {
		(void)munmap(p, n);
	}

#elif ZONE_BACKEND == ZONE_USE_MALLOC

	/* Frame granularity in bytes: ZONE_PAGE_MULT units of 4 KiB. */
	#define ZONE_PAGE_SIZE (ZONE_PAGE_MULT * 4 * 1024)
	/* Value zn_pg_alloc() yields on failure for this backend. */
	#define ZONE_PAGE_FAIL NULL

	/* Obtain n units of ZONE_PAGE_SIZE bytes from malloc(); NULL on failure. */
	static void *zn_pg_alloc(size_t n) {
		void *region = malloc(n * ZONE_PAGE_SIZE);
		return region;
	}
	/* Return a region obtained from zn_pg_alloc() to free(); the size is not needed by this backend. */
	static void zn_pg_free(void *p, size_t n) {
		(void)n;
		free(p);
	}

#elif ZONE_BACKEND == ZONE_USE_WASM_BULKMEM

	/* Frame granularity in bytes; ZONE_PAGE_MULT is not applied by this backend. */
	#define ZONE_PAGE_SIZE 65536
	/* Value compared against zn_pg_alloc()'s result to detect failure. */
	#define ZONE_PAGE_FAIL (void*)(uintptr_t)-1
	/* Declared elsewhere; presumably the linker-provided start of the heap -- TODO confirm. */
	extern uint8_t __heap_base;
	/* Head of a singly linked list of released pages; each free page stores the next pointer in its first word. */
	static void *zn_pg_free_adr = 0;
	static void *zn_pg_alloc(size_t n) {
		if (zn_pg_free_adr) {
			void *p = zn_pg_free_adr;
			zn_pg_free_adr = *(void **)zn_pg_free_adr;
			return p;
		}
		void *p = &__heap_base + __builtin_wasm_memory_size(0);
		__builtin_wasm_memory_grow(0, 1);
		return p;
	}
	/*
	 * Push every 64 KiB page of the n-byte region at p onto the free list.
	 * p is rounded down to a page boundary and n is rounded up to whole
	 * pages; pages are pushed from the highest address to the lowest, so
	 * the lowest page ends up at the head of the list.
	 */
	static void zn_pg_free(void *p, size_t n) {
		const uintptr_t base = (uintptr_t)p & ~(uintptr_t)0xFFFF;
		for (size_t i = (n + 0xFFFF) >> 16; i > 0; i--) {
			void **page = (void **)(base + ((i - 1) << 16));
			*page = zn_pg_free_adr;
			zn_pg_free_adr = page;
		}
	}

#else

	#error "unknown or unsupported zone backend"

#endif

/*
 * Round cap up to the next multiple of ZONE_PAGE_SIZE.
 * The mask arithmetic relies on ZONE_PAGE_SIZE being a power of two.
 */
static inline size_t zn_pg_fit(size_t cap) {
	const size_t mask = ZONE_PAGE_SIZE - 1;
	const size_t pad = -cap & mask;
	return cap + pad;
}

/*
 * Create an unlinked, empty frame spanning `capacity` bytes (header
 * included). The page count handed to the backend is
 * capacity / ZONE_PAGE_SIZE, so callers pass a multiple of ZONE_PAGE_SIZE.
 * Calls zn_abort() when the backend reports failure.
 */
static ZoneFrame *zn_zf_new(size_t capacity) {
	const size_t pages = capacity / ZONE_PAGE_SIZE;
	ZoneFrame *frame = zn_pg_alloc(pages);
	if (frame == ZONE_PAGE_FAIL) {
		zn_abort("failed to allocate zone frame\n");
	}
	frame->next = NULL;
	frame->prev = NULL;
	frame->end = (uint8_t *)frame + capacity;
	frame->beg = frame->data;
	return frame;
}

/*
 * Return a frame to the backend.
 * The length passed is the size of the whole frame, measured from the frame
 * header to `end`. Measuring from `beg` would leave out the header and
 * every byte already allocated, so only part of the region would be
 * released.
 */
static void zn_zf_free(ZoneFrame *z) {
	zn_pg_free(z, (uintptr_t)z->end - (uintptr_t)z);
}

/*
 * Allocate n bytes aligned to `align` (a power of two) from zone z.
 *
 * Frames are tried starting at z->cur and following `next`; when none has
 * room, a new frame large enough for the request is appended to the list.
 * The frame that satisfies the request becomes z->cur, so zn_load() can
 * later walk back from it and reclaim every frame used since a snapshot.
 *
 * Aborts via zn_abort() if n is negative or if a new frame cannot be
 * obtained. The returned memory is not initialized.
 */
void *zn_alloc_align(Zone *z, ptrdiff_t n, size_t align) {
	if (n < 0) {
		zn_abort("zn_alloc_align: negative allocation size");
	}
	ZoneFrame *zf = z->cur;
	uint8_t *aligned;
	for (;;) {
		if (!zf) {
			/* Worst case: header, then up to align-1 padding bytes, then n. */
			zf = zn_zf_new(zn_pg_fit((size_t)n + sizeof(ZoneFrame) + align - 1));
			zf->prev = z->tail;
			if (z->tail) z->tail->next = zf;
			z->tail = zf;
		}
		aligned = PTR_ALIGN(zf->beg, align);
		/*
		 * Compare sizes rather than forming aligned + n, so an oversized
		 * request cannot overflow the pointer. An exact fit is accepted:
		 * a freshly created frame may be filled to its very last byte, and
		 * rejecting that case would allocate new frames forever.
		 */
		if ((uintptr_t)aligned <= (uintptr_t)zf->end
		    && (size_t)n <= (size_t)((uintptr_t)zf->end - (uintptr_t)aligned)) {
			break;
		}
		zf = zf->next;
	}
	z->cur = zf;
	zf->beg = aligned + n;
	return aligned;
}

/*
 * Resize the allocation at ptr from oldsz to newsz bytes.
 *
 * - A NULL ptr or a zero oldsz is a plain zn_alloc_align().
 * - If ptr is the most recent allocation in the current frame and newsz
 *   fits in that frame (an exact fit included), the block is resized in
 *   place and ptr is returned.
 * - Otherwise a new block is allocated and the first min(oldsz, newsz)
 *   bytes are copied, so shrinking never writes past the new block. The old
 *   block stays in the zone until the zone is rewound.
 *
 * Aborts via zn_abort() when a size is negative.
 */
void *zn_realloc_align(Zone *z, void *ptr, ptrdiff_t oldsz, ptrdiff_t newsz, size_t align) {
	if (oldsz < 0 || newsz < 0) {
		zn_abort("zn_realloc_align: negative size");
	}
	if (!ptr || !oldsz) return zn_alloc_align(z, newsz, align);

	ZoneFrame *cur = z->cur;
	if (cur && (uint8_t *)ptr == cur->beg - oldsz
	    && newsz <= cur->end - (uint8_t *)ptr) {
		cur->beg = (uint8_t *)ptr + newsz;
		return ptr;
	}

	void *p = zn_alloc_align(z, newsz, align);
	ZONE_MEMCPY(p, ptr, (size_t)(oldsz < newsz ? oldsz : newsz));
	return p;
}

/*
 * Return the smallest power of two that is >= n (1 for n == 0).
 * When no such power of two is representable (n > SIZE_MAX / 2 + 1) the
 * largest representable power of two is returned instead; without that stop
 * condition the shift would wrap to zero and the loop would never end.
 */
static inline size_t zn_align_for(size_t n) {
	size_t a = 1;
	while (a < n && a <= SIZE_MAX / 2) {
		a <<= 1;
	}
	return a;
}

/* Allocate n bytes, aligned to the power of two that zn_align_for() derives from n. */
void *zn_alloc(Zone *z, ptrdiff_t n) {
	const size_t align = zn_align_for(n);
	return zn_alloc_align(z, n, align);
}

/* Resize an allocation via zn_realloc_align(), with the alignment derived from newsz by zn_align_for(). */
void *zn_realloc(Zone *z, void *ptr, ptrdiff_t oldsz, ptrdiff_t newsz) {
	const size_t align = zn_align_for(newsz);
	return zn_realloc_align(z, ptr, oldsz, newsz, align);
}

/*
 * Allocate a zero-filled array of n elements of `size` bytes each, aligned
 * to the power of two that zn_align_for() derives from `size`.
 * Aborts via zn_abort() when n * size does not fit in ptrdiff_t, the type
 * zn_alloc_align() takes for its size.
 */
void *zn_calloc(Zone *z, size_t n, size_t size) {
	if (size != 0 && n > (size_t)PTRDIFF_MAX / size) {
		zn_abort("zn_calloc: allocation size overflow");
	}
	size_t total = n * size;
	void *p = zn_alloc_align(z, (ptrdiff_t)total, zn_align_for(size));
	ZONE_MEMSET(p, 0, total);
	return p;
}

/*
 * Release every frame of the zone, newest first, and leave the zone empty.
 * The zone's pointers are cleared so it can be reused (the next allocation
 * creates a fresh frame) instead of pointing at released memory.
 */
void zn_free(Zone *z) {
	ZoneFrame *zf = z->tail;
	while (zf) {
		/* Read the link before the frame's memory is released. */
		ZoneFrame *older = zf->prev;
		zn_zf_free(zf);
		zf = older;
	}
	z->cur = NULL;
	z->tail = NULL;
	z->last = NULL;
}

/*
 * Mark every frame as empty and make the oldest frame the current one.
 * No memory is returned to the backend; a zone without frames is left
 * untouched.
 */
void zn_reset(Zone *z) {
	ZoneFrame *oldest = NULL;
	for (ZoneFrame *it = z->tail; it; it = it->prev) {
		it->beg = it->data;
		oldest = it;
	}
	if (oldest) {
		z->cur = oldest;
	}
}

/*
 * Capture the zone's position in *m: the current frame, that frame's fill
 * pointer (null when the zone has no current frame) and the zone's `last`.
 */
void zn_save(Zone *z, ZoneState *m) {
	ZoneFrame *cur = z->cur;
	uint8_t *mark = 0;
	if (cur) {
		mark = cur->beg;
	}
	m->zf = cur;
	m->beg = mark;
	m->last = z->last;
}

/*
 * Rewind the zone to the position captured in *m.
 * A snapshot without a frame rewinds the whole zone via zn_reset().
 * Otherwise every frame from the current one back to (not including) the
 * saved frame is emptied, and the saved frame's fill pointer is restored.
 * In both cases the zone's `last` is set to the snapshot's `last`.
 */
void zn_load(Zone *z, ZoneState *m) {
	ZoneFrame *target = m->zf;
	if (!target) {
		zn_reset(z);
	} else {
		for (ZoneFrame *f = z->cur; f != target; f = z->cur) {
			f->beg = f->data;
			z->cur = f->prev;
		}
		target->beg = m->beg;
	}
	z->last = m->last;
}

/* utils */

#ifndef ZONE_NOSTDLIB
/*
 * Copy the NUL-terminated string s, terminator included, into zone z and
 * return the copy. s must not be NULL.
 * The length is taken with the compiler builtin, matching ZONE_MEMCPY, so
 * no declaration from <string.h> (which this file never includes) is needed.
 */
char *zn_strdup(Zone *z, const char *s) {
	size_t n = __builtin_strlen(s) + 1;
	char *d = zn_alloc_align(z, (ptrdiff_t)n, 1);
	ZONE_MEMCPY(d, s, n);
	return d;
}
#endif

/* Zero n bytes starting at ptr and hand ptr back (the fill primitive returns its destination). */
void *zn_zeroed(void *ptr, size_t n) {
	return ZONE_MEMSET(ptr, 0, n);
}

/*
 * Open a scope. The zone's position is captured before the ZoneState record
 * itself is allocated, so rewinding to it also reclaims the record. The
 * record lives inside the zone and becomes z->last.
 */
void zn_begin(Zone *z) {
	ZoneState snapshot = { 0 };
	zn_save(z, &snapshot);

	ZoneState *slot = zn_new(z, ZoneState);
	*slot = snapshot;
	z->last = slot;
}

/* Close the innermost scope by rewinding the zone to the snapshot z->last points at. */
void zn_end(Zone *z) {
	ZoneState *scope = z->last;
	zn_load(z, scope);
}

#endif
#endif