// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

package network

//docgen:jsonschema

import (
	"encoding/base64"
	"errors"
	"fmt"
	"math"
	"net/netip"
	"time"

	"github.com/siderolabs/gen/optional"
	"github.com/siderolabs/gen/xslices"

	"github.com/siderolabs/talos/pkg/machinery/config/config"
	"github.com/siderolabs/talos/pkg/machinery/config/internal/registry"
	"github.com/siderolabs/talos/pkg/machinery/config/types/meta"
	"github.com/siderolabs/talos/pkg/machinery/config/validation"
)

// WireguardKind is a Wireguard config document kind.
//
// It is the kind under which WireguardConfigV1Alpha1 is registered in init,
// and the value NewWireguardConfigV1Alpha1 stores in the document's MetaKind.
const WireguardKind = "WireguardConfig"

// init registers a factory for the WireguardConfig kind.
//
// Only the "v1alpha1" API version is known; for any other version the
// factory returns nil.
func init() {
	registry.Register(WireguardKind, func(version string) config.Document {
		if version != "v1alpha1" { //nolint:goconst
			return nil
		}

		return &WireguardConfigV1Alpha1{}
	})
}

// Compile-time assertions that *WireguardConfigV1Alpha1 implements every
// config interface it is expected to satisfy.
var (
	_ config.NetworkWireguardConfig = (*WireguardConfigV1Alpha1)(nil)
	_ config.ConflictingDocument    = (*WireguardConfigV1Alpha1)(nil)
	_ config.NamedDocument          = (*WireguardConfigV1Alpha1)(nil)
	_ config.Validator              = (*WireguardConfigV1Alpha1)(nil)
	_ config.SecretDocument         = (*WireguardConfigV1Alpha1)(nil)
)

// WireguardConfigV1Alpha1 is a config document to create and configure a Wireguard network link.
//
//	examples:
//	  - value: exampleWireguardConfigV1Alpha1()
//	alias: WireguardConfig
//	schemaRoot: true
//	schemaMeta: v1alpha1/WireguardConfig
type WireguardConfigV1Alpha1 struct {
	meta.Meta `yaml:",inline"`

	//   description: |
	//     Name of the Wireguard link (interface).
	//
	//   examples:
	//    - value: >
	//       "wg.int"
	//   schemaRequired: true
	MetaName string `yaml:"name"`
	//   description: |
	//     Specifies the private key of the device (base64 encoded, 32 bytes).
	//     Can be generated by `wg genkey`.
	//   schemaRequired: true
	WireguardPrivateKey string `yaml:"privateKey,omitempty"`
	//   description: |
	//     Specifies a device's listening port (UDP), between 0 and 65535.
	//     If not specified (or set to 0), a random port will be chosen.
	WireguardListenPort int `yaml:"listenPort,omitempty"`
	//   description: |
	//     Specifies a device's firewall mark.
	//     Useful for advanced routing setups, marking packets originating from this device.
	//     If not specified (or set to 0), no firewall mark is configured.
	WireguardFirewallMark int `yaml:"firewallMark,omitempty"`
	//   description: Specifies a list of peer configurations to apply to a device.
	WireguardPeers []WireguardPeer `yaml:"peers,omitempty"`

	//nolint:embeddedstructfieldcheck
	CommonLinkConfig `yaml:",inline"`
}

// WireguardPeer describes a Wireguard peer configuration.
type WireguardPeer struct {
	//   description: |
	//     Specifies the public key of this peer (base64 encoded, 32 bytes).
	//     Can be extracted from private key by running `wg pubkey < private.key`.
	//   schemaRequired: true
	WireguardPublicKey string `yaml:"publicKey,omitempty"`
	//   description: |
	//     Specifies the preshared key for this peer (base64 encoded, 32 bytes).
	//     Can be generated by `wg genpsk`.
	//     Optional, this key provides an additional layer of symmetric-key cryptography
	//     to the peer connection.
	WireguardPresharedKey string `yaml:"presharedKey,omitempty"`
	//   description: |
	//     Specifies the endpoint of this peer entry.
	//     Format: <IP address>:<port>.
	//     If not set, this device does not connect to the peer first;
	//     the peer is expected to initiate the connection.
	//   schema:
	//     type: string
	//     pattern: ^([0-9a-f.:]+|\[[0-9a-f:.]+\]):\d{1,5}$
	WireguardEndpoint AddrPort `yaml:"endpoint,omitempty"`
	//   description: |
	//     Specifies the persistent keepalive interval for this peer.
	//     Field format accepts any Go time.Duration format ('1h' for one hour, '10m' for ten minutes).
	//     If not set (or zero), no persistent keepalive is configured. Must not be negative.
	//   schema:
	//     type: string
	//     pattern: ^[-+]?(((\d+(\.\d*)?|\d*(\.\d+)+)([nuµm]?s|m|h))|0)+$
	WireguardPersistentKeepaliveInterval time.Duration `yaml:"persistentKeepaliveInterval,omitempty"`
	//   description: |
	//     AllowedIPs specifies a list of allowed IP addresses in CIDR notation for this peer.
	//     These IPs will be routed to this peer, and they define which IPs this peer is allowed to use.
	//   schema:
	//     type: array
	//     items:
	//       type: string
	//       pattern: ^[0-9a-f.:]+/\d{1,3}$
	WireguardAllowedIPs []Prefix `yaml:"allowedIPs,omitempty"`
}

// NewWireguardConfigV1Alpha1 creates a new WireguardConfig config document.
//
// The returned document has its kind and API version ("v1alpha1") filled in
// and its link name set to name; every other field is left at its zero value.
func NewWireguardConfigV1Alpha1(name string) *WireguardConfigV1Alpha1 {
	doc := &WireguardConfigV1Alpha1{MetaName: name}
	doc.MetaKind = WireguardKind
	doc.MetaAPIVersion = "v1alpha1"

	return doc
}

// exampleWireguardConfigV1Alpha1 builds the sample document referenced by the
// docgen `examples` annotation on WireguardConfigV1Alpha1.
//
// It describes a "wg1" link listening on UDP port 51820 with two peers (one
// with a fixed endpoint, one with a preshared key), a single static address
// and an MTU of 1420.
func exampleWireguardConfigV1Alpha1() *WireguardConfigV1Alpha1 {
	example := NewWireguardConfigV1Alpha1("wg1")

	example.WireguardPrivateKey = "OJ34O6J1z4ZZB+t16c+vYrzIrKddxyU3Z2eLhwYzqE8="
	example.WireguardListenPort = 51820

	// Peer reachable at a known endpoint.
	withEndpoint := WireguardPeer{
		WireguardPublicKey:  "fP+xJZvUA5n1Pi/f5wcPiV6tZ6fHwqcGaXe98NfEgkE=",
		WireguardAllowedIPs: []Prefix{{netip.MustParsePrefix("192.168.2.0/24")}},
		WireguardEndpoint:   AddrPort{netip.MustParseAddrPort("10.0.0.1:5180")},
	}

	// Peer without an endpoint, protected by an additional preshared key.
	withPresharedKey := WireguardPeer{
		WireguardPublicKey:    "TDd25Cwq6tMZANIKUaqred+Zt+09HtCqwFeOLtKQ9Cs=",
		WireguardPresharedKey: "UpH8htYK7yJBPg5+q4M/Tx0o5ipHbeSZtI/h/mHxOeU=",
		WireguardAllowedIPs:   []Prefix{{netip.MustParsePrefix("192.168.3.0/24")}},
	}

	example.WireguardPeers = []WireguardPeer{withEndpoint, withPresharedKey}

	example.LinkMTU = 1420
	example.LinkAddresses = []AddressConfig{
		{AddressAddress: netip.MustParsePrefix("192.168.1.100/24")},
	}

	return example
}

// Clone implements config.Document interface.
//
// It returns the result of DeepCopy, so the clone does not share state with s.
func (s *WireguardConfigV1Alpha1) Clone() config.Document {
	clone := s.DeepCopy()

	return clone
}

// Name implements config.NamedDocument interface.
//
// The document name is the name of the Wireguard link it configures.
func (s *WireguardConfigV1Alpha1) Name() string {
	linkName := s.MetaName

	return linkName
}

// WireguardConfig implements NetworkWireguardConfig interface.
//
// It has no behavior of its own; it exists so that the type satisfies
// config.NetworkWireguardConfig.
func (s *WireguardConfigV1Alpha1) WireguardConfig() {
	// Intentionally empty marker method.
}

// ConflictsWithKinds implements config.ConflictingDocument interface.
//
// The list is produced by the shared conflictingLinkKinds helper for the
// Wireguard kind (its exact contents are defined elsewhere in the package).
func (s *WireguardConfigV1Alpha1) ConflictsWithKinds() []string {
	kinds := conflictingLinkKinds(WireguardKind)

	return kinds
}

// validateWireguardKey returns an error unless key is standard (padded) base64
// that decodes to exactly 32 bytes.
//
// The base64 decoding error is returned as-is; a key of the wrong size yields
// "invalid wireguard key length". This is hand-rolled to avoid importing
// wgtypes into machinery.
func validateWireguardKey(key string) error {
	decoded, err := base64.StdEncoding.DecodeString(key)

	switch {
	case err != nil:
		return err
	case len(decoded) != 32:
		return errors.New("invalid wireguard key length")
	default:
		return nil
	}
}

// Validate implements config.Validator interface.
//
// It checks that:
//   - the link name and private key are set;
//   - the private key and every peer public/preshared key are base64-encoded 32-byte keys;
//   - the listen port fits in 16 bits and the firewall mark fits in 32 bits (both non-negative);
//   - every peer persistent keepalive interval is between 0 and 65535 seconds;
//
// and then merges in the warnings and errors of CommonLinkConfig.Validate.
// All problems are collected and returned joined, rather than stopping at the first one.
//
//nolint:gocyclo
func (s *WireguardConfigV1Alpha1) Validate(validation.RuntimeMode, ...validation.Option) ([]string, error) {
	var (
		errs     error
		warnings []string
	)

	// WireGuard carries the persistent keepalive interval as a 16-bit number of seconds.
	const maxPersistentKeepalive = math.MaxUint16 * time.Second

	if s.MetaName == "" {
		errs = errors.Join(errs, errors.New("name must be specified"))
	}

	if s.WireguardPrivateKey == "" {
		errs = errors.Join(errs, errors.New("wireguard private key must be specified"))
	} else if err := validateWireguardKey(s.WireguardPrivateKey); err != nil {
		errs = errors.Join(errs, fmt.Errorf("wireguard private key is invalid: %w", err))
	}

	if s.WireguardListenPort < 0 || s.WireguardListenPort > math.MaxUint16 {
		errs = errors.Join(errs, errors.New("wireguard listen port must be between 0 and 65535"))
	}

	// The firewall mark is a 32-bit unsigned value; widen to int64 so the upper
	// bound comparison also compiles and works on 32-bit platforms.
	if s.WireguardFirewallMark < 0 || int64(s.WireguardFirewallMark) > math.MaxUint32 {
		errs = errors.Join(errs, errors.New("wireguard firewall mark must be between 0 and 4294967295"))
	}

	for i, peer := range s.WireguardPeers {
		if peer.WireguardPublicKey == "" {
			errs = errors.Join(errs, fmt.Errorf("wireguard peer public key must be specified (peer index %d)", i))
		} else if err := validateWireguardKey(peer.WireguardPublicKey); err != nil {
			errs = errors.Join(errs, fmt.Errorf("wireguard peer public key is invalid (peer index %d): %w", i, err))
		}

		if peer.WireguardPresharedKey != "" {
			if err := validateWireguardKey(peer.WireguardPresharedKey); err != nil {
				errs = errors.Join(errs, fmt.Errorf("wireguard peer preshared key is invalid (peer index %d): %w", i, err))
			}
		}

		switch keepalive := peer.WireguardPersistentKeepaliveInterval; {
		case keepalive < 0:
			errs = errors.Join(errs, fmt.Errorf("wireguard peer persistent keepalive interval cannot be negative (peer index %d)", i))
		case keepalive > maxPersistentKeepalive:
			errs = errors.Join(errs, fmt.Errorf("wireguard peer persistent keepalive interval cannot exceed 65535 seconds (peer index %d)", i))
		}
	}

	extraWarnings, extraErrs := s.CommonLinkConfig.Validate()
	errs, warnings = errors.Join(errs, extraErrs), append(warnings, extraWarnings...)

	return warnings, errs
}

// PrivateKey implements NetworkWireguardConfig interface.
//
// The key is returned exactly as configured (base64 text), without decoding.
func (s *WireguardConfigV1Alpha1) PrivateKey() string {
	key := s.WireguardPrivateKey

	return key
}

// ListenPort implements NetworkWireguardConfig interface.
//
// A zero port is treated as "not configured" and reported as None.
func (s *WireguardConfigV1Alpha1) ListenPort() optional.Optional[int] {
	if port := s.WireguardListenPort; port != 0 {
		return optional.Some(port)
	}

	return optional.None[int]()
}

// FirewallMark implements NetworkWireguardConfig interface.
//
// A zero mark is treated as "not configured" and reported as None.
func (s *WireguardConfigV1Alpha1) FirewallMark() optional.Optional[int] {
	if mark := s.WireguardFirewallMark; mark != 0 {
		return optional.Some(mark)
	}

	return optional.None[int]()
}

// Peers implements NetworkWireguardConfig interface.
//
// Each configured WireguardPeer is exposed through the
// config.NetworkWireguardPeerConfig interface, preserving order.
func (s *WireguardConfigV1Alpha1) Peers() []config.NetworkWireguardPeerConfig {
	asPeerConfig := func(wp WireguardPeer) config.NetworkWireguardPeerConfig { return wp }

	return xslices.Map(s.WireguardPeers, asPeerConfig)
}

// PublicKey implements NetworkWireguardPeerConfig interface.
//
// The key is returned exactly as configured (base64 text), without decoding.
func (p WireguardPeer) PublicKey() string {
	key := p.WireguardPublicKey

	return key
}

// PresharedKey implements NetworkWireguardPeerConfig interface.
//
// An empty preshared key is reported as None.
func (p WireguardPeer) PresharedKey() optional.Optional[string] {
	if psk := p.WireguardPresharedKey; psk != "" {
		return optional.Some(psk)
	}

	return optional.None[string]()
}

// Endpoint implements NetworkWireguardPeerConfig interface.
//
// A zero endpoint is reported as None; otherwise the endpoint's string form
// is returned.
func (p WireguardPeer) Endpoint() optional.Optional[string] {
	if !p.WireguardEndpoint.IsZero() {
		return optional.Some(p.WireguardEndpoint.String())
	}

	return optional.None[string]()
}

// AllowedIPs implements NetworkWireguardPeerConfig interface.
//
// The configured Prefix wrappers are unwrapped into plain netip.Prefix
// values, preserving order.
func (p WireguardPeer) AllowedIPs() []netip.Prefix {
	unwrap := func(prefix Prefix) netip.Prefix { return prefix.Prefix }

	return xslices.Map(p.WireguardAllowedIPs, unwrap)
}

// PersistentKeepalive implements NetworkWireguardPeerConfig interface.
//
// A zero interval is treated as "not configured" and reported as None.
func (p WireguardPeer) PersistentKeepalive() optional.Optional[time.Duration] {
	if interval := p.WireguardPersistentKeepaliveInterval; interval != 0 {
		return optional.Some(interval)
	}

	return optional.None[time.Duration]()
}

// Redact does in-place replacement of secrets with the given string.
//
// The device private key and each peer's preshared key are replaced when
// they are non-empty; empty fields stay empty and public keys are untouched.
func (s *WireguardConfigV1Alpha1) Redact(replacement string) {
	if s.WireguardPrivateKey != "" {
		s.WireguardPrivateKey = replacement
	}

	// Index into the slice so that each peer is redacted in place, not a copy.
	for idx := range s.WireguardPeers {
		peer := &s.WireguardPeers[idx]
		peer.Redact(replacement)
	}
}

// Redact does in-place replacement of secrets with the given string.
//
// Only a non-empty preshared key is replaced; the public key is not a secret
// and is left as-is.
func (p *WireguardPeer) Redact(replacement string) {
	if p.WireguardPresharedKey == "" {
		return
	}

	p.WireguardPresharedKey = replacement
}
