shithub: moonfish

Download patch

ref: e1edf15affef37c38050e1d97eba438f58dff3a9
parent: 4d8547ddd37c51e447c2aa947b82157360be36ab
author: zamfofex <zamfofex@twdb.moe>
date: Tue Oct 29 13:34:03 EDT 2024

keep tree nodes across searches

--- a/chess.c
+++ b/chess.c
@@ -821,3 +821,26 @@
 	
 	*name = 0;
 }
+
+int moonfish_equal(struct moonfish_chess *a, struct moonfish_chess *b)
+{
+	int x, y, i;
+	/* returns 1 when 'a' and 'b' agree on everything compared below, 0 at the first difference; flags go first */
+	if (a->white != b->white) return 0;
+	if (a->passing != b->passing) return 0;
+	if (a->oo[0] != b->oo[0]) return 0;
+	if (a->oo[1] != b->oo[1]) return 0;
+	if (a->ooo[0] != b->ooo[0]) return 0;
+	if (a->ooo[1] != b->ooo[1]) return 0;
+	/* then the 8x8 squares: (x, y) is read from board[(x + 1) + (y + 2) * 10]; no other 'board' index is compared */
+	for (y = 0 ; y < 8 ; y++) {
+		for (x = 0 ; x < 8 ; x++) {
+			i = (x + 1) + (y + 2) * 10;
+			if (a->board[i] != b->board[i]) {
+				return 0;
+			}
+		}
+	}
+	/* nothing differed */
+	return 1;
+}
--- a/main.c
+++ b/main.c
@@ -8,14 +8,14 @@
 
 #include "moonfish.h"
 
-static void moonfish_go(struct moonfish_chess *chess)
+static void moonfish_go(struct moonfish_node *node)
 {
-	static struct moonfish_move move;
+	static struct moonfish_result result;
 	static struct moonfish_options options;
+	static struct moonfish_chess chess;
 	
 	long int our_time, their_time, *xtime, time;
 	char *arg, *end;
-	long int count;
 	char name[6];
 	
 	our_time = -1;
@@ -22,6 +22,8 @@
 	their_time = -1;
 	time = -1;
 	
+	moonfish_root(node, &chess);
+	
 	for (;;) {
 		
 		arg = strtok(NULL, "\r\n\t ");
@@ -29,7 +31,7 @@
 		
 		if (!strcmp(arg, "wtime") || !strcmp(arg, "btime")) {
 			
-			if (chess->white) {
+			if (chess.white) {
 				if (!strcmp(arg, "wtime")) xtime = &our_time;
 				else xtime = &their_time;
 			}
@@ -75,15 +77,16 @@
 	
 	options.max_time = time;
 	options.our_time = our_time;
-	count = moonfish_best_move(chess, &move, &options);
-	moonfish_to_uci(chess, &move, name);
+	moonfish_best_move(node, &result, &options);
+	moonfish_to_uci(&chess, &result.move, name);
 	
-	printf("info nodes %ld\n", count);
+	printf("info nodes %ld\n", result.node_count);
 	printf("bestmove %s\n", name);
 }
 
-static void moonfish_position(struct moonfish_chess *chess)
+static void moonfish_position(struct moonfish_node *node)
 {
+	static struct moonfish_chess chess, chess0;
 	static struct moonfish_move move;
 	static char line[2048];
 	
@@ -95,6 +98,8 @@
 		exit(1);
 	}
 	
+	moonfish_chess(&chess);
+	
 	if (!strcmp(arg, "fen")) {
 		
 		arg = strtok(NULL, "\r\n");
@@ -103,22 +108,18 @@
 			exit(1);
 		}
 		
-		moonfish_from_fen(chess, arg);
+		moonfish_from_fen(&chess, arg);
 		
 		arg = strstr(arg, "moves");
-		if (arg == NULL) return;
-		
-		do arg--;
-		while (*arg == '\t' || *arg == ' ');
-		
-		strcpy(line, arg);
-		strtok(line, "\r\n\t ");
+		if (arg != NULL) {
+			do arg--;
+			while (*arg == '\t' || *arg == ' ');
+			strcpy(line, arg);
+			strtok(line, "\r\n\t ");
+		}
 	}
 	else {
-		if (!strcmp(arg, "startpos")) {
-			moonfish_chess(chess);
-		}
-		else {
+		if (strcmp(arg, "startpos")) {
 			fprintf(stderr, "malformed 'position' command\n");
 			exit(1);
 		}
@@ -125,25 +126,35 @@
 	}
 	
 	arg = strtok(NULL, "\r\n\t ");
-	if (arg == NULL || strcmp(arg, "moves")) return;
 	
-	for (;;) {
-		arg = strtok(NULL, "\r\n\t ");
-		if (arg == NULL) break;
-		if (moonfish_from_uci(chess, &move, arg)) {
-			fprintf(stderr, "malformed move '%s'\n", arg);
-			exit(1);
+	if (arg != NULL && !strcmp(arg, "moves")) {
+		
+		for (;;) {
+			
+			arg = strtok(NULL, "\r\n\t ");
+			if (arg == NULL) break;
+			if (moonfish_from_uci(&chess, &move, arg)) {
+				fprintf(stderr, "malformed move '%s'\n", arg);
+				exit(1);
+			}
+			
+			moonfish_root(node, &chess0);
+			if (moonfish_equal(&chess0, &chess)) moonfish_reroot(node, &move.chess);
+			
+			chess = move.chess;
 		}
-		*chess = move.chess;
 	}
+	
+	moonfish_root(node, &chess0);
+	if (!moonfish_equal(&chess0, &chess)) moonfish_reroot(node, &chess);
 }
 
 int main(int argc, char **argv)
 {
 	static char line[2048];
-	static struct moonfish_chess chess;
 	
 	char *arg;
+	struct moonfish_node *node;
 	
 	if (argc > 1) {
 		fprintf(stderr, "usage: %s (no arguments)\n", argv[0]);
@@ -150,7 +161,7 @@
 		return 1;
 	}
 	
-	moonfish_chess(&chess);
+	node = moonfish_new();
 	
 	for (;;) {
 		
@@ -166,7 +177,7 @@
 		if (arg == NULL) continue;
 		
 		if (!strcmp(arg, "go")) {
-			moonfish_go(&chess);
+			moonfish_go(node);
 			continue;
 		}
 		
@@ -173,7 +184,7 @@
 		if (!strcmp(arg, "quit")) break;
 		
 		if (!strcmp(arg, "position")) {
-			moonfish_position(&chess);
+			moonfish_position(node);
 			continue;
 		}
 		
@@ -194,5 +205,6 @@
 		fprintf(stderr, "warning: unknown command '%s'\n", arg);
 	}
 	
+	moonfish_finish(node);
 	return 0;
 }
--- a/moonfish.h
+++ b/moonfish.h
@@ -90,6 +90,9 @@
 	unsigned char from, to;
 };
 
+/* represents cross-search state */
+struct moonfish_node;
+
 /* represents options for the search */
 struct moonfish_options {
 	long int max_time;
@@ -96,6 +99,12 @@
 	long int our_time;
 };
 
+/* represents a search result (it is filled in by "moonfish_best_move") */
+struct moonfish_result {
+	struct moonfish_move move; /* the move of the most visited child of the searched node */
+	long int node_count; /* the searched node's visit count (the UCI loop prints it as "info nodes") */
+};
+
 /* the PST */
 extern double moonfish_values[];
 
@@ -113,10 +122,8 @@
 int moonfish_moves(struct moonfish_chess *chess, struct moonfish_move *moves, unsigned char from);
 
 /* tries to find the best move in the given position with the given options */
-/* the move is stored in the "move" pointer */
 /* the move found is the best for the player whose turn it is on the given position */
-/* this will return the number of positions that were looked into */
-long int moonfish_best_move(struct moonfish_chess *chess, struct moonfish_move *move, struct moonfish_options *options);
+void moonfish_best_move(struct moonfish_node *node, struct moonfish_result *result, struct moonfish_options *options);
 
 /* returns the depth-zero score for the given position */
 double moonfish_score(struct moonfish_chess *chess);
@@ -176,5 +183,21 @@
 /* returns whether the game ended due to checkmate */
 /* note: 0 means false (i.e. no checkmate) */
 int moonfish_checkmate(struct moonfish_chess *chess);
+
+/* returns whether two positions are equal */
+/* note: 0 means false (i.e. the positions are different) */
+int moonfish_equal(struct moonfish_chess *a, struct moonfish_chess *b);
+
+/* sets the state's position */
+void moonfish_reroot(struct moonfish_node *node, struct moonfish_chess *chess);
+
+/* gets the state's position (it is stored in the given position pointer) */
+void moonfish_root(struct moonfish_node *node, struct moonfish_chess *chess);
+
+/* creates a new state (with the initial position) */
+struct moonfish_node *moonfish_new(void);
+
+/* frees the given state (so that it is no longer usable) */
+void moonfish_finish(struct moonfish_node *node);
 
 #endif
--- a/search.c
+++ b/search.c
@@ -84,6 +84,14 @@
 	free(node->children);
 }
 
+static void moonfish_node(struct moonfish_node *node)
+{ /* resets a node's bookkeeping: no parent, zero children counted, zero score, zero visits ('move' is left as is) */
+	node->parent = NULL;
+	node->count = 0; /* NOTE(review): 'children' is not assigned here; confirm "moonfish_discard" never frees it while 'count' is 0 */
+	node->score = 0;
+	node->visits = 0;
+}
+
 static void moonfish_expand(struct moonfish_node *node)
 {
 	struct moonfish_move moves[32];
@@ -90,7 +98,6 @@
 	int x, y;
 	int count, i;
 	
-	node->count = 0;
 	node->children = NULL;
 	
 	for (y = 0 ; y < 8 ; y++) {
@@ -184,10 +191,8 @@
 	}
 }
 
-long int moonfish_best_move(struct moonfish_chess *chess, struct moonfish_move *best_move, struct moonfish_options *options)
+void moonfish_best_move(struct moonfish_node *node, struct moonfish_result *result, struct moonfish_options *options)
 {
-	static struct moonfish_node node;
-	
 	struct moonfish_node *best_node;
 	long int time, time0;
 	int i;
@@ -200,28 +205,80 @@
 	if (options->our_time >= 0 && time > options->our_time / 16) time = options->our_time / 16;
 	time -= time / 32 + 125;
 	
-	node.move.chess = *chess;
-	node.parent = NULL;
-	node.count = 0;
-	node.score = 0;
-	node.visits = 0;
-	
-	moonfish_search(&node, 0x800);
+	moonfish_search(node, 0x800);
 	while (moonfish_clock() - time0 < time) {
-		moonfish_search(&node, 0x2000);
+		moonfish_search(node, 0x2000);
 	}
 	
 	best_visits = -1;
 	best_node = NULL;
 	
-	for (i = 0 ; i < node.count ; i++) {
-		if (node.children[i].visits > best_visits) {
-			best_node = node.children + i;
+	for (i = 0 ; i < node->count ; i++) {
+		if (node->children[i].visits > best_visits) {
+			best_node = node->children + i;
 			best_visits = best_node->visits;
 		}
 	}
 	
-	*best_move = best_node->move;
-	moonfish_discard(&node);
-	return node.visits;
+	result->move = best_node->move;
+	result->node_count = node->visits;
+}
+
+void moonfish_reroot(struct moonfish_node *node, struct moonfish_chess *chess)
+{
+	int i, j;
+	struct moonfish_node *children;
+	/* makes 'chess' the position of 'node', keeping the subtree of a matching direct child when there is one */
+	children = node->children;
+	/* look for a direct child whose position equals 'chess' (the loop does not run when 'count' is 0) */
+	for (i = 0 ; i < node->count ; i++) {
+		if (moonfish_equal(&children[i].move.chess, chess)) break;
+	}
+	/* no child matched: hand the node to "moonfish_discard", reset it, and store 'chess' in it */
+	if (i == node->count) {
+		moonfish_discard(node);
+		moonfish_node(node);
+		node->move.chess = *chess;
+		return;
+	}
+	/* child 'i' matched: discard every one of its siblings */
+	for (j = 0 ; j < node->count ; j++) {
+		if (i == j) continue;
+		moonfish_discard(children + j);
+	}
+	/* copy the matching child into the node's own storage, clear its parent, then free the old child array */
+	*node = children[i];
+	node->parent = NULL;
+	free(children);
+	/* point the copied node's children back at 'node' (presumably they referred to the array slot just freed) */
+	for (i = 0 ; i < node->count ; i++) {
+		node->children[i].parent = node;
+	}
+}
+
+void moonfish_root(struct moonfish_node *node, struct moonfish_chess *chess)
+{
+	*chess = node->move.chess; /* copies the node's stored position into '*chess' by value; the node is not modified */
+}
+
+struct moonfish_node *moonfish_new(void)
+{
+	struct moonfish_node *node;
+	/* allocates one node; an allocation failure is fatal (reported with "perror", then exit status 1) */
+	node = malloc(sizeof *node);
+	if (node == NULL) {
+		perror("malloc");
+		exit(1);
+	}
+	/* reset the node's bookkeeping, then let "moonfish_chess" fill in the position stored in the node */
+	moonfish_node(node);
+	moonfish_chess(&node->move.chess);
+	/* the returned pointer is meant to be released with "moonfish_finish" */
+	return node;
+}
+
+void moonfish_finish(struct moonfish_node *node)
+{
+	moonfish_discard(node); /* first hand the node to "moonfish_discard" (its visible tail frees 'children') */
+	free(node); /* then release the node itself; 'node' must not be used after this call */
 }
--