ref: 1c65e9795d3a51c26a87296d53b0ebe9efdc7719
dir: /oai.c/
#include <u.h>
#include <libc.h>
#include <bio.h>
#include <json.h>
#include <String.h>
#include "oai.h"
/*
 * Print the command-line synopsis to stderr and exit with status "usage".
 * The synopsis lists every flag main accepts (including -9, which forces
 * the generic Plan 9 system prompt).
 */
static void
usage(void)
{
	fprint(2, "usage: %s [-9dq] [-k apikey] [-m model] [-u baseurl] [-s sysprompt] [-vibe] [-yolo]\n", argv0);
	exits("usage");
}
static int quiet = 0;
static int vibemode = 0;
static int yolomode = 0;
static char *callfunc(char*);
static char *writetofile(char*, char*);
/*
 * Decide whether a tool call may run.
 *
 * The call's name and its JSON arguments (when they parse) are printed
 * to stderr.  If the matching auto-approve flag is set (yolomode for
 * destructive calls, vibemode otherwise) the call is allowed at once;
 * otherwise the user is asked "Continue? (y/n)" on fd 0 until the
 * answer line starts with 'y' or 'n'.
 *
 * Returns 1 to allow, 0 to deny.  End of input counts as "no", so a
 * closed or exhausted stdin can no longer make this loop forever.
 * A read error is fatal.
 */
static int
allowed(OToolcall toolcall, int destructive)
{
	JSON *j;
	char c, answer;
	int n;

	fprint(2, "Attempt to call command: %s\n", toolcall.name);
	j = jsonparse(toolcall.arguments);
	if (j) {
		fprint(2, "%J\n", j);
		jsonfree(j);
	}
	if (destructive ? yolomode : vibemode)
		return 1;
	for (;;) {
		fprint(2, "Continue? (y/n) ");
		/*
		 * Consume the whole answer line one byte at a time and keep only
		 * its first character.  Anything typed after that character is
		 * swallowed, so it is not left on fd 0 to be read later as the
		 * next user prompt.
		 */
		answer = 0;
		for (;;) {
			n = read(0, &c, 1);
			if (n < 0)
				sysfatal("lost connection");
			if (n == 0)
				break;
			if (answer == 0)
				answer = c;
			if (c == '\n')
				break;
		}
		/* end of input with nothing typed: refuse instead of spinning */
		if (n == 0 && answer == 0)
			return 0;
		switch (answer) {
		case 'y':
			return 1;
		case 'n':
			return 0;
		}
	}
}
/*
 * Message for a tool call the user declined, naming the tool.
 * The string is allocated by smprint; the caller owns and frees it.
 */
static char*
abortcall(OToolcall tc)
{
	char *msg;

	msg = smprint("%s request aborted by user!", tc.name);
	return msg;
}
/*
 * Join the string values of a JSON element list, starting at first,
 * into one newly allocated string with join between elements.
 * If quote is set, each value is quoted with quotestrdup first (rc-style
 * quoting, suitable for building a command line).
 *
 * Elements without a string value (jsonstr returns nil) are treated as
 * "".  Previously that case crashed inside quotestrdup.  An empty list
 * yields "".  The caller frees the result.
 */
static char*
joinstrarray(JSONEl *first, char *join, int quote)
{
	JSONEl *el;
	String *s;
	char *cs, *str;

	s = s_new();
	for (el = first; el; el = el->next) {
		str = jsonstr(el->val);
		if (str == nil)
			str = "";
		cs = quote ? quotestrdup(str) : str;
		if (el != first)
			s = s_append(s, join);
		s = s_append(s, cs);
		if (quote)
			free(cs);
	}
	/*
	 * Ensure a NUL-terminated buffer exists even when nothing was
	 * appended, so s_to_c never hands strdup a nil pointer.
	 */
	s_terminate(s);
	cs = strdup(s_to_c(s));
	s_free(s);
	return cs;
}
#include "tools/listfiles.cinc"
#include "tools/readfile.cinc"
#include "tools/man.cinc"
#include "tools/writefile.cinc"
#include "tools/cmd.cinc"
/* Tool list; main hands it to the model via req.tools. */
OTool *tools = nil;
/*
 * Register every tool implemented in the .cinc files included above,
 * all as kind Function, under the names the model will call.
 *
 * Only the first maketool call stores its result (tools starts out nil).
 * The later calls pass the existing list and ignore the return value.
 * NOTE(review): presumably maketool appends in place and returns the
 * list head; confirm before restructuring.  The meaning of the final
 * nil argument is not visible here.
 */
static void
inittools(void)
{
	tools = maketool(tools, Function, "list_files", listfilesdesc, listfilesargs, list_files, nil);
	maketool(tools, Function, "read_file", readfiledesc, readfileargs, read_file, nil);
	maketool(tools, Function, "search_man", lookmandesc, lookmanargs, lookman, nil);
	maketool(tools, Function, "read_man", mandesc, manargs, man, nil);
	maketool(tools, Function, "write_file", writefiledesc, writefileargs, writefile, nil);
	maketool(tools, Function, "run_command", cmddesc, cmdargs, cmd, nil);
}
static char*
writetofile(char *file, char *content)
{
int fd;
long n;
char *ret;
fd = open(file, OWRITE|OTRUNC);
if (fd < 0)
fd = create(file, OWRITE|OTRUNC, 0666);
if (fd < 0)
return smprint("cannot write to file: %r");
n = strlen(content);
if (write(fd, content, n) != n)
ret = strdup("write error: %r");
else
ret = strdup("file written successfully");
close(fd);
return ret;
}
/*
 * Run s as an rc(1) command line and capture its standard output.
 *
 * The child's stdout is a pipe.  Its stdin and stderr are inherited
 * from this process.  Returns the output, up to the first NUL byte, as
 * a newly allocated string, or nil if the command printed nothing.
 * pipe/fork failures are fatal.  If exec fails, the child reports the
 * error and exits, so it never continues into the parent's code.
 */
static char*
callfunc(char *s)
{
	int pin[2];
	Biobuf *bin;

	if (pipe(pin) < 0)
		sysfatal("pipe: %r");
	switch (fork()) {
	default:
		/* parent */
		break;
	case 0:
		/* child: route stdout into the pipe and become rc */
		dup(pin[1], 1);
		close(pin[1]);
		close(pin[0]);
		execl("/bin/rc", "rc", "-c", s, nil);
		/* only reached when exec failed; must not fall into parent code */
		sysfatal("exec /bin/rc: %r");
	case -1:
		sysfatal("fork: %r");
	}
	close(pin[1]);
	bin = Bfdopen(pin[0], OREAD);
	if (bin == nil)
		sysfatal("Bfdopen: %r");
	s = Brdstr(bin, 0, 0);
	Bterm(bin);
	close(pin[0]);
	/* wait returns a malloc'd Waitmsg; free it to avoid a leak per call */
	free(wait());
	return s;
}
/*
 * Shell mode, entered from main when an input line starts with the
 * same marker character checked below.  Prompts with ">>> " and runs each
 * non-empty line read from bin as an rc command via callfunc:
 *  - a line starting with the marker character leaves shell mode and
 *    returns nil;
 *  - a line starting with '!' runs the rest of the line and returns its
 *    output (main then sends it to the model as the user prompt);
 *  - any other line is run, its output printed, and the prompt shown again.
 * Returns nil at end of input.  A non-nil result comes from callfunc and
 * is malloc'd; the caller frees it.
 */
static char*
readconsole(Biobuf *bin)
{
	char *s, *result;
	print(">>> ");
	while (s = Brdstr(bin, '\n', 1)) {
		/* empty line: just prompt again */
		if (!s[0]) {
			free(s);
			print(">>> ");
			continue;
		}
		switch (s[0]) {
		case '':
			/* marker character: leave shell mode without a prompt */
			free(s);
			return nil;
		case '!':
			/* run the command; its output becomes the next user prompt */
			result = callfunc(s+1);
			free(s);
			return result;
		default:
			/* run the command locally and just show its output */
			result = callfunc(s);
			free(s);
			print("%s", result);
			free(result);
			print(">>> ");
		}
	}
	return nil;
}
/*
 * Prompt "!!! what " and read lines from bin until one begins with a
 * character listed in valid.  Blank or invalid answers re-prompt.
 * Returns that first character, or 0 at end of input.
 */
static char
ask(Biobuf *bin, char *what, char *valid)
{
	char *line;
	char answer;

	for (;;) {
		print("!!! %s ", what);
		line = Brdstr(bin, '\n', 1);
		if (line == nil)
			return 0;
		answer = line[0];
		free(line);
		if (answer != 0 && strchr(valid, answer) != nil)
			return answer;
	}
}
/* Report the token counts from a model response: completion, prompt, total. */
static void
printusage(OResult *r)
{
	fprint(1, "usage: c=%d p=%d t=%d\n",
		r->tokscompletion, r->toksprompt, r->tokstotal);
}
/*
 * Built-in system prompts, assembled at compile time from the included
 * .princ files.  Presumably each file contains C string literals that
 * concatenate; confirm.  Both share oai_common.princ.  main selects
 * frontprompt on 9front (when /dist/9front or /dist/plan9front exists)
 * and plan9prompt otherwise, or when -9 is given.
 */
char *plan9prompt =
#include "oai_common.princ"
#include "plan9.princ"
;
char *frontprompt =
#include "oai_common.princ"
#include "front.princ"
;
/*
 * Interactive chat client.
 *
 * Picks a default system prompt (9front or generic Plan 9), parses
 * flags, initializes the API client and the tool list, then reads user
 * lines from standard input.  Each non-empty line is appended to the
 * conversation and sent to the model, and the reply is printed and kept
 * in the conversation.  A line starting with the marker character enters
 * readconsole's shell mode.  A failed request offers to retry.
 */
void
main(int argc, char **argv)
{
	Biobuf *bin;
	char *s;
	ORequest req;
	OResult res;
	char *sysprompt;
	char *url = nil;
	char *key = nil;
	char *model = nil;

	/* access returns 0 on success: 9front if either marker file exists */
	if (!(access("/dist/9front", AEXIST) && access("/dist/plan9front", AEXIST))) {
		/* 9front system */
		sysprompt = frontprompt;
	} else {
		/* other plan 9 system */
		sysprompt = plan9prompt;
	}
	ARGBEGIN{
	case 'h':
		usage();
	case 'k':
		key = EARGF(usage());
		break;
	case 'm':
		model = EARGF(usage());
		break;
	case 'u':
		url = EARGF(usage());
		break;
	case 's':
		sysprompt = EARGF(usage());
		break;
	case '9':
		sysprompt = plan9prompt;
		break;
	case 'q':
		quiet++;
		break;
	case 'd':
		oaidebug++;
		break;
	case 'v':
		/* "-vibe" parses as flag 'v' with argument "ibe" */
		if (strcmp(EARGF(usage()), "ibe") == 0)
			vibemode = 1;
		else
			usage();
		break;
	case 'y':
		/* "-yolo" parses as flag 'y' with argument "olo" */
		if (strcmp(EARGF(usage()), "olo") == 0)
			yolomode = 1;
		else
			usage();
		break;
	default:
		/* reject unknown flags instead of silently ignoring them */
		usage();
	}ARGEND;
	/* auto-approving destructive calls implies auto-approving the rest */
	if (yolomode)
		vibemode = 1;
	if (!initoai(url, key, model))
		usage();
	bin = Bfdopen(0, OREAD);
	assert(bin);
	inittools();
	req.prompts = nil;
	req.tools = tools;
	if (sysprompt)
		addstrprompt(&req, "system", "%s", sysprompt);
	if (!quiet) print("user: ");
	while (s = Brdstr(bin, '\n', 1)) {
		if (s[0] == '') {
			free(s);
			s = readconsole(bin);
		}
		if (!(s && s[0])) {
			free(s);
			goto Next;
		}
		/*
		 * Pass the user text as an argument, never as the format string:
		 * a '%' in the input must not be interpreted as a format verb.
		 * NOTE(review): s is not freed here; confirm whether addstrprompt
		 * copies it (it formats, so it probably does) and free it if so.
		 */
		addstrprompt(&req, "user", "%s", s);
Again:
		res = makerequest(&req);
		printusage(&res);
		if (!res.success) {
			fprint(2, "ERROR: %r\n");
			switch (ask(bin, "Try again (y/n)?", "yn")) {
			default:
				fprint(2, "exiting!\n");
				exits("fail");
			case 'y':
				goto Again;
			case 'n':
				fprint(2, "bye.\n");
				exits(nil);
			}
		}
		print("%s%s%s\n\n", res.role, (quiet ? "" : ": "), res.message);
		addprompt(&req, res.asprompt);
Next:
		if (!quiet) print("user: ");
	}
	exits(nil);
}