/*
 * shithub: util
 * ref: 07398654cc5684fec2ff81fc1b21c3862536291c
 * dir: /ann/main.c/
 */
#include <u.h>
#include <libc.h>
#include <ctype.h>
#include "ann.h"

/*
 * Write the command-line synopsis to file descriptor 2, using argv[0]
 * as the program name, then terminate with exit status "usage".
 */
void
usage(char **argv)
{
	char *prog;

	prog = argv[0];
	fprint(2, "usage: %s [-train] filename [num_layers num_input_layer ... num_output_layer]\n", prog);
	exits("usage");
}

/*
 * Load or create a neural network, then run (and optionally train) it
 * on numbers read from file descriptor 0.
 *
 * Command line: [-train] filename [num_layers n_first ... n_last]
 *  - Any first argument beginning with "-t" selects training mode.
 *  - If filename exists it is loaded with annload(); layer sizes given
 *    on the command line must then match the loaded network.
 *  - If it does not exist, the layer sizes are required and a new
 *    network is built with anncreatev().
 *
 * Input is a stream of whitespace-separated tokens made only of digits,
 * '.' and '-'.  A record may span several lines; whatever is left on a
 * line after a record completes is ignored.  Every completed input
 * record is passed to annrun() and the outputs are printed on one line.
 * In training mode input records alternate with expected-output
 * records, each pair is handed to anntrain(), and the network is
 * written back with annsave() once input is exhausted.
 */
void
main(int argc, char **argv)
{
	Ann *ann;
	char *filename;
	int train;
	Dir *dir;
	int num_layers = 0;
	int *layers = nil;
	int i;
	char *line;
	double *input;
	double *output = nil;
	double *runoutput;
	int ninput;
	int noutput;
	int offset;
	double f;
	int trainline;
	int nline;
	int nread;

	train = 0;

	if (argc < 2)
		usage(argv);

	filename = argv[1];

	if (argv[1][0] == '-' && argv[1][1] == 't') {
		if (argc < 3)
			usage(argv);

		train = 1;
		filename = argv[2];
	}

	/* a missing file is not an error yet: the network may be created below */
	ann = nil;
	dir = dirstat(filename);
	if (dir != nil) {
		free(dir);
		ann = annload(filename);
		if (ann == nil)
			exits("load");
	}

	if (argc >= (train + 3)) {
		num_layers = atoi(argv[train + 2]);

		if (num_layers < 2 || argc != (train + 3 + num_layers))
			usage(argv);

		layers = calloc(num_layers, sizeof(int));
		if (layers == nil) {
			fprint(2, "out of memory\n");
			exits("calloc");
		}

		for (i = 0; i < num_layers; i++)
			layers[i] = atoi(argv[train + 3 + i]);
	}

	if (num_layers > 0) {
		if (ann != nil) {
			/* the requested topology must agree with the loaded network */
			if (ann->n != num_layers) {
				fprint(2, "num_layers: %d != %d\n", ann->n, num_layers);
				exits("num_layers");
			}

			for (i = 0; i < num_layers; i++) {
				if (layers[i] != ann->layers[i]->n) {
					fprint(2, "num_layer_%d: %d != %d\n", i, layers[i], ann->layers[i]->n);
					exits("num_layer");
				}
			}
		} else {
			/* atoi() yields 0 for garbage; refuse empty or negative layers */
			for (i = 0; i < num_layers; i++) {
				if (layers[i] < 1) {
					fprint(2, "invalid layer size: %s\n", argv[train + 3 + i]);
					exits("layer size");
				}
			}

			ann = anncreatev(num_layers, layers);
			if (ann == nil)
				exits("anncreatev");
		}
	}

	if (ann == nil) {
		fprint(2, "file not found: %s\n", filename);
		exits("file not found");
	}

	ninput = ann->layers[0]->n;
	noutput = ann->layers[ann->n - 1]->n;

	/* the parse loop stores a value before testing the count, so a
	 * zero-width record would write past an empty buffer */
	if (ninput < 1 || noutput < 1) {
		fprint(2, "empty input or output layer: %d, %d\n", ninput, noutput);
		exits("layer size");
	}

	input = calloc(ninput, sizeof(double));
	if (input == nil) {
		fprint(2, "out of memory\n");
		exits("calloc");
	}
	if (train == 1) {
		output = calloc(noutput, sizeof(double));
		if (output == nil) {
			fprint(2, "out of memory\n");
			exits("calloc");
		}
	}

	trainline = 0;	/* 0: collecting inputs, 1: collecting expected outputs */
	nline = ninput;	/* number of values in the record being collected */
	nread = 0;	/* values collected so far; persists across lines */

	/*
	 * NOTE(review): line is advanced in place and never released; if
	 * readline() returns allocated memory every line leaks -- confirm
	 * ownership before adding a free().
	 */
	while ((line = readline(0)) != nil) {
		do {
			while (isspace(*line))
				line++;
			if (*line == '\0')
				break;

			/* a token is a run of digits, '.' and '-' ended by space or NUL */
			offset = 0;
			while (isdigit(line[offset]) || line[offset] == '.' || line[offset] == '-')
				offset++;
			if (!isspace(line[offset]) && line[offset] != '\0') {
				fprint(2, "input error: %s\n", line);
				exits("input");
			}

			f = atof(line);
			if (trainline == 0)
				input[nread++] = f;
			else
				output[nread++] = f;
			line += offset;
		} while (nread < nline);

		/* record still incomplete: keep collecting from the next line */
		if (nread != nline)
			continue;

		if (trainline == 0) {
			runoutput = annrun(ann, input);
			if (runoutput == nil) {
				fprint(2, "annrun failed\n");
				exits("annrun");
			}
			for (i = 0; i < noutput; i++)
				print("%f%c", runoutput[i], (i == (noutput-1))? '\n': ' ');
			free(runoutput);
		}

		if (train == 1) {
			if (trainline == 0) {
				/* next record holds the expected outputs */
				trainline = 1;
				nline = noutput;
			} else {
				anntrain(ann, input, output);
				trainline = 0;
				nline = ninput;
			}
		}
		nread = 0;
	}

	if (train == 1 && annsave(filename, ann) != 0)
		exits("save");

	exits(nil);
}