shithub: furgit

ref: c1f17baa57bad0f61e639fc39c8cd5e4872142f6
dir: /internal/flatex/slice_inflate.go/

View raw version
package flatex

import (
	"io"
	"math/bits"
	"sync"
)

// sliceInflater is a specialized DEFLATE decoder that reads directly from an
// in-memory byte slice. It mirrors the main decompressor but avoids the
// overhead of the Reader interfaces, enabling faster byte-slice decoding.
//
// Decoding is driven as a state machine: step holds the method to run next,
// and each step either produces output in toRead, records an error in err,
// or installs the following step.
type sliceInflater struct {
	// input is the compressed data, pos is the index of the next unread byte
	// in input, and roffset counts the input bytes consumed so far (it is the
	// value carried by CorruptInputError).
	input   []byte
	pos     int
	roffset int64

	// b is the bit buffer, filled from its low end one input byte at a time;
	// nb is the number of valid bits currently held in b.
	b  uint32
	nb uint

	// h1 and h2 are the decoders built by readHuffman for a dynamic block:
	// h1 for literal/length symbols, h2 for distance symbols. h1 is also
	// used temporarily to decode the code-length alphabet.
	h1, h2 huffmanDecoder

	// bits and codebits are scratch arrays of code lengths used while
	// readHuffman builds the decoders. They are kept across reset so they
	// are not reallocated.
	bits     *[maxNumLit + maxNumDist]int
	codebits *[numCodes]int

	// window receives decoded bytes and is the source of back-reference
	// copies.
	window windowDecoder

	// toRead is output flushed from window and waiting to be consumed.
	// step is the next state-machine method to run and stepState selects the
	// resume point inside huffmanBlock. final records whether the current
	// block carried the final-block flag. err is the first error hit (io.EOF
	// after the final block). hl and hd are the literal/length and distance
	// decoders for the current block; hd is nil for fixed-Huffman blocks.
	// copyLen and copyDist describe the copy still in progress: remaining
	// length and back-reference distance (copyLen is also the remaining byte
	// count of a stored block).
	toRead    []byte
	step      func(*sliceInflater)
	stepState int
	final     bool
	err       error
	hl, hd    *huffmanDecoder
	copyLen   int
	copyDist  int
}

// sliceInflaterPool hands out reusable sliceInflater values. A freshly
// created inflater already owns its two code-length scratch arrays, and the
// shared fixed Huffman decoder is initialised before the inflater is built.
var sliceInflaterPool = sync.Pool{
	New: func() any {
		fixedHuffmanDecoderInit()

		inflater := new(sliceInflater)
		inflater.bits = new([maxNumLit + maxNumDist]int)
		inflater.codebits = new([numCodes]int)
		return inflater
	},
}

// reset prepares f to decode src from its first byte. All decoding state is
// zeroed except the scratch arrays and the window value, which are carried
// over so their storage can be reused; the window is then re-initialised for
// maxMatchOffset bytes of history. The returned error is always nil.
func (f *sliceInflater) reset(src []byte) error {
	fresh := sliceInflater{
		input:    src,
		bits:     f.bits,
		codebits: f.codebits,
		window:   f.window,
		step:     (*sliceInflater).nextBlock,
	}
	*f = fresh
	f.window.init(maxMatchOffset)
	return nil
}

// nextBlock consumes the three-bit block header (final flag plus two-bit
// block type) and dispatches to the decoder for that type: 0 is a stored
// block, 1 uses the fixed Huffman decoder with no distance decoder, 2 reads
// dynamic Huffman tables first. Type 3 is rejected as corrupt input. Any
// failure is recorded in f.err.
func (f *sliceInflater) nextBlock() {
	const headerBits = 1 + 2 // final flag + block type

	for f.nb < headerBits {
		err := f.moreBits()
		if err != nil {
			f.err = err
			return
		}
	}

	header := f.b & 7
	f.b >>= headerBits
	f.nb -= headerBits
	f.final = header&1 == 1

	switch blockType := header >> 1; blockType {
	case 0:
		f.dataBlock()
	case 1:
		f.hl, f.hd = &fixedHuffmanDecoder, nil
		f.huffmanBlock()
	case 2:
		err := f.readHuffman()
		if err != nil {
			f.err = err
			return
		}
		f.hl, f.hd = &f.h1, &f.h2
		f.huffmanBlock()
	default:
		f.err = CorruptInputError(f.roffset)
	}
}

// huffmanBlock decodes symbols of a Huffman-compressed block using f.hl for
// literal/length symbols and f.hd for distances (f.hd == nil means distances
// are read as raw 5-bit values). Literals and back-reference copies are
// written into f.window.
//
// The method returns when the window has no room left (after setting
// f.toRead, f.step and f.stepState so a later call resumes where it stopped),
// when the end-of-block symbol is decoded, or when an error has been stored
// in f.err.
func (f *sliceInflater) huffmanBlock() {
	const (
		stateInit = iota // resume by decoding the next literal/length symbol
		stateDict        // resume an unfinished back-reference copy
	)
	switch f.stepState {
	case stateInit:
		goto readLiteral
	case stateDict:
		goto copyHistory
	}

readLiteral:
	{
		v, err := f.huffSym(f.hl)
		if err != nil {
			f.err = err
			return
		}
		var n uint     // number of extra length bits to read
		var length int // match length (base value, extra bits added below)
		switch {
		case v < 256:
			// Literal byte: emit it, flushing when the window is full.
			f.window.writeByte(byte(v))
			if f.window.availWrite() == 0 {
				f.toRead = f.window.readFlush()
				f.step = (*sliceInflater).huffmanBlock
				f.stepState = stateInit
				return
			}
			goto readLiteral
		case v == 256:
			// End-of-block symbol.
			f.finishBlock()
			return
		case v < 265:
			// Symbols 257..264: lengths 3..10, no extra bits.
			length = v - (257 - 3)
			n = 0
		case v < 269:
			// Symbols 265..268: base lengths from 11, 1 extra bit.
			length = v*2 - (265*2 - 11)
			n = 1
		case v < 273:
			// Symbols 269..272: base lengths from 19, 2 extra bits.
			length = v*4 - (269*4 - 19)
			n = 2
		case v < 277:
			// Symbols 273..276: base lengths from 35, 3 extra bits.
			length = v*8 - (273*8 - 35)
			n = 3
		case v < 281:
			// Symbols 277..280: base lengths from 67, 4 extra bits.
			length = v*16 - (277*16 - 67)
			n = 4
		case v < 285:
			// Symbols 281..284: base lengths from 131, 5 extra bits.
			length = v*32 - (281*32 - 131)
			n = 5
		case v < maxNumLit:
			// Remaining valid symbol: fixed length 258, no extra bits.
			length = 258
			n = 0
		default:
			f.err = CorruptInputError(f.roffset)
			return
		}
		if n > 0 {
			// Read the n extra bits and add them to the base length.
			for f.nb < n {
				if err = f.moreBits(); err != nil {
					f.err = err
					return
				}
			}
			length += int(f.b & uint32(1<<n-1))
			f.b >>= n
			f.nb -= n
		}

		var dist int
		if f.hd == nil {
			// No distance decoder: the distance symbol is a plain 5-bit
			// field whose bits are stored in reverse order.
			for f.nb < 5 {
				if err = f.moreBits(); err != nil {
					f.err = err
					return
				}
			}
			dist = int(bits.Reverse8(uint8(f.b & 0x1F << 3)))
			f.b >>= 5
			f.nb -= 5
		} else {
			if dist, err = f.huffSym(f.hd); err != nil {
				f.err = err
				return
			}
		}

		// Turn the distance symbol into an actual distance.
		switch {
		case dist < 4:
			// Symbols 0..3 map directly to distances 1..4.
			dist++
		case dist < maxNumDist:
			// Larger symbols carry (dist-2)/2 extra bits; the low bit of the
			// symbol selects the upper or lower half of the range.
			nb := uint(dist-2) >> 1
			extra := (dist & 1) << nb
			for f.nb < nb {
				if err = f.moreBits(); err != nil {
					f.err = err
					return
				}
			}
			extra |= int(f.b & uint32(1<<nb-1))
			f.b >>= nb
			f.nb -= nb
			dist = 1<<(nb+1) + 1 + extra
		default:
			f.err = CorruptInputError(f.roffset)
			return
		}

		// A distance larger than what the window reports as history is
		// treated as corrupt input. The length is not checked here.
		if dist > f.window.histSize() {
			f.err = CorruptInputError(f.roffset)
			return
		}

		f.copyLen, f.copyDist = length, dist
		goto copyHistory
	}

copyHistory:
	{
		// Perform (or continue) the pending copy. writeCopy is used when
		// tryWriteCopy reports that it copied nothing.
		// NOTE(review): tryWriteCopy is presumably a fast path that declines
		// some cases - confirm against windowDecoder.
		cnt := f.window.tryWriteCopy(f.copyDist, f.copyLen)
		if cnt == 0 {
			cnt = f.window.writeCopy(f.copyDist, f.copyLen)
		}
		f.copyLen -= cnt

		// Flush when the window is full or the copy is unfinished, and
		// arrange to resume at copyHistory on the next call.
		if f.window.availWrite() == 0 || f.copyLen > 0 {
			f.toRead = f.window.readFlush()
			f.step = (*sliceInflater).huffmanBlock
			f.stepState = stateDict
			return
		}
		goto readLiteral
	}
}

// dataBlock starts decoding a stored (uncompressed) block. It discards any
// bits still held in the bit buffer, reads the four-byte header (a 16-bit
// little-endian length followed by its one's complement) and then either
// finishes the block at once when the length is zero or hands over to
// copyData to move the payload into the window.
//
// A truncated header sets f.err to io.ErrUnexpectedEOF; a length that does
// not match its complement sets f.err to a CorruptInputError.
func (f *sliceInflater) dataBlock() {
	// Stored data is read byte-wise from f.input, so buffered bits are dropped.
	f.nb = 0
	f.b = 0

	if f.pos+4 > len(f.input) {
		// The partial header is consumed. Advance roffset together with pos
		// so the two counters stay in step, as they do on every other read.
		f.roffset += int64(len(f.input) - f.pos)
		f.pos = len(f.input)
		f.err = io.ErrUnexpectedEOF
		return
	}
	hdr := f.input[f.pos : f.pos+4]
	f.pos += 4
	f.roffset += 4

	// Length, then the one's complement of the length.
	n := int(hdr[0]) | int(hdr[1])<<8
	nn := int(hdr[2]) | int(hdr[3])<<8
	if uint16(nn) != uint16(^n) {
		f.err = CorruptInputError(f.roffset)
		return
	}

	if n == 0 {
		// Empty stored block: flush what the window holds and move on.
		f.toRead = f.window.readFlush()
		f.finishBlock()
		return
	}

	f.copyLen = n
	f.copyData()
}

// copyData moves the remaining f.copyLen bytes of a stored block from the
// input into the window. Whenever the window has no writable space it
// flushes to f.toRead and registers itself as the next step, so the copy
// continues on a later call. When nothing is left to copy the block is
// finished. If the input ends before the payload does, f.err is set to
// io.ErrUnexpectedEOF and nothing from the short chunk is copied.
func (f *sliceInflater) copyData() {
	for f.copyLen != 0 {
		dst := f.window.writeSlice()
		if len(dst) == 0 {
			f.toRead = f.window.readFlush()
			f.step = (*sliceInflater).copyData
			return
		}

		// Copy as much as both the block and the window allow.
		chunk := len(dst)
		if f.copyLen < chunk {
			chunk = f.copyLen
		}
		end := f.pos + chunk
		if end > len(f.input) {
			f.err = io.ErrUnexpectedEOF
			return
		}

		copy(dst[:chunk], f.input[f.pos:end])
		f.pos = end
		f.roffset += int64(chunk)
		f.copyLen -= chunk
		f.window.writeMark(chunk)

		if f.window.availWrite() == 0 {
			f.toRead = f.window.readFlush()
			f.step = (*sliceInflater).copyData
			return
		}
	}
	f.finishBlock()
}

// finishBlock closes the current block. The state machine is always pointed
// back at nextBlock with stepState cleared. If the block carried the final
// flag, any output still readable from the window is flushed to f.toRead
// and f.err becomes io.EOF.
func (f *sliceInflater) finishBlock() {
	f.step = (*sliceInflater).nextBlock
	f.stepState = 0

	if !f.final {
		return
	}
	if f.window.availRead() > 0 {
		f.toRead = f.window.readFlush()
	}
	f.err = io.EOF
}

// moreBits appends the next input byte above the bits already held in the
// bit buffer, growing f.nb by eight. It returns io.ErrUnexpectedEOF when the
// input is exhausted.
func (f *sliceInflater) moreBits() error {
	at := f.pos
	if at >= len(f.input) {
		return io.ErrUnexpectedEOF
	}
	f.b |= uint32(f.input[at]) << (f.nb & 31)
	f.nb += 8
	f.pos = at + 1
	f.roffset++
	return nil
}

// huffSym decodes one symbol from the input using decoder h and returns it.
// The bit buffer is worked on in local copies and written back to f on every
// return path. It returns io.ErrUnexpectedEOF if the input runs out, and a
// CorruptInputError (also stored in f.err) when the table entry for the
// current bits has a zero code length.
func (f *sliceInflater) huffSym(h *huffmanDecoder) (int, error) {
	need := uint(h.min)
	avail, acc := f.nb, f.b
	for {
		// Top up the local bit buffer until it holds at least need bits.
		for avail < need {
			if f.pos >= len(f.input) {
				f.b, f.nb = acc, avail
				return 0, io.ErrUnexpectedEOF
			}
			acc |= uint32(f.input[f.pos]) << (avail & 31)
			avail += 8
			f.pos++
			f.roffset++
		}

		// Look up the primary table, following a link for longer codes.
		entry := h.chunks[acc&(huffmanNumChunks-1)]
		need = uint(entry & huffmanCountMask)
		if need > huffmanChunkBits {
			entry = h.links[entry>>huffmanValueShift][(acc>>huffmanChunkBits)&h.linkMask]
			need = uint(entry & huffmanCountMask)
		}

		if need > avail {
			// The code is longer than what is buffered: fetch more bits.
			continue
		}
		if need == 0 {
			f.b, f.nb = acc, avail
			f.err = CorruptInputError(f.roffset)
			return 0, f.err
		}
		f.b, f.nb = acc>>(need&31), avail-need
		return int(entry >> huffmanValueShift), nil
	}
}

// readHuffman parses the header of a dynamic-Huffman block and builds the
// two decoders used for it: f.h1 for literal/length symbols and f.h2 for
// distance symbols.
//
// It returns io.ErrUnexpectedEOF (via moreBits/huffSym) when the input ends
// early and a CorruptInputError when the header is invalid: too many
// literal or distance codes, a code-length table that does not form a valid
// decoder, a "repeat previous" code with no previous length, or a repeat
// that would run past the number of code lengths announced in the header.
func (f *sliceInflater) readHuffman() error {
	// Literal/length code count (5 bits), distance code count (5 bits) and
	// code-length code count (4 bits).
	for f.nb < 5+5+4 {
		if err := f.moreBits(); err != nil {
			return err
		}
	}
	nlit := int(f.b&0x1F) + 257
	if nlit > maxNumLit {
		return CorruptInputError(f.roffset)
	}
	f.b >>= 5
	ndist := int(f.b&0x1F) + 1
	if ndist > maxNumDist {
		return CorruptInputError(f.roffset)
	}
	f.b >>= 5
	nclen := int(f.b&0xF) + 4
	f.b >>= 4
	f.nb -= 5 + 5 + 4

	// Start from all-zero scratch tables: only nclen code-length entries and
	// nlit+ndist length entries are written below.
	codebits := f.codebits[:]
	bits := f.bits[:]
	clear(codebits)
	clear(bits)

	// nclen three-bit code lengths, stored in codeOrder order.
	for i := 0; i < nclen; i++ {
		for f.nb < 3 {
			if err := f.moreBits(); err != nil {
				return err
			}
		}
		codebits[codeOrder[i]] = int(f.b & 0x7)
		f.b >>= 3
		f.nb -= 3
	}
	if !f.h1.init(codebits) {
		return CorruptInputError(f.roffset)
	}

	// nlit+ndist code lengths, encoded with the code-length decoder that
	// was just built in f.h1.
	total := nlit + ndist
	for i := 0; i < total; {
		x, err := f.huffSym(&f.h1)
		if err != nil {
			return err
		}
		if x < 16 {
			// Literal code length.
			bits[i] = x
			i++
			continue
		}

		// Run-length codes: a base repeat count, a number of extra bits
		// added to that count, and the value to fill with.
		var (
			repeat int
			extra  uint
			fill   int
		)
		switch x {
		case 16:
			// Repeat the previous length 3..6 times.
			if i == 0 {
				return CorruptInputError(f.roffset)
			}
			repeat, extra, fill = 3, 2, bits[i-1]
		case 17:
			// 3..10 zero lengths.
			repeat, extra, fill = 3, 3, 0
		case 18:
			// 11..138 zero lengths.
			repeat, extra, fill = 11, 7, 0
		default:
			return CorruptInputError(f.roffset)
		}
		for f.nb < extra {
			if err := f.moreBits(); err != nil {
				return err
			}
		}
		repeat += int(f.b & uint32(1<<extra-1))
		f.b >>= extra
		f.nb -= extra

		// A run may not extend past the lengths announced in the header.
		// Accepting it would silently ignore part of a malformed stream.
		// total never exceeds len(bits), so this also keeps writes in range.
		if i+repeat > total {
			return CorruptInputError(f.roffset)
		}
		for ; repeat > 0; repeat-- {
			bits[i] = fill
			i++
		}
	}

	if !f.h1.init(bits[:nlit]) {
		return CorruptInputError(f.roffset)
	}
	if !f.h2.init(bits[nlit : nlit+ndist]) {
		return CorruptInputError(f.roffset)
	}

	// Raise the minimum number of bits huffSym fetches for a literal/length
	// symbol to the code length of the end-of-block marker when that is
	// larger than the decoder's own minimum.
	if f.h1.min < bits[endBlockMarker] {
		f.h1.min = bits[endBlockMarker]
	}
	return nil
}