// Licensed to Apache Software Foundation (ASF) under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. Apache Software Foundation (ASF) licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

// Package protector provides a set of protectors that stop the query services when the resource usage exceeds the limit.
package protector

import (
	"context"
	"crypto/rand"
	"encoding/binary"
	"errors"
	"fmt"
	"runtime/metrics"
	"sync/atomic"
	"time"

	"github.com/dustin/go-humanize"

	"github.com/apache/skywalking-banyandb/banyand/observability"
	"github.com/apache/skywalking-banyandb/pkg/cgroups"
	"github.com/apache/skywalking-banyandb/pkg/logger"
	"github.com/apache/skywalking-banyandb/pkg/meter"
	"github.com/apache/skywalking-banyandb/pkg/run"
)

// scope is the metrics sub-scope ("memory_protector" under the root scope)
// that NewMemory uses to create the protector's gauges.
var scope = observability.RootScope.SubScope("memory_protector")

// secureRandFloat64 returns a random float64 in the half-open range [0.0, 1.0)
// drawn from crypto/rand.
//
// If the random source reports an error, the fixed value 0.5 is returned so the
// caller still gets a usable (though deterministic) factor.
func secureRandFloat64() float64 {
	var buf [8]byte
	if _, err := rand.Read(buf[:]); err != nil {
		// Best-effort fallback: in this file the result is only used as back-off jitter.
		return 0.5
	}
	// Keep only the top 53 bits. Every 53-bit integer is exactly representable
	// as a float64, so dividing by 2^53 can never round up to 1.0. Dividing the
	// full 64-bit value by MaxUint64 instead yields exactly 1.0 for inputs near
	// the top of the range, which would break the exclusive upper bound.
	return float64(binary.BigEndian.Uint64(buf[:])>>11) / (1 << 53)
}

// Memory is an interface for monitoring and limiting memory usage to prevent OOM.
// The notes on individual methods describe the *memory implementation in this file.
type Memory interface {
	// AvailableBytes returns the bytes remaining under the limit (limit - usage),
	// 0 when usage has reached the limit, and -1 when no limit is configured.
	AvailableBytes() int64
	// GetLimit returns the configured memory limit in bytes; 0 means no limit is set.
	GetLimit() uint64
	// AcquireResource blocks until `size` bytes fit under the limit or ctx is done.
	// It returns nil on success and an error wrapping ctx.Err() otherwise.
	AcquireResource(ctx context.Context, size uint64) error
	// ShouldCache returns true if the file size is smaller than the threshold.
	ShouldCache(fileSize int64) bool
	run.PreRunner
	run.Config
	run.Service
	// State returns the current memory pressure state.
	State() State
}

// State represents memory pressure levels for load shedding.
type State int

const (
	// StateLow indicates normal memory usage, accept all requests.
	// It is the zero value of State.
	StateLow State = iota
	// StateHigh indicates high memory pressure, reject new requests.
	StateHigh
)

// Compile-time check that *memory implements Memory.
var _ Memory = (*memory)(nil)

// memory is a protector that stops the query services when the memory usage exceeds the limit.
type memory struct {
	// omr is the metrics registry handed to NewMemory.
	omr            observability.MetricsRegistry
	// limitGauge is the gauge created under the name "limit"; PreRun sets it to the effective limit.
	limitGauge     meter.Gauge
	// usageGauge is the gauge created under the name "usage".
	usageGauge     meter.Gauge
	// l is the protector's logger.
	l              *logger.Logger
	// closed is closed by GracefulStop; Serve returns it as the stop notification
	// and the sampling goroutine exits when it is closed.
	closed         chan struct{}
	// blockedChan is a buffered channel (capacity cgroups.CPUs()) that
	// AcquireResource uses as a semaphore to bound the number of concurrent waiters.
	blockedChan    chan struct{}
	// allowedPercent is the value of the `allowed-percent` flag.
	allowedPercent int
	// allowedBytes is the value of the `allowed-bytes` flag.
	allowedBytes   run.Bytes
	// limit is the effective limit in bytes computed by PreRun; 0 means the protector is disabled.
	limit          atomic.Uint64
	// usage is the most recently sampled memory usage in bytes; it is read and
	// written through the sync/atomic functions.
	usage          uint64
}

// NewMemory builds a Memory protector whose gauges are registered through omr.
//
// The wait queue used by AcquireResource is sized to the CPU count reported by
// cgroups. The limit itself is not known yet; it is computed later in PreRun.
func NewMemory(omr observability.MetricsRegistry) Memory {
	waiters := cgroups.CPUs()
	gauges := omr.With(scope)

	m := &memory{
		omr:         omr,
		blockedChan: make(chan struct{}, waiters),
		closed:      make(chan struct{}),
		l:           logger.GetLogger("memory-protector"),
	}
	m.limitGauge = gauges.NewGauge("limit")
	m.usageGauge = gauges.NewGauge("usage")
	return m
}

// AcquireResource waits, using exponential back-off with jitter, until `size`
// additional bytes fit under the memory limit.
//
// It does not reserve anything: it only compares the most recently sampled
// usage plus `size` against the limit. When no limit is configured (limit is 0)
// it returns nil immediately.
//
// The number of goroutines waiting at once is bounded by the capacity of
// blockedChan; callers beyond that bound block until a slot frees up or ctx is
// done.
//
// It returns nil once the request fits, or an error wrapping ctx.Err() if ctx
// is done first. A `size` larger than the limit can never fit and therefore
// waits until ctx is done.
func (m *memory) AcquireResource(ctx context.Context, size uint64) error {
	if m.limit.Load() == 0 {
		return nil
	}
	start := time.Now()

	// Take a slot in the bounded wait queue; it is released when this call returns.
	select {
	case m.blockedChan <- struct{}{}:
		defer func() { <-m.blockedChan }()
	case <-ctx.Done():
		return fmt.Errorf("context canceled while waiting for blocked queue slot: %w", ctx.Err())
	}

	const (
		initialBackoff = 10 * time.Millisecond
		maxBackoff     = 1 * time.Second
		backoffFactor  = 1.5
		jitterFactor   = 0.1
	)

	backoff := initialBackoff
	attempt := 0

	for {
		currentUsage := atomic.LoadUint64(&m.usage)
		limit := m.limit.Load()
		// Compare by subtraction: currentUsage+size can wrap around uint64 for a
		// very large size and would then wrongly appear to fit under the limit.
		if size <= limit && currentUsage <= limit-size {
			return nil
		}

		attempt++

		// Grow the delay before waiting, so the first wait is
		// initialBackoff*backoffFactor, capped at maxBackoff.
		nextBackoff := min(time.Duration(float64(backoff)*backoffFactor), maxBackoff)

		// Add up to jitterFactor of extra delay to prevent thundering herd.
		jitter := time.Duration(secureRandFloat64() * float64(nextBackoff) * jitterFactor)
		finalBackoff := nextBackoff + jitter

		timer := time.NewTimer(finalBackoff)
		select {
		case <-timer.C:
			backoff = nextBackoff
		case <-ctx.Done():
			// Release the pending timer right away instead of letting it run out.
			timer.Stop()
			return fmt.Errorf(
				"context canceled: memory acquisition failed (currentUsage: %d, limit: %d, size: %d, attempts: %d, blockedDuration: %v): %w",
				currentUsage, limit, size, attempt, time.Since(start), ctx.Err(),
			)
		}
	}
}

// GetLimit reports the protector's memory limit in bytes.
// The value is 0 while no limit has been stored.
func (m *memory) GetLimit() uint64 {
	limit := m.limit.Load()
	return limit
}

// AvailableBytes returns the available memory (limit - usage) in bytes.
//
// It returns -1 when no limit is configured and 0 when the sampled usage has
// reached or exceeded the limit.
func (m *memory) AvailableBytes() int64 {
	// Read the limit once so the guard and the subtraction below use the same
	// value; separate loads could disagree and make the subtraction underflow.
	limit := m.limit.Load()
	if limit == 0 {
		return -1
	}
	usage := atomic.LoadUint64(&m.usage)
	if usage >= limit {
		return 0
	}
	return int64(limit - usage)
}

// State reports the current memory pressure for load shedding decisions.
//
// Pressure is StateHigh when nothing is available or when the available bytes
// are at or below one fifth (20%) of the limit; otherwise it is StateLow.
// With no limit configured the method fails open and reports StateLow.
func (m *memory) State() State {
	maxBytes := m.GetLimit()
	if maxBytes == 0 {
		// Fail open: without a limit there is nothing to measure against.
		return StateLow
	}

	free := m.AvailableBytes()
	// The 20% margin acts as a buffer zone against rapid state oscillations.
	lowWater := int64(maxBytes / 5)

	switch {
	case free <= 0, free <= lowWater:
		return StateHigh
	default:
		return StateLow
	}
}

// Name reports the protector's name, which is also used for its flag set and logger.
func (m *memory) Name() string {
	const name = "memory-protector"
	return name
}

// FlagSet builds the protector's command-line flags:
// `allowed-percent` (default 75) and `allowed-bytes`.
func (m *memory) FlagSet() *run.FlagSet {
	const (
		percentHelp = "Allowed percentage of total memory usage. If usage exceeds this value, the query services will stop. " +
			"This takes effect only if `allowed-bytes` is 0. If usage is too high, it may cause OS page cache eviction."
		bytesHelp = "Allowed bytes of memory usage. If the memory usage exceeds this value, the query services will stop. " +
			"Setting a large value may evict data from the OS page cache, causing high disk I/O."
	)

	fs := run.NewFlagSet(m.Name())
	fs.IntVarP(&m.allowedPercent, "allowed-percent", "", 75, percentHelp)
	fs.VarP(&m.allowedBytes, "allowed-bytes", "", bytesHelp)
	return fs
}

// Validate validates the protector's flags.
//
// A positive allowed-bytes is an absolute limit and makes allowed-percent
// irrelevant (PreRun and GetThreshold only read the percentage when
// allowed-bytes is not positive), so the percentage is range-checked only when
// it will actually be used. It returns an error when allowed-bytes is not
// positive and allowed-percent is outside (0, 100].
func (m *memory) Validate() error {
	if m.allowedBytes > 0 {
		return nil
	}
	if m.allowedPercent <= 0 || m.allowedPercent > 100 {
		return errors.New("allowed-percent must be in the range (0, 100]")
	}
	return nil
}

// PreRun initializes the protector by computing the effective memory limit.
//
// A positive allowed-bytes is used as the limit directly. Otherwise the limit
// is allowed-percent of the cgroup memory limit. If the cgroup limit cannot be
// read or is outside (0, 1e18], the limit stays 0, which disables the
// protector. The method always returns nil.
func (m *memory) PreRun(context.Context) error {
	m.l = logger.GetLogger(m.Name())
	if m.allowedBytes > 0 {
		m.limit.Store(uint64(m.allowedBytes))
		m.l.Info().
			Str("limit", humanize.Bytes(m.limit.Load())).
			Msg("memory protector enabled")
	} else {
		cgLimit, err := cgroups.MemoryLimit()
		if err != nil {
			m.l.Warn().Err(err).Msg("failed to get memory limit from cgroups, disable memory protector")
			return nil
		}
		if cgLimit <= 0 || cgLimit > 1e18 {
			m.l.Warn().Int64("cgroup_memory_limit", cgLimit).Msg("cgroup memory limit is invalid, disable memory protector")
			return nil
		}
		// total*percent can exceed MaxUint64 for limits the guard above still
		// accepts (1e18 * 100), so scale the quotient and the remainder
		// separately. The result equals floor(total*percent/100) exactly.
		total, percent := uint64(cgLimit), uint64(m.allowedPercent)
		m.limit.Store(total/100*percent + total%100*percent/100)
		m.l.Info().
			Str("limit", humanize.Bytes(m.limit.Load())).
			Str("cgroup_limit", humanize.Bytes(uint64(cgLimit))).
			Int("percent", m.allowedPercent).
			Msg("memory protector enabled")
	}
	m.limitGauge.Set(float64(m.limit.Load()))
	return nil
}

// GracefulStop shuts the protector down by closing its stop channel, which
// ends the sampling goroutine started by Serve and signals the StopNotify.
func (m *memory) GracefulStop() {
	done := m.closed
	close(done)
}

// Serve starts the protector and returns the channel that is closed on
// GracefulStop.
//
// When a limit is configured, a background goroutine samples Go runtime memory
// metrics every 5 seconds, stores their sum as the current usage, publishes it
// on the usage gauge, and logs a warning when it exceeds the limit. The
// goroutine exits when the protector is stopped. With no limit configured no
// goroutine is started.
func (m *memory) Serve() run.StopNotify {
	if m.limit.Load() == 0 {
		return m.closed
	}
	go func() {
		ticker := time.NewTicker(5 * time.Second)
		defer ticker.Stop()

		// metrics.Read fills the samples in place, so one slice is reused
		// across ticks instead of being rebuilt every time.
		samples := []metrics.Sample{
			{Name: "/memory/classes/heap/objects:bytes"},
			{Name: "/memory/classes/heap/stacks:bytes"},
			{Name: "/memory/classes/metadata/mcache/inuse:bytes"},
			{Name: "/memory/classes/metadata/mspan/inuse:bytes"},
			{Name: "/memory/classes/metadata/other:bytes"},
			{Name: "/memory/classes/os-stacks:bytes"},
			{Name: "/memory/classes/other:bytes"},
		}

		for {
			select {
			case <-m.closed:
				return
			case <-ticker.C:
				metrics.Read(samples)
				var usedBytes uint64
				for i := range samples {
					// A metric the running Go version does not support is
					// reported as KindBad, and Uint64 panics on any kind other
					// than KindUint64; skip such samples.
					if samples[i].Value.Kind() != metrics.KindUint64 {
						continue
					}
					usedBytes += samples[i].Value.Uint64()
				}

				atomic.StoreUint64(&m.usage, usedBytes)
				// Keep the exported usage gauge in step with the sampled value.
				m.usageGauge.Set(float64(usedBytes))

				if limit := m.limit.Load(); usedBytes > limit {
					m.l.Warn().Str("used", humanize.Bytes(usedBytes)).Str("limit", humanize.Bytes(limit)).Msg("memory usage exceeds limit")
				}
			}
		}
	}()
	return m.closed
}

// GetThreshold returns the file-size threshold, in bytes, used for large file
// detection.
//
// The threshold is 1% of a memory budget derived from the cgroup memory limit,
// with a floor of 10MB. If the cgroup limit cannot be read, a default of 64MB
// is returned.
//
// NOTE(review): with allowed-bytes set, the budget is the memory left over
// after the allowed bytes (cgroup limit - allowed-bytes); with allowed-percent
// it is the allowed share itself (cgroup limit * percent / 100). The two
// branches measure different things — confirm which one is intended as the
// "page cache" size.
func (m *memory) GetThreshold() int64 {
	// Try reading cgroup memory limit
	cgLimit, err := cgroups.MemoryLimit()
	if err != nil {
		if dl := m.l.Debug(); dl.Enabled() {
			dl.Err(err).Msg("failed to get memory limit from cgroups, using default threshold")
		}
		// Fallback default threshold of 64MB
		return 64 << 20
	}

	// Determine effective memory to use based on flags
	var totalMemory int64
	if m.allowedBytes > 0 {
		totalMemory = cgLimit - int64(m.allowedBytes)
	} else {
		// cgLimit*percent can overflow int64 for very large cgroup limits, so
		// scale the quotient and the remainder separately. For a percentage in
		// (0, 100] this equals cgLimit*percent/100 whenever that product fits.
		percent := int64(m.allowedPercent)
		totalMemory = cgLimit/100*percent + cgLimit%100*percent/100
	}

	// Compute 1% of that memory as the threshold, never going below 10MB.
	const minThreshold = 10 << 20 // 10MB
	return max(totalMemory/100, minThreshold)
}

// ShouldCache reports whether fileSize is strictly below the threshold
// returned by GetThreshold.
func (m *memory) ShouldCache(fileSize int64) bool {
	threshold := m.GetThreshold()
	return fileSize < threshold
}
