shithub: oai

ref: 1f68d6a68ba2b4512d963ea2815280cabe707f53
dir: /oailib.c/

View raw version
#include <u.h>
#include <libc.h>
#include <bio.h>
#include <json.h>
#include "oai.h"

/* Server base URL handed to webfs ("baseurl" ctl); set by initoai from its url argument or $oaiurl. */
static char *baseurl = nil;
/* Token sent as "Authorization: Bearer"; when nil, makerequest sends "no-key". */
static char *apikey = nil;

/* Application callback run for each tool call; when nil, tool_calls responses are not executed. */
static OToolcallfn toolcall = nil;

/* When non-zero, requests, responses and finish reasons are printed to fd 2. */
int oaidebug = 0;

/*
 * initoai configures the library.
 *
 * url	server base URL; when nil, $oaiurl is consulted.  If neither is
 *	available the call fails and nothing is changed.
 * key	API key; when nil, $oaikey is consulted (may stay nil).
 * fn	tool-call callback; may be nil.
 *
 * Installs the %J JSON format verb.  Returns 1 on success, 0 with the
 * error string set on failure.  The strings are kept by reference.
 */
int
initoai(char *url, char *key, OToolcallfn fn)
{
	if (url == nil && (url = getenv("oaiurl")) == nil) {
		werrstr("invalid baseurl");
		return 0;
	}

	baseurl = url;
	toolcall = fn;
	apikey = key != nil ? key : getenv("oaikey");

	JSONfmtinstall();
	return 1;
}

static JSONEl*
mkstrjson(char *name, char *value)
{
	JSONEl *el;
	
	el = mallocz(sizeof(JSONEl), 1);
	assert(el);
	el->name = strdup(name);
	assert(el->name);
	el->val = mallocz(sizeof(JSON), 1);
	assert(el->val);
	el->val->t = JSONString;
	el->val->s = strdup(value);
	assert(el->val->s);
	return el;
}

static JSONEl*
prompt2json(OPrompt *p)
{
	JSONEl *el;
	JSONEl *top;
	
	el = top = mallocz(sizeof(JSONEl), 1);
	el->val = mallocz(sizeof(JSON), 1);
	el->val->t = JSONObject;
	el = el->val->first = mkstrjson("role", p->role);
	if (p->callid) {
		el = el->next = mkstrjson("tool_call_id", p->callid);
	}
	if (p->content) {
		el = el->next = mkstrjson("content", p->content);
		return top;
	}
	el = el->next = mallocz(sizeof(JSONEl), 1);
	el->name = p->jcname ? strdup(p->jcname) : strdup("content");
	el->val = p->jcontent;
	return top;
}

/*
 * tooltype maps a Tooltype to its wire name ("function" or "custom").
 * Unknown values are reported on fd 2 and mapped to the empty string.
 */
static char*
tooltype(Tooltype t)
{
	if (t == Function)
		return "function";
	if (t == Custom)
		return "custom";

	fprint(2, "invalid tool type: %d\n", t);
	return "";
}

/*
 * func2json appends a "function": {"name", "description", ["parameters"]}
 * member after el, which is the tool object's "type" member.
 *
 * t->parameters holds JSON-schema text.  It is parsed and emitted as a
 * JSON value, because function "parameters" must be a schema object, not
 * a string.  If the text does not parse, a warning is printed and it is
 * sent as a string, as before, so the failure stays visible server-side.
 */
static void
func2json(OTool *t, JSONEl *el)
{
	JSONEl *f;
	JSON *params;

	el->next = mallocz(sizeof(JSONEl), 1);
	assert(el->next);
	el = el->next;
	el->name = strdup("function");
	assert(el->name);
	el->val = mallocz(sizeof(JSON), 1);
	assert(el->val);
	el->val->t = JSONObject;

	f = el->val->first = mkstrjson("name", t->name);
	f = f->next = mkstrjson("description", t->description);
	if (!t->parameters)
		return;

	params = jsonparse(t->parameters);
	if (params == nil) {
		fprint(2, "tool %s: parameters are not valid json (%r), sending as string\n", t->name);
		f->next = mkstrjson("parameters", t->parameters);
		return;
	}
	f->next = mallocz(sizeof(JSONEl), 1);
	assert(f->next);
	f = f->next;
	f->name = strdup("parameters");
	assert(f->name);
	f->val = params;
}

/*
 * custom2json links "name" and "description" members directly after el,
 * which is the tool object's "type" member.  Anything previously after
 * el is replaced.
 */
static void
custom2json(OTool *t, JSONEl *el)
{
	JSONEl *name;

	name = mkstrjson("name", t->name);
	name->next = mkstrjson("description", t->description);
	el->next = name;
}

/*
 * tool2json converts one tool into an unnamed array element holding an
 * object that starts with "type" and continues with the type-specific
 * members added by func2json or custom2json.
 */
static JSONEl*
tool2json(OTool *t)
{
	JSONEl *entry;
	JSON *obj;

	obj = mallocz(sizeof(JSON), 1);
	obj->t = JSONObject;
	obj->first = mkstrjson("type", tooltype(t->type));

	if (t->type == Function)
		func2json(t, obj->first);
	else if (t->type == Custom)
		custom2json(t, obj->first);

	entry = mallocz(sizeof(JSONEl), 1);
	entry->val = obj;
	return entry;
}

/*
 * req2json builds the request body:
 *	{"messages": [...], ["tools": [...],] ["model": "..."]}
 *
 * Messages and tools keep the order of their linked lists.  "tools" is
 * left out when the request has none, because the chat-completions
 * endpoint rejects an empty tools array.  "model" is emitted only when
 * req->model is set.
 */
static JSON*
req2json(ORequest *req)
{
	JSON *j;
	JSONEl *field, *item;
	OPrompt *p;
	OTool *t;

	j = mallocz(sizeof(JSON), 1);
	assert(j);
	j->t = JSONObject;

	field = j->first = mallocz(sizeof(JSONEl), 1);
	assert(field);
	field->name = strdup("messages");
	assert(field->name);
	field->val = mallocz(sizeof(JSON), 1);
	assert(field->val);
	field->val->t = JSONArray;

	item = nil;
	for (p = req->prompts; p; p = p->next) {
		if (item == nil)
			item = field->val->first = prompt2json(p);
		else
			item = item->next = prompt2json(p);
	}

	if (req->tools) {
		field = field->next = mallocz(sizeof(JSONEl), 1);
		assert(field);
		field->name = strdup("tools");
		assert(field->name);
		field->val = mallocz(sizeof(JSON), 1);
		assert(field->val);
		field->val->t = JSONArray;

		item = nil;
		for (t = req->tools; t; t = t->next) {
			if (item == nil)
				item = field->val->first = tool2json(t);
			else
				item = item->next = tool2json(t);
		}
	}

	if (req->model)
		field->next = mkstrjson("model", req->model);

	return j;
}

/*
 * getfirstchoice returns the value of the first element of the
 * response's "choices" array.  A missing, non-array or empty array is
 * fatal.
 */
static JSON*
getfirstchoice(JSON *j)
{
	JSON *arr;
	JSONEl *head;

	arr = jsonbyname(j, "choices");
	if (arr == nil)
		sysfatal("no choices");
	if (arr->t != JSONArray)
		sysfatal("choices is not an array");

	head = arr->first;
	if (head == nil || head->val == nil)
		sysfatal("no first choices");
	return head->val;
}

/*
 * getmessage returns the "message" object of a choice.  A missing member
 * or a non-object value is fatal.
 */
static JSON*
getmessage(JSON *j)
{
	JSON *msg;

	if ((msg = jsonbyname(j, "message")) == nil)
		sysfatal("missing message");
	if (msg->t != JSONObject)
		sysfatal("message is not an object");
	return msg;
}

/*
 * j2res extracts the first choice's message role and content into a
 * successful OResult.  Both strings are strdup'd, so they outlive j.
 *
 * The result is zeroed first and success is set explicitly.  Previously
 * success was left uninitialised on this, the normal success path.
 * Missing choices, message, role or content are fatal.  The choice and
 * message lookups reuse getfirstchoice/getmessage instead of duplicating
 * their checks.
 */
static OResult
j2res(JSON *j)
{
	OResult r;
	JSON *message;
	char *role, *content;

	memset(&r, 0, sizeof r);

	message = getmessage(getfirstchoice(j));

	role = jsonstr(jsonbyname(message, "role"));
	if (!role)
		sysfatal("choice has no role");

	content = jsonstr(jsonbyname(message, "content"));
	if (!content)
		sysfatal("choice has no content");

	r.role = strdup(role);
	r.message = strdup(content);
	r.success = 1;
	return r;
}

/*
 * addtcprompt appends one tool-call record to pr's JSON array content,
 * creating the array on first use.  Each record has the form
 *	{"id": ..., "type": ..., "function": {"name": ..., "arguments": ...}}
 */
static void
addtcprompt(OPrompt *pr, OToolcall *tc)
{
	JSONEl *slot, *f;
	JSON *call, *fn;

	if (pr->jcontent) {
		slot = pr->jcontent->first;
		while (slot->next)
			slot = slot->next;
		slot->next = mallocz(sizeof(JSONEl), 1);
		slot = slot->next;
	} else {
		pr->jcontent = mallocz(sizeof(JSON), 1);
		pr->jcontent->t = JSONArray;
		slot = pr->jcontent->first = mallocz(sizeof(JSONEl), 1);
	}

	call = slot->val = mallocz(sizeof(JSON), 1);
	call->t = JSONObject;
	call->first = mkstrjson("id", tc->id);
	call->first->next = mkstrjson("type", tooltype(tc->type));

	f = call->first->next->next = mallocz(sizeof(JSONEl), 1);
	f->name = strdup("function");
	fn = f->val = mallocz(sizeof(JSON), 1);
	fn->t = JSONObject;
	fn->first = mkstrjson("name", tc->name);
	fn->first->next = mkstrjson("arguments", tc->arguments);
}

enum {
	Fstop,
	Ftoolcall,
};

/*
 * getfinishreason classifies the first choice's "finish_reason".
 * Returns Fstop or Ftoolcall, or -1 for any other string.  A missing or
 * non-string finish_reason is fatal.
 */
static int
getfinishreason(JSON *jres)
{
	static struct {
		char *name;
		int reason;
	} known[] = {
		{ "stop", Fstop },
		{ "tool_calls", Ftoolcall },
	};
	JSON *fr;
	char *why;
	int i;

	fr = jsonbyname(getfirstchoice(jres), "finish_reason");
	if (fr == nil)
		sysfatal("no finish reason");
	if (fr->t != JSONString)
		sysfatal("finish reason not a string");

	why = jsonstr(fr);
	if (why == nil)
		sysfatal("invalid finish reason");

	if (oaidebug)
		fprint(2, "finish reason: %s\n", why);

	for (i = 0; i < nelem(known); i++)
		if (strcmp(why, known[i].name) == 0)
			return known[i].reason;

	if (oaidebug)
		fprint(2, "unknown finish_reason\n");
	return -1;
}

/*
 * tcstr fetches obj[key], requires it to be a JSON string and returns a
 * strdup'd copy.  missing and badtype are the fatal messages used when
 * the member is absent or has the wrong type.
 */
static char*
tcstr(JSON *obj, char *key, char *missing, char *badtype)
{
	JSON *v;

	v = jsonbyname(obj, key);
	if (v == nil)
		sysfatal("%s", missing);
	if (v->t != JSONString)
		sysfatal("%s", badtype);
	return strdup(jsonstr(v));
}

/*
 * tcparse fills tc from one element of a message's "tool_calls" array.
 * Only "function" calls are supported.  "custom" is recognised but
 * fatal, as is any malformed field.  id, name and arguments are
 * strdup'd.
 */
static void
tcparse(JSON *j, OToolcall *tc)
{
	JSON *kindj, *fn;
	char *kind;

	kindj = jsonbyname(j, "type");
	if (kindj == nil)
		sysfatal("tool_call without type");
	kind = jsonstr(kindj);
	if (kind == nil)
		sysfatal("missing tool_call type");

	if (strcmp(kind, "function") == 0)
		tc->type = Function;
	else if (strcmp(kind, "custom") == 0)
		tc->type = Custom;
	else
		sysfatal("invalid tool_call type: %s", kind);
	if (tc->type != Function)
		sysfatal("tool_call type %s not implemented!", kind);

	tc->id = tcstr(j, "id", "missing tool_call id", "tool_call id not a string");

	fn = jsonbyname(j, "function");
	if (fn == nil)
		sysfatal("missing tool_call function");
	if (fn->t != JSONObject)
		sysfatal("tool_call function wrong type");

	tc->name = tcstr(fn, "name", "tool_call without name", "tool_call name not a string");
	tc->arguments = tcstr(fn, "arguments", "tool_call without arguments", "tool_call arguments not a string");
}

/*
 * calltool handles a "tool_calls" response.
 *
 * It appends an assistant prompt whose JSON content (under "tool_calls")
 * records each executed call.  Each call is run through the
 * application's toolcall callback, and every non-nil result becomes a
 * "tool" prompt tied to the call id.  Calls whose callback returns nil
 * are skipped entirely.  It then re-issues the request with the extended
 * prompt list and returns that result.  Fatal if no callback is
 * installed or the response is malformed.
 */
static OResult
calltool(JSON *j, ORequest *req)
{
	JSON *msg, *calls;
	JSONEl *call;
	OToolcall tc;
	OPrompt *assist, *reply;
	char *out;

	if (!toolcall)
		sysfatal("tool calls not supported by application!");

	msg = getmessage(getfirstchoice(j));
	calls = jsonbyname(msg, "tool_calls");
	if (calls == nil)
		sysfatal("tool call with missing tool_calls");
	if (calls->t != JSONArray)
		sysfatal("tool_calls is not an array");

	/* the assistant turn requesting the calls precedes the tool replies */
	assist = makeprompt("assistant");
	assist->jcname = strdup("tool_calls");
	addprompt(req, assist);

	for (call = calls->first; call != nil; call = call->next) {
		tcparse(call->val, &tc);
		out = toolcall(tc);
		if (out == nil)
			continue;
		addtcprompt(assist, &tc);

		reply = makeprompt("tool");
		reply->content = out;
		reply->callid = strdup(tc.id);
		addprompt(req, reply);
	}

	return makerequest(*req);
}

/*
 * makerequest serialises req and POSTs it to <baseurl>/v1/chat/completions
 * through webfs (/mnt/web), then interprets the JSON reply.
 *
 * finish_reason "stop": returns the first choice's role and content
 *	(via j2res).
 * finish_reason "tool_calls": when a callback is installed, runs the
 *	calls via calltool, which extends req and issues a follow-up
 *	request whose result is returned; otherwise success = 0.
 * anything else: success = 0.
 *
 * webfs errors and empty or unparsable responses are fatal, like the
 * other malformed-response cases in this file.
 */
OResult
makerequest(ORequest req)
{
	char buf[128];
	int ctlfd, pbodyfd, conn, n;
	Biobuf *body;
	char *s;
	OResult ret;
	JSON *jreq, *jres;

	/* failure paths return ret unchanged, so never hand back garbage */
	memset(&ret, 0, sizeof ret);

	jreq = req2json(&req);

	ctlfd = open("/mnt/web/clone", ORDWR);
	if (ctlfd < 0)
		sysfatal("webfs ctl open: %r");
	/* keep one byte free so the terminating NUL stays inside buf */
	if ((n = read(ctlfd, buf, sizeof buf - 1)) < 0)
		sysfatal("webfs ctl read: %r");
	buf[n] = 0;
	conn = atoi(buf);

	fprint(ctlfd, "useragent 9front\n");
	fprint(ctlfd, "contenttype application/json\n");
	fprint(ctlfd, "headers Authorization: Bearer %s\n", apikey ? apikey : "no-key");
	fprint(ctlfd, "baseurl %s\n", baseurl);
	fprint(ctlfd, "url ./v1/chat/completions\n");

	if (oaidebug)
		fprint(2, "request:\n%J\n\n", jreq);

	snprint(buf, sizeof buf, "/mnt/web/%d/postbody", conn);
	pbodyfd = open(buf, OWRITE);
	if (pbodyfd < 0)
		sysfatal("webfs pbody open: %r");
	fprint(pbodyfd, "%J", jreq);
	close(pbodyfd);

	snprint(buf, sizeof buf, "/mnt/web/%d/body", conn);
	body = Bopen(buf, OREAD);
	if (!body)
		sysfatal("webfs body open: %r");

	s = Brdstr(body, 0, 0);
	Bterm(body);
	close(ctlfd);
	if (!s)
		sysfatal("webfs body read: empty response");

	jres = jsonparse(s);
	if (!jres)
		sysfatal("cannot parse response: %r");
	if (oaidebug)
		fprint(2, "response\n%J\n\n", jres);

	switch (getfinishreason(jres)) {
	case Fstop:
		ret = j2res(jres);
		jsonfree(jreq);
		break;
	case Ftoolcall:
		if (toolcall)
			ret = calltool(jres, &req);
		break;
	default:
		break;
	}
	/*
	 * NOTE(review): jreq is still not freed on the tool-call/unknown
	 * paths (pre-existing leak).  Confirm prompt2json does not share
	 * JSON with the prompts before releasing it there.
	 */

	/* calltool and j2res copy everything they keep out of jres */
	free(s);
	jsonfree(jres);
	return ret;
}

/*
 * addstrprompt appends a prompt with the given role and a print-style
 * formatted string content to the end of r's prompt list.
 * The role is strdup'd; the formatted text is owned by the prompt.
 * Always returns 1.
 */
int
addstrprompt(ORequest *r, char *role, char *content, ...)
{
	OPrompt **pp;
	char *text;
	va_list ap;

	va_start(ap, content);
	text = vsmprint(content, ap);
	va_end(ap);

	for (pp = &r->prompts; *pp != nil; pp = &(*pp)->next)
		;
	*pp = mallocz(sizeof(OPrompt), 1);
	(*pp)->role = strdup(role);
	(*pp)->content = text;
	return 1;
}

/*
 * addprompt links p, as the new last element, onto r's prompt list.
 * p->next is cleared first, so any list p was on is cut off.
 * Always returns 1.
 */
int
addprompt(ORequest *r, OPrompt *p)
{
	OPrompt **link;

	p->next = nil;
	for (link = &r->prompts; *link != nil; link = &(*link)->next)
		;
	*link = p;
	return 1;
}

/*
 * makeprompt allocates a zeroed prompt with a copy of role and no
 * content.  The prompt is not added to any request.
 */
OPrompt*
makeprompt(char *role)
{
	OPrompt *np;

	np = mallocz(sizeof *np, 1);
	np->role = strdup(role);
	return np;
}

/*
 * addgenmessage appends a content part {"type": type[, cname: content]}
 * to p's JSON array content, creating the array on first use.
 *
 * Fails (nil, error string set) if p already has string content, or if
 * its JSON content is not an array.  On success it returns the last
 * member written, "type" or cname, so callers can append more fields.
 */
static JSONEl*
addgenmessage(OPrompt *p, char *type, char *cname, char *content)
{
	JSONEl **tail, *field;
	JSON *part;

	if (p->content != nil) {
		werrstr("prompt has string content");
		return nil;
	}
	if (p->jcontent == nil) {
		p->jcontent = mallocz(sizeof(JSON), 1);
		p->jcontent->t = JSONArray;
	} else if (p->jcontent->t != JSONArray) {
		werrstr("prompt json is not array");
		return nil;
	}

	for (tail = &p->jcontent->first; *tail != nil; tail = &(*tail)->next)
		;
	*tail = mallocz(sizeof(JSONEl), 1);
	part = (*tail)->val = mallocz(sizeof(JSON), 1);
	part->t = JSONObject;

	field = part->first = mallocz(sizeof(JSONEl), 1);
	field->name = strdup("type");
	field->val = mallocz(sizeof(JSON), 1);
	field->val->t = JSONString;
	field->val->s = strdup(type);
	if (cname == nil)
		return field;

	field = field->next = mallocz(sizeof(JSONEl), 1);
	field->name = strdup(cname);
	field->val = mallocz(sizeof(JSON), 1);
	field->val->t = JSONString;
	field->val->s = strdup(content);
	return field;
}

/*
 * appendfield links a new string member name: value directly after jel
 * (replacing jel->next) and returns it.  Both strings are duplicated.
 */
static JSONEl*
appendfield(JSONEl *jel, char *name, char *value)
{
	JSONEl *nf;

	nf = mallocz(sizeof *nf, 1);
	nf->name = strdup(name);
	nf->val = mallocz(sizeof *nf->val, 1);
	nf->val->t = JSONString;
	nf->val->s = strdup(value);
	jel->next = nf;
	return nf;
}

/*
 * addtextmessage appends a {"type": "text", "text": text} part to p.
 * Returns 1 on success, 0 (error string set) otherwise.
 */
int
addtextmessage(OPrompt *p, char *text)
{
	if (addgenmessage(p, "text", "text", text) == nil)
		return 0;
	return 1;
}

/*
 * addfilemessage appends a {"type": "file", ...} part to p's JSON content.
 *
 * data/ndata	when data is non-nil, its ndata bytes are base64-encoded
 *		into a "file_data" member.
 * filename	when non-nil, added as "filename".
 * fileid	when non-nil, added as "file_id".
 *
 * Returns 1 on success, 0 with the error string set on failure.
 *
 * NOTE(review): the fields are siblings of "type".  Confirm this matches
 * the target API, which may expect them nested in a "file" object.
 */
int
addfilemessage(OPrompt *p, uchar *data, long ndata, char *filename, char *fileid)
{
	char *b64;
	int nb64;
	JSONEl *jel;

	if (data) {
		/*
		 * base64 emits 4 chars per started 3-byte group, plus the NUL.
		 * The old ndata*2+1 sizing gave a 1-byte input 3 bytes where
		 * 5 are needed, and could overflow for large ndata.  The bound
		 * keeps nb64 within enc64's int range.
		 */
		if (ndata < 0 || ndata > 0x7FFFFFFE / 4 * 3) {
			werrstr("invalid file size");
			return 0;
		}
		nb64 = (ndata + 2) / 3 * 4 + 1;
		b64 = mallocz(nb64, 1);
		if (!b64) {
			werrstr("out of memory");
			return 0;
		}
		if (enc64(b64, nb64, data, ndata) < 0) {
			werrstr("enc64 error");
			free(b64);
			return 0;
		}
		jel = addgenmessage(p, "file", "file_data", b64);
		free(b64);
	} else {
		jel = addgenmessage(p, "file", nil, nil);
	}

	if (!jel)
		return 0;

	if (filename)
		jel = appendfield(jel, "filename", filename);
	if (fileid)
		jel = appendfield(jel, "file_id", fileid);

	return !!jel;
}

/*
 * maketool allocates a tool description and, when tool is non-nil,
 * appends it to the end of that list.  Returns the new tool, not the
 * list head.
 * name and description are copied.  parameters is copied only when
 * non-nil and non-empty; otherwise it stays nil.
 */
OTool*
maketool(OTool *tool, Tooltype type, char *name, char *description, char *parameters)
{
	OTool *nt, *last;

	nt = mallocz(sizeof(OTool), 1);
	assert(nt);

	nt->type = type;
	nt->name = strdup(name);
	nt->description = strdup(description);
	nt->parameters = (parameters != nil && *parameters != '\0') ? strdup(parameters) : nil;

	if (tool != nil) {
		for (last = tool; last->next != nil; last = last->next)
			;
		last->next = nt;
	}
	return nt;
}