ref: c1f17baa57bad0f61e639fc39c8cd5e4872142f6
dir: /internal/flatex/slice_inflate.go/
package flatex
import (
"io"
"math/bits"
"sync"
)
// sliceInflater is a specialized DEFLATE decoder that reads directly from an
// in-memory byte slice. It mirrors the main decompressor but avoids the
// overhead of the Reader interfaces, enabling faster byte-slice decoding.
//
// Decoding is a resumable state machine: each stage stores the stage to run
// next in step (and, inside huffmanBlock, a resume point in stepState), places
// decoded output in toRead when the window must be drained, and records
// failures in err (finishBlock sets io.EOF after the final block).
// NOTE(review): the loop that invokes step and drains toRead is not in this
// section — confirm against the caller.
type sliceInflater struct {
	// Input side.
	input   []byte // compressed data; set by reset
	pos     int    // index of the next unread byte in input
	roffset int64  // input bytes consumed so far; used as the CorruptInputError offset
	b       uint32 // bit buffer; new bytes are shifted in above the nb valid bits
	nb      uint   // number of valid bits currently held in b
	// Decoders for dynamic blocks, built by readHuffman (h1 is also used
	// temporarily to decode the code-length code).
	h1, h2 huffmanDecoder
	bits     *[maxNumLit + maxNumDist]int // scratch: literal/length then distance code lengths
	codebits *[numCodes]int               // scratch: code-length code lengths
	// Output side.
	window windowDecoder // history window decoded bytes are written into
	toRead []byte        // output ready for the caller, obtained from window.readFlush
	// State machine.
	step      func(*sliceInflater) // next stage to run
	stepState int                  // resume point within huffmanBlock (read symbol vs. finish copy)
	final     bool                 // the current block had its BFINAL bit set
	err       error                // first error seen; io.EOF once the final block is finished
	// Per-block state.
	hl, hd   *huffmanDecoder // literal/length and distance decoders; hd == nil selects fixed 5-bit distance codes
	copyLen  int             // bytes still to produce: pending match length, or remaining stored-block bytes
	copyDist int             // distance of the pending match
}
// sliceInflaterPool recycles sliceInflater values so the code-length scratch
// arrays are allocated once per pooled value rather than once per decode.
// New calls fixedHuffmanDecoderInit before handing out a fresh inflater
// (presumably a one-time setup of fixedHuffmanDecoder, which nextBlock uses
// for fixed-Huffman blocks).
var sliceInflaterPool = sync.Pool{
	New: func() any {
		fixedHuffmanDecoderInit()
		inf := new(sliceInflater)
		inf.bits = new([maxNumLit + maxNumDist]int)
		inf.codebits = new([numCodes]int)
		return inf
	},
}
// reset prepares f to decode src from the start. All decoding state is zeroed
// except the scratch arrays (bits, codebits) and the window value, which are
// carried over; the window is then re-initialized via init(maxMatchOffset).
// The first stage to run is nextBlock. It always returns nil.
func (f *sliceInflater) reset(src []byte) error {
	fresh := sliceInflater{
		input:    src,
		bits:     f.bits,
		codebits: f.codebits,
		window:   f.window,
		step:     (*sliceInflater).nextBlock,
	}
	*f = fresh
	f.window.init(maxMatchOffset)
	return nil
}
// nextBlock reads a 3-bit block header (BFINAL, then the 2-bit BTYPE) and
// dispatches on the block type: 0 stored, 1 fixed Huffman, 2 dynamic Huffman.
// BTYPE 3 is reserved and reported as CorruptInputError. Errors are stored in
// f.err.
func (f *sliceInflater) nextBlock() {
	const headerBits = 1 + 2
	for f.nb < headerBits {
		if err := f.moreBits(); err != nil {
			f.err = err
			return
		}
	}
	hdr := f.b
	f.b >>= headerBits
	f.nb -= headerBits
	f.final = hdr&1 == 1
	switch (hdr >> 1) & 3 {
	case 0: // stored
		f.dataBlock()
	case 1: // fixed Huffman codes
		f.hl, f.hd = &fixedHuffmanDecoder, nil
		f.huffmanBlock()
	case 2: // dynamic Huffman codes, described in the block header
		if err := f.readHuffman(); err != nil {
			f.err = err
			return
		}
		f.hl, f.hd = &f.h1, &f.h2
		f.huffmanBlock()
	default:
		f.err = CorruptInputError(f.roffset)
	}
}
// huffmanBlock decodes the symbols of a fixed- or dynamic-Huffman block with
// the literal/length decoder f.hl and the distance decoder f.hd (nil selects
// the fixed 5-bit distance codes), writing the output into f.window.
//
// The function is resumable. When the window has no room left, or a
// back-reference could not be copied completely, it hands the buffered output
// to the caller through f.toRead, sets f.step back to huffmanBlock and stores
// in f.stepState where to continue: stateInit reads the next symbol, stateDict
// finishes the pending copy described by f.copyLen/f.copyDist. The
// end-of-block symbol (256) hands control to finishBlock. Errors are stored
// in f.err.
func (f *sliceInflater) huffmanBlock() {
	const (
		stateInit = iota // read the next literal/length symbol
		stateDict        // finish the copy in f.copyLen/f.copyDist
	)
	switch f.stepState {
	case stateInit:
		goto readLiteral
	case stateDict:
		goto copyHistory
	}
	// Decode one literal/length symbol and, for a length, its distance.
readLiteral:
	{
		v, err := f.huffSym(f.hl)
		if err != nil {
			f.err = err
			return
		}
		var n uint // number of extra bits that follow a length code
		var length int
		switch {
		case v < 256:
			// Literal byte.
			f.window.writeByte(byte(v))
			if f.window.availWrite() == 0 {
				f.toRead = f.window.readFlush()
				f.step = (*sliceInflater).huffmanBlock
				f.stepState = stateInit
				return
			}
			goto readLiteral
		case v == 256:
			// End of block.
			f.finishBlock()
			return
		// Length codes: each range maps to a base length plus n extra bits.
		case v < 265:
			length = v - (257 - 3)
			n = 0
		case v < 269:
			length = v*2 - (265*2 - 11)
			n = 1
		case v < 273:
			length = v*4 - (269*4 - 19)
			n = 2
		case v < 277:
			length = v*8 - (273*8 - 35)
			n = 3
		case v < 281:
			length = v*16 - (277*16 - 67)
			n = 4
		case v < 285:
			length = v*32 - (281*32 - 131)
			n = 5
		case v < maxNumLit:
			// Code 285 (assuming maxNumLit is 286): fixed length 258.
			length = 258
			n = 0
		default:
			f.err = CorruptInputError(f.roffset)
			return
		}
		if n > 0 {
			for f.nb < n {
				if err = f.moreBits(); err != nil {
					f.err = err
					return
				}
			}
			length += int(f.b & uint32(1<<n-1))
			f.b >>= n
			f.nb -= n
		}
		var dist int
		if f.hd == nil {
			// Fixed distance codes are 5 bits stored most-significant bit
			// first; shifting into the top of a byte and reversing it yields
			// the code value.
			for f.nb < 5 {
				if err = f.moreBits(); err != nil {
					f.err = err
					return
				}
			}
			dist = int(bits.Reverse8(uint8(f.b & 0x1F << 3)))
			f.b >>= 5
			f.nb -= 5
		} else {
			if dist, err = f.huffSym(f.hd); err != nil {
				f.err = err
				return
			}
		}
		switch {
		case dist < 4:
			// Codes 0-3 are distances 1-4 with no extra bits.
			dist++
		case dist < maxNumDist:
			// Code d has (d-2)/2 extra bits; the low bit of d becomes the
			// top extra bit.
			nb := uint(dist-2) >> 1
			extra := (dist & 1) << nb
			for f.nb < nb {
				if err = f.moreBits(); err != nil {
					f.err = err
					return
				}
			}
			extra |= int(f.b & uint32(1<<nb-1))
			f.b >>= nb
			f.nb -= nb
			dist = 1<<(nb+1) + 1 + extra
		default:
			f.err = CorruptInputError(f.roffset)
			return
		}
		// Reject references further back than the window's history
		// (histSize is presumably the amount of output available to copy from).
		if dist > f.window.histSize() {
			f.err = CorruptInputError(f.roffset)
			return
		}
		f.copyLen, f.copyDist = length, dist
		goto copyHistory
	}
	// Copy f.copyLen bytes from f.copyDist back, possibly across several calls.
copyHistory:
	{
		// tryWriteCopy is attempted first; writeCopy is used when it copies nothing.
		cnt := f.window.tryWriteCopy(f.copyDist, f.copyLen)
		if cnt == 0 {
			cnt = f.window.writeCopy(f.copyDist, f.copyLen)
		}
		f.copyLen -= cnt
		if f.window.availWrite() == 0 || f.copyLen > 0 {
			f.toRead = f.window.readFlush()
			f.step = (*sliceInflater).huffmanBlock
			f.stepState = stateDict
			return
		}
		goto readLiteral
	}
}
// dataBlock starts a stored (uncompressed) block. Stored blocks begin on a
// byte boundary, so bits still buffered in f.b (the unused tail of the last
// byte read) are dropped. Next come the 16-bit little-endian LEN and its
// ones' complement NLEN; a mismatch is a CorruptInputError. A zero-length
// block finishes immediately; otherwise copyData copies LEN raw bytes.
//
// If fewer than 4 header bytes remain, the rest of the input is consumed and
// f.err is set to io.ErrUnexpectedEOF. roffset is advanced together with pos
// on that path too, so it keeps counting exactly the bytes consumed.
func (f *sliceInflater) dataBlock() {
	f.nb = 0
	f.b = 0
	if remaining := len(f.input) - f.pos; remaining < 4 {
		f.roffset += int64(remaining)
		f.pos = len(f.input)
		f.err = io.ErrUnexpectedEOF
		return
	}
	hdr := f.input[f.pos : f.pos+4]
	f.pos += 4
	f.roffset += 4
	n := int(hdr[0]) | int(hdr[1])<<8
	nn := int(hdr[2]) | int(hdr[3])<<8
	if uint16(nn) != uint16(^n) {
		f.err = CorruptInputError(f.roffset)
		return
	}
	if n == 0 {
		f.toRead = f.window.readFlush()
		f.finishBlock()
		return
	}
	f.copyLen = n
	f.copyData()
}
// copyData copies the remaining f.copyLen bytes of a stored block from the
// input into the window. Whenever the window is full it hands the buffered
// output to the caller via f.toRead and sets f.step so decoding resumes here.
// Once all bytes are copied, finishBlock ends the block.
//
// If the input ends before the block does, the bytes that are available are
// still committed to the window, and then f.err is set to
// io.ErrUnexpectedEOF. Previously they were discarded. Committing them
// matches huffmanBlock, whose bytes decoded before a truncation also remain
// in the window.
func (f *sliceInflater) copyData() {
	for f.copyLen > 0 {
		buf := f.window.writeSlice()
		if len(buf) == 0 {
			f.toRead = f.window.readFlush()
			f.step = (*sliceInflater).copyData
			return
		}
		n := min(f.copyLen, len(buf))
		truncated := false
		if avail := len(f.input) - f.pos; n > avail {
			n = avail
			truncated = true
		}
		copy(buf[:n], f.input[f.pos:f.pos+n])
		f.pos += n
		f.roffset += int64(n)
		f.copyLen -= n
		f.window.writeMark(n)
		if truncated {
			f.err = io.ErrUnexpectedEOF
			return
		}
		if f.window.availWrite() == 0 {
			f.toRead = f.window.readFlush()
			f.step = (*sliceInflater).copyData
			return
		}
	}
	f.finishBlock()
}
// finishBlock ends the current block: the next stage becomes nextBlock with a
// cleared stepState. If the block was marked final, any unread window output
// is moved to f.toRead and f.err is set to io.EOF.
func (f *sliceInflater) finishBlock() {
	f.step = (*sliceInflater).nextBlock
	f.stepState = 0
	if !f.final {
		return
	}
	if f.window.availRead() > 0 {
		f.toRead = f.window.readFlush()
	}
	f.err = io.EOF
}
// moreBits appends the next input byte to the bit buffer, above the nb bits
// already held, and advances pos and roffset. It returns
// io.ErrUnexpectedEOF, leaving the state untouched, when the input is
// exhausted.
func (f *sliceInflater) moreBits() error {
	if f.pos < len(f.input) {
		f.b |= uint32(f.input[f.pos]) << (f.nb & 31)
		f.nb += 8
		f.pos++
		f.roffset++
		return nil
	}
	return io.ErrUnexpectedEOF
}
// huffSym decodes one symbol with decoder h, pulling input bytes into the bit
// buffer only as they are needed.
//
// It works on local copies of f.b/f.nb and writes them back on every exit.
// It first ensures at least h.min bits are buffered. It then looks up the
// code length n in h.chunks, following h.links when n exceeds
// huffmanChunkBits. If fewer than n bits are buffered, it reads more and
// retries the lookup.
//
// A zero length means the bits do not form a valid code: that case returns a
// CorruptInputError, which is also stored in f.err. Running out of input
// returns io.ErrUnexpectedEOF without touching f.err.
func (f *sliceInflater) huffSym(h *huffmanDecoder) (int, error) {
	n := uint(h.min)
	nb, b := f.nb, f.b
	for {
		// Refill byte by byte until n bits are available.
		for nb < n {
			if f.pos >= len(f.input) {
				f.b = b
				f.nb = nb
				return 0, io.ErrUnexpectedEOF
			}
			c := f.input[f.pos]
			f.pos++
			f.roffset++
			b |= uint32(c) << (nb & 31)
			nb += 8
		}
		chunk := h.chunks[b&(huffmanNumChunks-1)]
		n = uint(chunk & huffmanCountMask)
		if n > huffmanChunkBits {
			// Long code: resolve through the secondary link table.
			chunk = h.links[chunk>>huffmanValueShift][(b>>huffmanChunkBits)&h.linkMask]
			n = uint(chunk & huffmanCountMask)
		}
		if n <= nb {
			if n == 0 {
				f.b = b
				f.nb = nb
				f.err = CorruptInputError(f.roffset)
				return 0, f.err
			}
			f.b = b >> (n & 31)
			f.nb = nb - n
			return int(chunk >> huffmanValueShift), nil
		}
		// Otherwise the code is longer than the buffered bits: loop to read more.
	}
}
// readHuffman reads the header of a dynamic-Huffman block (RFC 1951 section
// 3.2.7). It builds the literal/length decoder in f.h1 and the distance
// decoder in f.h2.
//
// The header holds:
//   - HLIT (5 bits), HDIST (5 bits) and HCLEN (4 bits);
//   - HCLEN+4 three-bit lengths for the code-length alphabet, in codeOrder
//     order;
//   - HLIT+257 literal/length lengths followed by HDIST+1 distance lengths,
//     encoded with that code-length code.
//
// f.h1 is used temporarily to decode the code-length code and is then
// re-initialized as the literal/length decoder.
//
// It returns io.ErrUnexpectedEOF if the input ends early and a
// CorruptInputError if the header is malformed. That includes a repeat code
// that would run past HLIT+HDIST entries; such a repeat used to be accepted
// as long as it stayed inside the scratch array.
func (f *sliceInflater) readHuffman() error {
	// HLIT[5], HDIST[5], HCLEN[4].
	for f.nb < 5+5+4 {
		if err := f.moreBits(); err != nil {
			return err
		}
	}
	nlit := int(f.b&0x1F) + 257
	if nlit > maxNumLit {
		return CorruptInputError(f.roffset)
	}
	f.b >>= 5
	ndist := int(f.b&0x1F) + 1
	if ndist > maxNumDist {
		return CorruptInputError(f.roffset)
	}
	f.b >>= 5
	nclen := int(f.b&0xF) + 4
	f.b >>= 4
	f.nb -= 5 + 5 + 4

	// (HCLEN+4) three-bit lengths for the code-length alphabet, sent in
	// codeOrder. Lengths that are not transmitted must read as zero.
	codeLens := f.codebits[:]
	clear(codeLens)
	for i := 0; i < nclen; i++ {
		for f.nb < 3 {
			if err := f.moreBits(); err != nil {
				return err
			}
		}
		codeLens[codeOrder[i]] = int(f.b & 0x7)
		f.b >>= 3
		f.nb -= 3
	}
	if !f.h1.init(codeLens) {
		return CorruptInputError(f.roffset)
	}

	// Literal/length lengths followed by distance lengths, decoded as one
	// sequence of total entries. A repeat may cross from one alphabet into
	// the other, but must not run past total. The loop writes every entry
	// below total, so the scratch array needs no zeroing.
	lengths := f.bits[:]
	total := nlit + ndist
	for i := 0; i < total; {
		x, err := f.huffSym(&f.h1)
		if err != nil {
			return err
		}
		if x < 16 {
			lengths[i] = x
			i++
			continue
		}
		// Repeat codes:
		//   16: copy the previous length 3-6 times (2 extra bits)
		//   17: write a zero length 3-10 times (3 extra bits)
		//   18: write a zero length 11-138 times (7 extra bits)
		var (
			rep   int
			extra uint
			val   int
		)
		switch x {
		case 16:
			if i == 0 {
				return CorruptInputError(f.roffset)
			}
			rep, extra, val = 3, 2, lengths[i-1]
		case 17:
			rep, extra = 3, 3
		case 18:
			rep, extra = 11, 7
		default:
			return CorruptInputError(f.roffset)
		}
		for f.nb < extra {
			if err := f.moreBits(); err != nil {
				return err
			}
		}
		rep += int(f.b & uint32(1<<extra-1))
		f.b >>= extra
		f.nb -= extra
		if i+rep > total {
			return CorruptInputError(f.roffset)
		}
		for end := i + rep; i < end; i++ {
			lengths[i] = val
		}
	}

	if !f.h1.init(lengths[:nlit]) {
		return CorruptInputError(f.roffset)
	}
	if !f.h2.init(lengths[nlit:total]) {
		return CorruptInputError(f.roffset)
	}
	// Every block ends with the end-of-block symbol, so at least that many
	// bits always remain before the block ends. huffSym can therefore read
	// that many bits up front without reading past the compressed stream.
	if f.h1.min < lengths[endBlockMarker] {
		f.h1.min = lengths[endBlockMarker]
	}
	return nil
}