#ifndef ARGS_H
#define ARGS_H

#include "str.h"

/* Parser state: create with args_begin() and pass to arg_get()/arg_getv()
 * on every call. */
typedef struct {
	/* Current argv slot; parsing ends when it points at a NULL entry.
	 * While a short-option cluster ("-abc") is being walked, the pointer
	 * stored in this slot is itself advanced one character at a time. */
	const char **arg;
	/* Address of the terminating NUL of the short-option cluster being
	 * walked, or NULL before any cluster. Because the argv entry's start
	 * moves as characters are consumed, its end is what identifies "still
	 * inside the same cluster". */
	const char *opt_end;
	/* Nonzero after a bare "--": every remaining argument is then returned
	 * as an operand. */
	int no_opts;
} ArgsState;

/* Status codes returned by arg_getv()/arg_get(). They are all <= 0.
 * Recognised options are reported through the same return value (the option
 * character, or the caller-supplied value for a long option). Those values
 * should presumably be positive so they cannot be confused with these
 * codes; NOTE(review): confirm at call sites. */
typedef enum {
	ARG_OK    =  0,	/* an operand was stored in *arg */
	ARG_END   = -1,	/* no arguments left */
	ARG_BAD   = -2,	/* unrecognised option; *arg holds its name */
	ARG_EMPTY = -3,	/* option needs a parameter but none followed; *arg holds its name */
} ArgResult;

/* Begin parsing argv, a NULL-terminated vector such as main() receives.
 * argv[0] (the program name) is skipped. Parsing short-option clusters
 * advances the pointers stored in argv itself, so argv entries may no longer
 * point at their leading '-' afterwards. */
ArgsState args_begin(const char **argv);

/* Return the next option or operand.
 *
 * fmt is a getopt-style short-option spec: each character is an accepted
 * option, and a ':' after one means that option takes a parameter.
 * The variadic tail is a list of (const char *name, int value) pairs
 * describing long options ("--name"), terminated by a null const char *.
 * Prefix a name with ':' if the option takes a parameter. Prefer the arg_get
 * macro, which appends the terminator for you.
 *
 * Returns:
 *   - the option character, or the pair's value for a long option; if the
 *     option takes a parameter, *arg is set to it;
 *   - ARG_OK with *arg set to an operand;
 *   - ARG_END when no arguments remain;
 *   - ARG_BAD / ARG_EMPTY with *arg set to the offending option name.
 */
ArgResult arg_getv(ArgsState *a, const char *fmt, Str *arg, ...);
/* The sentinel is cast explicitly because NULL may expand to a plain integer
 * 0, while arg_getv reads it with va_arg(ap, const char *). Passing an int
 * where a pointer is read is undefined behaviour.
 * __VA_OPT__ requires C23 (or a compiler extension). */
#define arg_get(a, fmt, arg, ...)\
	    arg_getv(a, fmt, arg __VA_OPT__(,) __VA_ARGS__, (const char *)NULL)

#ifdef ARGS_IMPL
#include <stdarg.h>
#include <string.h>

/* Start a parse over argv, a NULL-terminated vector such as main()'s.
 * argv[0] (the program name) is skipped. If argv is empty (argv[0] is NULL,
 * which execve() permits), parsing starts at that terminator rather than one
 * slot past it. The first call then reports ARG_END instead of reading out of
 * bounds. */
ArgsState args_begin(const char **argv) {
	const char **first = argv[0] ? argv + 1 : argv;
	return (ArgsState) { .arg = first, .opt_end = NULL, .no_opts = 0 };
}

/* Look up a long option by name.
 * opts is a NULL-terminated array of option names. A leading ':' (marking an
 * option that takes a parameter) is ignored when comparing.
 * Returns the index of the first matching name, or -1 if none matches. */
static int arg_opt_find(const char **opts, Str key) {
	int idx = 0;
	for (const char **p = opts; *p; p++, idx++) {
		const char *name = *p + (**p == ':');
		if (str_eql(str_from_cstr(name), key))
			return idx;
	}
	return -1;
}

/* Fetch the parameter for an option that requires one.
 *
 * rem is any text attached to the option itself (after "-o" or "--opt=").
 * If it is non-empty it is the parameter. Otherwise the following argv entry
 * is taken: a->arg is moved onto it, and callers then step past it.
 * If there is no following entry, *arg is set to the option name and
 * ARG_EMPTY is returned; a->arg is left unchanged. */
static ArgResult arg_param(ArgsState *a, Str name, Str rem, Str *arg) {
	/* Attached value: "-ofile" / "--out=file". */
	if (rem.n > 0) {
		*arg = rem;
		return ARG_OK;
	}

	/* Nothing attached and nothing after it: the parameter is missing. */
	if (!a->arg[1]) {
		*arg = name;
		return ARG_EMPTY;
	}

	/* Separate value: "-o file". */
	a->arg++;
	*arg = str_from_cstr(*a->arg);
	return ARG_OK;
}

/* Handle a long option, "--name" or "--name=value".
 *
 * opts/optv are the NULL-terminated long-option names and their return
 * values. Both are NULL when the caller supplied no long options, in which
 * case every long option is rejected.
 *
 * Returns optv[i] for the matched option and advances past the argument.
 * If the option's name starts with ':', *arg is set to its parameter.
 * Returns ARG_BAD with *arg = the name for an unknown option, or the
 * arg_param() error (ARG_EMPTY) when a required parameter is missing.
 * On error the position is left unchanged. */
static int arg_got_long(ArgsState *a, const char **opts, int *optv, Str *arg) {
	const char *body = *a->arg + 2;	/* skip the leading "--" */
	Cut key = str_cut(str_from_cstr(body), '=');

	if (!opts || !optv) {
		*arg = key.head;
		return ARG_BAD;
	}

	int o = arg_opt_find(opts, key.head);
	if (o < 0) {
		*arg = key.head;
		return ARG_BAD;
	}

	if (opts[o][0] == ':') {
		if (strchr(body, '=')) {
			/* Explicit "--name=value": the value is whatever follows the
			 * '=', even if empty. Without this check "--name=" would
			 * silently swallow the next argv entry as its value. */
			*arg = key.tail;
		} else {
			int x = arg_param(a, key.head, key.tail, arg);
			if (x < 0) return x;
		}
	}
	/* NOTE(review): a flag option written "--flag=value" still succeeds and
	 * the value is ignored; getopt_long would reject it. */
	a->arg++;
	return optv[o];
}

/* Consume one option character from a short-option cluster ("-abc").
 *
 * On entry *a->arg points at the option character; the leading '-' has
 * already been stepped over. fmt is getopt-style: each character is an
 * accepted option, and a ':' right after one means it takes a parameter,
 * either attached ("-ofile") or as the next argument ("-o file").
 *
 * Returns the option character, with *arg set to the parameter if it takes
 * one. Returns the negative arg_param() result when the parameter is
 * missing, or ARG_BAD with *arg set to the unrecognised one-character option.
 * In both error cases the position is left unchanged. */
static ArgResult arg_got_short(ArgsState *a, const char *fmt, Str *arg) {
	const char *here = *a->arg;
	Str opt = { (char*)here, 1 };
	Str rem = str_from_cstr(here + 1);

	for (const char *spec = fmt; *spec != '\0'; spec++) {
		/* ':' only annotates the preceding option and never matches. */
		if (*spec == ':' || *spec != *here)
			continue;

		if (spec[1] != ':') {
			/* Plain flag: step to the next character of the cluster. */
			*a->arg = here + 1;
			return *spec;
		}

		int x = arg_param(a, opt, rem, arg);
		if (x < 0)
			return x;
		/* Parameter taken; move past the argv entry that held it. */
		a->arg++;
		return *spec;
	}

	*arg = opt;
	return ARG_BAD;
}

/* Core parser loop: return the next option or operand.
 *
 * fmt is the short-option spec (see arg_got_short). opts/optv are the
 * NULL-terminated long-option names and their return values, or both NULL
 * when no long options were given.
 *
 * Recognised forms:
 *   "--"            stop option processing; later arguments are operands
 *   "--name[=val]"  long option, handled by arg_got_long
 *   "-abc"          short-option cluster, one character per call
 *   "-" or "x"      operand: *arg = the argument, returns ARG_OK
 * Returns ARG_END once argv's NULL entry is reached.
 *
 * Clusters are walked by advancing the argv entry itself one character at a
 * time. The argument's end address (a->opt_end) is therefore what marks
 * "still inside this cluster". opt_end is cleared whenever a cluster is left,
 * whether exhausted or ended by consuming a parameter. Otherwise a later
 * argument sharing storage with an earlier one would be misparsed as a
 * continuation of it; merged identical string literals in a hand-built argv
 * are one way that happens.
 */
static ArgResult arg_get_long(ArgsState *a, const char *fmt, const char **opts, int *optv, Str *arg) {
	for (;;) {
		const char *cur = *a->arg;
		if (!cur) return ARG_END;
		if (a->no_opts) break;	/* after "--": everything is an operand */

		const char *arg_end = cur + strlen(cur);
		if (a->opt_end == arg_end) {
			/* Inside a short-option cluster. */
			if (!*cur) {
				/* Cluster exhausted: leave it and look at the next argument. */
				a->opt_end = NULL;
				a->arg++;
				continue;
			}
			const char **slot = a->arg;
			ArgResult r = arg_got_short(a, fmt, arg);
			if (a->arg != slot)
				a->opt_end = NULL;	/* a parameter ended the cluster */
			return r;
		}

		if (cur[0] != '-' || cur[1] == '\0') break;	/* operand, incl. "-" */
		if (cur[1] == '-') {
			if (cur[2] == '\0') {
				a->no_opts = 1;
				a->arg++;
				continue;
			}
			return arg_got_long(a, opts, optv, arg);
		}

		/* Start a new cluster: skip the '-' and remember where it ends. */
		(*a->arg)++;
		a->opt_end = arg_end;
	}

	*arg = str_from_cstr(*a->arg++);
	return ARG_OK;
}

/* Parse the next argument; see the declaration for the full contract.
 * The variadic tail is (const char *name, int value) pairs for long options,
 * terminated by a null const char * (the arg_get macro appends it). */
ArgResult arg_getv(ArgsState *a, const char *fmt, Str *arg, ...) {
	va_list ap;
	int n = 0;

	/* Pass 1: count the (name, value) pairs before the null sentinel. */
	va_start(ap, arg);
	while (va_arg(ap, const char *)) {
		(void)va_arg(ap, int);
		n++;
	}
	va_end(ap);

	if (n == 0)
		return arg_get_long(a, fmt, NULL, NULL, arg);

	/* VLAs are acceptable here: n is the number of arguments written at the
	 * call site, a compile-time constant, never user input.
	 * opt gets one extra slot for the NULL terminator that arg_opt_find()
	 * loops until. Without it, looking up an unknown long option reads
	 * opt[n], one past the end of the array. */
	const char *opt[n + 1];
	int optv[n];
	va_start(ap, arg);
	for (int i = 0; i < n; i++) {
		opt[i] = va_arg(ap, const char *);
		optv[i] = va_arg(ap, int);
	}
	va_end(ap);
	opt[n] = NULL;
	return arg_get_long(a, fmt, opt, optv, arg);
}

#endif
#endif