shithub: oai

ref: 785ad1d9d08166d95344f1d7f53346eb27cca31c
dir: /oai.c/

View raw version
#include <u.h>
#include <libc.h>
#include <bio.h>
#include <json.h>
#include <String.h>
#include "oai.h"

/*
 * Print the command-line synopsis to stderr and exit with status "usage".
 * Lists every flag handled by main's ARGBEGIN switch; -vibe and -yolo are
 * parsed there as -v / -y followed by the mandatory text "ibe" / "olo".
 */
static void
usage(void)
{
	fprint(2, "usage: %s [-9dhq] [-k apikey] [-m model] [-u baseurl] [-s sysprompt] [-vibe] [-yolo]\n", argv0);
	exits("usage");
}

/*
 * Command-line switches, all off by default and set in main:
 *   quiet     (-q)    hide the "user: " prompt and the ": " after the role
 *   vibemode  (-vibe) run non-destructive tool calls without asking
 *   yolomode  (-yolo) run destructive tool calls without asking too
 *                     (main also turns vibemode on when this is set)
 */
static int quiet, vibemode, yolomode;

/* Run an rc(1) command line and return its standard output; defined below. */
static char *callfunc(char *cmd);

/*
 * Show a pending tool call (its name and pretty-printed JSON arguments)
 * on stderr and decide whether it may run.
 *
 * destructive selects the command-line override that skips the question:
 * yolomode for destructive calls, vibemode otherwise. Without the override
 * the user is asked "Continue? (y/n)" on fd 0 until a reply starts with
 * 'y' or 'n'. End of input counts as a refusal.
 *
 * Returns 1 to allow the call, 0 to refuse it.
 *
 * NOTE(review): this reads fd 0 directly, while main reads the same fd
 * through a Biobuf; input already buffered there is not seen here.
 */
static int
allowed(OToolcall toolcall, int destructive)
{
	JSON *j;
	char buf[64];
	int n;

	fprint(2, "Attempt to call command: %s\n", toolcall.name);
	j = jsonparse(toolcall.arguments);
	if (j) {
		fprint(2, "%J\n", j);
		jsonfree(j);
	}

	if (destructive ? yolomode : vibemode)
		return 1;

	for (;;) {
		fprint(2, "Continue? (y/n) ");
		/* read the whole reply so leftovers ("es\n" of "yes") don't cause a re-prompt next time */
		n = read(0, buf, sizeof buf);
		if (n < 0)
			sysfatal("lost connection");
		if (n == 0) {
			/* EOF: nobody can answer any more; refuse instead of looping forever */
			fprint(2, "\nEOF, denying.\n");
			return 0;
		}
		switch (buf[0]) {
		case 'y':
			return 1;
		case 'n':
			return 0;
		}
	}
}

/*
 * Build the tool result reported back to the model when the user
 * refuses a call. The string is allocated with smprint.
 */
static char*
abortcall(OToolcall tc)
{
	char *msg;

	msg = smprint("%s request aborted by user!", tc.name);
	return msg;
}

/*
 * Tool "list_files": list the entries of the folder named by the
 * "folder" string in the JSON arguments. Each entry goes on its own
 * line, and subdirectories get a trailing '/'.
 *
 * Asks for confirmation through allowed() as a non-destructive call.
 * Always returns a malloc'd string: the listing on success, otherwise a
 * short error message meant for the model.
 */
static char*
list_files(OToolcall tc, void*)
{
	String *str;
	int n, i, fd;
	Dir *d;
	char *folder, *s;
	JSON *j, *jf;

	if (!allowed(tc, 0))
		return abortcall(tc);

	j = jsonparse(tc.arguments);
	if (!j) {
		/* unparsable arguments: never hand a nil JSON to jsonbyname */
		fprint(2, "list_files: bad arguments!\n");
		return strdup("list_files: bad arguments");
	}
	jf = jsonbyname(j, "folder");
	folder = jsonstr(jf);
	if (!(folder && folder[0])) {
		fprint(2, "list_files: invalid folder!\n");
		jsonfree(j);
		return strdup("list_files: invalid folder");
	}

	d = dirstat(folder);
	if (!d) {
		s = smprint("list_files: %s: %r", folder);
		jsonfree(j);
		return s;
	}
	if (!(d->mode&DMDIR)) {
		s = smprint("%s is a file, not a folder", folder);
		free(d);
		jsonfree(j);
		return s;
	}
	free(d);

	fd = open(folder, OREAD);
	if (fd < 0) {
		s = smprint("list_files: %s: %r", folder);
		jsonfree(j);
		return s;
	}
	/* folder may point into j, so it is not used past this point */
	jsonfree(j);

	n = dirreadall(fd, &d);
	if (n < 0) {
		/* format the error before close() can touch errstr */
		s = smprint("list_files: %r");
		close(fd);
		return s;
	}
	close(fd);

	str = s_new();
	for (i = 0; i < n; i++) {
		str = s_append(str, d[i].name);
		if (d[i].mode&DMDIR)
			str = s_append(str, "/");
		str = s_append(str, "\n");
	}
	free(d);
	s = strdup(s_to_c(str));
	s_free(str);
	return s;
}

/*
 * Tool metadata for list_files, passed to maketool() in inittools():
 * a human-readable description and a JSON Schema for the arguments,
 * which require a single string property "folder".
 */
static char* listfilesdesc = "list all files in the specified directory, similar to `ls` command. Paths are relative to the current working directory. Use `.` for the current directory.";
static char* listfilesargs = "{"
"	\"type\": \"object\","
"	\"properties\": {"
"		\"folder\": {"
"			\"type\": \"string\","
"			\"description\": \"relative or absolute path to the folder.\""
"		}"
"	},"
"	\"required\": [ \"folder\" ]"
"}";

/*
 * Tool "read_file": return the contents of the file named by the "file"
 * string in the JSON arguments. Directories are rejected.
 *
 * Asks for confirmation through allowed() as a non-destructive call.
 * Always returns a malloc'd string: the file contents ("" for an empty
 * file), otherwise an error message meant for the model.
 */
static char*
read_file(OToolcall toolcall, void*)
{
	JSON *j, *fj;
	char *file, *s;
	Biobuf *io;
	Dir *dir;

	if (!allowed(toolcall, 0))
		return abortcall(toolcall);

	j = jsonparse(toolcall.arguments);
	if (!j) {
		/* unparsable arguments: never hand a nil JSON to jsonbyname */
		fprint(2, "bad arguments in read_file request!\n");
		return strdup("bad request in read_file: unparsable arguments!\n");
	}
	fj = jsonbyname(j, "file");
	if (!fj) {
		fprint(2, "no file in read_file request!\n");
		jsonfree(j);
		return strdup("bad request in read_file!\n");
	}

	file = jsonstr(fj);
	if (!(file && file[0])) {
		fprint(2, "invalid file in read_file request!\n");
		jsonfree(j);
		return strdup("bad request in read_file: no file!\n");
	}

	dir = dirstat(file);
	if (!dir) {
		s = smprint("error opening file: %r\n");
		jsonfree(j);
		return s;
	}
	if (dir->mode&DMDIR) {
		free(dir);
		jsonfree(j);
		return smprint("error: file is a directory\n");
	}
	free(dir);

	io = Bopen(file, OREAD);
	if (!io) {
		fprint(2, "open file: %r\n");
		s = smprint("open file in read_file: %r");
		jsonfree(j);
		return s;
	}

	fprint(2, "read_file: %s\n", file);
	jsonfree(j);

	/* delimiter 0: read up to the first NUL byte or EOF, i.e. a whole text file */
	s = Brdstr(io, 0, 0);
	Bterm(io);
	if (!s)
		s = strdup("");	/* nothing was read (empty file): still return a string */
	return s;
}

/*
 * Tool metadata for read_file, passed to maketool() in inittools():
 * a human-readable description and a JSON Schema for the arguments,
 * which require a single string property "file".
 */
static char *readfiledesc = "read the contents of a specified file, similar to the `cat` unix command. Especially useful on a Plan 9 system for interaction with filesystems.";
static char *readfileargs = "{"
"	\"type\": \"object\","
"	\"properties\": {"
"		\"file\": {"
"			\"type\": \"string\","
"			\"description\": \"relative or absolute path to the file\""
"		}"
"	},"
"	\"required\": [ \"file\" ]"
"}";

static char*
lookman(OToolcall tc, void*)
{
	JSON *j, *jk;
	char *s, *cmd;
	
	if (!allowed(tc, 0))
		return abortcall(tc);
	
	j = jsonparse(tc.arguments);
	jk = jsonbyname(j, "keyword");
	s = jsonstr(jk);
	if (!(s && s[0]))
		return strdup("lookman: missing keyword");
	
	cmd = smprint("lookman '%s'", s);
	jsonfree(j);
	s = callfunc(cmd);
	free(cmd);
	return s;
}

static char *lookmandesc = "Search the man pages for the specified keyword. The result shows the name of the man page and the section number.";
static char *lookmanargs = "{"
"	\"type\": \"object\","
"	\"properties\": {"
"		\"keyword\": {"
"			\"type\": \"string\","
"			\"description\": \"keyword to search for in the man pages\""
"		}"
"	},"
"	\"required\": [ \"keyword\" ]"
"}";

static char*
man(OToolcall tc, void*)
{
	JSON *j, *js, *jn;
	char *sec, *name;
	char *s, *cmd;
	
	if (!allowed(tc, 0))
		return abortcall(tc);
	
	j = jsonparse(tc.arguments);
	js = jsonbyname(j, "section");
	jn = jsonbyname(j, "name");
	
	sec = jsonstr(js);
	name = jsonstr(jn);
	if (!(sec && name))
		return strdup("man: missing section or name");
	
	if (!(sec[0] && name[0]))
		return strdup("man: empty section or name");
	
	cmd = smprint("man '%s' '%s'", sec, name);
	jsonfree(j);
	s = callfunc(cmd);
	free(cmd);
	return s;
}

/*
 * Tool metadata for man (registered as "read_man" in inittools()):
 * a human-readable description and a JSON Schema for the arguments,
 * which require the string properties "section" and "name".
 */
static char *mandesc = "get the contents of the specified man page.";
static char *manargs = "{"
"	\"type\": \"object\","
"	\"properties\": {"
"		\"section\": {"
"			\"type\": \"string\","
"			\"description\": \"section number of the man page\""
"		},"
"		\"name\": {"
"			\"type\": \"string\","
"			\"description\": \"name of the man page\""
"		}"
"	},"
"	\"required\": [ \"section\", \"name\" ]"
"}";

/*
 * Tools offered to the model. Built by inittools() and stored in
 * req.tools by main.
 */
OTool *tools = nil;

/*
 * Register the four tools (list_files, read_file, search_man, read_man).
 * Only the result of the first maketool() call is stored in tools; the
 * later calls pass the existing list and ignore the return value.
 * NOTE(review): this presumably relies on maketool appending to a non-nil
 * list it is given — confirm against the library.
 */
static void
inittools(void)
{
	tools = maketool(tools, Function, "list_files", listfilesdesc, listfilesargs, list_files, nil);
	maketool(tools, Function, "read_file", readfiledesc, readfileargs, read_file, nil);
	maketool(tools, Function, "search_man", lookmandesc, lookmanargs, lookman, nil);
	maketool(tools, Function, "read_man", mandesc, manargs, man, nil);
}

/*
 * Run the rc(1) command line s in a child process. Return everything the
 * child writes to standard output as a malloc'd string, or "" if it
 * writes nothing. The child inherits this process's stdin and stderr.
 * Blocks until the output is drained and the child has been reaped.
 *
 * s is interpreted by rc; callers must quote any untrusted text they
 * splice into it.
 */
static char*
callfunc(char *s)
{
	int pfd[2];
	Biobuf *bin;
	char *out;

	if (pipe(pfd) < 0)
		sysfatal("pipe: %r");

	switch (fork()) {
	case -1:
		sysfatal("fork: %r");
	case 0:
		/* child: its stdout becomes the write end of the pipe */
		dup(pfd[1], 1);
		close(pfd[1]);
		close(pfd[0]);
		execl("/bin/rc", "rc", "-c", s, nil);
		/* exec failed: exit, never fall through into the parent's code */
		fprint(2, "exec /bin/rc: %r\n");
		exits("exec");
	default:
		/* parent */
		break;
	}

	close(pfd[1]);
	bin = Bfdopen(pfd[0], OREAD);
	if (bin == nil) {
		/* close the read end so a writing child cannot block forever */
		close(pfd[0]);
		out = nil;
	} else {
		/* delimiter 0: collect everything up to EOF in one string */
		out = Brdstr(bin, 0, 0);
		Bterm(bin);	/* frees bin and closes pfd[0] */
	}
	free(wait());	/* wait returns a malloc'd Waitmsg */
	if (out == nil)
		out = strdup("");
	return out;
}

/*
 * Shell-escape mode, entered from main when a user line starts with the
 * console marker character.
 *
 * Prompts with ">>> " and reads lines from bin:
 *   - an empty line just re-prompts;
 *   - a line starting with the marker character leaves console mode and
 *     returns nil;
 *   - "!cmd" runs cmd with callfunc and returns its output, which main
 *     then sends to the model as the user message;
 *   - any other line is run with callfunc, its output is printed and the
 *     prompt is shown again.
 * Returns nil at end of input.
 *
 * NOTE(review): the character constant in the first case below appears
 * empty ('') in this copy of the file, which is not valid C; the original
 * presumably holds a literal control character — confirm before editing.
 */
static char*
readconsole(Biobuf *bin)
{
	char *s, *result;
	
	print(">>> ");
	while (s = Brdstr(bin, '\n', 1)) {
		if (!s[0]) {
			free(s);
			print(">>> ");
			continue;
		}
		switch (s[0]) {
		/* marker character: leave console mode */
		case '':
			free(s);
			return nil;
		/* "!cmd": the command's output becomes the next user message */
		case '!':
			result = callfunc(s+1);
			free(s);
			return result;
		/* anything else: run it and just show the output */
		default:
			result = callfunc(s);
			free(s);
			print("%s", result);
			free(result);
			print(">>> ");
		}
	}
	return nil;
}

/*
 * Ask a question on stdout as "!!! what " and keep reading answers from
 * bin until one starts with a character listed in valid; that character
 * is returned. Empty or unrecognised answers repeat the question.
 * Returns 0 at end of input.
 */
static char
ask(Biobuf *bin, char *what, char *valid)
{
	char *line;
	char answer;

	for (;;) {
		print("!!! %s ", what);
		if ((line = Brdstr(bin, '\n', 1)) == nil)
			break;
		answer = line[0];
		free(line);
		/* an empty line must not match valid's terminating NUL */
		if (answer != '\0' && strchr(valid, answer) != nil)
			return answer;
	}
	return 0;
}

/*
 * Report the token counts of a request on stdout as
 * "usage: c=<completion> p=<prompt> t=<total>".
 */
static void
printusage(OResult *r)
{
	print("usage: c=%d p=%d t=%d\n",
		r->tokscompletion,
		r->toksprompt,
		r->tokstotal);
}

/*
 * Built-in system prompts. Each initializer is assembled from
 * oai_common.princ followed by a system-specific file; those files
 * presumably contain adjacent C string literals that concatenate into
 * one prompt — confirm. main defaults to frontprompt when /dist/9front
 * or /dist/plan9front exists and to plan9prompt otherwise. The -9 flag
 * forces plan9prompt and -s replaces the prompt entirely.
 */
char *plan9prompt =
#include "oai_common.princ"
#include "plan9.princ"
;

char *frontprompt =
#include "oai_common.princ"
#include "front.princ"
;

/*
 * Interactive chat loop.
 *
 * Setup:
 *   - Choose a default system prompt: frontprompt on 9front (either
 *     /dist/9front or /dist/plan9front exists), plan9prompt otherwise.
 *   - Parse flags, initialise the API client with initoai, and register
 *     the tools.
 *
 * Loop, one stdin line at a time:
 *   - Add the line as a "user" prompt to the running request.
 *   - Call makerequest and print the reply, which is appended to the
 *     request as well.
 *   - A line starting with the console marker character switches to
 *     readconsole() instead.
 *   - On a failed request the user is asked whether to retry.
 */
void
main(int argc, char **argv)
{
	Biobuf *bin;
	char *s;
	ORequest req;
	OResult res;
	char *sysprompt;
	
	char *url = nil;
	char *key = nil;
	char *model = nil;
	
	/* access returns 0 when the path exists, so this means "either exists" */
	if (!(access("/dist/9front", AEXIST) && access("/dist/plan9front", AEXIST))) {
		/* 9front system */
		sysprompt = frontprompt;
	} else {
		/* other plan 9 system */
		sysprompt = plan9prompt;
	}
	
	ARGBEGIN{
	case 'h':
		usage();
	case 'k':
		key = EARGF(usage());
		break;
	case 'm':
		model = EARGF(usage());
		break;
	case 'u':
		url = EARGF(usage());
		break;
	case 's':
		sysprompt = EARGF(usage());
		break;
	case '9':
		sysprompt = plan9prompt;
		break;
	case 'q':
		quiet++;
		break;
	case 'd':
		oaidebug++;
		break;
	case 'v':
		/* -vibe: the rest of the argument must be exactly "ibe" */
		if (strcmp(EARGF(usage()), "ibe") == 0)
			vibemode = 1;
		else
			usage();
		break;
	case 'y':
		/* -yolo: the rest of the argument must be exactly "olo" */
		if (strcmp(EARGF(usage()), "olo") == 0)
			yolomode = 1;
		else
			usage();
		break;
	}ARGEND;
	
	/* auto-approving destructive calls implies auto-approving the rest */
	if (yolomode)
		vibemode = 1;
	
	if (!initoai(url, key, model))
		usage();
	
	bin = Bfdopen(0, OREAD);
	assert(bin);
	
	inittools();
	
	/* start from a zeroed request so no field is left uninitialized */
	memset(&req, 0, sizeof req);
	req.prompts = nil;
	req.tools = tools;
	
	if (sysprompt)
		addstrprompt(&req, "system", "%s", sysprompt);
	
	if (!quiet) print("user: ");
	while (s = Brdstr(bin, '\n', 1)) {
		/*
		 * NOTE(review): this character constant appears empty ('') in
		 * this copy of the file, which is not valid C; the original
		 * presumably holds a literal control character that enters
		 * console mode — confirm.
		 */
		if (s[0] == '') {
			free(s);
			s = readconsole(bin);
		}
		if (!(s && s[0])) {
			free(s);
			goto Next;
		}
		/*
		 * s is untrusted text: pass it as an argument, never as the
		 * format, so a '%' in the user's message is sent literally.
		 * NOTE(review): s is not freed here; whether addstrprompt keeps
		 * or copies it is not visible in this file.
		 */
		addstrprompt(&req, "user", "%s", s);
Again:
		res = makerequest(&req);
		printusage(&res);
		if (!res.success) {
			fprint(2, "ERROR: %r\n");
			switch (ask(bin, "Try again (y/n)?", "yn")) {
			default:
				/* end of input while asking */
				fprint(2, "exiting!\n");
				exits("fail");
			case 'y':
				goto Again;
			case 'n':
				fprint(2, "bye.\n");
				exits(nil);
			}
		}
		print("%s%s%s\n\n", res.role, (quiet ? "" : ": "), res.message);
		addprompt(&req, res.asprompt);
Next:
		if (!quiet) print("user: ");
	}
	exits(nil);
}