ref: 07398654cc5684fec2ff81fc1b21c3862536291c
dir: /ann/main.c/
#include <u.h>
#include <libc.h>
#include <ctype.h>
#include "ann.h"

/*
 * Run, or with -train train, the neural network stored in filename.
 *
 * If filename exists, the network is loaded from it.  Otherwise it is
 * created from the layer sizes given on the command line.  If sizes are
 * given for an existing file, they must match the loaded network.
 *
 * Whitespace-separated numbers are read from file descriptor 0 and
 * grouped into records:
 *   - Each input record has as many values as the first layer has
 *     neurons; the network's output for it is printed on one line.
 *   - In training mode, each input record is followed by an output
 *     record (one value per last-layer neuron), and the network is
 *     trained on the pair.  The network is written back to filename
 *     at end of input.
 *
 * A record may span several lines.  Once a record is complete, any
 * further values on the same line are ignored.
 */

void
usage(char **argv)
{
	fprint(2, "usage: %s [-train] filename [num_layers num_input_layer ... num_output_layer]\n", argv[0]);
	exits("usage");
}

/* calloc that reports the failure and exits instead of returning nil. */
static void*
ecalloc(ulong n, ulong size)
{
	void *p;

	p = calloc(n, size);
	if(p == nil){
		fprint(2, "out of memory\n");
		exits("calloc");
	}
	return p;
}

/*
 * Parse the number at the start of s.  s must be non-empty and must not
 * begin with whitespace.  The value is stored in *f, and a pointer just
 * past the token is returned.
 *
 * The token must consist only of characters from "0123456789.-+eE",
 * must be followed by whitespace or the end of the string, and must be
 * consumed entirely by strtod.  This rejects inputs such as "1-2",
 * "1.2.3" or a lone "-" instead of silently using a prefix.  On error
 * the program exits with "input".
 */
static char*
parsenum(char *s, double *f)
{
	char *tokend, *numend;

	tokend = s;
	/* check for '\0' first: strchr finds the terminator of its set */
	while(*tokend != '\0' && strchr("0123456789.-+eE", *tokend) != nil)
		tokend++;
	if(*tokend != '\0' && !isspace((uchar)*tokend)){
		fprint(2, "input error: %s\n", s);
		exits("input");
	}
	*f = strtod(s, &numend);
	if(numend != tokend){
		fprint(2, "input error: %s\n", s);
		exits("input");
	}
	return tokend;
}

/* Exit unless ann has exactly n layers whose sizes equal layers[0..n-1]. */
static void
checklayers(Ann *ann, int n, int *layers)
{
	int i;

	if(ann->n != n){
		fprint(2, "num_layers: %d != %d\n", ann->n, n);
		exits("num_layers");
	}
	for(i = 0; i < n; i++){
		if(layers[i] != ann->layers[i]->n){
			fprint(2, "num_layer_%d: %d != %d\n", i, layers[i], ann->layers[i]->n);
			exits("num_layer");
		}
	}
}

void
main(int argc, char **argv)
{
	Ann *ann;
	Dir *dir;
	char *filename, *line, *p;
	int train, num_layers, *layers, i, j;
	int ninput, noutput, trainline, nline;
	double f, *input, *output, *runoutput;

	train = 0;
	if(argc < 2)
		usage(argv);
	filename = argv[1];
	/* any first argument starting with "-t" (e.g. "-train") selects training */
	if(argv[1][0] == '-' && argv[1][1] == 't'){
		if(argc < 3)
			usage(argv);
		train = 1;
		filename = argv[2];
	}

	/* load an existing network if the file is present */
	ann = nil;
	dir = dirstat(filename);
	if(dir != nil){
		free(dir);
		ann = annload(filename);
		if(ann == nil){
			fprint(2, "cannot load %s\n", filename);
			exits("load");
		}
	}

	/* optional shape: num_layers followed by exactly num_layers sizes */
	num_layers = 0;
	layers = nil;
	if(argc >= train + 3){
		num_layers = atoi(argv[train + 2]);
		if(num_layers < 2 || argc != train + 3 + num_layers)
			usage(argv);
		layers = ecalloc(num_layers, sizeof *layers);
		for(i = 0; i < num_layers; i++){
			layers[i] = atoi(argv[train + 3 + i]);
			if(layers[i] < 1)	/* a layer needs at least one neuron */
				usage(argv);
		}
	}

	if(num_layers > 0){
		if(ann != nil)
			checklayers(ann, num_layers, layers);
		else{
			/*
			 * NOTE(review): layers is deliberately not freed, in case
			 * anncreatev keeps the pointer; confirm against ann.c.
			 */
			ann = anncreatev(num_layers, layers);
			if(ann == nil)
				exits("anncreatev");
		}
	}
	if(ann == nil){
		fprint(2, "file not found: %s\n", filename);
		exits("file not found");
	}

	ninput = ann->layers[0]->n;
	noutput = ann->layers[ann->n - 1]->n;
	input = ecalloc(ninput, sizeof *input);
	output = nil;
	if(train == 1)
		output = ecalloc(noutput, sizeof *output);

	/*
	 * trainline is 0 while an input record is being filled and 1 while
	 * an output record is being filled (training only).  nline is the
	 * size of the record currently being filled, and i is the number of
	 * values read into it so far.
	 */
	trainline = 0;
	nline = ninput;
	i = 0;
	/*
	 * NOTE(review): readline is defined outside this file; if it returns
	 * allocated memory, each line is leaked here; confirm and free it.
	 */
	while((line = readline(0)) != nil){
		p = line;
		while(i < nline){
			while(isspace((uchar)*p))
				p++;
			if(*p == '\0')
				break;
			p = parsenum(p, &f);
			if(trainline == 0)
				input[i++] = f;
			else
				output[i++] = f;
		}
		if(i < nline)
			continue;	/* record continues on the next line */

		if(trainline == 0){
			runoutput = annrun(ann, input);
			for(j = 0; j < noutput; j++)
				print("%f%c", runoutput[j], (j == noutput - 1) ? '\n' : ' ');
			free(runoutput);
		}
		if(train == 1){
			if(trainline == 0){
				/* the expected output record follows */
				trainline = 1;
				nline = noutput;
			}else{
				anntrain(ann, input, output);
				trainline = 0;
				nline = ninput;
			}
		}
		i = 0;
	}

	if(train == 1 && annsave(filename, ann) != 0){
		fprint(2, "cannot save %s\n", filename);
		exits("save");
	}
	exits(nil);
}