diff --git a/internal/network/mqtt/buffer.go b/internal/network/mqtt/buffer.go index 4b9f5608..591a0749 100644 --- a/internal/network/mqtt/buffer.go +++ b/internal/network/mqtt/buffer.go @@ -23,7 +23,7 @@ const smallBufferSize = 64 const maxInt = int(^uint(0) >> 1) // buffers are reusable fixed-side buffers for faster encoding. -var buffers = newBufferPool(maxMessageSize) +var buffers = newBufferPool(MaxMessageSize) // bufferPool represents a thread safe buffer pool type bufferPool struct { diff --git a/internal/network/mqtt/mqtt.go b/internal/network/mqtt/mqtt.go index 266374d5..110baf32 100644 --- a/internal/network/mqtt/mqtt.go +++ b/internal/network/mqtt/mqtt.go @@ -23,7 +23,7 @@ import ( const ( maxHeaderSize = 6 maxTopicSize = 1024 // max MQTT header size - maxMessageSize = 65536 // max MQTT message size is impossible to increase as per protocol (uint16 len) + MaxMessageSize = 65536 // max MQTT message size is impossible to increase as per protocol (uint16 len) ) // ErrMessageTooLarge occurs when a message encoded/decoded is larger than max MQTT frame. @@ -327,7 +327,7 @@ func (p *Publish) EncodeTo(w io.Writer) (int, error) { length += 2 } - if length > maxMessageSize { + if length > MaxMessageSize { return 0, ErrMessageTooLarge } diff --git a/internal/provider/storage/countLimiter.go b/internal/provider/storage/countLimiter.go new file mode 100644 index 00000000..e541dd75 --- /dev/null +++ b/internal/provider/storage/countLimiter.go @@ -0,0 +1,40 @@ +package storage + +import "github.com/emitter-io/emitter/internal/message" + +type Limiter interface { + Admit(*message.Message) bool + Limit(*message.Frame) +} + +// MessageCountLimiter provide an Limiter implementation to replace the "limit" +// parameter in the Query() function. +type MessageCountLimiter struct { + count int64 `binary:"-"` + MsgLimit int64 // TODO: why is this exported? +} + +func (limiter *MessageCountLimiter) Admit(m *message.Message) bool { + // As this function won't be called multiple times once the limit is reached, + // the following implementation should be faster than using a branching statement + // to check if the limit is reached, before incrementing the counter. + limiter.count += 1 + return limiter.count <= limiter.MsgLimit + + // The following implementation would use a branching each time the function is called. + /* + if limiter.count < limiter.MsgLimit { + limiter.count += 1 + return true + } + return false + */ +} + +func (limiter *MessageCountLimiter) Limit(frame *message.Frame) { + frame.Limit(int(limiter.MsgLimit)) +} + +func NewMessageNumberLimiter(limit int64) Limiter { + return &MessageCountLimiter{MsgLimit: limit} +} diff --git a/internal/provider/storage/sizeLimiter.go b/internal/provider/storage/sizeLimiter.go new file mode 100644 index 00000000..55133e63 --- /dev/null +++ b/internal/provider/storage/sizeLimiter.go @@ -0,0 +1,46 @@ +package storage + +import "github.com/emitter-io/emitter/internal/message" + +// MessageSizeLimiter provide an Limiter implementation based on both the +// number of messages and the total size of the response. +type MessageSizeLimiter struct { + count int64 `binary:"-"` + size int64 `binary:"-"` + countLimit int64 + sizeLimit int64 +} + +func (limiter *MessageSizeLimiter) Admit(m *message.Message) bool { + // As this function won't be called multiple times once the limit is reached, + // the following implementation should be faster than using a branching statement + // to check if the limit is reached, before incrementing the counter. + // Todo: discuss whether it's ok to make that assumption + + // This size calculation comes from mqtt.go:EncodeTo() line 320. + // Todo: discuss whether this is the best way to calculate the size. + // 2 bytes for message ID. + limiter.size += int64(2 + len(m.Channel) + len(m.Payload)) + limiter.count += 1 + return limiter.count <= limiter.countLimit && limiter.size <= limiter.sizeLimit +} + +func (limiter *MessageSizeLimiter) Limit(frame *message.Frame) { + // Limit takes the first N elements that fit into a message, sorted by message time + frame.Sort() + frame.Limit(int(limiter.countLimit)) + + totalSize := int64(0) + for i, m := range *frame { + totalSize += int64(2 + len(m.Channel) + len(m.Payload)) + if totalSize > limiter.sizeLimit { + *frame = (*frame)[:i] + break + } + limiter.size += totalSize + } +} + +func NewMessageSizeLimiter(countLimit, sizeLimit int64) Limiter { + return &MessageSizeLimiter{countLimit: countLimit, sizeLimit: sizeLimit} +} diff --git a/internal/provider/storage/sizeLimiter_test.go b/internal/provider/storage/sizeLimiter_test.go new file mode 100644 index 00000000..72d39758 --- /dev/null +++ b/internal/provider/storage/sizeLimiter_test.go @@ -0,0 +1,29 @@ +package storage + +import ( + "fmt" + "testing" + + "github.com/emitter-io/emitter/internal/message" + "github.com/stretchr/testify/assert" +) + +func TestXxx(t *testing.T) { + frame := make(message.Frame, 0, 100) + for i := int64(0); i < 100; i++ { + msg := message.New(message.Ssid{0, 1, 2}, []byte("a/b/c/"), []byte(fmt.Sprintf("%d", i))) + msg.ID.SetTime(msg.ID.Time() + (i * 10000)) + frame = append(frame, *msg) + } + + sizeLimiter := NewMessageSizeLimiter(100, 50) + sizeLimiter.Limit(&frame) + + assert.Len(t, frame, 5) + assert.Equal(t, message.Ssid{0, 1, 2}, frame[0].Ssid()) + assert.Equal(t, "95", string(frame[0].Payload)) + assert.Equal(t, "96", string(frame[1].Payload)) + assert.Equal(t, "97", string(frame[2].Payload)) + assert.Equal(t, "98", string(frame[3].Payload)) + assert.Equal(t, "99", string(frame[4].Payload)) +} diff --git a/internal/provider/storage/ssd_test.go b/internal/provider/storage/ssd_test.go index bb96e21c..7cf3f53b 100644 --- a/internal/provider/storage/ssd_test.go +++ b/internal/provider/storage/ssd_test.go @@ -101,6 +101,12 @@ func TestSSD_QueryRange(t *testing.T) { }) } +func TestSSD_QueryResponseSizeLimited(t *testing.T) { + runSSDTest(func(store *SSD) { + testResponseSizeLimited(t, store) + }) +} + func TestSSD_QuerySurveyed(t *testing.T) { runSSDTest(func(s *SSD) { const wildcard = uint32(1815237614) diff --git a/internal/provider/storage/storage.go b/internal/provider/storage/storage.go index 6bb19dc2..6e137f05 100644 --- a/internal/provider/storage/storage.go +++ b/internal/provider/storage/storage.go @@ -70,8 +70,8 @@ type lookupQuery struct { From int64 // (required) The beginning of the time window. UntilTime int64 // Lookup stops when reaches this time. UntilID message.ID // Lookup stops when reaches this message ID. - LimitByCount *MessageNumberLimiter - //LimitBySize *MessageSizeLimiter + LimitByCount *MessageCountLimiter + LimitBySize *MessageSizeLimiter } // newLookupQuery creates a new lookup query @@ -85,8 +85,10 @@ func newLookupQuery(ssid message.Ssid, from, until time.Time, untilID message.ID } switch v := limiter.(type) { - case *MessageNumberLimiter: + case *MessageCountLimiter: query.LimitByCount = v + case *MessageSizeLimiter: + query.LimitBySize = v } return query } @@ -95,37 +97,13 @@ func (q *lookupQuery) Limiter() Limiter { switch { case q.LimitByCount != nil: return q.LimitByCount + case q.LimitBySize != nil: + return q.LimitBySize default: - return &MessageNumberLimiter{} + return &MessageCountLimiter{} } } -type Limiter interface { - Admit(*message.Message) bool - Limit(*message.Frame) -} - -// MessageNumberLimiter provide an Limiter implementation to replace the "limit" -// parameter in the Query() function. -type MessageNumberLimiter struct { - count int64 `binary:"-"` - MsgLimit int64 -} - -func (limiter *MessageNumberLimiter) Admit(m *message.Message) bool { - admit := limiter.count < limiter.MsgLimit - limiter.count += 1 - return admit -} - -func (limiter *MessageNumberLimiter) Limit(frame *message.Frame) { - frame.Limit(int(limiter.MsgLimit)) -} - -func NewMessageNumberLimiter(limit int64) Limiter { - return &MessageNumberLimiter{MsgLimit: limit} -} - // configUint32 retrieves an uint32 from the config func configUint32(config map[string]interface{}, name string, defaultValue uint32) uint32 { if v, ok := config[name]; ok { diff --git a/internal/provider/storage/storage_test.go b/internal/provider/storage/storage_test.go index b51c1140..4209d91b 100644 --- a/internal/provider/storage/storage_test.go +++ b/internal/provider/storage/storage_test.go @@ -47,7 +47,7 @@ func TestNoop_Store(t *testing.T) { func TestNoop_Query(t *testing.T) { s := new(Noop) zero := time.Unix(0, 0) - r, err := s.Query(testMessage(1, 2, 3).Ssid(), zero, zero, nil, NewMessageNumberLimiter(10)) + r, err := s.Query(testMessage(1, 2, 3).Ssid(), zero, zero, nil, NewMessageSizeLimiter(10, MaxMessageSize)) assert.NoError(t, err) for range r { t.Errorf("Should be empty") @@ -80,7 +80,7 @@ func testOrder(t *testing.T, store Storage) { // Issue a query zero := time.Unix(0, 0) - f, err := store.Query([]uint32{0, 1, 2}, zero, zero, nil, NewMessageNumberLimiter(5)) + f, err := store.Query([]uint32{0, 1, 2}, zero, zero, nil, NewMessageSizeLimiter(5, MaxMessageSize)) assert.NoError(t, err) assert.Len(t, f, 5) @@ -103,7 +103,7 @@ func testRetained(t *testing.T, store Storage) { // Issue a query zero := time.Unix(0, 0) - f, err := store.Query([]uint32{0, 1, 2}, zero, zero, nil, NewMessageNumberLimiter(1)) + f, err := store.Query([]uint32{0, 1, 2}, zero, zero, nil, NewMessageSizeLimiter(1, MaxMessageSize)) assert.NoError(t, err) assert.Len(t, f, 1) @@ -124,7 +124,7 @@ func testUntilID(t *testing.T, store Storage) { // Issue a query zero := time.Unix(0, 0) - f, err := store.Query([]uint32{0, 1, 2}, zero, zero, fourth, NewMessageNumberLimiter(100)) + f, err := store.Query([]uint32{0, 1, 2}, zero, zero, fourth, NewMessageSizeLimiter(100, MaxMessageSize)) assert.NoError(t, err) assert.Len(t, f, 4) @@ -146,7 +146,7 @@ func testRange(t *testing.T, store Storage) { } // Issue a query - f, err := store.Query([]uint32{0, 1, 2}, time.Unix(t0, 0), time.Unix(t1, 0), nil, NewMessageNumberLimiter(5)) + f, err := store.Query([]uint32{0, 1, 2}, time.Unix(t0, 0), time.Unix(t1, 0), nil, NewMessageSizeLimiter(5, MaxMessageSize)) assert.NoError(t, err) assert.Len(t, f, 5) @@ -158,6 +158,21 @@ func testRange(t *testing.T, store Storage) { assert.Equal(t, "60", string(f[4].Payload)) } +func testResponseSizeLimited(t *testing.T, store Storage) { + var t0, t1 int64 + for i := int64(0); i < 100; i++ { + msg := message.New(message.Ssid{0, 1, 2}, []byte("a/b/c/"), []byte(fmt.Sprintf("%d", i))) + msg.ID.SetTime(msg.ID.Time() + (i * 10000)) + assert.NoError(t, store.Store(msg)) + } + + // Issue a query + f, err := store.Query([]uint32{0, 1, 2}, time.Unix(t0, 0), time.Unix(t1, 0), nil, NewMessageSizeLimiter(100, 200)) + assert.NoError(t, err) + + assert.Len(t, f, 20) +} + func Test_configUint32(t *testing.T) { raw := `{ "provider": "memory", diff --git a/internal/service/history/history.go b/internal/service/history/history.go index 1db61b27..cdbb60eb 100644 --- a/internal/service/history/history.go +++ b/internal/service/history/history.go @@ -19,6 +19,7 @@ import ( "github.com/emitter-io/emitter/internal/errors" "github.com/emitter-io/emitter/internal/message" + "github.com/emitter-io/emitter/internal/network/mqtt" "github.com/emitter-io/emitter/internal/provider/logging" "github.com/emitter-io/emitter/internal/provider/storage" "github.com/emitter-io/emitter/internal/security" @@ -76,7 +77,8 @@ func (s *Service) OnRequest(c service.Conn, payload []byte) (service.Response, b ssid := message.NewSsid(key.Contract(), channel.Query) t0, t1 := channel.Window() // Get the window - messageLimiter := storage.NewMessageNumberLimiter(limit) + //messageLimiter := storage.NewMessageNumberLimiter(limit) + messageLimiter := storage.NewMessageSizeLimiter(limit, mqtt.MaxMessageSize) msgs, err := s.store.Query(ssid, t0, t1, request.LastMessageID, messageLimiter) if err != nil { logging.LogError("conn", "query last messages", err)