shithub: furgit

Download patch

ref: c8f00194c617796e2b83f715b4d2ece80a34a716
parent: fdf1bd9a68091a1640b79fcb4ea979ed10f8b904
author: Runxi Yu <me@runxiyu.org>
date: Thu Mar 5 13:07:04 EST 2026

internal/compress/zlib: Use flate's input-consumed counter

--- a/internal/compress/zlib/reader.go
+++ b/internal/compress/zlib/reader.go
@@ -71,36 +71,14 @@
 type Reader struct {
 	r            flate.Reader
 	decompressor io.ReadCloser
+	progress     flate.InputProgress
 	digest       hash.Hash32
-	counter      *countingFlateReader
+	headerRead   uint64
+	trailerRead  uint64
 	err          error
 	scratch      [4]byte
 }
 
-// countingFlateReader wraps flate input and tracks consumed bytes.
-type countingFlateReader struct {
-	inner flate.Reader
-	read  uint64
-}
-
-// Read implements io.Reader.
-func (reader *countingFlateReader) Read(dst []byte) (int, error) {
-	n, err := reader.inner.Read(dst)
-	reader.read += uint64(n)
-
-	return n, err
-}
-
-// ReadByte implements io.ByteReader.
-func (reader *countingFlateReader) ReadByte() (byte, error) {
-	b, err := reader.inner.ReadByte()
-	if err == nil {
-		reader.read++
-	}
-
-	return b, err
-}
-
 // NewReader creates a new ReadCloser.
 // Reads from the returned ReadCloser read and decompress data from r.
 // If r does not implement [io.ByteReader], the decompressor may read more
@@ -152,7 +130,8 @@
 	}
 
 	// Finished file; check checksum.
-	_, err = io.ReadFull(z.r, z.scratch[0:4])
+	readN, err := io.ReadFull(z.r, z.scratch[0:4])
+	z.trailerRead += uint64(readN)
 	if err != nil {
 		if errors.Is(err, io.EOF) {
 			err = io.ErrUnexpectedEOF
@@ -178,11 +157,12 @@
 // This count includes the zlib header, deflate payload, and zlib checksum
 // trailer bytes read by the reader.
 func (z *Reader) InputConsumed() uint64 {
-	if z.counter == nil {
-		return 0
+	out := z.headerRead + z.trailerRead
+	if z.progress != nil {
+		out += uint64(z.progress.InputConsumed())
 	}
 
-	return z.counter.read
+	return out
 }
 
 // Close does not close the wrapped [io.Reader] originally passed to [NewReader].
--- a/internal/compress/zlib/reader_reset.go
+++ b/internal/compress/zlib/reader_reset.go
@@ -25,11 +25,12 @@
 		input = bufio.NewReader(r)
 	}
 
-	z.counter = &countingFlateReader{inner: input}
-	z.r = z.counter
+	z.r = input
 
 	// Read the header (RFC 1950 section 2.2.).
-	_, z.err = io.ReadFull(z.r, z.scratch[0:2])
+	readN, err := io.ReadFull(z.r, z.scratch[0:2])
+	z.headerRead += uint64(readN)
+	z.err = err
 	if z.err != nil {
 		if errors.Is(z.err, io.EOF) {
 			z.err = io.ErrUnexpectedEOF
@@ -47,7 +48,8 @@
 
 	haveDict := z.scratch[1]&0x20 != 0
 	if haveDict {
-		_, z.err = io.ReadFull(z.r, z.scratch[0:4])
+		readN, z.err = io.ReadFull(z.r, z.scratch[0:4])
+		z.headerRead += uint64(readN)
 		if z.err != nil {
 			if errors.Is(z.err, io.EOF) {
 				z.err = io.ErrUnexpectedEOF
@@ -74,6 +76,11 @@
 		if z.err != nil {
 			return z.err
 		}
+		progress, ok := z.decompressor.(flate.InputProgress)
+		if !ok {
+			panic("zlib: pooled decompressor does not implement flate.InputProgress")
+		}
+		z.progress = progress
 
 		z.digest = adler32.New()
 
@@ -85,6 +92,11 @@
 	} else {
 		z.decompressor = flate.NewReader(z.r)
 	}
+	progress, ok := z.decompressor.(flate.InputProgress)
+	if !ok {
+		panic("zlib: decompressor does not implement flate.InputProgress")
+	}
+	z.progress = progress
 
 	z.digest = adler32.New()
 
--