shithub: ridefs

ref: 8eef000ed6e16e02354e6798f227223cc38b2d42
dir: /ridefs.c/

View raw version
#include <u.h>
#include <libc.h>
#include <fcall.h>
#include <thread.h>
#include <9p.h>
#include <stdio.h>
#include <json.h>


typedef struct Client Client;
typedef struct Rfid Rfid;

/*
 * One slot of cpool: the state of one connection to a RIDE interpreter.
 * A slot is free while its reference count is 0 (mkclient and genqroot
 * test ref == 0); rmclient zeroes it when the last reference is dropped.
 */
struct Client {
	Ref;	/* anonymous Plan 9 C member: incref/decref take a Client* directly */

	char *addr;	/* set by "connect" on the client's ctl; dialled by rideinit */
	ulong umask;	/* AND-ed into the client directory's mode by mkdirent */
	int timeout;	/* alarm() period around io writes (fswrite); alarm takes milliseconds */

	/* internal use */
	char *user;	/* owner reported by stat: getuser() when the slot was claimed */
	char *pres;  /* previous result: copy of the last reply read by writeqio, served by io reads */
	long time0;	/* creation time, reported as atime/mtime */
	int id;	/* index of this slot in cpool */
	int oio;     /* io opened: set once rideinit succeeded; ctl then also shows the R* fields */
	int iopid;   /* read/write fork: pid of the proc doing io work, 0 when none; fsflush interrupts it */
	int fd;      /* data: connection returned by dial */
	int cfd;     /* ctl: control fd returned by dial */

	/* id reply info: fields of the interpreter's ReplyIdentify message, filled by rideinit */
	int Rapiversion;
	int Rport;
	int Rpid;
	char *Ripaddress;
	char *Rvendor;
	char *Rlanguage;
	char *Rversion;
	char *Rmachine;
	char *Rarch;
	char *Rproject;
	char *Rprocess;
	char *Ruser;
	char *Rtoken;
	char *Rdate;
	char *Rplatform;
};

/*
 * File kinds. The indentation mirrors the served tree: the root holds
 * ctl (Qrctl), clone and one numbered directory per client (Qclient),
 * and each client directory holds ctl (Qctl) and io. The values index
 * nametab and form the low 32 bits of qid paths (mkqid); QCOUNT is the
 * number of kinds.
 */
enum {
	Qroot,
		Qrctl,
		Qclone,
		Qclient,
			Qctl,
			Qio,
	QCOUNT
};

/*
 * Directory entry names indexed by file kind. Qclient is nil because
 * client directories are named by their numeric cpool index instead.
 */
static char *nametab[] = {
	"/",
		"ctl",
		"clone",
		nil,
			"ctl",
			"io",
};

/*
 * Per-fid state stored in Fid.aux.
 * The Ref member is not used by any code in this file.
 */
struct Rfid {
	Ref;
	int kind;	/* Q* file kind the fid refers to */
	int client;	/* cpool index of the client, or -1 when the fid refers to none */
};

/* Qid version of every file; also reported as "version" by the root ctl. */
#define RIDESRV_VERS 0
static Client *cpool;	/* client slots, nclients long; a slot is free while its ref is 0 */
static char *mtpt;	/* mount point handed to postmountsrv (default /mnt/ride) */
static char *service;	/* srv name handed to postmountsrv; nil unless -s */
static char *net;	/* network passed to netmkaddr when dialling; nil unless -x */
static char *user;	/* owner reported for the server-wide files in mkdirent */
static ulong umask;	/* server permission mask; also each new client's initial umask */
static long bufsz;	/* size of the message and text buffers allocated per operation */
static long time0;	/* server start time; atime/mtime of the server-wide files */
static uint nclients;	/* number of slots in cpool */
static int debug;	/* -d count; in this file only reported and set through the root ctl */
static int timeout;	/* alarm period used when opening io; initial per-client timeout */

/*
 * Return n zero-filled bytes from emalloc9p (which exits on failure),
 * attributing the block to our caller so leak reports point at it.
 */
void*
ecalloc(ulong n){
	void *v;

	v = emalloc9p(n);
	setmalloctag(v, getcallerpc(&n));
	return memset(v, 0, n);
}

/*
 * Resize p to n bytes with erealloc9p (which exits on failure) and
 * record our caller as the one responsible for the resized block.
 */
void*
erealloc(void *p, ulong n){
	void *q;

	q = erealloc9p(p, n);
	setrealloctag(q, getcallerpc(&n));
	return q;
}

/*
 * Duplicate s with estrdup9p (which exits on failure) and attribute the
 * new block to our caller. A nil s yields nil, so optional fields (for
 * example members missing from an interpreter reply) can be copied
 * without a separate check.
 */
char*
estrdup(char *s){
	char *t;

	if(s == nil)
		return nil;
	t = estrdup9p(s);
	/* a fresh allocation: record the allocating pc, not a realloc pc */
	setmalloctag(t, getcallerpc(&s));
	return t;
}

/*
 * Claim the first free slot of cpool (ref == 0) and initialise it from
 * the current server-wide defaults. The caller owns the one reference
 * taken here and releases it with rmclient.
 * Returns the slot index, or -1 when every slot is in use.
 */
int
mkclient(void){
	Client *c;
	uint i;

	for(i = 0; i < nclients; i++)
		if(cpool[i].ref == 0)
			break;
	/* cpool is a pointer, so nelem() cannot size it: compare with nclients */
	if(i >= nclients)
		return -1;
	c = &cpool[i];

	incref(c);
	c->id = i;
	c->timeout = timeout;
	c->umask = umask;
	c->user = estrdup(getuser());
	c->time0 = time(0);

	return i;
}

/*
 * Map a cpool index to its slot, or nil when i is out of range
 * (including the -1 "no client" sentinel). No reference is taken.
 */
Client*
clientref(int i){
	/* valid indices are 0 .. nclients-1; nclients itself is one past the end */
	if(i < 0 || i >= nclients)
		return nil;

	return &cpool[i];
}

/*
 * Drop one reference to client i. When the last reference goes, free
 * every string the slot owns, close its connection and zero the slot
 * so mkclient can reuse it. Out-of-range indices are ignored.
 */
void
rmclient(int i){
	Client *c;

	c = clientref(i);
	if(c == nil || decref(c))
		return;

	/* free(nil) is a no-op, so no guards are needed */
	free(c->Ripaddress);
	free(c->Rvendor);
	free(c->Rlanguage);
	free(c->Rversion);
	free(c->Rmachine);
	free(c->Rarch);
	free(c->Rproject);
	free(c->Rprocess);
	free(c->Ruser);
	free(c->Rtoken);
	free(c->Rdate);
	free(c->Rplatform);

	free(c->addr);	/* set by "connect" on ctl */
	free(c->pres);	/* last io reply */
	free(c->user);

	/* 0 means "never opened" because slots start zeroed */
	if(c->fd)
		close(c->fd);
	if(c->cfd)
		close(c->cfd);

	memset(c, 0, sizeof(*c));
}

/*
 * Fill q for a file of kind k. The path is the client's cpool index in
 * the high 32 bits and the kind in the low 32 bits; objects that belong
 * to no client (c == nil: the root, its ctl and clone) use index 0,
 * which cannot collide because their kinds differ from per-client kinds.
 */
static void
mkqid(Qid *q, int k, Client *c){
	u64int id;

	id = (c != nil) ? (u64int)c->id : 0;
	q->vers = RIDESRV_VERS;
	q->path = (id << 32) | ((u64int)k & 0xffffffff);
	q->type = (k == Qroot || k == Qclient) ? QTDIR : QTFILE;
}

/*
 * Send one RIDE message on fd in a single write: a 4-byte big-endian
 * total length (payload + 8), the magic "RIDE", then the n payload bytes.
 * Returns the number of bytes written, or -1 on a bad n or a failed or
 * short write.
 */
long
writemsg(int fd, void *pld, long n){
	ulong len;
	uchar *r;
	long w;

	if(n < 0 || n > 0x7fffffff - 8)
		return -1;
	len = n + 8;
	r = ecalloc(len);
	/* big-endian length: shift the length, not the constant */
	r[0] = len>>24 & 0xff;
	r[1] = len>>16 & 0xff;
	r[2] = len>>8 & 0xff;
	r[3] = len & 0xff;
	memmove(r+4, "RIDE", 4);
	memmove(r+8, pld, n);
	w = write(fd, r, len);
	free(r);
	if(w != len)
		return -1;
	return w;
}

/*
 * Read one RIDE message from fd into pld, a caller-owned buffer of n
 * bytes, and NUL-terminate the payload in place (pld is never
 * reallocated, so the caller's pointer stays valid).
 * Returns the payload length, or
 *	-1 the 8-byte header could not be read,
 *	-2 the magic is not "RIDE",
 *	-3 the length is malformed or the payload plus NUL does not fit in n,
 *	-4 the payload could not be read.
 */
long
readmsg(int fd, char *pld, long n){
	uchar hdr[8];
	ulong tot;
	long len;

	if(readn(fd, hdr, sizeof hdr) != sizeof hdr)
		return -1;
	if(memcmp(hdr+4, "RIDE", 4) != 0)
		return -2;
	/* unsigned bytes: plain char would sign-extend into the length */
	tot = (ulong)hdr[0]<<24 | (ulong)hdr[1]<<16 | (ulong)hdr[2]<<8 | hdr[3];
	if(tot < 8 || n <= 0 || tot - 8 >= (ulong)n)
		return -3;
	len = tot - 8;
	if(readn(fd, pld, len) != len)
		return -4;

	pld[len] = '\0';
	return len;
}

/* Copy of the string member name of JSON object o, or nil if it is absent or not a string. */
static char*
jstrdup(JSON *o, char *name){
	JSON *v;

	if((v = jsonbyname(o, name)) == nil)
		return nil;
	return estrdup(jsonstr(v));
}

/* Value of the numeric member name of JSON object o, or 0 if it is absent or not a number. */
static int
jnum(JSON *o, char *name){
	JSON *v;

	v = jsonbyname(o, name);
	if(v == nil || v->t != JSONNumber)
		return 0;
	return v->n;
}

/*
 * Connect client to the interpreter at its addr and run the RIDE
 * handshake: expect "SupportedProtocols=2", answer "UsingProtocol=2",
 * send Identify and store the fields of the ReplyIdentify answer.
 * Returns nil on success or an error string; on failure the connection
 * is closed again so a later open can retry cleanly.
 */
char *
rideinit(int client){
	int fd;
	char *addr, *pld, *s, *err;
	Client *c;
	JSON *j, *o;
	JSONEl *d;

	if((c = clientref(client)) == nil || c->addr == nil)
		return "no server set";
	addr = netmkaddr(c->addr, net, "tcp");
	if((fd = dial(addr, nil, nil, &c->cfd)) < 0)
		return "failed to dial addr";
	c->fd = fd;

	j = nil;
	pld = ecalloc(bufsz);

	err = "failed to read handshake";
	if(readmsg(fd, pld, bufsz) < 0)
		goto fail;
	err = "unrecognized protocol";
	if(strcmp(pld, "SupportedProtocols=2") != 0)
		goto fail;

	err = "failed to write handshake";
	s = "UsingProtocol=2";
	if(writemsg(fd, s, strlen(s)) < 0)
		goto fail;

	err = "failed to send identification message";
	s = "[\"Identify\",{\"apiVersion\":1,\"identity\":1}]";
	if(writemsg(fd, s, strlen(s)) < 0)
		goto fail;

	err = "failed to receive identification message";
	if(readmsg(fd, pld, bufsz) < 0)
		goto fail;

	/* expected shape: ["ReplyIdentify", {...}] */
	err = "unrecognized reply";
	if((j = jsonparse(pld)) == nil || j->t != JSONArray || j->first == nil)
		goto fail;
	err = "unexpected identification reply";
	if((s = jsonstr(j->first->val)) == nil || strcmp(s, "ReplyIdentify") != 0)
		goto fail;
	err = "malformed identification reply";
	if((d = j->first->next) == nil || d->val == nil || d->val->t != JSONObject)
		goto fail;

	/* members may be missing: jnum yields 0 and jstrdup yields nil for them */
	o = d->val;
	c->Rapiversion = jnum(o, "apiVersion");
	c->Rport       = jnum(o, "Port");
	c->Rpid        = jnum(o, "pid");
	c->Ripaddress  = jstrdup(o, "IPAddress");
	c->Rvendor     = jstrdup(o, "Vendor");
	c->Rlanguage   = jstrdup(o, "Language");
	c->Rversion    = jstrdup(o, "version");
	c->Rmachine    = jstrdup(o, "Machine");
	c->Rarch       = jstrdup(o, "arch");
	c->Rproject    = jstrdup(o, "Project");
	c->Rprocess    = jstrdup(o, "Process");
	c->Ruser       = jstrdup(o, "User");
	c->Rtoken      = jstrdup(o, "token");
	c->Rdate       = jstrdup(o, "date");
	c->Rplatform   = jstrdup(o, "platform");

	jsonfree(j);
	free(pld);
	return nil;

fail:
	if(j != nil)
		jsonfree(j);
	free(pld);
	close(fd);
	if(c->cfd > 0)
		close(c->cfd);
	c->fd = 0;
	c->cfd = 0;
	return err;
}


/*
 * Render the server-wide settings shown by the root ctl into a freshly
 * allocated buffer stored in *buf (the caller frees it).
 * Returns the text length; output is truncated to fit bufsz.
 */
long
readqrctl(char **buf){
	char *b;

	b = ecalloc(bufsz);
	/* Plan 9 print verbs, sized to each variable's type, bounded by bufsz */
	snprint(b, bufsz,
		"version %d\n"
		"bufsz %ld\n"
		"nclients %ud\n"
		"debug %d\n"
		"timeout %d\n"
		"umask %luo\n",
		RIDESRV_VERS, bufsz, nclients,
		debug, timeout, umask);

	*buf = b;
	return strlen(b);
}

/*
 * Apply "key value" settings written to the root ctl. Recognised keys
 * are bufsz, umask, nclients, debug and timeout; tokens are separated by
 * blanks or newlines, unknown tokens are skipped, and a key without a
 * value is ignored. Returns n (the whole write is consumed).
 *
 * NOTE(review): resizing cpool may move it while forked io procs hold
 * Client pointers into it; resize only while no io is in flight.
 */
long
writeqrctl(char *b, long n){
	char *buf, *s, *v, *sep;
	uint old, need, i;

	sep = " \t\r\n";
	/* 9P write data is not NUL-terminated: tokenise a terminated copy */
	buf = ecalloc(n+1);
	memmove(buf, b, n);

	old = nclients;
	for(s = strtok(buf, sep); s != nil; s = strtok(nil, sep)){
		if(strcmp(s, "bufsz") == 0 && (v = strtok(nil, sep)) != nil)
			bufsz = strtoul(v, nil, 0);
		else if(strcmp(s, "umask") == 0 && (v = strtok(nil, sep)) != nil)
			umask = strtoul(v, nil, 0);
		else if(strcmp(s, "nclients") == 0 && (v = strtok(nil, sep)) != nil)
			nclients = strtoul(v, nil, 0);
		else if(strcmp(s, "debug") == 0 && (v = strtok(nil, sep)) != nil)
			debug = atoi(v);
		else if(strcmp(s, "timeout") == 0 && (v = strtok(nil, sep)) != nil)
			timeout = atoi(v);
	}
	free(buf);

	if(nclients != old){
		/* never shrink below the highest slot still in use */
		need = 0;
		for(i = 0; i < old; i++)
			if(cpool[i].ref != 0)
				need = i+1;
		if(nclients < need)
			nclients = need;
		cpool = erealloc(cpool, (nclients > 0 ? nclients : 1)*sizeof(*cpool));
		/* erealloc does not clear new memory; new slots must read as free */
		if(nclients > old)
			memset(cpool+old, 0, (nclients-old)*sizeof(*cpool));
	}

	return n;
}

/*
 * Render a client's ctl text into a fresh buffer stored in *buf (the
 * caller frees it): its settings, followed by the interpreter's identify
 * reply once io has been opened. Returns the text length; output is
 * truncated to fit bufsz, and unset strings print as <nil>.
 */
long
readqctl(char **buf, Client *c){
	char *b, *p, *e;

	b = ecalloc(bufsz);
	e = b + bufsz;
	/* one format string: adjacent literals, not comma-separated arguments */
	p = seprint(b, e,
		"connect %s\n"
		"timeout %d\n"
		"umask %luo\n",
		c->addr, c->timeout, c->umask);
	/* append (not overwrite) the identify info, bounded by the same buffer */
	if(p != nil && c->oio)
		seprint(p, e,
			"Rapiversion %d\n"
			"Rport %d\n"
			"Rpid %d\n"
			"Ripaddress %s\n"
			"Rvendor %s\n"
			"Rlanguage %s\n"
			"Rversion %s\n"
			"Rmachine %s\n"
			"Rarch %s\n"
			"Rproject %s\n"
			"Rprocess %s\n"
			"Ruser %s\n"
			"Rtoken %s\n"
			"Rdate %s\n"
			"Rplatform %s\n",
			c->Rapiversion, c->Rport, c->Rpid, c->Ripaddress,
			c->Rvendor, c->Rlanguage, c->Rversion, c->Rmachine,
			c->Rarch, c->Rproject, c->Rprocess, c->Ruser,
			c->Rtoken, c->Rdate, c->Rplatform);

	*buf = b;
	return strlen(b);
}

/*
 * Apply "key value" settings written to a client's ctl: connect <addr>,
 * timeout <n>, umask <mask>. Tokens are separated by blanks or newlines,
 * unknown tokens are skipped and a key without a value is ignored.
 * Returns n (the whole write is consumed).
 */
long
writeqctl(char *b, long n, Client *c){
	char *buf, *s, *v, *sep;

	sep = " \t\r\n";
	/* 9P write data is not NUL-terminated: tokenise a terminated copy */
	buf = ecalloc(n+1);
	memmove(buf, b, n);

	for(s = strtok(buf, sep); s != nil; s = strtok(nil, sep)){
		if(strcmp(s, "connect") == 0 && (v = strtok(nil, sep)) != nil){
			free(c->addr);	/* replacing an earlier address */
			c->addr = estrdup(v);
		}else if(strcmp(s, "timeout") == 0 && (v = strtok(nil, sep)) != nil)
			c->timeout = atoi(v);
		else if(strcmp(s, "umask") == 0 && (v = strtok(nil, sep)) != nil)
			c->umask = strtoul(v, nil, 0);
	}
	free(buf);

	return n;
}

/*
 * Return in *buf a fresh copy of the client's last io reply (empty
 * before any command has run); the caller frees it. Returns its length.
 */
long
readqio(char **buf, Client *c){
	*buf = estrdup(c->pres != nil ? c->pres : "");
	return strlen(*buf);
}

/*
 * Send the n bytes at b to the interpreter as a RIDE Execute command
 * (the text is JSON-escaped and prefixed with six blanks), then read the
 * next message and keep a copy of it in c->pres.
 * Returns n on success, or -1 if the client is not connected, the
 * command does not fit in bufsz, or the exchange fails.
 */
long
writeqio(char *b, long n, Client *c){
	char *pld, *p, *e, *hd, *tl;
	long i, ret;
	uchar ch;

	hd = "[\"Execute\",{\"text\":\"      ";
	tl = "\",\"trace\":0}]";
	if(c == nil || !c->oio || strlen(hd) + strlen(tl) >= bufsz)
		return -1;

	ret = -1;
	pld = ecalloc(bufsz);
	e = pld + bufsz;
	p = strecpy(pld, e, hd);
	/* b comes from a 9P write: not NUL-terminated, may hold quotes or newlines */
	for(i = 0; i < n; i++){
		/* the longest escape, \u00XX, is 6 bytes; keep one more for the NUL */
		if(e - p < 7)
			goto end;
		ch = b[i];
		switch(ch){
		case '"':
		case '\\':
			*p++ = '\\';
			*p++ = ch;
			break;
		case '\n':
			*p++ = '\\';
			*p++ = 'n';
			break;
		case '\r':
			*p++ = '\\';
			*p++ = 'r';
			break;
		case '\t':
			*p++ = '\\';
			*p++ = 't';
			break;
		default:
			if(ch < 0x20)
				p = seprint(p, e, "\\u%04ux", ch);
			else
				*p++ = ch;
		}
	}
	if(e - p < strlen(tl) + 1)
		goto end;
	p = strecpy(p, e, tl);

	if(writemsg(c->fd, pld, p - pld) < 0)
		goto end;
	if(readmsg(c->fd, pld, bufsz) < 0)
		goto end;
	free(c->pres);
	c->pres = estrdup(pld);
	ret = n;

end:
	free(pld);
	return ret;
}

/*
 * Fill d for a file of the given kind. The root, its ctl and clone are
 * server-wide and take owner, times and mask from the globals; the
 * client directory and its files take them from c. Every string put in
 * d is freshly allocated because lib9p frees the Dir strings after use.
 */
void
mkdirent(Dir *d, int kind, Client *c){
	char *buf, *owner;
	ulong mask;

	if(kind == Qroot || kind == Qrctl || kind == Qclone)
		c = nil;	/* server-wide objects never belong to a client */

	mkqid(&d->qid, kind, c);
	if(c == nil){
		owner = (user != nil) ? user : getuser();
		mask = umask;
		d->atime = d->mtime = time0;
	}else{
		owner = c->user;
		mask = c->umask;
		d->atime = d->mtime = c->time0;
	}
	d->uid = estrdup(owner);
	d->gid = estrdup(owner);
	d->muid = estrdup(owner);

	/* client directories are named by their cpool index */
	if(kind == Qclient && c != nil)
		d->name = smprint("%d", c->id);
	else if(nametab[kind] != nil)
		d->name = estrdup(nametab[kind]);

	/* directories must carry DMDIR; the ctl/clone/io files are written to */
	if(d->qid.type & QTDIR)
		d->mode = DMDIR | (0777 & mask);
	else
		d->mode = 0666 & mask;

	buf = nil;
	switch(kind){
	case Qrctl:
		d->length = readqrctl(&buf);
		break;
	case Qctl:
		if(c != nil)
			d->length = readqctl(&buf, c);
		break;
	case Qio:
		if(c != nil && c->pres != nil)
			d->length = strlen(c->pres);
		break;
	}
	free(buf);
}

/*
 * dirread9p generator for the root directory. Entries 0..(Qclient-Qroot-2)
 * are the fixed files (ctl, clone); the following entries are the
 * in-use client directories in slot order. Returns -1 past the end.
 * The i-th client is found by scanning, so the result depends only on i
 * (no hidden cursor that could skip or repeat clients).
 */
int
genqroot(int i, Dir *d, void*){
	int j;

	if(i < Qclient - (Qroot+1)){
		mkdirent(d, Qroot+1 + i, nil);
		return 0;
	}

	i -= Qclient - (Qroot+1);
	for(j = 0; j < nclients; j++){
		if(cpool[j].ref == 0)
			continue;
		if(i-- == 0){
			mkdirent(d, Qclient, clientref(j));
			return 0;
		}
	}
	return -1;
}

/*
 * dirread9p generator for a client directory: entry i describes the
 * i-th file kind after Qclient (ctl, then io) for the client in aux;
 * -1 once the kinds run out.
 */
int
genqclient(int i, Dir *d, void *aux){
	int kind;

	kind = Qclient + 1 + i;
	if(kind < QCOUNT){
		mkdirent(d, kind, aux);
		return 0;
	}
	return -1;
}


/*
 * lib9p destroyfid hook: release the fid's client reference, if it has
 * one, and always free the Rfid itself.
 */
static void
fsdestroyfid(Fid *fid){
	Rfid *f;

	if((f = fid->aux) == nil)
		return;
	if(f->client > -1)
		rmclient(f->client);
	free(f);
	fid->aux = nil;
}

/* lib9p start hook: unmount whatever is currently mounted on mtpt, if one is set. */
static void
fsstart(Srv*){
	if(mtpt == nil)
		return;
	unmount(nil, mtpt);
}

/* lib9p end hook: post "shutdown" to our whole note group, then exit. */
static void
fsend(Srv*){
	int self;

	self = getpid();
	postnote(PNGROUP, self, "shutdown");
	exits(nil);
}

/*
 * Tattach: bind the fid to the root directory, which belongs to no
 * client. The -1 sentinel matters: ecalloc leaves client at 0, which
 * would alias cpool slot 0 and make clone/destroy touch its refcount.
 */
static void
fsattach(Req *r){
	Rfid *f;

	f = ecalloc(sizeof(*f));
	f->kind = Qroot;
	f->client = -1;

	mkqid(&r->fid->qid, f->kind, nil);
	r->fid->aux = f;
	r->ofcall.qid = r->fid->qid;

	respond(r, nil);
}

/*
 * Note handler for the proc forked by fsopen: "alarm" (timeout) and
 * "interrupt" (posted by fsflush) make the blocked system call fail
 * instead of killing the proc, so the open still gets a reply.
 */
static void
opennote(void*, char *msg){
	if(strcmp(msg, "alarm") == 0 || strcmp(msg, "interrupt") == 0)
		noted(NCONT);
	noted(NDFLT);
}

/*
 * Topen. Opening clone allocates a new client and turns the fid into
 * that client's ctl. Opening io connects to the interpreter in a forked
 * proc (the handshake can block) which answers the request and exits.
 * Every other file opens without further work.
 */
static void
fsopen(Req *r){
	char *res;
	Rfid *f;
	Client *c;

	f = r->fid->aux;
	c = clientref(f->client);

	switch(f->kind){
	case Qclone:
		if((f->client = mkclient()) == -1){
			respond(r, "reached client limit");
			return;
		}

		f->kind = Qctl;
		c = clientref(f->client);

		mkqid(&r->ofcall.qid, f->kind, c);
		r->fid->qid = r->ofcall.qid;

		respond(r, nil);
		return;
	case Qio:
		if(c == nil || c->oio || c->iopid != 0){
			respond(r, "client in use");
			return;
		}
		if(c->addr == nil){
			respond(r, "no server set");
			return;
		}

		switch(rfork(RFPROC|RFNOWAIT|RFMEM)){
		case -1:
			respond(r, "failed to init ride");
			return;
		case 0:
			/* the child records its own pid, so it can never be left stale */
			notify(opennote);
			c->iopid = getpid();
			alarm(c->timeout);
			res = rideinit(f->client);
			alarm(0);
			if(res == nil)
				c->oio = 1;
			c->iopid = 0;
			respond(r, res);
			/* the 9P service loop belongs to the parent */
			exits(nil);
		}
		return;
	default:
		/* every request must be answered, even when there is nothing to do */
		respond(r, nil);
	}
}

/*
 * Tread. Directories are listed by dirread9p with the matching generator;
 * ctl, root ctl and io text is rendered into a fresh buffer and served
 * from the requested offset by readbuf. clone cannot be read.
 */
static void
fsread(Req *r){
	Rfid *f;
	Client *c;
	char *text;
	long len;

	f = r->fid->aux;
	c = clientref(f->client);
	text = nil;
	len = -1;

	switch(f->kind){
	case Qclone:
		respond(r, "read prohibited");
		return;
	case Qroot:
		dirread9p(r, genqroot, nil);
		break;
	case Qclient:
		dirread9p(r, genqclient, c);
		break;
	case Qrctl:
		len = readqrctl(&text);
		break;
	case Qctl:
		len = readqctl(&text, c);
		break;
	case Qio:
		len = readqio(&text, c);
		break;
	}

	if(text != nil){
		readbuf(r, text, len);
		free(text);
	}

	respond(r, nil);
}

/*
 * Note handler for the proc forked by fswrite: "alarm" (timeout) and
 * "interrupt" (posted by fsflush) make the blocked system call fail
 * instead of killing the proc, so the write still gets a reply.
 */
static void
writenote(void*, char *msg){
	if(strcmp(msg, "alarm") == 0 || strcmp(msg, "interrupt") == 0)
		noted(NCONT);
	noted(NDFLT);
}

/*
 * Twrite. ctl writes are applied and answered at once. An io write
 * sends the command to the interpreter from a forked proc (the exchange
 * can block), which answers the request and exits. Other files refuse
 * writes.
 */
static void
fswrite(Req *r){
	Rfid *f;
	Client *c;

	f = r->fid->aux;
	c = clientref(f->client);

	switch(f->kind){
	case Qrctl:
		r->ofcall.count = writeqrctl(r->ifcall.data, r->ifcall.count);
		respond(r, nil);
		return;
	case Qctl:
		r->ofcall.count = writeqctl(r->ifcall.data, r->ifcall.count, c);
		respond(r, nil);
		return;
	case Qio:
		if(c == nil || !c->oio){
			respond(r, "not connected");
			return;
		}
		switch(rfork(RFPROC|RFNOWAIT|RFMEM)){
		case -1:
			respond(r, "failed to send command");
			return;
		case 0:
			notify(writenote);
			c->iopid = getpid();
			alarm(c->timeout);
			if(writeqio(r->ifcall.data, r->ifcall.count, c) < 0){
				alarm(0);
				c->iopid = 0;
				respond(r, "failed to send command");
			}else{
				alarm(0);
				c->iopid = 0;
				/* the whole 9P write was consumed */
				r->ofcall.count = r->ifcall.count;
				respond(r, nil);
			}
			/* the 9P service loop belongs to the parent */
			exits(nil);
		}
		return;
	default:
		respond(r, "write prohibited");
	}
}

/*
 * Tflush. A flush carries no fid of its own; the request being flushed
 * is r->oldreq. If it is io work running in a forked proc, interrupt
 * that proc so it stops blocking and answers the old request. The flush
 * itself is answered exactly once; lib9p holds that reply until the old
 * request has been answered.
 */
static void
fsflush(Req *r){
	Req *o;
	Rfid *f;
	Client *c;

	o = r->oldreq;
	f = (o != nil && o->fid != nil) ? o->fid->aux : nil;
	c = (f != nil && f->kind == Qio) ? clientref(f->client) : nil;

	if(c != nil && c->iopid > 0)
		postnote(PNPROC, c->iopid, "interrupt");

	respond(r, nil);
}

/* Tstat: describe whatever file the fid currently refers to. */
static void
fsstat(Req *r){
	Rfid *f;

	f = r->fid->aux;
	mkdirent(&r->d, f->kind, clientref(f->client));
	respond(r, nil);
}

/*
 * Walk one path element. From the root, valid names are "ctl", "clone"
 * and the decimal index of an in-use client (which takes a reference on
 * that client). From a client directory, valid names are "ctl" and "io".
 * ".." from a client directory drops the reference and returns to the
 * root; ".." from the root stays there. The fid is left unchanged when
 * the walk fails.
 */
static char*
fswalk1(Fid *fid, char *name, Qid *qid){
	Rfid *f;
	Client *c;
	int i, n;
	char *nend;

	if(!(fid->qid.type&QTDIR))
		return "cannot walk from non-directory";

	f = fid->aux;
	if(strcmp(name, "..") == 0){
		if(f->kind == Qclient){
			rmclient(f->client);
			f->client = -1;
			f->kind = Qroot;
		}
	} else {
		for(i = f->kind+1; i<QCOUNT; i++){
			if(nametab[i] && strcmp(name, nametab[i]) == 0)
				break;
			if(i == Qclient){
				n = strtol(name, &nend, 10);
				if(nend != name && *nend == 0
				&& nil != (c = clientref(n)) && c->ref != 0){
					f->client = n;
					incref(c);
					break;
				}
				/* kinds after Qclient exist only inside a client directory */
				return "directory entry not found";
			}
		}
		if(i >= QCOUNT)
			return "directory entry not found";
		f->kind = i;
	}
	/* the qid must carry the fid's own client, not a leftover index */
	mkqid(qid, f->kind, clientref(f->client));
	fid->qid = *qid;
	return nil;
}

/*
 * lib9p clone hook: give newfid its own copy of oldfid's Rfid and take
 * one more reference on the client it refers to, if any.
 */
static char*
fsclone(Fid *oldfid, Fid *newfid){
	Rfid *src, *dst;

	if((src = oldfid->aux) == nil)
		return "bad fid";

	dst = ecalloc(sizeof(*dst));
	*dst = *src;
	if(dst->client >= 0)
		incref(clientref(dst->client));
	newfid->aux = dst;

	return nil;
}

/*
 * lib9p dispatch table handed to postmountsrv in main. Operations not
 * listed here fall back to lib9p's own defaults (see 9p(2)).
 */
Srv fs = {
	.destroyfid = fsdestroyfid,
	.start      = fsstart,
	.end        = fsend,
	.attach     = fsattach,
	.open       = fsopen,
	.read       = fsread,
	.write      = fswrite,
	.flush      = fsflush,
	.stat       = fsstat,
	.walk1      = fswalk1,
	.clone      = fsclone,
};

/*
 * Print the command synopsis and exit. EARGF calls this when an option
 * lacks its argument, so it must not return (EARGF aborts if it does).
 * fprint(2) is unbuffered, so the message is written before exits.
 */
void
usage(void){
	fprint(2, "usage: %s [-Dd] [-T timeout] [-m mtpt] [-s service] [-x net]\n", argv0);
	exits("usage");
}

/*
 * Set the defaults, parse the options, allocate the client pool and
 * serve the file tree on mtpt (and as service, if given).
 *	-D	lib9p protocol tracing (chatty9p)
 *	-d	debug level, reported by the root ctl
 *	-T ms	io timeout passed to alarm
 *	-m dir	mount point (default /mnt/ride)
 *	-s name	post the service under this name
 *	-x net	network used when dialling interpreters
 */
void
main(int argc, char **argv){
	timeout = 10000;
	mtpt = "/mnt/ride";
	bufsz = 4096;
	umask = 0755;
	time0 = time(0);
	nclients = 256;
	/* owner of the server-wide files; getuser's result is copied to be safe */
	user = estrdup(getuser());

	ARGBEGIN{
	case 'D': chatty9p++; break;
	case 'd': debug++; break;
	case 'T': timeout = atoi(EARGF(usage())); break;
	case 'm': mtpt = EARGF(usage()); break;
	case 's': service = EARGF(usage()); break;
	case 'x': net = EARGF(usage()); break;
	default:  usage(); return;
	}ARGEND

	cpool = ecalloc(nclients*sizeof(*cpool));

	rfork(RFNOTEG);
	postmountsrv(&fs, service, mtpt, MREPL);
	exits(nil);
}