1
0
mirror of https://github.com/kubernetes-sigs/descheduler.git synced 2026-01-26 13:29:11 +01:00

Bump To k8s 1.22.0

This commit is contained in:
Sean Malloy
2021-08-19 00:37:47 -05:00
parent 5420988a28
commit c079c7aaae
1614 changed files with 126680 additions and 52128 deletions

View File

@@ -6,6 +6,7 @@ require (
github.com/Azure/go-autorest v14.2.0+incompatible
github.com/Azure/go-autorest/autorest/date v0.3.0
github.com/Azure/go-autorest/autorest/mocks v0.4.1
github.com/Azure/go-autorest/logger v0.2.1
github.com/Azure/go-autorest/tracing v0.6.0
github.com/form3tech-oss/jwt-go v3.2.2+incompatible
golang.org/x/crypto v0.0.0-20201002170205-7f63de1d35b0

View File

@@ -4,6 +4,8 @@ github.com/Azure/go-autorest/autorest/date v0.3.0 h1:7gUk1U5M/CQbp9WoqinNzJar+8K
github.com/Azure/go-autorest/autorest/date v0.3.0/go.mod h1:BI0uouVdmngYNUzGWeSYnokU+TrmwEsOqdt8Y6sso74=
github.com/Azure/go-autorest/autorest/mocks v0.4.1 h1:K0laFcLE6VLTOwNgSxaGbUcLPuGXlNkbVvq4cW4nIHk=
github.com/Azure/go-autorest/autorest/mocks v0.4.1/go.mod h1:LTp+uSrOhSkaKrUy935gNZuuIPPVsHlr9DSOxSayd+k=
github.com/Azure/go-autorest/logger v0.2.1 h1:IG7i4p/mDa2Ce4TRyAO8IHnVhAVF3RFU+ZtXWSmf4Tg=
github.com/Azure/go-autorest/logger v0.2.1/go.mod h1:T9E3cAhj2VqvPOtCYAvby9aBXkZmbF5NWuPV8+WeEW8=
github.com/Azure/go-autorest/tracing v0.6.0 h1:TYi4+3m5t6K48TGI9AUdb+IzbnSxvnvUMfuitfgcfuo=
github.com/Azure/go-autorest/tracing v0.6.0/go.mod h1:+vhtPC754Xsa23ID7GlGsrdKBpUA79WCAKPPZVC2DeU=
github.com/form3tech-oss/jwt-go v3.2.2+incompatible h1:TcekIExNqud5crz4xD2pavyTgWiPvpYe4Xau31I0PRk=

View File

@@ -28,6 +28,7 @@ const (
mimeTypeFormPost = "application/x-www-form-urlencoded"
)
// DO NOT ACCESS THIS DIRECTLY. go through sender()
var defaultSender Sender
var defaultSenderInit = &sync.Once{}

View File

@@ -30,11 +30,13 @@ import (
"net/http"
"net/url"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/Azure/go-autorest/autorest/date"
"github.com/Azure/go-autorest/logger"
"github.com/form3tech-oss/jwt-go"
)
@@ -69,13 +71,22 @@ const (
defaultMaxMSIRefreshAttempts = 5
// asMSIEndpointEnv is the environment variable used to store the endpoint on App Service and Functions
asMSIEndpointEnv = "MSI_ENDPOINT"
msiEndpointEnv = "MSI_ENDPOINT"
// asMSISecretEnv is the environment variable used to store the request secret on App Service and Functions
asMSISecretEnv = "MSI_SECRET"
msiSecretEnv = "MSI_SECRET"
// the API version to use for the App Service MSI endpoint
appServiceAPIVersion = "2017-09-01"
// the API version to use for the legacy App Service MSI endpoint
appServiceAPIVersion2017 = "2017-09-01"
// secret header used when authenticating against app service MSI endpoint
secretHeader = "Secret"
// the format for expires_on in UTC with AM/PM
expiresOnDateFormatPM = "1/2/2006 15:04:05 PM +00:00"
// the format for expires_on in UTC without AM/PM
expiresOnDateFormat = "1/2/2006 15:04:05 +00:00"
)
// OAuthTokenProvider is an interface which should be implemented by an access token retriever
@@ -282,6 +293,8 @@ func (secret ServicePrincipalCertificateSecret) MarshalJSON() ([]byte, error) {
// ServicePrincipalMSISecret implements ServicePrincipalSecret for machines running the MSI Extension.
type ServicePrincipalMSISecret struct {
msiType msiType
clientResourceID string
}
// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
@@ -652,94 +665,173 @@ func NewServicePrincipalTokenFromAuthorizationCode(oauthConfig OAuthConfig, clie
)
}
type msiType int
const (
msiTypeUnavailable msiType = iota
msiTypeAppServiceV20170901
msiTypeCloudShell
msiTypeIMDS
)
func (m msiType) String() string {
switch m {
case msiTypeUnavailable:
return "unavailable"
case msiTypeAppServiceV20170901:
return "AppServiceV20170901"
case msiTypeCloudShell:
return "CloudShell"
case msiTypeIMDS:
return "IMDS"
default:
return fmt.Sprintf("unhandled MSI type %d", m)
}
}
// returns the MSI type and endpoint, or an error
func getMSIType() (msiType, string, error) {
if endpointEnvVar := os.Getenv(msiEndpointEnv); endpointEnvVar != "" {
// if the env var MSI_ENDPOINT is set
if secretEnvVar := os.Getenv(msiSecretEnv); secretEnvVar != "" {
// if BOTH the env vars MSI_ENDPOINT and MSI_SECRET are set the msiType is AppService
return msiTypeAppServiceV20170901, endpointEnvVar, nil
}
// if ONLY the env var MSI_ENDPOINT is set the msiType is CloudShell
return msiTypeCloudShell, endpointEnvVar, nil
} else if msiAvailableHook(context.Background(), sender()) {
// if MSI_ENDPOINT is NOT set AND the IMDS endpoint is available the msiType is IMDS. This will timeout after 500 milliseconds
return msiTypeIMDS, msiEndpoint, nil
} else {
// if MSI_ENDPOINT is NOT set and IMDS endpoint is not available Managed Identity is not available
return msiTypeUnavailable, "", errors.New("MSI not available")
}
}
// GetMSIVMEndpoint gets the MSI endpoint on Virtual Machines.
// NOTE: this always returns the IMDS endpoint, it does not work for app services or cloud shell.
// Deprecated: NewServicePrincipalTokenFromMSI() and variants will automatically detect the endpoint.
func GetMSIVMEndpoint() (string, error) {
return msiEndpoint, nil
}
// NOTE: this only indicates if the ASE environment credentials have been set
// which does not necessarily mean that the caller is authenticating via ASE!
func isAppService() bool {
_, asMSIEndpointEnvExists := os.LookupEnv(asMSIEndpointEnv)
_, asMSISecretEnvExists := os.LookupEnv(asMSISecretEnv)
return asMSIEndpointEnvExists && asMSISecretEnvExists
}
// GetMSIAppServiceEndpoint get the MSI endpoint for App Service and Functions
// GetMSIAppServiceEndpoint get the MSI endpoint for App Service and Functions.
// It will return an error when not running in an app service/functions environment.
// Deprecated: NewServicePrincipalTokenFromMSI() and variants will automatically detect the endpoint.
func GetMSIAppServiceEndpoint() (string, error) {
asMSIEndpoint, asMSIEndpointEnvExists := os.LookupEnv(asMSIEndpointEnv)
if asMSIEndpointEnvExists {
return asMSIEndpoint, nil
msiType, endpoint, err := getMSIType()
if err != nil {
return "", err
}
switch msiType {
case msiTypeAppServiceV20170901:
return endpoint, nil
default:
return "", fmt.Errorf("%s is not app service environment", msiType)
}
return "", errors.New("MSI endpoint not found")
}
// GetMSIEndpoint get the appropriate MSI endpoint depending on the runtime environment
// Deprecated: NewServicePrincipalTokenFromMSI() and variants will automatically detect the endpoint.
func GetMSIEndpoint() (string, error) {
if isAppService() {
return GetMSIAppServiceEndpoint()
}
return GetMSIVMEndpoint()
_, endpoint, err := getMSIType()
return endpoint, err
}
// NewServicePrincipalTokenFromMSI creates a ServicePrincipalToken via the MSI VM Extension.
// It will use the system assigned identity when creating the token.
// msiEndpoint - empty string, or pass a non-empty string to override the default value.
// Deprecated: use NewServicePrincipalTokenFromManagedIdentity() instead.
func NewServicePrincipalTokenFromMSI(msiEndpoint, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
return newServicePrincipalTokenFromMSI(msiEndpoint, resource, nil, nil, callbacks...)
return newServicePrincipalTokenFromMSI(msiEndpoint, resource, "", "", callbacks...)
}
// NewServicePrincipalTokenFromMSIWithUserAssignedID creates a ServicePrincipalToken via the MSI VM Extension.
// It will use the clientID of specified user assigned identity when creating the token.
// msiEndpoint - empty string, or pass a non-empty string to override the default value.
// Deprecated: use NewServicePrincipalTokenFromManagedIdentity() instead.
func NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, resource string, userAssignedID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
return newServicePrincipalTokenFromMSI(msiEndpoint, resource, &userAssignedID, nil, callbacks...)
if err := validateStringParam(userAssignedID, "userAssignedID"); err != nil {
return nil, err
}
return newServicePrincipalTokenFromMSI(msiEndpoint, resource, userAssignedID, "", callbacks...)
}
// NewServicePrincipalTokenFromMSIWithIdentityResourceID creates a ServicePrincipalToken via the MSI VM Extension.
// It will use the azure resource id of user assigned identity when creating the token.
// msiEndpoint - empty string, or pass a non-empty string to override the default value.
// Deprecated: use NewServicePrincipalTokenFromManagedIdentity() instead.
func NewServicePrincipalTokenFromMSIWithIdentityResourceID(msiEndpoint, resource string, identityResourceID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
return newServicePrincipalTokenFromMSI(msiEndpoint, resource, nil, &identityResourceID, callbacks...)
}
func newServicePrincipalTokenFromMSI(msiEndpoint, resource string, userAssignedID *string, identityResourceID *string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
if err := validateStringParam(msiEndpoint, "msiEndpoint"); err != nil {
if err := validateStringParam(identityResourceID, "identityResourceID"); err != nil {
return nil, err
}
return newServicePrincipalTokenFromMSI(msiEndpoint, resource, "", identityResourceID, callbacks...)
}
// ManagedIdentityOptions contains optional values for configuring managed identity authentication.
type ManagedIdentityOptions struct {
// ClientID is the user-assigned identity to use during authentication.
// It is mutually exclusive with IdentityResourceID.
ClientID string
// IdentityResourceID is the resource ID of the user-assigned identity to use during authentication.
// It is mutually exclusive with ClientID.
IdentityResourceID string
}
// NewServicePrincipalTokenFromManagedIdentity creates a ServicePrincipalToken using a managed identity.
// It supports the following managed identity environments.
// - App Service Environment (API version 2017-09-01 only)
// - Cloud shell
// - IMDS with a system or user assigned identity
func NewServicePrincipalTokenFromManagedIdentity(resource string, options *ManagedIdentityOptions, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
if options == nil {
options = &ManagedIdentityOptions{}
}
return newServicePrincipalTokenFromMSI("", resource, options.ClientID, options.IdentityResourceID, callbacks...)
}
func newServicePrincipalTokenFromMSI(msiEndpoint, resource, userAssignedID, identityResourceID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
if err := validateStringParam(resource, "resource"); err != nil {
return nil, err
}
if userAssignedID != nil {
if err := validateStringParam(*userAssignedID, "userAssignedID"); err != nil {
return nil, err
}
if userAssignedID != "" && identityResourceID != "" {
return nil, errors.New("cannot specify userAssignedID and identityResourceID")
}
if identityResourceID != nil {
if err := validateStringParam(*identityResourceID, "identityResourceID"); err != nil {
return nil, err
}
msiType, endpoint, err := getMSIType()
if err != nil {
logger.Instance.Writef(logger.LogError, "Error determining managed identity environment: %v", err)
return nil, err
}
// We set the oauth config token endpoint to be MSI's endpoint
msiEndpointURL, err := url.Parse(msiEndpoint)
logger.Instance.Writef(logger.LogInfo, "Managed identity environment is %s, endpoint is %s", msiType, endpoint)
if msiEndpoint != "" {
endpoint = msiEndpoint
logger.Instance.Writef(logger.LogInfo, "Managed identity custom endpoint is %s", endpoint)
}
msiEndpointURL, err := url.Parse(endpoint)
if err != nil {
return nil, err
}
v := url.Values{}
v.Set("resource", resource)
// App Service MSI currently only supports token API version 2017-09-01
if isAppService() {
v.Set("api-version", appServiceAPIVersion)
} else {
v.Set("api-version", msiAPIVersion)
// cloud shell sends its data in the request body
if msiType != msiTypeCloudShell {
v := url.Values{}
v.Set("resource", resource)
clientIDParam := "client_id"
switch msiType {
case msiTypeAppServiceV20170901:
clientIDParam = "clientid"
v.Set("api-version", appServiceAPIVersion2017)
break
case msiTypeIMDS:
v.Set("api-version", msiAPIVersion)
}
if userAssignedID != "" {
v.Set(clientIDParam, userAssignedID)
} else if identityResourceID != "" {
v.Set("mi_res_id", identityResourceID)
}
msiEndpointURL.RawQuery = v.Encode()
}
if userAssignedID != nil {
v.Set("client_id", *userAssignedID)
}
if identityResourceID != nil {
v.Set("mi_res_id", *identityResourceID)
}
msiEndpointURL.RawQuery = v.Encode()
spt := &ServicePrincipalToken{
inner: servicePrincipalToken{
@@ -747,10 +839,14 @@ func newServicePrincipalTokenFromMSI(msiEndpoint, resource string, userAssignedI
OauthConfig: OAuthConfig{
TokenEndpoint: *msiEndpointURL,
},
Secret: &ServicePrincipalMSISecret{},
Secret: &ServicePrincipalMSISecret{
msiType: msiType,
clientResourceID: identityResourceID,
},
Resource: resource,
AutoRefresh: true,
RefreshWithin: defaultRefresh,
ClientID: userAssignedID,
},
refreshLock: &sync.RWMutex{},
sender: sender(),
@@ -758,10 +854,6 @@ func newServicePrincipalTokenFromMSI(msiEndpoint, resource string, userAssignedI
MaxMSIRefreshAttempts: defaultMaxMSIRefreshAttempts,
}
if userAssignedID != nil {
spt.inner.ClientID = *userAssignedID
}
return spt, nil
}
@@ -858,31 +950,6 @@ func (spt *ServicePrincipalToken) getGrantType() string {
}
}
func isIMDS(u url.URL) bool {
return isMSIEndpoint(u) == true || isASEEndpoint(u) == true
}
func isMSIEndpoint(endpoint url.URL) bool {
msi, err := url.Parse(msiEndpoint)
if err != nil {
return false
}
return endpoint.Host == msi.Host && endpoint.Path == msi.Path
}
func isASEEndpoint(endpoint url.URL) bool {
aseEndpoint, err := GetMSIAppServiceEndpoint()
if err != nil {
// app service environment isn't enabled
return false
}
ase, err := url.Parse(aseEndpoint)
if err != nil {
return false
}
return endpoint.Host == ase.Host && endpoint.Path == ase.Path
}
func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource string) error {
if spt.customRefreshFunc != nil {
token, err := spt.customRefreshFunc(ctx, resource)
@@ -892,19 +959,45 @@ func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource
spt.inner.Token = *token
return spt.InvokeRefreshCallbacks(spt.inner.Token)
}
req, err := http.NewRequest(http.MethodPost, spt.inner.OauthConfig.TokenEndpoint.String(), nil)
if err != nil {
return fmt.Errorf("adal: Failed to build the refresh request. Error = '%v'", err)
}
req.Header.Add("User-Agent", UserAgent())
// Add header when runtime is on App Service or Functions
if isASEEndpoint(spt.inner.OauthConfig.TokenEndpoint) {
asMSISecret, _ := os.LookupEnv(asMSISecretEnv)
req.Header.Add("Secret", asMSISecret)
}
req = req.WithContext(ctx)
if !isIMDS(spt.inner.OauthConfig.TokenEndpoint) {
var resp *http.Response
authBodyFilter := func(b []byte) []byte {
if logger.Level() != logger.LogAuth {
return []byte("**REDACTED** authentication body")
}
return b
}
if msiSecret, ok := spt.inner.Secret.(*ServicePrincipalMSISecret); ok {
switch msiSecret.msiType {
case msiTypeAppServiceV20170901:
req.Method = http.MethodGet
req.Header.Set("secret", os.Getenv(msiSecretEnv))
break
case msiTypeCloudShell:
req.Header.Set("Metadata", "true")
data := url.Values{}
data.Set("resource", spt.inner.Resource)
if spt.inner.ClientID != "" {
data.Set("client_id", spt.inner.ClientID)
} else if msiSecret.clientResourceID != "" {
data.Set("msi_res_id", msiSecret.clientResourceID)
}
req.Body = ioutil.NopCloser(strings.NewReader(data.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
break
case msiTypeIMDS:
req.Method = http.MethodGet
req.Header.Set("Metadata", "true")
break
}
logger.Instance.WriteRequest(req, logger.Filter{Body: authBodyFilter})
resp, err = retryForIMDS(spt.sender, req, spt.MaxMSIRefreshAttempts)
} else {
v := url.Values{}
v.Set("client_id", spt.inner.ClientID)
v.Set("resource", resource)
@@ -933,40 +1026,26 @@ func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource
req.ContentLength = int64(len(s))
req.Header.Set(contentType, mimeTypeFormPost)
req.Body = body
}
if _, ok := spt.inner.Secret.(*ServicePrincipalMSISecret); ok {
req.Method = http.MethodGet
req.Header.Set(metadataHeader, "true")
}
var resp *http.Response
if isMSIEndpoint(spt.inner.OauthConfig.TokenEndpoint) {
resp, err = getMSIEndpoint(ctx, spt.sender)
if err != nil {
// return a TokenRefreshError here so that we don't keep retrying
return newTokenRefreshError(fmt.Sprintf("the MSI endpoint is not available. Failed HTTP request to MSI endpoint: %v", err), nil)
}
resp.Body.Close()
}
if isIMDS(spt.inner.OauthConfig.TokenEndpoint) {
resp, err = retryForIMDS(spt.sender, req, spt.MaxMSIRefreshAttempts)
} else {
logger.Instance.WriteRequest(req, logger.Filter{Body: authBodyFilter})
resp, err = spt.sender.Do(req)
}
// don't return a TokenRefreshError here; this will allow retry logic to apply
if err != nil {
// don't return a TokenRefreshError here; this will allow retry logic to apply
return fmt.Errorf("adal: Failed to execute the refresh request. Error = '%v'", err)
} else if resp == nil {
return fmt.Errorf("adal: received nil response and error")
}
logger.Instance.WriteResponse(resp, logger.Filter{Body: authBodyFilter})
defer resp.Body.Close()
rb, err := ioutil.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
if err != nil {
return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Failed reading response body: %v", resp.StatusCode, err), resp)
return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Failed reading response body: %v Endpoint %s", resp.StatusCode, err, req.URL.String()), resp)
}
return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Response body: %s", resp.StatusCode, string(rb)), resp)
return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Response body: %s Endpoint %s", resp.StatusCode, string(rb), req.URL.String()), resp)
}
// for the following error cases don't return a TokenRefreshError. the operation succeeded
@@ -979,15 +1058,60 @@ func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource
if len(strings.Trim(string(rb), " ")) == 0 {
return fmt.Errorf("adal: Empty service principal token received during refresh")
}
var token Token
token := struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
// AAD returns expires_in as a string, ADFS returns it as an int
ExpiresIn json.Number `json:"expires_in"`
// expires_on can be in two formats, a UTC time stamp or the number of seconds.
ExpiresOn string `json:"expires_on"`
NotBefore json.Number `json:"not_before"`
Resource string `json:"resource"`
Type string `json:"token_type"`
}{}
// return a TokenRefreshError in the follow error cases as the token is in an unexpected format
err = json.Unmarshal(rb, &token)
if err != nil {
return fmt.Errorf("adal: Failed to unmarshal the service principal token during refresh. Error = '%v' JSON = '%s'", err, string(rb))
return newTokenRefreshError(fmt.Sprintf("adal: Failed to unmarshal the service principal token during refresh. Error = '%v' JSON = '%s'", err, string(rb)), resp)
}
expiresOn := json.Number("")
// ADFS doesn't include the expires_on field
if token.ExpiresOn != "" {
if expiresOn, err = parseExpiresOn(token.ExpiresOn); err != nil {
return newTokenRefreshError(fmt.Sprintf("adal: failed to parse expires_on: %v value '%s'", err, token.ExpiresOn), resp)
}
}
spt.inner.Token.AccessToken = token.AccessToken
spt.inner.Token.RefreshToken = token.RefreshToken
spt.inner.Token.ExpiresIn = token.ExpiresIn
spt.inner.Token.ExpiresOn = expiresOn
spt.inner.Token.NotBefore = token.NotBefore
spt.inner.Token.Resource = token.Resource
spt.inner.Token.Type = token.Type
spt.inner.Token = token
return spt.InvokeRefreshCallbacks(spt.inner.Token)
}
return spt.InvokeRefreshCallbacks(token)
// converts expires_on to the number of seconds
func parseExpiresOn(s string) (json.Number, error) {
// convert the expiration date to the number of seconds from now
timeToDuration := func(t time.Time) json.Number {
dur := t.Sub(time.Now().UTC())
return json.Number(strconv.FormatInt(int64(dur.Round(time.Second).Seconds()), 10))
}
if _, err := strconv.ParseInt(s, 10, 64); err == nil {
// this is the number of seconds case, no conversion required
return json.Number(s), nil
} else if eo, err := time.Parse(expiresOnDateFormatPM, s); err == nil {
return timeToDuration(eo), nil
} else if eo, err := time.Parse(expiresOnDateFormat, s); err == nil {
return timeToDuration(eo), nil
} else {
// unknown format
return json.Number(""), err
}
}
// retry logic specific to retrieving a token from the IMDS endpoint
@@ -1118,46 +1242,6 @@ func (mt *MultiTenantServicePrincipalToken) AuxiliaryOAuthTokens() []string {
return tokens
}
// EnsureFreshWithContext will refresh the token if it will expire within the refresh window (as set by
// RefreshWithin) and autoRefresh flag is on. This method is safe for concurrent use.
func (mt *MultiTenantServicePrincipalToken) EnsureFreshWithContext(ctx context.Context) error {
if err := mt.PrimaryToken.EnsureFreshWithContext(ctx); err != nil {
return fmt.Errorf("failed to refresh primary token: %v", err)
}
for _, aux := range mt.AuxiliaryTokens {
if err := aux.EnsureFreshWithContext(ctx); err != nil {
return fmt.Errorf("failed to refresh auxiliary token: %v", err)
}
}
return nil
}
// RefreshWithContext obtains a fresh token for the Service Principal.
func (mt *MultiTenantServicePrincipalToken) RefreshWithContext(ctx context.Context) error {
if err := mt.PrimaryToken.RefreshWithContext(ctx); err != nil {
return fmt.Errorf("failed to refresh primary token: %v", err)
}
for _, aux := range mt.AuxiliaryTokens {
if err := aux.RefreshWithContext(ctx); err != nil {
return fmt.Errorf("failed to refresh auxiliary token: %v", err)
}
}
return nil
}
// RefreshExchangeWithContext refreshes the token, but for a different resource.
func (mt *MultiTenantServicePrincipalToken) RefreshExchangeWithContext(ctx context.Context, resource string) error {
if err := mt.PrimaryToken.RefreshExchangeWithContext(ctx, resource); err != nil {
return fmt.Errorf("failed to refresh primary token: %v", err)
}
for _, aux := range mt.AuxiliaryTokens {
if err := aux.RefreshExchangeWithContext(ctx, resource); err != nil {
return fmt.Errorf("failed to refresh auxiliary token: %v", err)
}
}
return nil
}
// NewMultiTenantServicePrincipalToken creates a new MultiTenantServicePrincipalToken with the specified credentials and resource.
func NewMultiTenantServicePrincipalToken(multiTenantCfg MultiTenantOAuthConfig, clientID string, secret string, resource string) (*MultiTenantServicePrincipalToken, error) {
if err := validateStringParam(clientID, "clientID"); err != nil {
@@ -1188,6 +1272,55 @@ func NewMultiTenantServicePrincipalToken(multiTenantCfg MultiTenantOAuthConfig,
return &m, nil
}
// NewMultiTenantServicePrincipalTokenFromCertificate creates a new MultiTenantServicePrincipalToken with the specified certificate credentials and resource.
func NewMultiTenantServicePrincipalTokenFromCertificate(multiTenantCfg MultiTenantOAuthConfig, clientID string, certificate *x509.Certificate, privateKey *rsa.PrivateKey, resource string) (*MultiTenantServicePrincipalToken, error) {
if err := validateStringParam(clientID, "clientID"); err != nil {
return nil, err
}
if err := validateStringParam(resource, "resource"); err != nil {
return nil, err
}
if certificate == nil {
return nil, fmt.Errorf("parameter 'certificate' cannot be nil")
}
if privateKey == nil {
return nil, fmt.Errorf("parameter 'privateKey' cannot be nil")
}
auxTenants := multiTenantCfg.AuxiliaryTenants()
m := MultiTenantServicePrincipalToken{
AuxiliaryTokens: make([]*ServicePrincipalToken, len(auxTenants)),
}
primary, err := NewServicePrincipalTokenWithSecret(
*multiTenantCfg.PrimaryTenant(),
clientID,
resource,
&ServicePrincipalCertificateSecret{
PrivateKey: privateKey,
Certificate: certificate,
},
)
if err != nil {
return nil, fmt.Errorf("failed to create SPT for primary tenant: %v", err)
}
m.PrimaryToken = primary
for i := range auxTenants {
aux, err := NewServicePrincipalTokenWithSecret(
*auxTenants[i],
clientID,
resource,
&ServicePrincipalCertificateSecret{
PrivateKey: privateKey,
Certificate: certificate,
},
)
if err != nil {
return nil, fmt.Errorf("failed to create SPT for auxiliary tenant: %v", err)
}
m.AuxiliaryTokens[i] = aux
}
return &m, nil
}
// MSIAvailable returns true if the MSI endpoint is available for authentication.
func MSIAvailable(ctx context.Context, sender Sender) bool {
resp, err := getMSIEndpoint(ctx, sender)
@@ -1196,3 +1329,8 @@ func MSIAvailable(ctx context.Context, sender Sender) bool {
}
return err == nil
}
// used for testing purposes
var msiAvailableHook = func(ctx context.Context, sender Sender) bool {
return MSIAvailable(ctx, sender)
}

View File

@@ -18,13 +18,12 @@ package adal
import (
"context"
"fmt"
"net/http"
"time"
)
func getMSIEndpoint(ctx context.Context, sender Sender) (*http.Response, error) {
// this cannot fail, the return sig is due to legacy reasons
msiEndpoint, _ := GetMSIVMEndpoint()
tempCtx, cancel := context.WithTimeout(ctx, 500*time.Millisecond)
defer cancel()
// http.NewRequestWithContext() was added in Go 1.13
@@ -34,3 +33,43 @@ func getMSIEndpoint(ctx context.Context, sender Sender) (*http.Response, error)
req.URL.RawQuery = q.Encode()
return sender.Do(req)
}
// EnsureFreshWithContext will refresh the token if it will expire within the refresh window (as set by
// RefreshWithin) and autoRefresh flag is on. This method is safe for concurrent use.
func (mt *MultiTenantServicePrincipalToken) EnsureFreshWithContext(ctx context.Context) error {
if err := mt.PrimaryToken.EnsureFreshWithContext(ctx); err != nil {
return fmt.Errorf("failed to refresh primary token: %w", err)
}
for _, aux := range mt.AuxiliaryTokens {
if err := aux.EnsureFreshWithContext(ctx); err != nil {
return fmt.Errorf("failed to refresh auxiliary token: %w", err)
}
}
return nil
}
// RefreshWithContext obtains a fresh token for the Service Principal.
func (mt *MultiTenantServicePrincipalToken) RefreshWithContext(ctx context.Context) error {
if err := mt.PrimaryToken.RefreshWithContext(ctx); err != nil {
return fmt.Errorf("failed to refresh primary token: %w", err)
}
for _, aux := range mt.AuxiliaryTokens {
if err := aux.RefreshWithContext(ctx); err != nil {
return fmt.Errorf("failed to refresh auxiliary token: %w", err)
}
}
return nil
}
// RefreshExchangeWithContext refreshes the token, but for a different resource.
func (mt *MultiTenantServicePrincipalToken) RefreshExchangeWithContext(ctx context.Context, resource string) error {
if err := mt.PrimaryToken.RefreshExchangeWithContext(ctx, resource); err != nil {
return fmt.Errorf("failed to refresh primary token: %w", err)
}
for _, aux := range mt.AuxiliaryTokens {
if err := aux.RefreshExchangeWithContext(ctx, resource); err != nil {
return fmt.Errorf("failed to refresh auxiliary token: %w", err)
}
}
return nil
}

View File

@@ -23,8 +23,6 @@ import (
)
func getMSIEndpoint(ctx context.Context, sender Sender) (*http.Response, error) {
// this cannot fail, the return sig is due to legacy reasons
msiEndpoint, _ := GetMSIVMEndpoint()
tempCtx, cancel := context.WithTimeout(ctx, 500*time.Millisecond)
defer cancel()
req, _ := http.NewRequest(http.MethodGet, msiEndpoint, nil)
@@ -34,3 +32,43 @@ func getMSIEndpoint(ctx context.Context, sender Sender) (*http.Response, error)
req.URL.RawQuery = q.Encode()
return sender.Do(req)
}
// EnsureFreshWithContext will refresh the token if it will expire within the refresh window (as set by
// RefreshWithin) and autoRefresh flag is on. This method is safe for concurrent use.
func (mt *MultiTenantServicePrincipalToken) EnsureFreshWithContext(ctx context.Context) error {
if err := mt.PrimaryToken.EnsureFreshWithContext(ctx); err != nil {
return err
}
for _, aux := range mt.AuxiliaryTokens {
if err := aux.EnsureFreshWithContext(ctx); err != nil {
return err
}
}
return nil
}
// RefreshWithContext obtains a fresh token for the Service Principal.
func (mt *MultiTenantServicePrincipalToken) RefreshWithContext(ctx context.Context) error {
if err := mt.PrimaryToken.RefreshWithContext(ctx); err != nil {
return err
}
for _, aux := range mt.AuxiliaryTokens {
if err := aux.RefreshWithContext(ctx); err != nil {
return err
}
}
return nil
}
// RefreshExchangeWithContext refreshes the token, but for a different resource.
func (mt *MultiTenantServicePrincipalToken) RefreshExchangeWithContext(ctx context.Context, resource string) error {
if err := mt.PrimaryToken.RefreshExchangeWithContext(ctx, resource); err != nil {
return err
}
for _, aux := range mt.AuxiliaryTokens {
if err := aux.RefreshExchangeWithContext(ctx, resource); err != nil {
return err
}
}
return nil
}

View File

@@ -42,6 +42,52 @@ const (
var pollingCodes = [...]int{http.StatusNoContent, http.StatusAccepted, http.StatusCreated, http.StatusOK}
// FutureAPI contains the set of methods on the Future type.
type FutureAPI interface {
// Response returns the last HTTP response.
Response() *http.Response
// Status returns the last status message of the operation.
Status() string
// PollingMethod returns the method used to monitor the status of the asynchronous operation.
PollingMethod() PollingMethodType
// DoneWithContext queries the service to see if the operation has completed.
DoneWithContext(context.Context, autorest.Sender) (bool, error)
// GetPollingDelay returns a duration the application should wait before checking
// the status of the asynchronous request and true; this value is returned from
// the service via the Retry-After response header. If the header wasn't returned
// then the function returns the zero-value time.Duration and false.
GetPollingDelay() (time.Duration, bool)
// WaitForCompletionRef will return when one of the following conditions is met: the long
// running operation has completed, the provided context is cancelled, or the client's
// polling duration has been exceeded. It will retry failed polling attempts based on
// the retry value defined in the client up to the maximum retry attempts.
// If no deadline is specified in the context then the client.PollingDuration will be
// used to determine if a default deadline should be used.
// If PollingDuration is greater than zero the value will be used as the context's timeout.
// If PollingDuration is zero then no default deadline will be used.
WaitForCompletionRef(context.Context, autorest.Client) error
// MarshalJSON implements the json.Marshaler interface.
MarshalJSON() ([]byte, error)
// MarshalJSON implements the json.Unmarshaler interface.
UnmarshalJSON([]byte) error
// PollingURL returns the URL used for retrieving the status of the long-running operation.
PollingURL() string
// GetResult should be called once polling has completed successfully.
// It makes the final GET call to retrieve the resultant payload.
GetResult(autorest.Sender) (*http.Response, error)
}
var _ FutureAPI = (*Future)(nil)
// Future provides a mechanism to access the status and results of an asynchronous request.
// Since futures are stateful they should be passed by value to avoid race conditions.
type Future struct {

View File

@@ -37,6 +37,9 @@ const (
// should be included in the response.
HeaderReturnClientID = "x-ms-return-client-request-id"
// HeaderContentType is the type of the content in the HTTP response.
HeaderContentType = "Content-Type"
// HeaderRequestID is the Azure extension header of the service generated request ID returned
// in the response.
HeaderRequestID = "x-ms-request-id"
@@ -89,54 +92,85 @@ func (se ServiceError) Error() string {
// UnmarshalJSON implements the json.Unmarshaler interface for the ServiceError type.
func (se *ServiceError) UnmarshalJSON(b []byte) error {
// per the OData v4 spec the details field must be an array of JSON objects.
// unfortunately not all services adhear to the spec and just return a single
// object instead of an array with one object. so we have to perform some
// shenanigans to accommodate both cases.
// http://docs.oasis-open.org/odata/odata-json-format/v4.0/os/odata-json-format-v4.0-os.html#_Toc372793091
type serviceError1 struct {
type serviceErrorInternal struct {
Code string `json:"code"`
Message string `json:"message"`
Target *string `json:"target"`
Details []map[string]interface{} `json:"details"`
InnerError map[string]interface{} `json:"innererror"`
AdditionalInfo []map[string]interface{} `json:"additionalInfo"`
Target *string `json:"target,omitempty"`
AdditionalInfo []map[string]interface{} `json:"additionalInfo,omitempty"`
// not all services conform to the OData v4 spec.
// the following fields are where we've seen discrepancies
// spec calls for []map[string]interface{} but have seen map[string]interface{}
Details interface{} `json:"details,omitempty"`
// spec calls for map[string]interface{} but have seen []map[string]interface{} and string
InnerError interface{} `json:"innererror,omitempty"`
}
type serviceError2 struct {
Code string `json:"code"`
Message string `json:"message"`
Target *string `json:"target"`
Details map[string]interface{} `json:"details"`
InnerError map[string]interface{} `json:"innererror"`
AdditionalInfo []map[string]interface{} `json:"additionalInfo"`
sei := serviceErrorInternal{}
if err := json.Unmarshal(b, &sei); err != nil {
return err
}
se1 := serviceError1{}
err := json.Unmarshal(b, &se1)
if err == nil {
se.populate(se1.Code, se1.Message, se1.Target, se1.Details, se1.InnerError, se1.AdditionalInfo)
return nil
// copy the fields we know to be correct
se.AdditionalInfo = sei.AdditionalInfo
se.Code = sei.Code
se.Message = sei.Message
se.Target = sei.Target
// converts an []interface{} to []map[string]interface{}
arrayOfObjs := func(v interface{}) ([]map[string]interface{}, bool) {
arrayOf, ok := v.([]interface{})
if !ok {
return nil, false
}
final := []map[string]interface{}{}
for _, item := range arrayOf {
as, ok := item.(map[string]interface{})
if !ok {
return nil, false
}
final = append(final, as)
}
return final, true
}
se2 := serviceError2{}
err = json.Unmarshal(b, &se2)
if err == nil {
se.populate(se2.Code, se2.Message, se2.Target, nil, se2.InnerError, se2.AdditionalInfo)
se.Details = append(se.Details, se2.Details)
return nil
}
return err
}
// convert the remaining fields, falling back to raw JSON if necessary
func (se *ServiceError) populate(code, message string, target *string, details []map[string]interface{}, inner map[string]interface{}, additional []map[string]interface{}) {
se.Code = code
se.Message = message
se.Target = target
se.Details = details
se.InnerError = inner
se.AdditionalInfo = additional
if c, ok := arrayOfObjs(sei.Details); ok {
se.Details = c
} else if c, ok := sei.Details.(map[string]interface{}); ok {
se.Details = []map[string]interface{}{c}
} else if sei.Details != nil {
// stuff into Details
se.Details = []map[string]interface{}{
{"raw": sei.Details},
}
}
if c, ok := sei.InnerError.(map[string]interface{}); ok {
se.InnerError = c
} else if c, ok := arrayOfObjs(sei.InnerError); ok {
// if there's only one error extract it
if len(c) == 1 {
se.InnerError = c[0]
} else {
// multiple errors, stuff them into the value
se.InnerError = map[string]interface{}{
"multi": c,
}
}
} else if c, ok := sei.InnerError.(string); ok {
se.InnerError = map[string]interface{}{"error": c}
} else if sei.InnerError != nil {
// stuff into InnerError
se.InnerError = map[string]interface{}{
"raw": sei.InnerError,
}
}
return nil
}
// RequestError describes an error response returned by Azure service.
@@ -307,16 +341,30 @@ func WithErrorUnlessStatusCode(codes ...int) autorest.RespondDecorator {
// Check if error is unwrapped ServiceError
decoder := autorest.NewDecoder(encodedAs, bytes.NewReader(b.Bytes()))
if err := decoder.Decode(&e.ServiceError); err != nil {
return err
return fmt.Errorf("autorest/azure: error response cannot be parsed: %q error: %v", b.String(), err)
}
// for example, should the API return the literal value `null` as the response
if e.ServiceError == nil {
e.ServiceError = &ServiceError{
Code: "Unknown",
Message: "Unknown service error",
Details: []map[string]interface{}{
{
"HttpResponse.Body": b.String(),
},
},
}
}
}
if e.ServiceError.Message == "" {
if e.ServiceError != nil && e.ServiceError.Message == "" {
// if we're here it means the returned error wasn't OData v4 compliant.
// try to unmarshal the body in hopes of getting something.
rawBody := map[string]interface{}{}
decoder := autorest.NewDecoder(encodedAs, bytes.NewReader(b.Bytes()))
if err := decoder.Decode(&rawBody); err != nil {
return err
return fmt.Errorf("autorest/azure: error response cannot be parsed: %q error: %v", b.String(), err)
}
e.ServiceError = &ServiceError{

View File

@@ -17,6 +17,7 @@ package autorest
import (
"bytes"
"crypto/tls"
"errors"
"fmt"
"io"
"io/ioutil"
@@ -165,7 +166,8 @@ type Client struct {
// Setting this to zero will use the provided context to control the duration.
PollingDuration time.Duration
// RetryAttempts sets the default number of retry attempts for client.
// RetryAttempts sets the total number of times the client will attempt to make an HTTP request.
// Set the value to 1 to disable retries. DO NOT set the value to less than 1.
RetryAttempts int
// RetryDuration sets the delay duration for retries.
@@ -259,6 +261,9 @@ func (c Client) Do(r *http.Request) (*http.Response, error) {
},
})
resp, err := SendWithSender(c.sender(tls.RenegotiateNever), r)
if resp == nil && err == nil {
err = errors.New("autorest: received nil response and error")
}
logger.Instance.WriteResponse(resp, logger.Filter{})
Respond(resp, c.ByInspecting())
return resp, err

View File

@@ -96,3 +96,8 @@ func (e DetailedError) Error() string {
}
return fmt.Sprintf("%s#%s: %s: StatusCode=%d -- Original Error: %v", e.PackageType, e.Method, e.Message, e.StatusCode, e.Original)
}
// Unwrap returns the original error.
func (e DetailedError) Unwrap() error {
return e.Original
}

View File

@@ -4,9 +4,9 @@ go 1.12
require (
github.com/Azure/go-autorest v14.2.0+incompatible
github.com/Azure/go-autorest/autorest/adal v0.9.5
github.com/Azure/go-autorest/autorest/adal v0.9.13
github.com/Azure/go-autorest/autorest/mocks v0.4.1
github.com/Azure/go-autorest/logger v0.2.0
github.com/Azure/go-autorest/logger v0.2.1
github.com/Azure/go-autorest/tracing v0.6.0
golang.org/x/crypto v0.0.0-20201002170205-7f63de1d35b0
)

View File

@@ -1,13 +1,13 @@
github.com/Azure/go-autorest v14.2.0+incompatible h1:V5VMDjClD3GiElqLWO7mz2MxNAK/vTfRHdAubSIPRgs=
github.com/Azure/go-autorest v14.2.0+incompatible/go.mod h1:r+4oMnoxhatjLLJ6zxSWATqVooLgysK6ZNox3g/xq24=
github.com/Azure/go-autorest/autorest/adal v0.9.5 h1:Y3bBUV4rTuxenJJs41HU3qmqsb+auo+a3Lz+PlJPpL0=
github.com/Azure/go-autorest/autorest/adal v0.9.5/go.mod h1:B7KF7jKIeC9Mct5spmyCB/A8CG/sEz1vwIRGv/bbw7A=
github.com/Azure/go-autorest/autorest/adal v0.9.13 h1:Mp5hbtOePIzM8pJVRa3YLrWWmZtoxRXqUEzCfJt3+/Q=
github.com/Azure/go-autorest/autorest/adal v0.9.13/go.mod h1:W/MM4U6nLxnIskrw4UwWzlHfGjwUS50aOsc/I3yuU8M=
github.com/Azure/go-autorest/autorest/date v0.3.0 h1:7gUk1U5M/CQbp9WoqinNzJar+8KY+LPI6wiWrP/myHw=
github.com/Azure/go-autorest/autorest/date v0.3.0/go.mod h1:BI0uouVdmngYNUzGWeSYnokU+TrmwEsOqdt8Y6sso74=
github.com/Azure/go-autorest/autorest/mocks v0.4.1 h1:K0laFcLE6VLTOwNgSxaGbUcLPuGXlNkbVvq4cW4nIHk=
github.com/Azure/go-autorest/autorest/mocks v0.4.1/go.mod h1:LTp+uSrOhSkaKrUy935gNZuuIPPVsHlr9DSOxSayd+k=
github.com/Azure/go-autorest/logger v0.2.0 h1:e4RVHVZKC5p6UANLJHkM4OfR1UKZPj8Wt8Pcx+3oqrE=
github.com/Azure/go-autorest/logger v0.2.0/go.mod h1:T9E3cAhj2VqvPOtCYAvby9aBXkZmbF5NWuPV8+WeEW8=
github.com/Azure/go-autorest/logger v0.2.1 h1:IG7i4p/mDa2Ce4TRyAO8IHnVhAVF3RFU+ZtXWSmf4Tg=
github.com/Azure/go-autorest/logger v0.2.1/go.mod h1:T9E3cAhj2VqvPOtCYAvby9aBXkZmbF5NWuPV8+WeEW8=
github.com/Azure/go-autorest/tracing v0.6.0 h1:TYi4+3m5t6K48TGI9AUdb+IzbnSxvnvUMfuitfgcfuo=
github.com/Azure/go-autorest/tracing v0.6.0/go.mod h1:+vhtPC754Xsa23ID7GlGsrdKBpUA79WCAKPPZVC2DeU=
github.com/form3tech-oss/jwt-go v3.2.2+incompatible h1:TcekIExNqud5crz4xD2pavyTgWiPvpYe4Xau31I0PRk=

View File

@@ -26,8 +26,6 @@ import (
"net/url"
"reflect"
"strings"
"github.com/Azure/go-autorest/autorest/adal"
)
// EncodedAs is a series of constants specifying various data encodings
@@ -207,18 +205,6 @@ func ChangeToGet(req *http.Request) *http.Request {
return req
}
// IsTokenRefreshError returns true if the specified error implements the TokenRefreshError
// interface. If err is a DetailedError it will walk the chain of Original errors.
func IsTokenRefreshError(err error) bool {
if _, ok := err.(adal.TokenRefreshError); ok {
return true
}
if de, ok := err.(DetailedError); ok {
return IsTokenRefreshError(de.Original)
}
return false
}
// IsTemporaryNetworkError returns true if the specified error is a temporary network error or false
// if it's not. If the error doesn't implement the net.Error interface the return value is true.
func IsTemporaryNetworkError(err error) bool {

View File

@@ -0,0 +1,29 @@
// +build go1.13
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package autorest
import (
"errors"
"github.com/Azure/go-autorest/autorest/adal"
)
// IsTokenRefreshError returns true if the specified error implements the TokenRefreshError interface.
func IsTokenRefreshError(err error) bool {
var tre adal.TokenRefreshError
return errors.As(err, &tre)
}

View File

@@ -0,0 +1,31 @@
// +build !go1.13
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package autorest
import "github.com/Azure/go-autorest/autorest/adal"
// IsTokenRefreshError returns true if the specified error implements the TokenRefreshError
// interface. If err is a DetailedError it will walk the chain of Original errors.
func IsTokenRefreshError(err error) bool {
if _, ok := err.(adal.TokenRefreshError); ok {
return true
}
if de, ok := err.(DetailedError); ok {
return IsTokenRefreshError(de.Original)
}
return false
}

View File

@@ -55,6 +55,10 @@ const (
// LogDebug tells a logger to log all LogDebug, LogInfo, LogWarning, LogError, LogPanic and LogFatal entries passed to it.
LogDebug
// LogAuth is a special case of LogDebug, it tells a logger to also log the body of an authentication request and response.
// NOTE: this can disclose sensitive information, use with care.
LogAuth
)
const (
@@ -65,6 +69,7 @@ const (
logWarning = "WARNING"
logInfo = "INFO"
logDebug = "DEBUG"
logAuth = "AUTH"
logUnknown = "UNKNOWN"
)
@@ -83,6 +88,8 @@ func ParseLevel(s string) (lt LevelType, err error) {
lt = LogInfo
case logDebug:
lt = LogDebug
case logAuth:
lt = LogAuth
default:
err = fmt.Errorf("bad log level '%s'", s)
}
@@ -106,6 +113,8 @@ func (lt LevelType) String() string {
return logInfo
case LogDebug:
return logDebug
case LogAuth:
return logAuth
default:
return logUnknown
}