// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package auth

import (
	"crypto/tls"
	"fmt"
	"net/http"
	"strings"

	xoauth2 "golang.org/x/oauth2"

	"github.com/apache/pulsar-client-go/oauth2"
	"github.com/apache/pulsar-client-go/oauth2/cache"
	"github.com/apache/pulsar-client-go/oauth2/clock"
)

// Keys of the parameter map read by NewAuthenticationOAuth2WithParams, plus
// the value accepted for ConfigParamType.
const (
	// ConfigParamType selects the grant flow to use.
	ConfigParamType = "type"

	// ConfigParamTypeClientCredentials is the only ConfigParamType value that
	// NewAuthenticationOAuth2WithParams accepts.
	ConfigParamTypeClientCredentials = "client_credentials"

	// ConfigParamIssuerURL is the OAuth2 issuer endpoint.
	ConfigParamIssuerURL = "issuerUrl"

	// ConfigParamAudience is the OAuth2 audience.
	ConfigParamAudience = "audience"

	// ConfigParamScope holds additional scopes, separated by spaces.
	ConfigParamScope = "scope"

	// ConfigParamKeyFile is handed to the client-credentials flow as its
	// KeyFile option.
	ConfigParamKeyFile = "privateKey"

	// ConfigParamClientID is the OAuth2 client identifier.
	ConfigParamClientID = "clientId"
)

// oauth2AuthProvider is a Provider that authenticates with an access token
// obtained through an OAuth2 client-credentials flow.
type oauth2AuthProvider struct {
	// clock is set to clock.RealClock{} by NewAuthenticationOAuth2; no method
	// visible in this file reads it.
	clock clock.Clock

	// issuer carries the issuer endpoint, client ID and audience.
	issuer oauth2.Issuer

	// source is the token cache created by Init. While it is nil, GetData
	// reports anonymous access.
	source cache.CachingTokenSource

	// defaultTransport is the base round tripper recorded by WithTransport,
	// and tokenTransport is the token-attaching wrapper built around it.
	defaultTransport http.RoundTripper
	tokenTransport   *transport

	// flow is the client-credentials flow that Init hands to the token cache.
	flow *oauth2.ClientCredentialsFlow
}

// NewAuthenticationOAuth2WithParams returns a Provider configured from a
// string map of parameters.
//
// ConfigParamIssuerURL, ConfigParamClientID and ConfigParamAudience describe
// the issuer. ConfigParamKeyFile and ConfigParamScope configure the flow.
// ConfigParamType must equal ConfigParamTypeClientCredentials. Any other
// value, including an absent one, is rejected with an error. Errors from
// building the client-credentials flow are returned unchanged.
func NewAuthenticationOAuth2WithParams(params map[string]string) (Provider, error) {
	issuer := oauth2.Issuer{
		IssuerEndpoint: params[ConfigParamIssuerURL],
		ClientID:       params[ConfigParamClientID],
		Audience:       params[ConfigParamAudience],
	}

	switch params[ConfigParamType] {
	case ConfigParamTypeClientCredentials:
		// Scopes are whitespace-separated. strings.Fields drops empty entries.
		// An unset or blank scope parameter therefore yields no additional
		// scopes. Splitting on " " instead would turn "" into a single empty
		// scope, and doubled separators into extra empty scopes.
		scopes := strings.Fields(params[ConfigParamScope])
		flow, err := oauth2.NewDefaultClientCredentialsFlow(oauth2.ClientCredentialsFlowOptions{
			KeyFile:          params[ConfigParamKeyFile],
			AdditionalScopes: scopes,
		})
		if err != nil {
			return nil, err
		}
		return NewAuthenticationOAuth2(issuer, flow), nil
	default:
		return nil, fmt.Errorf("unsupported authentication type: %s", params[ConfigParamType])
	}
}

// NewAuthenticationOAuth2 returns a Provider that obtains access tokens from
// issuer through the given client-credentials flow. The token cache is not
// created here. Until Init succeeds, GetData reports anonymous access.
func NewAuthenticationOAuth2(
	issuer oauth2.Issuer,
	flow *oauth2.ClientCredentialsFlow) Provider {
	p := new(oauth2AuthProvider)
	p.clock = clock.RealClock{}
	p.issuer = issuer
	p.flow = flow
	return p
}

// Init builds the caching token source for the configured audience and flow.
// If building it fails, the provider is left untouched and the error is
// returned.
func (p *oauth2AuthProvider) Init() error {
	tokenCache, err := cache.NewDefaultTokenCache(p.issuer.Audience, p.flow)
	if err == nil {
		p.source = tokenCache
	}
	return err
}

// Name returns the authentication method name for this provider. It is
// "token" because GetData yields a raw access token.
// NOTE(review): presumably this lets the broker treat it like plain token
// authentication; confirm.
func (p *oauth2AuthProvider) Name() string {
	const methodName = "token"
	return methodName
}

// GetTLSCertificate returns no client certificate and no error. This provider
// authenticates with an access token only.
func (p *oauth2AuthProvider) GetTLSCertificate() (*tls.Certificate, error) {
	var noCert *tls.Certificate
	return noCert, nil
}

// GetData returns the current access token, taken from the caching token
// source, as raw bytes. When Init has not installed a token source, it
// returns (nil, nil), meaning anonymous access.
func (p *oauth2AuthProvider) GetData() ([]byte, error) {
	if src := p.source; src != nil {
		tok, err := src.Token()
		if err != nil {
			return nil, err
		}
		return []byte(tok.AccessToken), nil
	}
	// No token source: anonymous access.
	return nil, nil
}

// Close releases provider resources. This provider holds none, so Close
// always returns nil.
func (p *oauth2AuthProvider) Close() error {
	var noErr error
	return noErr
}

// transport is an HTTP round tripper that attaches OAuth2 access tokens to
// outgoing requests by wrapping an x/oauth2 Transport. It invalidates the
// cached token when the server answers 401 Unauthorized.
type transport struct {
	// source is the token cache. RoundTrip invalidates it after a 401.
	source cache.CachingTokenSource

	// wrapped adds the Authorization header. Its Base is used directly for
	// requests that already carry that header.
	wrapped *xoauth2.Transport
}

// RoundTrip implements http.RoundTripper.
//
// Requests that already carry an Authorization header go unchanged to the
// base transport, so caller-supplied credentials are never overwritten. All
// other requests go through the x/oauth2 transport, which attaches a token
// from the token source.
//
// If the server replies 401 Unauthorized, the cached token is invalidated so
// that a later request can obtain a fresh one. The 401 response itself is
// still returned. If the invalidation fails, the response body is closed,
// so the connection is not leaked, and the wrapped error is returned.
func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
	if req.Header.Get("Authorization") != "" {
		// Match x/oauth2's documented behavior: a nil Base means
		// http.DefaultTransport. Calling a nil Base directly would panic.
		base := t.wrapped.Base
		if base == nil {
			base = http.DefaultTransport
		}
		return base.RoundTrip(req)
	}

	res, err := t.wrapped.RoundTrip(req)
	if err != nil {
		return nil, err
	}

	if res.StatusCode == http.StatusUnauthorized {
		if invErr := t.source.InvalidateToken(); invErr != nil {
			// Best effort: the invalidation failure is the error worth
			// reporting. The body must be closed because the caller never
			// receives res.
			_ = res.Body.Close()
			return nil, fmt.Errorf("invalidating cached oauth2 token after 401 response: %w", invErr)
		}
	}

	return res, nil
}

// WrappedRoundTripper returns the base round tripper beneath the
// token-attaching transport. It may be nil; x/oauth2 treats a nil Base as
// http.DefaultTransport.
func (t *transport) WrappedRoundTripper() http.RoundTripper {
	inner := t.wrapped
	return inner.Base
}

// RoundTrip implements http.RoundTripper by delegating to the token-attaching
// transport installed by WithTransport. If WithTransport has not been called,
// tokenTransport is nil. In that case a transport is built on the spot from
// the provider's current token source and default transport, rather than
// dereferencing the nil pointer and panicking.
func (p *oauth2AuthProvider) RoundTrip(req *http.Request) (*http.Response, error) {
	if p.tokenTransport == nil {
		return p.Transport().RoundTrip(req)
	}
	return p.tokenTransport.RoundTrip(req)
}

// Transport returns a new round tripper that attaches access tokens from the
// provider's token source, with the provider's default transport as its base.
// The values of p.source and p.defaultTransport are copied at call time.
// Later changes to those fields do not affect the returned transport.
func (p *oauth2AuthProvider) Transport() http.RoundTripper {
	inner := &xoauth2.Transport{
		Base:   p.defaultTransport,
		Source: p.source,
	}
	return &transport{wrapped: inner, source: p.source}
}

// WithTransport records tripper as the provider's base HTTP transport. It
// also rebuilds the token-attaching transport that RoundTrip uses around
// tripper. It always returns nil.
//
// The token source is captured now: if Init runs later, the rebuilt transport
// keeps whatever p.source held at this point, possibly nil.
func (p *oauth2AuthProvider) WithTransport(tripper http.RoundTripper) error {
	inner := &xoauth2.Transport{
		Base:   tripper,
		Source: p.source,
	}
	p.defaultTransport = tripper
	p.tokenTransport = &transport{wrapped: inner, source: p.source}
	return nil
}
