/*
 * ZDNS Copyright 2024 Regents of the University of Michigan
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy
 * of the License at http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied. See the License for the specific language governing
 * permissions and limitations under the License.
 */
package zdns

import (
	"context"
	"strings"

	"github.com/pkg/errors"
	"github.com/zmap/dns"
)

/*
It's unfortunate that we need this nslookup functionality in the main zdns package, but it's necessary in order to
look up NS records easily from within zdns without introducing circular dependencies between the modules.
*/

// NSRecord describes a single name server found by an NS lookup, together with
// any IP addresses that were resolved for it.
//
// The `groups` struct tags list the output groups each field belongs to; they
// are presumably consumed by the output serializer to select fields per
// verbosity level (TODO confirm against the serializer in use).
type NSRecord struct {
	// Name is the host name of the name server. DoNSLookup fills it from the
	// NS answer's data with any trailing dot removed.
	Name          string   `json:"name" groups:"short,normal,long,trace"`
	// Type is the record type of the answer this entry was built from.
	// DoNSLookup only emits entries for "NS" answers.
	Type          string   `json:"type" groups:"short,normal,long,trace"`
	// IPv4Addresses holds the IPv4 addresses found for Name; the key is omitted
	// from the JSON output when the slice is empty.
	IPv4Addresses []string `json:"ipv4_addresses,omitempty" groups:"short,normal,long,trace"`
	// IPv6Addresses holds the IPv6 addresses found for Name; the key is omitted
	// from the JSON output when the slice is empty.
	IPv6Addresses []string `json:"ipv6_addresses,omitempty" groups:"short,normal,long,trace"`
	// TTL is the TTL carried by the NS answer. Unlike the other fields it is
	// not tagged with the "short" group.
	TTL           uint32   `json:"ttl" groups:"normal,long,trace"`
}

// NSResult is the result of an NS lookup: the list of name servers that were
// found for the queried name.
type NSResult struct {
	// Servers holds one entry per NS answer, in the order the answers were
	// processed; the key is omitted from the JSON output when the slice is empty.
	Servers []NSRecord `json:"servers,omitempty" groups:"short,normal,long,trace"`
}

// DoNSLookup performs a DNS NS lookup for lookupName and resolves the addresses
// of every name server named in the answer.
//
// When isIterative is true the NS query is resolved with r.IterativeLookup and
// nameServer is not used for that query; otherwise the query is sent to
// nameServer with r.ExternalLookup. lookupA and lookupAAAA select which address
// families are reported for each name server; at least one must be true.
//
// Addresses are taken from the A/AAAA records in the additional section of the
// NS response when present. For any requested family that has no such record,
// a follow-up r.DoTargetedLookup is issued for the name server's host name.
// That follow-up is best effort: if it fails, the record is still returned,
// just without addresses for that family.
//
// Returns:
//   - (nil, nil, "", err) when lookupName is empty or neither address family
//     was requested;
//   - a pointer to an empty NSResult plus the trace, status and error of the
//     NS query when that query did not succeed;
//   - otherwise the populated NSResult, the accumulated trace (NS query plus
//     all follow-up lookups), StatusNoError and a nil error.
func (r *Resolver) DoNSLookup(ctx context.Context, lookupName string, nameServer *NameServer, isIterative, lookupA, lookupAAAA bool) (*NSResult, Trace, Status, error) {
	if len(lookupName) == 0 {
		return nil, nil, "", errors.New("no name provided for NS lookup")
	}
	if !lookupA && !lookupAAAA {
		return nil, nil, "", errors.New("must lookup either A or AAAA")
	}

	var (
		ns     *SingleQueryResult
		trace  Trace
		status Status
		err    error
	)
	question := &Question{Name: lookupName, Type: dns.TypeNS, Class: dns.ClassINET}
	if isIterative {
		ns, trace, status, err = r.IterativeLookup(ctx, question)
	} else {
		ns, trace, status, err = r.ExternalLookup(ctx, question, nameServer)
	}

	var retv NSResult
	// The nil check guards against a lookup that reports success without a
	// result; it is treated like an answer containing no NS records.
	if status != StatusNoError || err != nil || ns == nil {
		return &retv, trace, status, err
	}

	// Index the glue addresses from the additional section by owner name.
	// DNS names compare case-insensitively, so keys are lower-cased to make
	// sure glue is matched even if its case differs from the NS target's.
	ipv4s := make(map[string][]string)
	ipv6s := make(map[string][]string)
	for _, additional := range ns.Additionals {
		a, ok := additional.(Answer)
		if !ok || !VerifyAddress(a.Type, a.Answer) {
			continue
		}
		key := strings.ToLower(strings.TrimSuffix(a.Name, "."))
		switch a.Type {
		case "A":
			ipv4s[key] = append(ipv4s[key], a.Answer)
		case "AAAA":
			ipv6s[key] = append(ipv6s[key], a.Answer)
		}
	}

	for _, answer := range ns.Answers {
		a, ok := answer.(Answer)
		if !ok || a.Type != "NS" {
			continue
		}

		rec := NSRecord{
			Name: strings.TrimSuffix(a.Answer, "."),
			Type: a.Type,
			TTL:  a.TTL,
		}
		key := strings.ToLower(rec.Name)

		// Use glue where available and remember which requested families
		// still have to be resolved separately.
		var needIPv4, needIPv6 bool
		if lookupA {
			if ips, found := ipv4s[key]; found {
				rec.IPv4Addresses = ips
			} else {
				needIPv4 = true
			}
		}
		if lookupAAAA {
			if ips, found := ipv6s[key]; found {
				rec.IPv6Addresses = ips
			} else {
				needIPv6 = true
			}
		}

		if needIPv4 || needIPv6 {
			// Only the families that are actually missing are queried; results
			// for a family already covered by glue would be discarded anyway.
			// Status and error are deliberately ignored: a name server whose
			// addresses cannot be resolved is still reported, without them.
			// NOTE(review): this lookup is always non-iterative (isIterative is
			// passed as false) even when the NS query was iterative — confirm
			// that this is intentional.
			res, nextTrace, _, _ := r.DoTargetedLookup(ctx, rec.Name, nameServer, false, needIPv4, needIPv6)
			if res != nil {
				if needIPv4 {
					rec.IPv4Addresses = res.IPv4Addresses
				}
				if needIPv6 {
					rec.IPv6Addresses = res.IPv6Addresses
				}
			}
			trace = append(trace, nextTrace...)
		}

		retv.Servers = append(retv.Servers, rec)
	}
	return &retv, trace, StatusNoError, nil
}
