ref: 90671e4fafe7c674b2e46f60cf8fa82b782105b8
dir: /oailib.c/
#include <u.h>
#include <libc.h>
#include <bio.h>
#include <json.h>
#include "oai.h"
/* endpoint configuration, set by initoai; written to the webfs ctl file by makerequest */
static char *baseurl = nil;	/* webfs "baseurl" for every request */
static char *apikey = nil;	/* bearer token; may be nil ("no-key" is sent then) */
/* when non-zero, request/response JSON and the finish reason are printed on fd 2 */
int oaidebug = 0;
/*
 * copyjtree: deep-copy a JSON tree by formatting it with %J and
 * parsing the text back.  Returns nil when in is nil or when
 * formatting/parsing fails.
 */
static JSON*
copyjtree(JSON *in)
{
	char *text;
	JSON *dup;

	if (in == nil || (text = smprint("%J", in)) == nil)
		return nil;
	dup = jsonparse(text);
	free(text);
	return dup;
}
int
initoai(char *url, char *key)
{
if (!url)
url = getenv("oaiurl");
if (!url) {
werrstr("invalid baseurl");
return 0;
}
baseurl = url;
if (!key)
key = getenv("oaikey");
apikey = key;
JSONfmtinstall();
return 1;
}
static JSONEl*
mkstrjson(char *name, char *value)
{
JSONEl *el;
el = mallocz(sizeof(JSONEl), 1);
assert(el);
el->name = strdup(name);
assert(el->name);
el->val = mallocz(sizeof(JSON), 1);
assert(el->val);
el->val->t = JSONString;
el->val->s = strdup(value);
assert(el->val->s);
return el;
}
/*
 * prompt2json: convert p into an unnamed array element holding a message
 * object with members, in order:
 *	"role"
 *	"tool_call_id"	only when p->callid is set
 *	"content"	p->content when it is set; otherwise a copy of
 *			p->jcontent stored under p->jcname (or "content"
 *			when jcname is nil)
 */
static JSONEl*
prompt2json(OPrompt *p)
{
	JSONEl *item, *tail, *payload;
	JSON *msg;

	item = mallocz(sizeof(JSONEl), 1);
	msg = item->val = mallocz(sizeof(JSON), 1);
	msg->t = JSONObject;

	tail = msg->first = mkstrjson("role", p->role);
	if (p->callid)
		tail = tail->next = mkstrjson("tool_call_id", p->callid);

	if (p->content) {
		tail->next = mkstrjson("content", p->content);
	} else {
		payload = mallocz(sizeof(JSONEl), 1);
		payload->name = strdup(p->jcname ? p->jcname : "content");
		payload->val = copyjtree(p->jcontent);
		tail->next = payload;
	}
	return item;
}
/*
 * tooltype: the wire name of a tool type ("function" or "custom").
 * Unknown values are reported on fd 2 and yield "".
 */
static char*
tooltype(Tooltype t)
{
	if (t == Function)
		return "function";
	if (t == Custom)
		return "custom";
	fprint(2, "invalid tool type: %d\n", t);
	return "";
}
/*
 * func2json: link after el a "function" member whose value is an object
 * with "name", "description" and, when t->parameters is set, a copy of
 * the parameters tree under "parameters".
 */
static void
func2json(OTool *t, JSONEl *el)
{
	JSONEl *fn, *last;
	JSON *body;

	fn = mallocz(sizeof(JSONEl), 1);
	assert(fn);
	el->next = fn;
	fn->name = strdup("function");

	body = mallocz(sizeof(JSON), 1);
	assert(body);
	fn->val = body;
	body->t = JSONObject;

	last = body->first = mkstrjson("name", t->name);
	last = last->next = mkstrjson("description", t->description);
	if (t->parameters != nil) {
		last->next = mallocz(sizeof(JSONEl), 1);
		last->next->name = strdup("parameters");
		last->next->val = copyjtree(t->parameters);
	}
}
/*
 * custom2json: link "name" and "description" members of t after el.
 */
static void
custom2json(OTool *t, JSONEl *el)
{
	JSONEl *name;

	name = el->next = mkstrjson("name", t->name);
	name->next = mkstrjson("description", t->description);
}
/*
 * tool2json: convert t into an unnamed array element holding
 *	{"type": <tooltype>, ...type-specific members...}
 * The type-specific members come from func2json or custom2json.
 */
static JSONEl*
tool2json(OTool *t)
{
	JSONEl *item, *kind;
	JSON *obj;

	item = mallocz(sizeof(JSONEl), 1);
	obj = item->val = mallocz(sizeof(JSON), 1);
	obj->t = JSONObject;
	kind = obj->first = mkstrjson("type", tooltype(t->type));

	if (t->type == Function)
		func2json(t, kind);
	else if (t->type == Custom)
		custom2json(t, kind);
	return item;
}
static JSON*
req2json(ORequest *req)
{
JSON *j;
JSONEl *el, *tel;
OPrompt *p;
OTool *t;
j = mallocz(sizeof(JSON), 1);
assert(j);
j->t = JSONObject;
el = mallocz(sizeof(JSONEl), 1);
assert(el);
el->name = strdup("messages");
assert(el->name);
el->val = mallocz(sizeof(JSON), 1);
assert(el->val);
el->val->t = JSONArray;
j->first = el;
tel = el;
for (p = req->prompts; p; p = p->next) {
if (!el->val->first) {
el->val->first = prompt2json(p);
el = el->val->first;
continue;
}
el->next = prompt2json(p);
el = el->next;
}
el = mallocz(sizeof(JSONEl), 1);
assert(el);
el->name = strdup("tools");
assert(el->name);
el->val = mallocz(sizeof(JSON), 1);
assert(el->val);
el->val->t = JSONArray;
tel->next = el;
tel = tel->next;
for (t = req->tools; t; t = t->next) {
if (!el->val->first) {
el->val->first = tool2json(t);
el = el->val->first;
continue;
}
el->next = tool2json(t);
el = el->next;
}
if (req->model) {
tel->next = mkstrjson("model", req->model);
}
return j;
}
/*
 * getfirstchoice: return the value of choices[0] in the response j.
 * Exits via sysfatal if "choices" is missing, not an array, or empty.
 */
static JSON*
getfirstchoice(JSON *j)
{
	JSON *arr;
	JSONEl *head;

	if ((arr = jsonbyname(j, "choices")) == nil)
		sysfatal("no choices");
	if (arr->t != JSONArray)
		sysfatal("choices is not an array");
	head = arr->first;
	if (head == nil || head->val == nil)
		sysfatal("no first choices");
	return head->val;
}
/*
 * getmessage: return the "message" object of a choice.
 * Exits via sysfatal if it is missing or not an object.
 */
static JSON*
getmessage(JSON *j)
{
	JSON *m;

	m = jsonbyname(j, "message");
	if (m == nil)
		sysfatal("missing message");
	else if (m->t != JSONObject)
		sysfatal("message is not an object");
	return m;
}
/*
 * j2res: extract choices[0].message.{role,content} from the response j.
 * Returns an OResult with private copies in role and message, success
 * set to 1, and every other field zeroed.  Malformed responses are fatal
 * (sysfatal).
 */
static OResult
j2res(JSON *j)
{
	OResult r;
	JSON *choices;
	JSON *message;
	JSON *role;
	JSON *content;
	char *srole, *smsg;

	/* zero every field so nothing is returned uninitialised */
	memset(&r, 0, sizeof r);

	choices = jsonbyname(j, "choices");
	if (!choices)
		sysfatal("no choices");
	if (choices->t != JSONArray)
		sysfatal("invalid response: choices not an array");
	if (!(choices->first && choices->first->val))
		sysfatal("no response message");
	message = jsonbyname(choices->first->val, "message");
	if (!message)
		sysfatal("choice has no message");
	if (message->t != JSONObject)
		sysfatal("message is not an object");

	role = jsonbyname(message, "role");
	srole = jsonstr(role);
	if (!srole)
		sysfatal("choice has no role");
	content = jsonbyname(message, "content");
	smsg = jsonstr(content);
	if (!smsg)
		sysfatal("choice has no content");

	/* copy out of the tree: the caller frees j afterwards */
	r.role = strdup(srole);
	r.message = strdup(smsg);
	r.success = 1;
	return r;
}
/*
 * addtcprompt: append to pr's jcontent array one tool-call object
 *	{"id": ..., "type": ..., "function": {"name": ..., "arguments": ...}}
 * built from tc, creating the array on first use.  "arguments" is left
 * out when tc->arguments is nil.  Assumes an existing jcontent array has
 * at least one element.
 */
static void
addtcprompt(OPrompt *pr, OToolcall *tc)
{
	JSONEl *slot, *tail, *fn;
	JSON *call, *fnobj;

	slot = mallocz(sizeof(JSONEl), 1);
	if (pr->jcontent == nil) {
		pr->jcontent = mallocz(sizeof(JSON), 1);
		pr->jcontent->t = JSONArray;
		pr->jcontent->first = slot;
	} else {
		for (tail = pr->jcontent->first; tail->next != nil; tail = tail->next)
			;
		tail->next = slot;
	}

	fnobj = mallocz(sizeof(JSON), 1);
	fnobj->t = JSONObject;
	fnobj->first = mkstrjson("name", tc->name);
	if (tc->arguments != nil)
		fnobj->first->next = mkstrjson("arguments", tc->arguments);

	fn = mallocz(sizeof(JSONEl), 1);
	fn->name = strdup("function");
	fn->val = fnobj;

	call = mallocz(sizeof(JSON), 1);
	call->t = JSONObject;
	call->first = mkstrjson("id", tc->id);
	call->first->next = mkstrjson("type", tooltype(tc->type));
	call->first->next->next = fn;

	slot->val = call;
}
/* results of getfinishreason; it returns -1 for any other finish_reason */
enum {
	Fstop = 0,	/* finish_reason "stop" */
	Ftoolcall = 1,	/* finish_reason "tool_calls" */
};
/*
 * getfinishreason: classify choices[0].finish_reason of the response.
 * Returns Fstop for "stop", Ftoolcall for "tool_calls", and -1 for
 * anything else.  A missing or non-string finish_reason is fatal.
 */
static int
getfinishreason(JSON *jres)
{
	JSON *fr;
	char *reason;

	fr = jsonbyname(getfirstchoice(jres), "finish_reason");
	if (fr == nil)
		sysfatal("no finish reason");
	if (fr->t != JSONString)
		sysfatal("finish reason not a string");
	if ((reason = jsonstr(fr)) == nil)
		sysfatal("invalid finish reason");
	if (oaidebug)
		fprint(2, "finish reason: %s\n", reason);

	if (strcmp(reason, "stop") == 0)
		return Fstop;
	else if (strcmp(reason, "tool_calls") == 0)
		return Ftoolcall;

	if (oaidebug)
		fprint(2, "unknown finish_reason\n");
	return -1;
}
/*
 * tcstrfield: return a copy of the string member key of obj.  Exits via
 * sysfatal with the text missing if it is absent, or badtype if it is
 * not a string.
 */
static char*
tcstrfield(JSON *obj, char *key, char *missing, char *badtype)
{
	JSON *v;

	v = jsonbyname(obj, key);
	if (v == nil)
		sysfatal("%s", missing);
	if (v->t != JSONString)
		sysfatal("%s", badtype);
	return strdup(jsonstr(v));
}

/*
 * tcparse: fill tc from one tool_calls entry j:
 *	{"type": "function", "id": ..., "function": {"name": ..., "arguments": ...}}
 * id, name and arguments are copied.  Only "function" calls are supported;
 * "custom", unknown types and malformed entries are fatal.
 */
static void
tcparse(JSON *j, OToolcall *tc)
{
	JSON *v, *func;
	char *kind;

	if ((v = jsonbyname(j, "type")) == nil)
		sysfatal("tool_call without type");
	if ((kind = jsonstr(v)) == nil)
		sysfatal("missing tool_call type");
	if (strcmp(kind, "function") == 0)
		tc->type = Function;
	else if (strcmp(kind, "custom") == 0)
		tc->type = Custom;
	else
		sysfatal("invalid tool_call type: %s", kind);
	if (tc->type != Function)
		sysfatal("tool_call type %s not implemented!", kind);

	tc->id = tcstrfield(j, "id",
		"missing tool_call id", "tool_call id not a string");

	if ((func = jsonbyname(j, "function")) == nil)
		sysfatal("missing tool_call function");
	if (func->t != JSONObject)
		sysfatal("tool_call function wrong type");

	tc->name = tcstrfield(func, "name",
		"tool_call without name", "tool_call name not a string");
	tc->arguments = tcstrfield(func, "arguments",
		"tool_call without arguments", "tool_call arguments not a string");
}
/*
 * findtool: the first tool in req whose name equals name, or nil.
 */
static OTool*
findtool(ORequest *req, char *name)
{
	OTool *t;

	t = req->tools;
	while (t != nil && strcmp(t->name, name) != 0)
		t = t->next;
	return t;
}
/*
 * calltool: handle a response whose finish_reason was "tool_calls".
 *
 * Each entry of choices[0].message.tool_calls is parsed and dispatched to
 * the OTool of the same name.  Calls naming an unknown tool, or a tool
 * without a func, are reported on fd 2 and skipped.  Calls whose func
 * returns nil are skipped silently.  For every other call:
 *	- the call is recorded on an assistant prompt (jcname "tool_calls");
 *	- the returned string becomes the content of a "tool" prompt carrying
 *	  the call id.  Ownership of the string passes to that prompt.
 * The assistant prompt followed by the tool prompts is appended to req,
 * and the request is re-issued with makerequest, whose result is returned.
 *
 * If no call produced a result, req is left untouched and a zeroed result
 * (success == 0) is returned with errstr set.  Sending an assistant prompt
 * without tool_calls would produce an invalid request.
 */
static OResult
calltool(JSON *j, ORequest *req)
{
	JSON *msg, *calls;
	JSONEl *el;
	OTool *tool;
	OToolcall tc;
	OResult res;
	OPrompt *pr, *pres, *results, **tail, *next;
	char *r;
	int handled;

	msg = getmessage(getfirstchoice(j));
	calls = jsonbyname(msg, "tool_calls");
	if (!calls)
		sysfatal("tool call with missing tool_calls");
	if (calls->t != JSONArray)
		sysfatal("tool_calls is not an array");

	pr = makeprompt("assistant");
	pr->jcname = strdup("tool_calls");

	/*
	 * Tool results must follow the assistant prompt that carries the
	 * calls.  Collect them locally and commit them to req only once we
	 * know at least one call succeeded.
	 */
	results = nil;
	tail = &results;
	handled = 0;
	for (el = calls->first; el; el = el->next) {
		memset(&tc, 0, sizeof tc);
		tcparse(el->val, &tc);
		tool = findtool(req, tc.name);
		if (!tool) {
			fprint(2, "missing tool: %s\n", tc.name);
			continue;
		}
		if (!tool->func) {
			fprint(2, "tool %s has no func\n", tc.name);
			continue;
		}
		r = tool->func(tc, tool->aux);
		if (!r)
			continue;
		addtcprompt(pr, &tc);
		pres = makeprompt("tool");
		pres->content = r;
		pres->callid = strdup(tc.id);
		*tail = pres;
		tail = &pres->next;
		handled++;
		/*
		 * NOTE(review): tc.id/name/arguments are not freed because the
		 * OToolcall passed to the tool func may be retained by it.
		 * Confirm the callback contract before freeing them here.
		 */
	}

	if (handled == 0) {
		/* pr only owns role and jcname: no call was added to it */
		free(pr->role);
		free(pr->jcname);
		free(pr);
		werrstr("no tool call could be handled");
		memset(&res, 0, sizeof res);
		return res;
	}

	addprompt(req, pr);
	/* addprompt clears ->next, so fetch it before handing each one over */
	for (pres = results; pres; pres = next) {
		next = pres->next;
		addprompt(req, pres);
	}
	return makerequest(req);
}
OResult
makerequest(ORequest *req)
{
char buf[128];
int ctlfd, pbodyfd;
Biobuf *body;
char *s;
int n;
OResult ret;
JSON *jreq;
JSON *jres;
jreq = req2json(req);
if (!jreq) {
fprint(2, "error: %r\n");
ret.success = 0;
return ret;
}
ctlfd = open("/mnt/web/clone", ORDWR);
if (ctlfd < 0)
sysfatal("webfs ctl open: %r");
if ((n = read(ctlfd, buf, sizeof buf)) < 0)
sysfatal("webfs ctl read: %r");
buf[n] = 0;
n = atoi(buf);
fprint(ctlfd, "useragent 9front\n");
fprint(ctlfd, "contenttype application/json\n");
fprint(ctlfd, "headers Authorization: Bearer %s\n", apikey ? apikey : "no-key");
fprint(ctlfd, "baseurl %s\n", baseurl);
fprint(ctlfd, "url ./v1/chat/completions\n");
if (oaidebug)
fprint(2, "request:\n%J\n\n", jreq);
snprint(buf, sizeof buf, "/mnt/web/%d/postbody", n);
pbodyfd = open(buf, OWRITE);
if (pbodyfd < 0)
sysfatal("webfs pbody open: %r");
fprint(pbodyfd, "%J", jreq);
close(pbodyfd);
snprint(buf, sizeof buf, "/mnt/web/%d/body", n);
body = Bopen(buf, OREAD);
if (!body)
sysfatal("webfs body open: %r");
s = Brdstr(body, 0, 0);
Bterm(body);
close(ctlfd);
jres = jsonparse(s);
if (oaidebug)
fprint(2, "response\n%J\n\n", jres);
switch (getfinishreason(jres)) {
default:
ret.success = 0;
return ret;
case Fstop:
break;
case Ftoolcall:
ret.success = 0;
ret = calltool(jres, req);
return ret;
}
ret = j2res(jres);
free(s);
jsonfree(jreq);
jsonfree(jres);
ret.asprompt = makeprompt(ret.role);
ret.asprompt->content = strdup(ret.message);
return ret;
}
/*
 * addstrprompt: append to r a prompt with the given role whose string
 * content is produced by formatting content (a print-style format) with
 * the remaining arguments.  Always returns 1.
 */
int
addstrprompt(ORequest *r, char *role, char *content, ...)
{
	OPrompt *p;
	va_list ap;

	p = makeprompt(role);
	va_start(ap, content);
	p->content = vsmprint(content, ap);
	va_end(ap);
	return addprompt(r, p);
}
/*
 * addprompt: append the single prompt p (its next link is cleared) to
 * the end of r's prompt list.  A nil p is ignored.  Always returns 1.
 */
int
addprompt(ORequest *r, OPrompt *p)
{
	OPrompt **link;

	if (p == nil)
		return 1;
	p->next = nil;
	for (link = &r->prompts; *link != nil; link = &(*link)->next)
		;
	*link = p;
	return 1;
}
/*
 * makeprompt: allocate a zeroed prompt carrying a copy of role.
 */
OPrompt*
makeprompt(char *role)
{
	OPrompt *np;

	np = mallocz(sizeof *np, 1);
	np->role = strdup(role);
	return np;
}
/*
 * addgenmessage: append an object {"type": type [, cname: content]} to
 * p's jcontent array, creating the array if needed.  The cname member is
 * added only when cname is non-nil.  Returns the last member written, so
 * callers can chain further fields after it.  Returns nil with errstr set
 * when p already has string content or its jcontent is not an array.
 */
static JSONEl*
addgenmessage(OPrompt *p, char *type, char *cname, char *content)
{
	JSON *arr, *obj;
	JSONEl **link, *slot, *last;

	if (p->content) {
		werrstr("prompt has string content");
		return nil;
	}
	arr = p->jcontent;
	if (arr == nil) {
		arr = p->jcontent = mallocz(sizeof(JSON), 1);
		arr->t = JSONArray;
	} else if (arr->t != JSONArray) {
		werrstr("prompt json is not array");
		return nil;
	}

	for (link = &arr->first; *link != nil; link = &(*link)->next)
		;
	slot = *link = mallocz(sizeof(JSONEl), 1);	/* array item: name stays nil */
	obj = slot->val = mallocz(sizeof(JSON), 1);
	obj->t = JSONObject;

	last = obj->first = mkstrjson("type", type);
	if (cname != nil)
		last = last->next = mkstrjson(cname, content);
	return last;
}
/*
 * appendfield: link a new "name": "value" string member after jel and
 * return it.
 */
static JSONEl*
appendfield(JSONEl *jel, char *name, char *value)
{
	JSONEl *field;

	field = mkstrjson(name, value);
	jel->next = field;
	return field;
}
/*
 * addtextmessage: append a {"type": "text", "text": text} part to p.
 * Returns 1 on success, 0 with errstr set (see addgenmessage).
 */
int
addtextmessage(OPrompt *p, char *text)
{
	if (addgenmessage(p, "text", "text", text) == nil)
		return 0;
	return 1;
}
/*
 * addfilemessage: append a {"type": "file", ...} part to p's jcontent.
 * When data is non-nil, its ndata bytes are base64-encoded into
 * "file_data".  "filename" and "file_id" are added when given.
 * Returns 1 on success, or 0 with errstr set on failure.
 */
int
addfilemessage(OPrompt *p, uchar *data, long ndata, char *filename, char *fileid)
{
	char *b64;
	long nb64;
	JSONEl *jel;

	if (data) {
		if (ndata < 0) {
			werrstr("negative file size");
			return 0;
		}
		/*
		 * base64 emits 4 characters per 3 input bytes (the last group
		 * padded), plus the terminating NUL.  2*ndata+1 is too small
		 * when ndata == 1.
		 */
		nb64 = (ndata+2)/3*4 + 1;
		b64 = mallocz(nb64, 1);
		if (b64 == nil) {
			werrstr("no memory for base64 buffer");
			return 0;
		}
		if (enc64(b64, nb64, data, ndata) < 0) {
			werrstr("enc64 error");
			free(b64);
			return 0;
		}
		/* addgenmessage copies the string, so b64 can go right away */
		jel = addgenmessage(p, "file", "file_data", b64);
		free(b64);
	} else {
		jel = addgenmessage(p, "file", nil, nil);
	}
	if (!jel)
		return 0;
	if (filename)
		jel = appendfield(jel, "filename", filename);
	if (fileid)
		jel = appendfield(jel, "file_id", fileid);
	return !!jel;
}
/*
 * maketool: allocate a tool with copies of name and description, the
 * parsed parameters schema (left nil when parameters is nil or empty),
 * handler f and its argument aux.  When tool is non-nil the new tool is
 * appended to the end of that list.  Returns the new tool, not the list
 * head.
 */
OTool*
maketool(OTool *tool, Tooltype type, char *name, char *description, char *parameters, char* (*f)(OToolcall,void*), void *aux)
{
	OTool *nt, **link;

	nt = mallocz(sizeof *nt, 1);
	assert(nt);
	nt->type = type;
	nt->name = strdup(name);
	nt->description = strdup(description);
	if (parameters != nil && *parameters != '\0')
		nt->parameters = jsonparse(parameters);
	nt->func = f;
	nt->aux = aux;

	if (tool != nil) {
		for (link = &tool->next; *link != nil; link = &(*link)->next)
			;
		*link = nt;
	}
	return nt;
}