Skip to content

Commit

Permalink
use parent context when looking up IPs (#85)
Browse files Browse the repository at this point in the history
Co-authored-by: Andrew LeFevre <Andrew LeFevre>
  • Loading branch information
capnspacehook authored Sep 29, 2024
1 parent af289b8 commit 761daae
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
17 changes: 8 additions & 9 deletions filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ func createFilter(ctx context.Context, logger *zap.Logger, opts *FilterOptions,
f.additionalHostnames = timedcache.New[string](filterLogger, false)

nf4, nf6, err := openNfQueues(ctx, filterLogger, opts.TrafficQueue, newEnforcer, func(ipv6 bool) nfqueue.HookFunc {
return newGenericCallback(&f, ipv6)
return newGenericCallback(ctx, &f, ipv6)
})
if err != nil {
return nil, fmt.Errorf("error starting traffic nfqueues: %w", err)
Expand Down Expand Up @@ -823,7 +823,7 @@ func newDNSResponseCallback(f *FilterManager, ipv6 bool) nfqueue.HookFunc {
}
}

func newGenericCallback(f *filter, ipv6 bool) nfqueue.HookFunc {
func newGenericCallback(ctx context.Context, f *filter, ipv6 bool) nfqueue.HookFunc {
var queueNum uint16
if !ipv6 {
queueNum = f.opts.TrafficQueue.IPv4
Expand Down Expand Up @@ -909,7 +909,7 @@ func newGenericCallback(f *filter, ipv6 bool) nfqueue.HookFunc {

// validate that either the source or destination IP is allowed
var verdict int
allowed, err := f.validateIPs(logger, src, dst)
allowed, err := f.validateIPs(ctx, logger, src, dst)
if err != nil {
logger.Error("error validating IPs", zap.Stringer("conn.src", src), zap.Stringer("conn.dst", dst), zap.NamedError("error", err))
verdict = nfqueue.NfDrop
Expand All @@ -931,7 +931,7 @@ func newGenericCallback(f *filter, ipv6 bool) nfqueue.HookFunc {
}
}

func (f *filter) validateIPs(logger *zap.Logger, src, dst netip.Addr) (bool, error) {
func (f *filter) validateIPs(ctx context.Context, logger *zap.Logger, src, dst netip.Addr) (bool, error) {
// check if the destination IP is allowed first, as most likely
// we are validating an outbound connection
if f.allowedIPs.EntryExists(dst) {
Expand All @@ -948,7 +948,7 @@ func (f *filter) validateIPs(logger *zap.Logger, src, dst netip.Addr) (bool, err
// preform reverse IP lookups on the destination and then source
// IPs only if the IPs are not private
if !dst.IsPrivate() {
allowed, err := f.lookupAndValidateIP(logger, dst)
allowed, err := f.lookupAndValidateIP(ctx, logger, dst)
if err != nil {
return false, err
}
Expand All @@ -958,15 +958,14 @@ func (f *filter) validateIPs(logger *zap.Logger, src, dst netip.Addr) (bool, err
}

if !src.IsPrivate() {
return f.lookupAndValidateIP(logger, src)
return f.lookupAndValidateIP(ctx, logger, src)
}

return false, nil
}

func (f *filter) lookupAndValidateIP(logger *zap.Logger, ip netip.Addr) (bool, error) {
// TODO: build from top-level context
ctx, cancel := context.WithTimeout(context.Background(), dnsQueryTimeout)
func (f *filter) lookupAndValidateIP(ctx context.Context, logger *zap.Logger, ip netip.Addr) (bool, error) {
ctx, cancel := context.WithTimeout(ctx, dnsQueryTimeout)
defer cancel()

logger.Info("preforming reverse IP lookup", zap.Stringer("ip", ip))
Expand Down
7 changes: 5 additions & 2 deletions timedcache/timed_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,13 @@ func (t *TimedCache[T]) AddEntry(entry T, ttl time.Duration) {
case <-timer.C:
running = false
case s := <-status:
if s == reset {
switch s {
case reset:
// wait until timer is finished resetting
<-status
} else if s == stop {
case start:
// the timer has started, wait for another status
case stop:
return
}
}
Expand Down

0 comments on commit 761daae

Please sign in to comment.