Commit b59f0b04 authored by Levin Zimmermann's avatar Levin Zimmermann

go/neo/proto/msgpack: Fix case where Compression == None

'Compression' is py type 'Optional[int]' [1]. Before this patch only 'int' was
supported. Now NEO/go also understands 'Compression' with value 'Nil'.

Without this patch, NEO/go client tests fail with

```
have: neos://127.0.0.1:19847,127.0.0.1:28658/1: load 7fffffffffffffff:0000000000000006: 127.0.0.1:39230 - 127.0.0.1:46143 .291: decode: decode: M: AnswerObject.Compression: msgp: attempted to decode type "nil" with method for "uint"
```

[1] See https://lab.nexedi.com/nexedi/neoppod/-/blob/e3cd5c5bf/neo/tests/protocol#L21
    The fourth argument is 'compression':
    https://lab.nexedi.com/nexedi/neoppod/-/blob/e3cd5c5bf/neo/storage/handlers/client.py#L77-78
parent 6c431a55
...@@ -1148,8 +1148,28 @@ func (d *decoderM) genBasic(assignto string, typ *types.Basic, userType types.Ty ...@@ -1148,8 +1148,28 @@ func (d *decoderM) genBasic(assignto string, typ *types.Basic, userType types.Ty
v = fmt.Sprintf("%s(v)", typeName(userType)) v = fmt.Sprintf("%s(v)", typeName(userType))
} }
// optional emits optional value of int/float
optional := func(optionalValue string) {
if optionalValue != "" {
// Read%dBytes returns 'ErrShortBytes' in case prefix is
// correct float, but data is too short - catch this to return
// 'ErrDecodeOverflow' instead of type error.
d.emit(" err = mdecodeErr(%q, err)", d.pathName(assignto))
d.emit(" if err == ErrDecodeOverflow {")
d.emit(" return 0, err")
d.emit(" }")
d.emit(" tail, err = msgp.ReadNilBytes(data)")
d.emit(" if err != nil {")
d.emit(" return 0, mdecodeErr(%q, err)", d.pathName(assignto))
d.emit(" }")
d.emit(" v = %v", optionalValue)
} else {
d.emit(" return 0, mdecodeErr(%q, err)", d.pathName(assignto))
}
}
// mgetint emits assignto = mget<kind>int<size>() // mgetint emits assignto = mget<kind>int<size>()
mgetint := func(kind string, size int) { mgetint := func(kind string, size int, optionalValue string) {
// we are going to go into msgp - flush previously queued // we are going to go into msgp - flush previously queued
// overflow checks; put place for next overflow check after // overflow checks; put place for next overflow check after
// msgp is done. // msgp is done.
...@@ -1164,7 +1184,7 @@ func (d *decoderM) genBasic(assignto string, typ *types.Basic, userType types.Ty ...@@ -1164,7 +1184,7 @@ func (d *decoderM) genBasic(assignto string, typ *types.Basic, userType types.Ty
d.emit("{") d.emit("{")
d.emit("v, tail, err := msgp.Read%snt%dBytes(data)", KI, size) d.emit("v, tail, err := msgp.Read%snt%dBytes(data)", KI, size)
d.emit("if err != nil {") d.emit("if err != nil {")
d.emit(" return 0, mdecodeErr(%q, err)", d.pathName(assignto)) optional(optionalValue)
d.emit("}") d.emit("}")
d.emit("%s= %s", assignto, v) d.emit("%s= %s", assignto, v)
d.emit("%v += uint64(len(data) - len(tail))", d.var_("nread")) d.emit("%v += uint64(len(data) - len(tail))", d.var_("nread"))
...@@ -1182,22 +1202,7 @@ func (d *decoderM) genBasic(assignto string, typ *types.Basic, userType types.Ty ...@@ -1182,22 +1202,7 @@ func (d *decoderM) genBasic(assignto string, typ *types.Basic, userType types.Ty
d.emit("{") d.emit("{")
d.emit("v, tail, err := msgp.ReadFloat%dBytes(data)", size) d.emit("v, tail, err := msgp.ReadFloat%dBytes(data)", size)
d.emit("if err != nil {") d.emit("if err != nil {")
if optionalValue != "" { optional(optionalValue)
// ReadFloat%dBytes returns 'ErrShortBytes' in case prefix is
// correct float, but data is too short - catch this to return
// 'ErrDecodeOverflow' instead of type error.
d.emit(" err = mdecodeErr(%q, err)", d.pathName(assignto))
d.emit(" if err == ErrDecodeOverflow {")
d.emit(" return 0, err")
d.emit(" }")
d.emit(" tail, err = msgp.ReadNilBytes(data)")
d.emit(" if err != nil {")
d.emit(" return 0, mdecodeErr(%q, err)", d.pathName(assignto))
d.emit(" }")
d.emit(" v = %v", optionalValue)
} else {
d.emit(" return 0, mdecodeErr(%q, err)", d.pathName(assignto))
}
d.emit("}") d.emit("}")
d.emit("%s= %s", assignto, v) d.emit("%s= %s", assignto, v)
d.emit("%v += uint64(len(data) - len(tail))", d.var_("nread")) d.emit("%v += uint64(len(data) - len(tail))", d.var_("nread"))
...@@ -1213,6 +1218,13 @@ func (d *decoderM) genBasic(assignto string, typ *types.Basic, userType types.Ty ...@@ -1213,6 +1218,13 @@ func (d *decoderM) genBasic(assignto string, typ *types.Basic, userType types.Ty
return return
} }
// Compression can be nil ('None'), this means the same as
// no compression ('py/NoneType.__bool__' is 'False').
if typeName(userType) == "Compression" {
mgetint("u", 64, "0")
return
}
switch typ.Kind() { switch typ.Kind() {
case types.Bool: case types.Bool:
d.emit("switch op := msgpack.Op(data[%v]); op {", d.n) d.emit("switch op := msgpack.Op(data[%v]); op {", d.n)
...@@ -1224,15 +1236,15 @@ func (d *decoderM) genBasic(assignto string, typ *types.Basic, userType types.Ty ...@@ -1224,15 +1236,15 @@ func (d *decoderM) genBasic(assignto string, typ *types.Basic, userType types.Ty
d.n++ d.n++
d.overflow.Add(1) d.overflow.Add(1)
case types.Int8: mgetint("", 8) case types.Int8: mgetint("", 8, "")
case types.Int16: mgetint("", 16) case types.Int16: mgetint("", 16, "")
case types.Int32: mgetint("", 32) case types.Int32: mgetint("", 32, "")
case types.Int64: mgetint("", 64) case types.Int64: mgetint("", 64, "")
case types.Uint8: mgetint("u", 8) case types.Uint8: mgetint("u", 8, "")
case types.Uint16: mgetint("u", 16) case types.Uint16: mgetint("u", 16, "")
case types.Uint32: mgetint("u", 32) case types.Uint32: mgetint("u", 32, "")
case types.Uint64: mgetint("u", 64) case types.Uint64: mgetint("u", 64, "")
case types.Float64: mgetfloat(64, "") case types.Float64: mgetfloat(64, "")
} }
......
...@@ -4967,9 +4967,17 @@ func (p *AnswerRebaseObject) neoMsgDecodeM(data []byte) (int, error) { ...@@ -4967,9 +4967,17 @@ func (p *AnswerRebaseObject) neoMsgDecodeM(data []byte) (int, error) {
} }
{ {
v, tail, err := msgp.ReadUint64Bytes(data) v, tail, err := msgp.ReadUint64Bytes(data)
if err != nil {
err = mdecodeErr("AnswerRebaseObject.Compression", err)
if err == ErrDecodeOverflow {
return 0, err
}
tail, err = msgp.ReadNilBytes(data)
if err != nil { if err != nil {
return 0, mdecodeErr("AnswerRebaseObject.Compression", err) return 0, mdecodeErr("AnswerRebaseObject.Compression", err)
} }
v = 0
}
p.Compression = Compression(v) p.Compression = Compression(v)
nread += uint64(len(data) - len(tail)) nread += uint64(len(data) - len(tail))
data = tail data = tail
...@@ -5180,9 +5188,17 @@ func (p *StoreObject) neoMsgDecodeM(data []byte) (int, error) { ...@@ -5180,9 +5188,17 @@ func (p *StoreObject) neoMsgDecodeM(data []byte) (int, error) {
} }
{ {
v, tail, err := msgp.ReadUint64Bytes(data) v, tail, err := msgp.ReadUint64Bytes(data)
if err != nil {
err = mdecodeErr("StoreObject.Compression", err)
if err == ErrDecodeOverflow {
return 0, err
}
tail, err = msgp.ReadNilBytes(data)
if err != nil { if err != nil {
return 0, mdecodeErr("StoreObject.Compression", err) return 0, mdecodeErr("StoreObject.Compression", err)
} }
v = 0
}
p.Compression = Compression(v) p.Compression = Compression(v)
nread += uint64(len(data) - len(tail)) nread += uint64(len(data) - len(tail))
data = tail data = tail
...@@ -6223,9 +6239,17 @@ func (p *AnswerObject) neoMsgDecodeM(data []byte) (int, error) { ...@@ -6223,9 +6239,17 @@ func (p *AnswerObject) neoMsgDecodeM(data []byte) (int, error) {
} }
{ {
v, tail, err := msgp.ReadUint64Bytes(data) v, tail, err := msgp.ReadUint64Bytes(data)
if err != nil {
err = mdecodeErr("AnswerObject.Compression", err)
if err == ErrDecodeOverflow {
return 0, err
}
tail, err = msgp.ReadNilBytes(data)
if err != nil { if err != nil {
return 0, mdecodeErr("AnswerObject.Compression", err) return 0, mdecodeErr("AnswerObject.Compression", err)
} }
v = 0
}
p.Compression = Compression(v) p.Compression = Compression(v)
nread += uint64(len(data) - len(tail)) nread += uint64(len(data) - len(tail))
data = tail data = tail
...@@ -13168,9 +13192,17 @@ func (p *AddObject) neoMsgDecodeM(data []byte) (int, error) { ...@@ -13168,9 +13192,17 @@ func (p *AddObject) neoMsgDecodeM(data []byte) (int, error) {
} }
{ {
v, tail, err := msgp.ReadUint64Bytes(data) v, tail, err := msgp.ReadUint64Bytes(data)
if err != nil {
err = mdecodeErr("AddObject.Compression", err)
if err == ErrDecodeOverflow {
return 0, err
}
tail, err = msgp.ReadNilBytes(data)
if err != nil { if err != nil {
return 0, mdecodeErr("AddObject.Compression", err) return 0, mdecodeErr("AddObject.Compression", err)
} }
v = 0
}
p.Compression = Compression(v) p.Compression = Compression(v)
nread += uint64(len(data) - len(tail)) nread += uint64(len(data) - len(tail))
data = tail data = tail
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment