1
0
mirror of https://github.com/kubernetes-sigs/descheduler.git synced 2026-01-26 05:14:13 +01:00

bump(*): kubernetes release-1.16.0 dependencies

This commit is contained in:
Mike Dame
2019-10-12 11:11:43 -04:00
parent 5af668e89a
commit 1652ba7976
28121 changed files with 3491095 additions and 2280257 deletions

View File

@@ -0,0 +1,7 @@
Thank you for your contribution to Go-AutoRest! We will triage and review it as soon as we can.
As part of submitting, please make sure you can make the following assertions:
- [ ] I've tested my changes, adding unit tests if applicable.
- [ ] I've added Apache 2.0 Headers to the top of any new source files.
- [ ] I'm submitting this PR to the `dev` branch, except in the case of urgent bug fixes warranting their own release.
- [ ] If I'm targeting `master`, I've updated [CHANGELOG.md](https://github.com/Azure/go-autorest/blob/master/CHANGELOG.md) to address the changes I'm making.

View File

@@ -9,6 +9,7 @@ _obj
_test
.DS_Store
.idea/
.vscode/
# Architecture specific extensions/prefixes
*.[568vq]

View File

@@ -1,25 +0,0 @@
sudo: false
language: go
go:
- 1.9
- 1.8
- 1.7
- 1.6
install:
- go get -u github.com/golang/lint/golint
- go get -u github.com/Masterminds/glide
- go get -u github.com/stretchr/testify
- go get -u github.com/GoASTScanner/gas
- glide install
script:
- grep -L -r --include *.go --exclude-dir vendor -P "Copyright (\d{4}|\(c\)) Microsoft" ./ | tee /dev/stderr | test -z "$(< /dev/stdin)"
- test -z "$(gofmt -s -l -w ./autorest/. | tee /dev/stderr)"
- test -z "$(golint ./autorest/... | tee /dev/stderr)"
- go vet ./autorest/...
- test -z "$(gas ./autorest/... | tee /dev/stderr | grep Error)"
- go build -v ./autorest/...
- go test -v ./autorest/...

View File

@@ -1,5 +1,435 @@
# CHANGELOG
## v11.1.2
### Bug Fixes (back-ports)
- If zero bytes are read from a polling response body don't attempt to unmarshal them.
- For an LRO PUT operation the final GET URL was incorrectly set to the Location polling header in some cases.
- In `Future.WaitForCompletionRef()` if the provided context has a deadline don't add the default deadline.
## v11.1.1
### Bug Fixes
- When creating a future always include the polling tracker even if there's a failure; this allows the underlying response to be obtained by the caller.
## v11.1.0
### New Features
- Added `auth.NewAuthorizerFromCLI` to create an authorizer configured from the Azure 2.0 CLI.
- Added `adal.NewOAuthConfigWithAPIVersion` to create an OAuthConfig with the specified API version.
## v11.0.1
### New Features
- Added `x5c` header to client assertion for certificate Issuer+Subject Name authentication.
## v11.0.0
### Breaking Changes
- To handle differences between ADFS and AAD the following fields have had their types changed from `string` to `json.Number`
- ExpiresIn
- ExpiresOn
- NotBefore
### New Features
- Added `auth.NewAuthorizerFromFileWithResource` to create an authorizer from the config file with the specified resource.
- Setting a client's `PollingDuration` to zero will use the provided context to control a LRO's polling duration.
## v10.15.5
### Bug Fixes
- In `DoRetryForStatusCodes`, if a request's context is cancelled return the last response.
## v10.15.4
### Bug Fixes
- If a polling operation returns a failure status code return the associated error.
## v10.15.3
### Bug Fixes
- Initialize the polling URL and method for an LRO tracker on each iteration, favoring the Azure-AsyncOperation header.
## v10.15.2
### Bug Fixes
- Use fmt.Fprint when printing request/response so that any escape sequences aren't treated as format specifiers.
## v10.15.1
### Bug Fixes
- If an LRO API returns a ```Failed``` provisioning state in the initial response return an error at that point so the caller doesn't have to poll.
- For failed LROs without an OData v4 error include the response body in the error's ```AdditionalInfo``` field to aid in diagnosing the failure.
## v10.15.0
### New Features
- Add initial support for request/response logging via setting environment variables.
Setting ```AZURE_GO_SDK_LOG_LEVEL``` to ```LogInfo``` will log request/response
without their bodies. To include the bodies set the log level to ```LogDebug```.
By default the logger writes to strerr, however it can also write to stdout or a file
if specified in ```AZURE_GO_SDK_LOG_FILE```. Note that if the specified file
already exists it will be truncated.
IMPORTANT: by default the logger will redact the Authorization and Ocp-Apim-Subscription-Key
headers. Any other secrets will *not* be redacted.
## v10.14.0
### New Features
- Added package version that contains version constants and user-agent data.
### Bug Fixes
- Add the user-agent to token requests.
## v10.13.0
- Added support for additionalInfo in ServiceError type.
## v10.12.0
### New Features
- Added field ServicePrincipalToken.MaxMSIRefreshAttempts to configure the maximun number of attempts to refresh an MSI token.
## v10.11.4
### Bug Fixes
- If an LRO returns http.StatusOK on the initial response with no async headers return the response body from Future.GetResult().
- If there is no "final GET URL" return an error from Future.GetResult().
## v10.11.3
### Bug Fixes
- In IMDS retry logic, if we don't receive a response don't retry.
- Renamed the retry function so it's clear it's meant for IMDS only.
- For error response bodies that aren't OData-v4 compliant stick the raw JSON in the ServiceError.Details field so the information isn't lost.
- Also add the raw HTTP response to the DetailedResponse.
- Removed superfluous wrapping of response error in azure.DoRetryWithRegistration().
## v10.11.2
### Bug Fixes
- Validation for integers handles int and int64 types.
## v10.11.1
### Bug Fixes
- Adding User information to authorization config as parsed from CLI cache.
## v10.11.0
### New Features
- Added NewServicePrincipalTokenFromManualTokenSecret for creating a new SPT using a manual token and secret
- Added method ServicePrincipalToken.MarshalTokenJSON() to marshall the inner Token
## v10.10.0
### New Features
- Most ServicePrincipalTokens can now be marshalled/unmarshall to/from JSON (ServicePrincipalCertificateSecret and ServicePrincipalMSISecret are not supported).
- Added method ServicePrincipalToken.SetRefreshCallbacks().
## v10.9.2
### Bug Fixes
- Refreshing a refresh token obtained from a web app authorization code now works.
## v10.9.1
### Bug Fixes
- The retry logic for MSI token requests now uses exponential backoff per the guidelines.
- IsTemporaryNetworkError() will return true for errors that don't implement the net.Error interface.
## v10.9.0
### Deprecated Methods
| Old Method | New Method |
|-------------:|:-----------:|
|azure.NewFuture() | azure.NewFutureFromResponse()|
|Future.WaitForCompletion() | Future.WaitForCompletionRef()|
### New Features
- Added azure.NewFutureFromResponse() for creating a Future from the initial response from an async operation.
- Added Future.GetResult() for making the final GET call to retrieve the result from an async operation.
### Bug Fixes
- Some futures failed to return their results, this should now be fixed.
## v10.8.2
### Bug Fixes
- Add nil-gaurd to token retry logic.
## v10.8.1
### Bug Fixes
- Return a TokenRefreshError if the sender fails on the initial request.
- Don't retry on non-temporary network errors.
## v10.8.0
- Added NewAuthorizerFromEnvironmentWithResource() helper function.
## v10.7.0
### New Features
- Added *WithContext() methods to ADAL token refresh operations.
## v10.6.2
- Fixed a bug on device authentication.
## v10.6.1
- Added retries to MSI token get request.
## v10.6.0
- Changed MSI token implementation. Now, the token endpoint is the IMDS endpoint.
## v10.5.1
### Bug Fixes
- `DeviceFlowConfig.Authorizer()` now prints the device code message when running `go test`. `-v` flag is required.
## v10.5.0
### New Features
- Added NewPollingRequestWithContext() for use with polling asynchronous operations.
### Bug Fixes
- Make retry logic use the request's context instead of the deprecated Cancel object.
## v10.4.0
### New Features
- Added helper for parsing Azure Resource ID's.
- Added deprecation message to utils.GetEnvVarOrExit()
## v10.3.0
### New Features
- Added EnvironmentFromURL method to load an Environment from a given URL. This function is particularly useful in the private and hybrid Cloud model, where one may define their own endpoints
- Added TokenAudience endpoint to Environment structure. This is useful in private and hybrid cloud models where TokenAudience endpoint can be different from ResourceManagerEndpoint
## v10.2.0
### New Features
- Added endpoints for batch management.
## v10.1.3
### Bug Fixes
- In Client.Do() invoke WithInspection() last so that it will inspect WithAuthorization().
- Fixed authorization methods to invoke p.Prepare() first, aligning them with the other preparers.
## v10.1.2
- Corrected comment for auth.NewAuthorizerFromFile() function.
## v10.1.1
- Updated version number to match current release.
## v10.1.0
### New Features
- Expose the polling URL for futures.
### Bug Fixes
- Add validation.NewErrorWithValidationError back to prevent breaking changes (it is deprecated).
## v10.0.0
### New Features
- Added target and innererror fields to ServiceError to comply with OData v4 spec.
- The Done() method on futures will now return a ServiceError object when available (it used to return a partial value of such errors).
- Added helper methods for obtaining authorizers.
- Expose the polling URL for futures.
### Bug Fixes
- Switched from glide to dep for dependency management.
- Fixed unmarshaling of ServiceError for JSON bodies that don't conform to the OData spec.
- Fixed a race condition in token refresh.
### Breaking Changes
- The ServiceError.Details field type has been changed to match the OData v4 spec.
- Go v1.7 has been dropped from CI.
- API parameter validation failures will now return a unique error type validation.Error.
- The adal.Token type has been decomposed from adal.ServicePrincipalToken (this was necessary in order to fix the token refresh race).
## v9.10.0
- Fix the Service Bus suffix in Azure public env
- Add Service Bus Endpoint (AAD ResourceURI) for use in [Azure Service Bus RBAC Preview](https://docs.microsoft.com/en-us/azure/service-bus-messaging/service-bus-role-based-access-control)
## v9.9.0
### New Features
- Added EventGridKeyAuthorizer for key authorization with event grid topics.
### Bug Fixes
- Fixed race condition when auto-refreshing service principal tokens.
## v9.8.1
### Bug Fixes
- Added http.StatusNoContent (204) to the list of expected status codes for long-running operations.
- Updated runtime version info so it's current.
## v9.8.0
### New Features
- Added type azure.AsyncOpIncompleteError to be returned from a future's Result() method when the operation has not completed.
## v9.7.1
### Bug Fixes
- Use correct AAD and Graph endpoints for US Gov environment.
## v9.7.0
### New Features
- Added support for application/octet-stream MIME types.
## v9.6.1
### Bug Fixes
- Ensure Authorization header is added to request when polling for registration status.
## v9.6.0
### New Features
- Added support for acquiring tokens via MSI with a user assigned identity.
## v9.5.3
### Bug Fixes
- Don't remove encoding of existing URL Query parameters when calling autorest.WithQueryParameters.
- Set correct Content Type when using autorest.WithFormData.
## v9.5.2
### Bug Fixes
- Check for nil *http.Response before dereferencing it.
## v9.5.1
### Bug Fixes
- Don't count http.StatusTooManyRequests (429) against the retry cap.
- Use retry logic when SkipResourceProviderRegistration is set to true.
## v9.5.0
### New Features
- Added support for username + password, API key, authoriazation code and cognitive services authentication.
- Added field SkipResourceProviderRegistration to clients to provide a way to skip auto-registration of RPs.
- Added utility function AsStringSlice() to convert its parameters to a string slice.
### Bug Fixes
- When checking for authentication failures look at the error type not the status code as it could vary.
## v9.4.2
### Bug Fixes
- Validate parameters when creating credentials.
- Don't retry requests if the returned status is a 401 (http.StatusUnauthorized) as it will never succeed.
## v9.4.1
### Bug Fixes
- Update the AccessTokensPath() to read access tokens path through AZURE_ACCESS_TOKEN_FILE. If this
environment variable is not set, it will fall back to use default path set by Azure CLI.
- Use case-insensitive string comparison for polling states.
## v9.4.0
### New Features
- Added WaitForCompletion() to Future as a default polling implementation.
### Bug Fixes
- Method Future.Done() shouldn't update polling status for unexpected HTTP status codes.
## v9.3.1
### Bug Fixes
- DoRetryForStatusCodes will retry if sender.Do returns a non-nil error.
## v9.3.0
### New Features
- Added PollingMethod() to Future so callers know what kind of polling mechanism is used.
- Added azure.ChangeToGet() which transforms an http.Request into a GET (to be used with LROs).
## v9.2.0
### New Features
- Added support for custom Azure Stack endpoints.
- Added type azure.Future used to track the status of long-running operations.
### Bug Fixes
- Preserve the original error in DoRetryWithRegistration when registration fails.
## v9.1.1
- Fixes a bug regarding the cookie jar on `autorest.Client.Sender`.
## v9.1.0
### New Features
@@ -68,7 +498,7 @@ Support for UNIX time.
- Added telemetry.
## v7.2.3
- Fixing bug in calls to `DelayForBackoff` that caused doubling of delay
- Fixing bug in calls to `DelayForBackoff` that caused doubling of delay
duration.
## v7.2.2

77
vendor/github.com/Azure/go-autorest/Gopkg.lock generated vendored Normal file
View File

@@ -0,0 +1,77 @@
# This file is autogenerated, do not edit; changes may be undone by the next 'dep ensure'.
[[projects]]
digest = "1:0deddd908b6b4b768cfc272c16ee61e7088a60f7fe2f06c547bd3d8e1f8b8e77"
name = "github.com/davecgh/go-spew"
packages = ["spew"]
pruneopts = ""
revision = "8991bc29aa16c548c550c7ff78260e27b9ab7c73"
version = "v1.1.1"
[[projects]]
digest = "1:6098222470fe0172157ce9bbef5d2200df4edde17ee649c5d6e48330e4afa4c6"
name = "github.com/dgrijalva/jwt-go"
packages = ["."]
pruneopts = ""
revision = "06ea1031745cb8b3dab3f6a236daf2b0aa468b7e"
version = "v3.2.0"
[[projects]]
digest = "1:7f175a633086a933d1940a7e7dc2154a0070a7c25fb4a2f671f3eef1a34d1fd7"
name = "github.com/dimchansky/utfbom"
packages = ["."]
pruneopts = ""
revision = "5448fe645cb1964ba70ac8f9f2ffe975e61a536c"
version = "v1.0.0"
[[projects]]
digest = "1:096a8a9182648da3d00ff243b88407838902b6703fc12657f76890e08d1899bf"
name = "github.com/mitchellh/go-homedir"
packages = ["."]
pruneopts = ""
revision = "ae18d6b8b3205b561c79e8e5f69bff09736185f4"
version = "v1.0.0"
[[projects]]
digest = "1:256484dbbcd271f9ecebc6795b2df8cad4c458dd0f5fd82a8c2fa0c29f233411"
name = "github.com/pmezard/go-difflib"
packages = ["difflib"]
pruneopts = ""
revision = "792786c7400a136282c1664665ae0a8db921c6c2"
version = "v1.0.0"
[[projects]]
digest = "1:c587772fb8ad29ad4db67575dad25ba17a51f072ff18a22b4f0257a4d9c24f75"
name = "github.com/stretchr/testify"
packages = [
"assert",
"require",
]
pruneopts = ""
revision = "f35b8ab0b5a2cef36673838d662e249dd9c94686"
version = "v1.2.2"
[[projects]]
branch = "master"
digest = "1:d10ba2393047dc9ba61df4d36c849a03017175edfa3ee4da23c3a12e10326443"
name = "golang.org/x/crypto"
packages = [
"pkcs12",
"pkcs12/internal/rc2",
]
pruneopts = ""
revision = "5295e8364332db77d75fce11f1d19c053919a9c9"
[solve-meta]
analyzer-name = "dep"
analyzer-version = 1
input-imports = [
"github.com/dgrijalva/jwt-go",
"github.com/dimchansky/utfbom",
"github.com/mitchellh/go-homedir",
"github.com/stretchr/testify/require",
"golang.org/x/crypto/pkcs12",
]
solver-name = "gps-cdcl"
solver-version = 1

37
vendor/github.com/Azure/go-autorest/Gopkg.toml generated vendored Normal file
View File

@@ -0,0 +1,37 @@
# Gopkg.toml example
#
# Refer to https://github.com/golang/dep/blob/master/docs/Gopkg.toml.md
# for detailed Gopkg.toml documentation.
#
# required = ["github.com/user/thing/cmd/thing"]
# ignored = ["github.com/user/project/pkgX", "bitbucket.org/user/project/pkgA/pkgY"]
#
# [[constraint]]
# name = "github.com/user/project"
# version = "1.0.0"
#
# [[constraint]]
# name = "github.com/user/project2"
# branch = "dev"
# source = "github.com/myfork/project2"
#
# [[override]]
# name = "github.com/x/y"
# version = "2.4.0"
[[constraint]]
name = "github.com/dgrijalva/jwt-go"
version = "3.1.0"
[[constraint]]
name = "github.com/dimchansky/utfbom"
version = "1.0.0"
[[constraint]]
name = "github.com/mitchellh/go-homedir"
version = "1.0.0"
[[constraint]]
name = "github.com/stretchr/testify"
version = "1.2.0"

View File

@@ -1,11 +1,22 @@
# go-autorest
[![GoDoc](https://godoc.org/github.com/Azure/go-autorest/autorest?status.png)](https://godoc.org/github.com/Azure/go-autorest/autorest) [![Build Status](https://travis-ci.org/Azure/go-autorest.svg?branch=master)](https://travis-ci.org/Azure/go-autorest) [![Go Report Card](https://goreportcard.com/badge/Azure/go-autorest)](https://goreportcard.com/report/Azure/go-autorest)
[![GoDoc](https://godoc.org/github.com/Azure/go-autorest/autorest?status.png)](https://godoc.org/github.com/Azure/go-autorest/autorest)
[![Build Status](https://dev.azure.com/azure-sdk/public/_apis/build/status/go/Azure.go-autorest?branchName=release/v11.1)](https://dev.azure.com/azure-sdk/public/_build/latest?definitionId=625&branchName=release/v11.1)
[![Go Report Card](https://goreportcard.com/badge/Azure/go-autorest)](https://goreportcard.com/report/Azure/go-autorest)
## Usage
Package autorest implements an HTTP request pipeline suitable for use across multiple go-routines
and provides the shared routines relied on by AutoRest (see https://github.com/Azure/autorest/)
generated Go code.
Package go-autorest provides an HTTP request client for use with [Autorest](https://github.com/Azure/autorest.go)-generated API client packages.
An authentication client tested with Azure Active Directory (AAD) is also
provided in this repo in the package
`github.com/Azure/go-autorest/autorest/adal`. Despite its name, this package
is maintained only as part of the Azure Go SDK and is not related to other
"ADAL" libraries in [github.com/AzureAD](https://github.com/AzureAD).
## Overview
Package go-autorest implements an HTTP request pipeline suitable for use across
multiple goroutines and provides the shared routines used by packages generated
by [Autorest](https://github.com/Azure/autorest.go).
The package breaks sending and responding to HTTP requests into three phases: Preparing, Sending,
and Responding. A typical pattern is:
@@ -129,4 +140,10 @@ go get github.com/Azure/go-autorest/autorest/to
See LICENSE file.
-----
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
This project has adopted the [Microsoft Open Source Code of
Conduct](https://opensource.microsoft.com/codeofconduct/). For more information
see the [Code of Conduct
FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact
[opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional
questions or comments.

View File

@@ -1,19 +1,24 @@
# Azure Active Directory library for Go
# Azure Active Directory authentication for Go
This project provides a stand alone Azure Active Directory library for Go. The code was extracted
from [go-autorest](https://github.com/Azure/go-autorest/) project, which is used as a base for
[azure-sdk-for-go](https://github.com/Azure/azure-sdk-for-go).
This is a standalone package for authenticating with Azure Active
Directory from other Go libraries and applications, in particular the [Azure SDK
for Go](https://github.com/Azure/azure-sdk-for-go).
Note: Despite the package's name it is not related to other "ADAL" libraries
maintained in the [github.com/AzureAD](https://github.com/AzureAD) org. Issues
should be opened in [this repo's](https://github.com/Azure/go-autorest/issues)
or [the SDK's](https://github.com/Azure/azure-sdk-for-go/issues) issue
trackers.
## Installation
## Install
```
```bash
go get -u github.com/Azure/go-autorest/autorest/adal
```
## Usage
An Active Directory application is required in order to use this library. An application can be registered in the [Azure Portal](https://portal.azure.com/) follow these [guidelines](https://docs.microsoft.com/en-us/azure/active-directory/develop/active-directory-integrating-applications) or using the [Azure CLI](https://github.com/Azure/azure-cli).
An Active Directory application is required in order to use this library. An application can be registered in the [Azure Portal](https://portal.azure.com/) by following these [guidelines](https://docs.microsoft.com/en-us/azure/active-directory/develop/active-directory-integrating-applications) or using the [Azure CLI](https://github.com/Azure/azure-cli).
### Register an Azure AD Application with secret
@@ -218,6 +223,40 @@ if (err == nil) {
}
```
#### Username password authenticate
```Go
spt, err := adal.NewServicePrincipalTokenFromUsernamePassword(
oauthConfig,
applicationID,
username,
password,
resource,
callbacks...)
if (err == nil) {
token := spt.Token
}
```
#### Authorization code authenticate
``` Go
spt, err := adal.NewServicePrincipalTokenFromAuthorizationCode(
oauthConfig,
applicationID,
clientSecret,
authorizationCode,
redirectURI,
resource,
callbacks...)
err = spt.Refresh()
if (err == nil) {
token := spt.Token
}
```
### Command Line Tool
A command line tool is available in `cmd/adal.go` that can acquire a token for a given resource. It supports all flows mentioned above.

View File

@@ -281,7 +281,7 @@ func main() {
resource,
callback)
if err == nil {
err = saveToken(spt.Token)
err = saveToken(spt.Token())
}
case refreshMode:
_, err = refreshToken(

View File

@@ -19,22 +19,48 @@ import (
"net/url"
)
const (
activeDirectoryAPIVersion = "1.0"
)
// OAuthConfig represents the endpoints needed
// in OAuth operations
type OAuthConfig struct {
AuthorityEndpoint url.URL
AuthorizeEndpoint url.URL
TokenEndpoint url.URL
DeviceCodeEndpoint url.URL
AuthorityEndpoint url.URL `json:"authorityEndpoint"`
AuthorizeEndpoint url.URL `json:"authorizeEndpoint"`
TokenEndpoint url.URL `json:"tokenEndpoint"`
DeviceCodeEndpoint url.URL `json:"deviceCodeEndpoint"`
}
// IsZero returns true if the OAuthConfig object is zero-initialized.
func (oac OAuthConfig) IsZero() bool {
return oac == OAuthConfig{}
}
func validateStringParam(param, name string) error {
if len(param) == 0 {
return fmt.Errorf("parameter '" + name + "' cannot be empty")
}
return nil
}
// NewOAuthConfig returns an OAuthConfig with tenant specific urls
func NewOAuthConfig(activeDirectoryEndpoint, tenantID string) (*OAuthConfig, error) {
const activeDirectoryEndpointTemplate = "%s/oauth2/%s?api-version=%s"
apiVer := "1.0"
return NewOAuthConfigWithAPIVersion(activeDirectoryEndpoint, tenantID, &apiVer)
}
// NewOAuthConfigWithAPIVersion returns an OAuthConfig with tenant specific urls.
// If apiVersion is not nil the "api-version" query parameter will be appended to the endpoint URLs with the specified value.
func NewOAuthConfigWithAPIVersion(activeDirectoryEndpoint, tenantID string, apiVersion *string) (*OAuthConfig, error) {
if err := validateStringParam(activeDirectoryEndpoint, "activeDirectoryEndpoint"); err != nil {
return nil, err
}
api := ""
// it's legal for tenantID to be empty so don't validate it
if apiVersion != nil {
if err := validateStringParam(*apiVersion, "apiVersion"); err != nil {
return nil, err
}
api = fmt.Sprintf("?api-version=%s", *apiVersion)
}
const activeDirectoryEndpointTemplate = "%s/oauth2/%s%s"
u, err := url.Parse(activeDirectoryEndpoint)
if err != nil {
return nil, err
@@ -43,15 +69,15 @@ func NewOAuthConfig(activeDirectoryEndpoint, tenantID string) (*OAuthConfig, err
if err != nil {
return nil, err
}
authorizeURL, err := u.Parse(fmt.Sprintf(activeDirectoryEndpointTemplate, tenantID, "authorize", activeDirectoryAPIVersion))
authorizeURL, err := u.Parse(fmt.Sprintf(activeDirectoryEndpointTemplate, tenantID, "authorize", api))
if err != nil {
return nil, err
}
tokenURL, err := u.Parse(fmt.Sprintf(activeDirectoryEndpointTemplate, tenantID, "token", activeDirectoryAPIVersion))
tokenURL, err := u.Parse(fmt.Sprintf(activeDirectoryEndpointTemplate, tenantID, "token", api))
if err != nil {
return nil, err
}
deviceCodeURL, err := u.Parse(fmt.Sprintf(activeDirectoryEndpointTemplate, tenantID, "devicecode", activeDirectoryAPIVersion))
deviceCodeURL, err := u.Parse(fmt.Sprintf(activeDirectoryEndpointTemplate, tenantID, "devicecode", api))
if err != nil {
return nil, err
}

View File

@@ -42,3 +42,54 @@ func TestNewOAuthConfig(t *testing.T) {
t.Fatalf("autorest/adal Incorrect devicecode url for Tenant from Environment. expected(%s). actual(%v).", expected, config.DeviceCodeEndpoint)
}
}
func TestNewOAuthConfigWithAPIVersionNil(t *testing.T) {
const testActiveDirectoryEndpoint = "https://login.test.com"
const testTenantID = "tenant-id-test"
config, err := NewOAuthConfigWithAPIVersion(testActiveDirectoryEndpoint, testTenantID, nil)
if err != nil {
t.Fatalf("autorest/adal: Unexpected error while creating oauth configuration for tenant: %v.", err)
}
expected := "https://login.test.com/tenant-id-test/oauth2/authorize"
if config.AuthorizeEndpoint.String() != expected {
t.Fatalf("autorest/adal: Incorrect authorize url for Tenant from Environment. expected(%s). actual(%v).", expected, config.AuthorizeEndpoint)
}
expected = "https://login.test.com/tenant-id-test/oauth2/token"
if config.TokenEndpoint.String() != expected {
t.Fatalf("autorest/adal: Incorrect authorize url for Tenant from Environment. expected(%s). actual(%v).", expected, config.TokenEndpoint)
}
expected = "https://login.test.com/tenant-id-test/oauth2/devicecode"
if config.DeviceCodeEndpoint.String() != expected {
t.Fatalf("autorest/adal Incorrect devicecode url for Tenant from Environment. expected(%s). actual(%v).", expected, config.DeviceCodeEndpoint)
}
}
func TestNewOAuthConfigWithAPIVersionNotNil(t *testing.T) {
const testActiveDirectoryEndpoint = "https://login.test.com"
const testTenantID = "tenant-id-test"
apiVersion := "2.0"
config, err := NewOAuthConfigWithAPIVersion(testActiveDirectoryEndpoint, testTenantID, &apiVersion)
if err != nil {
t.Fatalf("autorest/adal: Unexpected error while creating oauth configuration for tenant: %v.", err)
}
expected := "https://login.test.com/tenant-id-test/oauth2/authorize?api-version=2.0"
if config.AuthorizeEndpoint.String() != expected {
t.Fatalf("autorest/adal: Incorrect authorize url for Tenant from Environment. expected(%s). actual(%v).", expected, config.AuthorizeEndpoint)
}
expected = "https://login.test.com/tenant-id-test/oauth2/token?api-version=2.0"
if config.TokenEndpoint.String() != expected {
t.Fatalf("autorest/adal: Incorrect authorize url for Tenant from Environment. expected(%s). actual(%v).", expected, config.TokenEndpoint)
}
expected = "https://login.test.com/tenant-id-test/oauth2/devicecode?api-version=2.0"
if config.DeviceCodeEndpoint.String() != expected {
t.Fatalf("autorest/adal Incorrect devicecode url for Tenant from Environment. expected(%s). actual(%v).", expected, config.DeviceCodeEndpoint)
}
}

View File

@@ -1,20 +0,0 @@
// +build !windows
package adal
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// msiPath is the path to the MSI Extension settings file (to discover the endpoint)
var msiPath = "/var/lib/waagent/ManagedIdentity-Settings"

View File

@@ -1,25 +0,0 @@
// +build windows
package adal
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"os"
"strings"
)
// msiPath is the path to the MSI Extension settings file (to discover the endpoint)
var msiPath = strings.Join([]string{os.Getenv("SystemDrive"), "WindowsAzure/Config/ManagedIdentity-Settings"}, "/")

View File

@@ -133,13 +133,13 @@ func TestSaveToken(t *testing.T) {
var actualToken Token
var expectedToken Token
json.Unmarshal([]byte(MockTokenJSON), expectedToken)
json.Unmarshal([]byte(MockTokenJSON), &expectedToken)
contents, err := ioutil.ReadFile(f.Name())
if err != nil {
t.Fatal("!!")
}
json.Unmarshal(contents, actualToken)
json.Unmarshal(contents, &actualToken)
if !reflect.DeepEqual(actualToken, expectedToken) {
t.Fatal("azure: token was not serialized correctly")
@@ -159,7 +159,7 @@ func TestSaveTokenFailsNoPermission(t *testing.T) {
}
func TestSaveTokenFailsCantCreate(t *testing.T) {
tokenPath := "/thiswontwork"
tokenPath := "/usr/thiswontwork"
if runtime.GOOS == "windows" {
tokenPath = path.Join(os.Getenv("windir"), "system32")
}

View File

@@ -38,7 +38,7 @@ func (sf SenderFunc) Do(r *http.Request) (*http.Response, error) {
return sf(r)
}
// SendDecorator takes and possibily decorates, by wrapping, a Sender. Decorators may affect the
// SendDecorator takes and possibly decorates, by wrapping, a Sender. Decorators may affect the
// http.Request and pass it along or, first, pass the http.Request along then react to the
// http.Response result.
type SendDecorator func(Sender) Sender

View File

@@ -15,21 +15,26 @@ package adal
// limitations under the License.
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/sha1"
"crypto/x509"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"math"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"time"
"github.com/Azure/go-autorest/autorest/date"
"github.com/Azure/go-autorest/version"
"github.com/dgrijalva/jwt-go"
)
@@ -42,11 +47,23 @@ const (
// OAuthGrantTypeClientCredentials is the "grant_type" identifier used in credential flows
OAuthGrantTypeClientCredentials = "client_credentials"
// OAuthGrantTypeUserPass is the "grant_type" identifier used in username and password auth flows
OAuthGrantTypeUserPass = "password"
// OAuthGrantTypeRefreshToken is the "grant_type" identifier used in refresh token flows
OAuthGrantTypeRefreshToken = "refresh_token"
// OAuthGrantTypeAuthorizationCode is the "grant_type" identifier used in authorization code flows
OAuthGrantTypeAuthorizationCode = "authorization_code"
// metadataHeader is the header required by MSI extension
metadataHeader = "Metadata"
// msiEndpoint is the well known endpoint for getting MSI authentications tokens
msiEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token"
// the default number of attempts to refresh an MSI authentication token
defaultMaxMSIRefreshAttempts = 5
)
// OAuthTokenProvider is an interface which should be implemented by an access token retriever
@@ -54,6 +71,12 @@ type OAuthTokenProvider interface {
OAuthToken() string
}
// TokenRefreshError is an interface used by errors returned during token refresh.
type TokenRefreshError interface {
error
Response() *http.Response
}
// Refresher is an interface for token refresh functionality
type Refresher interface {
Refresh() error
@@ -61,31 +84,52 @@ type Refresher interface {
EnsureFresh() error
}
// RefresherWithContext is an interface for token refresh functionality
type RefresherWithContext interface {
RefreshWithContext(ctx context.Context) error
RefreshExchangeWithContext(ctx context.Context, resource string) error
EnsureFreshWithContext(ctx context.Context) error
}
// TokenRefreshCallback is the type representing callbacks that will be called after
// a successful token refresh
type TokenRefreshCallback func(Token) error
// Token encapsulates the access token used to authorize Azure requests.
// https://docs.microsoft.com/en-us/azure/active-directory/develop/v1-oauth2-client-creds-grant-flow#service-to-service-access-token-response
type Token struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn string `json:"expires_in"`
ExpiresOn string `json:"expires_on"`
NotBefore string `json:"not_before"`
ExpiresIn json.Number `json:"expires_in"`
ExpiresOn json.Number `json:"expires_on"`
NotBefore json.Number `json:"not_before"`
Resource string `json:"resource"`
Type string `json:"token_type"`
}
func newToken() Token {
return Token{
ExpiresIn: "0",
ExpiresOn: "0",
NotBefore: "0",
}
}
// IsZero returns true if the token object is zero-initialized.
func (t Token) IsZero() bool {
return t == Token{}
}
// Expires returns the time.Time when the Token expires.
func (t Token) Expires() time.Time {
s, err := strconv.Atoi(t.ExpiresOn)
s, err := t.ExpiresOn.Float64()
if err != nil {
s = -3600
}
expiration := date.NewUnixTimeFromSeconds(float64(s))
expiration := date.NewUnixTimeFromSeconds(s)
return time.Time(expiration).UTC()
}
@@ -106,6 +150,12 @@ func (t *Token) OAuthToken() string {
return t.AccessToken
}
// ServicePrincipalSecret is an interface that allows various secret mechanism to fill the form
// that is submitted when acquiring an oAuth token.
type ServicePrincipalSecret interface {
SetAuthenticationValues(spt *ServicePrincipalToken, values *url.Values) error
}
// ServicePrincipalNoSecret represents a secret type that contains no secret
// meaning it is not valid for fetching a fresh token. This is used by Manual
type ServicePrincipalNoSecret struct {
@@ -117,15 +167,19 @@ func (noSecret *ServicePrincipalNoSecret) SetAuthenticationValues(spt *ServicePr
return fmt.Errorf("Manually created ServicePrincipalToken does not contain secret material to retrieve a new access token")
}
// ServicePrincipalSecret is an interface that allows various secret mechanism to fill the form
// that is submitted when acquiring an oAuth token.
type ServicePrincipalSecret interface {
SetAuthenticationValues(spt *ServicePrincipalToken, values *url.Values) error
// MarshalJSON implements the json.Marshaler interface.
func (noSecret ServicePrincipalNoSecret) MarshalJSON() ([]byte, error) {
type tokenType struct {
Type string `json:"type"`
}
return json.Marshal(tokenType{
Type: "ServicePrincipalNoSecret",
})
}
// ServicePrincipalTokenSecret implements ServicePrincipalSecret for client_secret type authorization.
type ServicePrincipalTokenSecret struct {
ClientSecret string
ClientSecret string `json:"value"`
}
// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
@@ -135,21 +189,24 @@ func (tokenSecret *ServicePrincipalTokenSecret) SetAuthenticationValues(spt *Ser
return nil
}
// MarshalJSON implements the json.Marshaler interface.
func (tokenSecret ServicePrincipalTokenSecret) MarshalJSON() ([]byte, error) {
type tokenType struct {
Type string `json:"type"`
Value string `json:"value"`
}
return json.Marshal(tokenType{
Type: "ServicePrincipalTokenSecret",
Value: tokenSecret.ClientSecret,
})
}
// ServicePrincipalCertificateSecret implements ServicePrincipalSecret for generic RSA cert auth with signed JWTs.
type ServicePrincipalCertificateSecret struct {
Certificate *x509.Certificate
PrivateKey *rsa.PrivateKey
}
// ServicePrincipalMSISecret implements ServicePrincipalSecret for machines running the MSI Extension.
type ServicePrincipalMSISecret struct {
}
// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
func (msiSecret *ServicePrincipalMSISecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
return nil
}
// SignJwt returns the JWT signed with the certificate's private key.
func (secret *ServicePrincipalCertificateSecret) SignJwt(spt *ServicePrincipalToken) (string, error) {
hasher := sha1.New()
@@ -169,10 +226,12 @@ func (secret *ServicePrincipalCertificateSecret) SignJwt(spt *ServicePrincipalTo
token := jwt.New(jwt.SigningMethodRS256)
token.Header["x5t"] = thumbprint
x5c := []string{base64.StdEncoding.EncodeToString(secret.Certificate.Raw)}
token.Header["x5c"] = x5c
token.Claims = jwt.MapClaims{
"aud": spt.oauthConfig.TokenEndpoint.String(),
"iss": spt.clientID,
"sub": spt.clientID,
"aud": spt.inner.OauthConfig.TokenEndpoint.String(),
"iss": spt.inner.ClientID,
"sub": spt.inner.ClientID,
"jti": base64.URLEncoding.EncodeToString(jti),
"nbf": time.Now().Unix(),
"exp": time.Now().Add(time.Hour * 24).Unix(),
@@ -195,30 +254,185 @@ func (secret *ServicePrincipalCertificateSecret) SetAuthenticationValues(spt *Se
return nil
}
// MarshalJSON implements the json.Marshaler interface.
func (secret ServicePrincipalCertificateSecret) MarshalJSON() ([]byte, error) {
return nil, errors.New("marshalling ServicePrincipalCertificateSecret is not supported")
}
// ServicePrincipalMSISecret implements ServicePrincipalSecret for machines running the MSI Extension.
type ServicePrincipalMSISecret struct {
}
// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
func (msiSecret *ServicePrincipalMSISecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
return nil
}
// MarshalJSON implements the json.Marshaler interface.
func (msiSecret ServicePrincipalMSISecret) MarshalJSON() ([]byte, error) {
return nil, errors.New("marshalling ServicePrincipalMSISecret is not supported")
}
// ServicePrincipalUsernamePasswordSecret implements ServicePrincipalSecret for username and password auth.
type ServicePrincipalUsernamePasswordSecret struct {
Username string `json:"username"`
Password string `json:"password"`
}
// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
func (secret *ServicePrincipalUsernamePasswordSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
v.Set("username", secret.Username)
v.Set("password", secret.Password)
return nil
}
// MarshalJSON implements the json.Marshaler interface.
func (secret ServicePrincipalUsernamePasswordSecret) MarshalJSON() ([]byte, error) {
type tokenType struct {
Type string `json:"type"`
Username string `json:"username"`
Password string `json:"password"`
}
return json.Marshal(tokenType{
Type: "ServicePrincipalUsernamePasswordSecret",
Username: secret.Username,
Password: secret.Password,
})
}
// ServicePrincipalAuthorizationCodeSecret implements ServicePrincipalSecret for authorization code auth.
type ServicePrincipalAuthorizationCodeSecret struct {
ClientSecret string `json:"value"`
AuthorizationCode string `json:"authCode"`
RedirectURI string `json:"redirect"`
}
// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
func (secret *ServicePrincipalAuthorizationCodeSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
v.Set("code", secret.AuthorizationCode)
v.Set("client_secret", secret.ClientSecret)
v.Set("redirect_uri", secret.RedirectURI)
return nil
}
// MarshalJSON implements the json.Marshaler interface.
func (secret ServicePrincipalAuthorizationCodeSecret) MarshalJSON() ([]byte, error) {
type tokenType struct {
Type string `json:"type"`
Value string `json:"value"`
AuthCode string `json:"authCode"`
Redirect string `json:"redirect"`
}
return json.Marshal(tokenType{
Type: "ServicePrincipalAuthorizationCodeSecret",
Value: secret.ClientSecret,
AuthCode: secret.AuthorizationCode,
Redirect: secret.RedirectURI,
})
}
// ServicePrincipalToken encapsulates a Token created for a Service Principal.
type ServicePrincipalToken struct {
Token
secret ServicePrincipalSecret
oauthConfig OAuthConfig
clientID string
resource string
autoRefresh bool
refreshWithin time.Duration
sender Sender
inner servicePrincipalToken
refreshLock *sync.RWMutex
sender Sender
refreshCallbacks []TokenRefreshCallback
// MaxMSIRefreshAttempts is the maximum number of attempts to refresh an MSI token.
MaxMSIRefreshAttempts int
}
// MarshalTokenJSON returns the marshalled inner token.
func (spt ServicePrincipalToken) MarshalTokenJSON() ([]byte, error) {
return json.Marshal(spt.inner.Token)
}
// SetRefreshCallbacks replaces any existing refresh callbacks with the specified callbacks.
func (spt *ServicePrincipalToken) SetRefreshCallbacks(callbacks []TokenRefreshCallback) {
spt.refreshCallbacks = callbacks
}
// MarshalJSON implements the json.Marshaler interface.
func (spt ServicePrincipalToken) MarshalJSON() ([]byte, error) {
return json.Marshal(spt.inner)
}
// UnmarshalJSON implements the json.Unmarshaler interface.
func (spt *ServicePrincipalToken) UnmarshalJSON(data []byte) error {
// need to determine the token type
raw := map[string]interface{}{}
err := json.Unmarshal(data, &raw)
if err != nil {
return err
}
secret := raw["secret"].(map[string]interface{})
switch secret["type"] {
case "ServicePrincipalNoSecret":
spt.inner.Secret = &ServicePrincipalNoSecret{}
case "ServicePrincipalTokenSecret":
spt.inner.Secret = &ServicePrincipalTokenSecret{}
case "ServicePrincipalCertificateSecret":
return errors.New("unmarshalling ServicePrincipalCertificateSecret is not supported")
case "ServicePrincipalMSISecret":
return errors.New("unmarshalling ServicePrincipalMSISecret is not supported")
case "ServicePrincipalUsernamePasswordSecret":
spt.inner.Secret = &ServicePrincipalUsernamePasswordSecret{}
case "ServicePrincipalAuthorizationCodeSecret":
spt.inner.Secret = &ServicePrincipalAuthorizationCodeSecret{}
default:
return fmt.Errorf("unrecognized token type '%s'", secret["type"])
}
err = json.Unmarshal(data, &spt.inner)
if err != nil {
return err
}
spt.refreshLock = &sync.RWMutex{}
spt.sender = &http.Client{}
return nil
}
// internal type used for marshalling/unmarshalling
type servicePrincipalToken struct {
Token Token `json:"token"`
Secret ServicePrincipalSecret `json:"secret"`
OauthConfig OAuthConfig `json:"oauth"`
ClientID string `json:"clientID"`
Resource string `json:"resource"`
AutoRefresh bool `json:"autoRefresh"`
RefreshWithin time.Duration `json:"refreshWithin"`
}
func validateOAuthConfig(oac OAuthConfig) error {
if oac.IsZero() {
return fmt.Errorf("parameter 'oauthConfig' cannot be zero-initialized")
}
return nil
}
// NewServicePrincipalTokenWithSecret create a ServicePrincipalToken using the supplied ServicePrincipalSecret implementation.
func NewServicePrincipalTokenWithSecret(oauthConfig OAuthConfig, id string, resource string, secret ServicePrincipalSecret, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
if err := validateOAuthConfig(oauthConfig); err != nil {
return nil, err
}
if err := validateStringParam(id, "id"); err != nil {
return nil, err
}
if err := validateStringParam(resource, "resource"); err != nil {
return nil, err
}
if secret == nil {
return nil, fmt.Errorf("parameter 'secret' cannot be nil")
}
spt := &ServicePrincipalToken{
oauthConfig: oauthConfig,
secret: secret,
clientID: id,
resource: resource,
autoRefresh: true,
refreshWithin: defaultRefresh,
inner: servicePrincipalToken{
Token: newToken(),
OauthConfig: oauthConfig,
Secret: secret,
ClientID: id,
Resource: resource,
AutoRefresh: true,
RefreshWithin: defaultRefresh,
},
refreshLock: &sync.RWMutex{},
sender: &http.Client{},
refreshCallbacks: callbacks,
}
@@ -227,6 +441,18 @@ func NewServicePrincipalTokenWithSecret(oauthConfig OAuthConfig, id string, reso
// NewServicePrincipalTokenFromManualToken creates a ServicePrincipalToken using the supplied token
func NewServicePrincipalTokenFromManualToken(oauthConfig OAuthConfig, clientID string, resource string, token Token, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
if err := validateOAuthConfig(oauthConfig); err != nil {
return nil, err
}
if err := validateStringParam(clientID, "clientID"); err != nil {
return nil, err
}
if err := validateStringParam(resource, "resource"); err != nil {
return nil, err
}
if token.IsZero() {
return nil, fmt.Errorf("parameter 'token' cannot be zero-initialized")
}
spt, err := NewServicePrincipalTokenWithSecret(
oauthConfig,
clientID,
@@ -237,7 +463,39 @@ func NewServicePrincipalTokenFromManualToken(oauthConfig OAuthConfig, clientID s
return nil, err
}
spt.Token = token
spt.inner.Token = token
return spt, nil
}
// NewServicePrincipalTokenFromManualTokenSecret creates a ServicePrincipalToken using the supplied token and secret
func NewServicePrincipalTokenFromManualTokenSecret(oauthConfig OAuthConfig, clientID string, resource string, token Token, secret ServicePrincipalSecret, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
if err := validateOAuthConfig(oauthConfig); err != nil {
return nil, err
}
if err := validateStringParam(clientID, "clientID"); err != nil {
return nil, err
}
if err := validateStringParam(resource, "resource"); err != nil {
return nil, err
}
if secret == nil {
return nil, fmt.Errorf("parameter 'secret' cannot be nil")
}
if token.IsZero() {
return nil, fmt.Errorf("parameter 'token' cannot be zero-initialized")
}
spt, err := NewServicePrincipalTokenWithSecret(
oauthConfig,
clientID,
resource,
secret,
callbacks...)
if err != nil {
return nil, err
}
spt.inner.Token = token
return spt, nil
}
@@ -245,6 +503,18 @@ func NewServicePrincipalTokenFromManualToken(oauthConfig OAuthConfig, clientID s
// NewServicePrincipalToken creates a ServicePrincipalToken from the supplied Service Principal
// credentials scoped to the named resource.
func NewServicePrincipalToken(oauthConfig OAuthConfig, clientID string, secret string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
if err := validateOAuthConfig(oauthConfig); err != nil {
return nil, err
}
if err := validateStringParam(clientID, "clientID"); err != nil {
return nil, err
}
if err := validateStringParam(secret, "secret"); err != nil {
return nil, err
}
if err := validateStringParam(resource, "resource"); err != nil {
return nil, err
}
return NewServicePrincipalTokenWithSecret(
oauthConfig,
clientID,
@@ -256,8 +526,23 @@ func NewServicePrincipalToken(oauthConfig OAuthConfig, clientID string, secret s
)
}
// NewServicePrincipalTokenFromCertificate create a ServicePrincipalToken from the supplied pkcs12 bytes.
// NewServicePrincipalTokenFromCertificate creates a ServicePrincipalToken from the supplied pkcs12 bytes.
func NewServicePrincipalTokenFromCertificate(oauthConfig OAuthConfig, clientID string, certificate *x509.Certificate, privateKey *rsa.PrivateKey, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
if err := validateOAuthConfig(oauthConfig); err != nil {
return nil, err
}
if err := validateStringParam(clientID, "clientID"); err != nil {
return nil, err
}
if err := validateStringParam(resource, "resource"); err != nil {
return nil, err
}
if certificate == nil {
return nil, fmt.Errorf("parameter 'certificate' cannot be nil")
}
if privateKey == nil {
return nil, fmt.Errorf("parameter 'privateKey' cannot be nil")
}
return NewServicePrincipalTokenWithSecret(
oauthConfig,
clientID,
@@ -270,59 +555,173 @@ func NewServicePrincipalTokenFromCertificate(oauthConfig OAuthConfig, clientID s
)
}
// GetMSIVMEndpoint gets the MSI endpoint on Virtual Machines.
func GetMSIVMEndpoint() (string, error) {
return getMSIVMEndpoint(msiPath)
// NewServicePrincipalTokenFromUsernamePassword creates a ServicePrincipalToken from the username and password.
func NewServicePrincipalTokenFromUsernamePassword(oauthConfig OAuthConfig, clientID string, username string, password string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
if err := validateOAuthConfig(oauthConfig); err != nil {
return nil, err
}
if err := validateStringParam(clientID, "clientID"); err != nil {
return nil, err
}
if err := validateStringParam(username, "username"); err != nil {
return nil, err
}
if err := validateStringParam(password, "password"); err != nil {
return nil, err
}
if err := validateStringParam(resource, "resource"); err != nil {
return nil, err
}
return NewServicePrincipalTokenWithSecret(
oauthConfig,
clientID,
resource,
&ServicePrincipalUsernamePasswordSecret{
Username: username,
Password: password,
},
callbacks...,
)
}
func getMSIVMEndpoint(path string) (string, error) {
// Read MSI settings
bytes, err := ioutil.ReadFile(path)
if err != nil {
return "", err
// NewServicePrincipalTokenFromAuthorizationCode creates a ServicePrincipalToken from the
func NewServicePrincipalTokenFromAuthorizationCode(oauthConfig OAuthConfig, clientID string, clientSecret string, authorizationCode string, redirectURI string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
if err := validateOAuthConfig(oauthConfig); err != nil {
return nil, err
}
msiSettings := struct {
URL string `json:"url"`
}{}
err = json.Unmarshal(bytes, &msiSettings)
if err != nil {
return "", err
if err := validateStringParam(clientID, "clientID"); err != nil {
return nil, err
}
if err := validateStringParam(clientSecret, "clientSecret"); err != nil {
return nil, err
}
if err := validateStringParam(authorizationCode, "authorizationCode"); err != nil {
return nil, err
}
if err := validateStringParam(redirectURI, "redirectURI"); err != nil {
return nil, err
}
if err := validateStringParam(resource, "resource"); err != nil {
return nil, err
}
return msiSettings.URL, nil
return NewServicePrincipalTokenWithSecret(
oauthConfig,
clientID,
resource,
&ServicePrincipalAuthorizationCodeSecret{
ClientSecret: clientSecret,
AuthorizationCode: authorizationCode,
RedirectURI: redirectURI,
},
callbacks...,
)
}
// GetMSIVMEndpoint gets the MSI endpoint on Virtual Machines.
func GetMSIVMEndpoint() (string, error) {
return msiEndpoint, nil
}
// NewServicePrincipalTokenFromMSI creates a ServicePrincipalToken via the MSI VM Extension.
// It will use the system assigned identity when creating the token.
func NewServicePrincipalTokenFromMSI(msiEndpoint, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
return newServicePrincipalTokenFromMSI(msiEndpoint, resource, nil, callbacks...)
}
// NewServicePrincipalTokenFromMSIWithUserAssignedID creates a ServicePrincipalToken via the MSI VM Extension.
// It will use the specified user assigned identity when creating the token.
func NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, resource string, userAssignedID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
return newServicePrincipalTokenFromMSI(msiEndpoint, resource, &userAssignedID, callbacks...)
}
func newServicePrincipalTokenFromMSI(msiEndpoint, resource string, userAssignedID *string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
if err := validateStringParam(msiEndpoint, "msiEndpoint"); err != nil {
return nil, err
}
if err := validateStringParam(resource, "resource"); err != nil {
return nil, err
}
if userAssignedID != nil {
if err := validateStringParam(*userAssignedID, "userAssignedID"); err != nil {
return nil, err
}
}
// We set the oauth config token endpoint to be MSI's endpoint
msiEndpointURL, err := url.Parse(msiEndpoint)
if err != nil {
return nil, err
}
oauthConfig, err := NewOAuthConfig(msiEndpointURL.String(), "")
if err != nil {
return nil, err
v := url.Values{}
v.Set("resource", resource)
v.Set("api-version", "2018-02-01")
if userAssignedID != nil {
v.Set("client_id", *userAssignedID)
}
msiEndpointURL.RawQuery = v.Encode()
spt := &ServicePrincipalToken{
oauthConfig: *oauthConfig,
secret: &ServicePrincipalMSISecret{},
resource: resource,
autoRefresh: true,
refreshWithin: defaultRefresh,
sender: &http.Client{},
refreshCallbacks: callbacks,
inner: servicePrincipalToken{
Token: newToken(),
OauthConfig: OAuthConfig{
TokenEndpoint: *msiEndpointURL,
},
Secret: &ServicePrincipalMSISecret{},
Resource: resource,
AutoRefresh: true,
RefreshWithin: defaultRefresh,
},
refreshLock: &sync.RWMutex{},
sender: &http.Client{},
refreshCallbacks: callbacks,
MaxMSIRefreshAttempts: defaultMaxMSIRefreshAttempts,
}
if userAssignedID != nil {
spt.inner.ClientID = *userAssignedID
}
return spt, nil
}
// internal type that implements TokenRefreshError
type tokenRefreshError struct {
message string
resp *http.Response
}
// Error implements the error interface which is part of the TokenRefreshError interface.
func (tre tokenRefreshError) Error() string {
return tre.message
}
// Response implements the TokenRefreshError interface, it returns the raw HTTP response from the refresh operation.
func (tre tokenRefreshError) Response() *http.Response {
return tre.resp
}
func newTokenRefreshError(message string, resp *http.Response) TokenRefreshError {
return tokenRefreshError{message: message, resp: resp}
}
// EnsureFresh will refresh the token if it will expire within the refresh window (as set by
// RefreshWithin) and autoRefresh flag is on.
// RefreshWithin) and autoRefresh flag is on. This method is safe for concurrent use.
func (spt *ServicePrincipalToken) EnsureFresh() error {
if spt.autoRefresh && spt.WillExpireIn(spt.refreshWithin) {
return spt.Refresh()
return spt.EnsureFreshWithContext(context.Background())
}
// EnsureFreshWithContext will refresh the token if it will expire within the refresh window (as set by
// RefreshWithin) and autoRefresh flag is on. This method is safe for concurrent use.
func (spt *ServicePrincipalToken) EnsureFreshWithContext(ctx context.Context) error {
if spt.inner.AutoRefresh && spt.inner.Token.WillExpireIn(spt.inner.RefreshWithin) {
// take the write lock then check to see if the token was already refreshed
spt.refreshLock.Lock()
defer spt.refreshLock.Unlock()
if spt.inner.Token.WillExpireIn(spt.inner.RefreshWithin) {
return spt.refreshInternal(ctx, spt.inner.Resource)
}
}
return nil
}
@@ -331,7 +730,7 @@ func (spt *ServicePrincipalToken) EnsureFresh() error {
func (spt *ServicePrincipalToken) InvokeRefreshCallbacks(token Token) error {
if spt.refreshCallbacks != nil {
for _, callback := range spt.refreshCallbacks {
err := callback(spt.Token)
err := callback(spt.inner.Token)
if err != nil {
return fmt.Errorf("adal: TokenRefreshCallback handler failed. Error = '%v'", err)
}
@@ -341,46 +740,103 @@ func (spt *ServicePrincipalToken) InvokeRefreshCallbacks(token Token) error {
}
// Refresh obtains a fresh token for the Service Principal.
// This method is not safe for concurrent use and should be syncrhonized.
func (spt *ServicePrincipalToken) Refresh() error {
return spt.refreshInternal(spt.resource)
return spt.RefreshWithContext(context.Background())
}
// RefreshWithContext obtains a fresh token for the Service Principal.
// This method is not safe for concurrent use and should be syncrhonized.
func (spt *ServicePrincipalToken) RefreshWithContext(ctx context.Context) error {
spt.refreshLock.Lock()
defer spt.refreshLock.Unlock()
return spt.refreshInternal(ctx, spt.inner.Resource)
}
// RefreshExchange refreshes the token, but for a different resource.
// This method is not safe for concurrent use and should be syncrhonized.
func (spt *ServicePrincipalToken) RefreshExchange(resource string) error {
return spt.refreshInternal(resource)
return spt.RefreshExchangeWithContext(context.Background(), resource)
}
func (spt *ServicePrincipalToken) refreshInternal(resource string) error {
v := url.Values{}
v.Set("client_id", spt.clientID)
v.Set("resource", resource)
// RefreshExchangeWithContext refreshes the token, but for a different resource.
// This method is not safe for concurrent use and should be syncrhonized.
func (spt *ServicePrincipalToken) RefreshExchangeWithContext(ctx context.Context, resource string) error {
spt.refreshLock.Lock()
defer spt.refreshLock.Unlock()
return spt.refreshInternal(ctx, resource)
}
if spt.RefreshToken != "" {
v.Set("grant_type", OAuthGrantTypeRefreshToken)
v.Set("refresh_token", spt.RefreshToken)
} else {
v.Set("grant_type", OAuthGrantTypeClientCredentials)
err := spt.secret.SetAuthenticationValues(spt, &v)
if err != nil {
return err
}
func (spt *ServicePrincipalToken) getGrantType() string {
switch spt.inner.Secret.(type) {
case *ServicePrincipalUsernamePasswordSecret:
return OAuthGrantTypeUserPass
case *ServicePrincipalAuthorizationCodeSecret:
return OAuthGrantTypeAuthorizationCode
default:
return OAuthGrantTypeClientCredentials
}
}
s := v.Encode()
body := ioutil.NopCloser(strings.NewReader(s))
req, err := http.NewRequest(http.MethodPost, spt.oauthConfig.TokenEndpoint.String(), body)
func isIMDS(u url.URL) bool {
imds, err := url.Parse(msiEndpoint)
if err != nil {
return false
}
return u.Host == imds.Host && u.Path == imds.Path
}
func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource string) error {
req, err := http.NewRequest(http.MethodPost, spt.inner.OauthConfig.TokenEndpoint.String(), nil)
if err != nil {
return fmt.Errorf("adal: Failed to build the refresh request. Error = '%v'", err)
}
req.Header.Add("User-Agent", version.UserAgent())
req = req.WithContext(ctx)
if !isIMDS(spt.inner.OauthConfig.TokenEndpoint) {
v := url.Values{}
v.Set("client_id", spt.inner.ClientID)
v.Set("resource", resource)
req.ContentLength = int64(len(s))
req.Header.Set(contentType, mimeTypeFormPost)
if _, ok := spt.secret.(*ServicePrincipalMSISecret); ok {
if spt.inner.Token.RefreshToken != "" {
v.Set("grant_type", OAuthGrantTypeRefreshToken)
v.Set("refresh_token", spt.inner.Token.RefreshToken)
// web apps must specify client_secret when refreshing tokens
// see https://docs.microsoft.com/en-us/azure/active-directory/develop/active-directory-protocols-oauth-code#refreshing-the-access-tokens
if spt.getGrantType() == OAuthGrantTypeAuthorizationCode {
err := spt.inner.Secret.SetAuthenticationValues(spt, &v)
if err != nil {
return err
}
}
} else {
v.Set("grant_type", spt.getGrantType())
err := spt.inner.Secret.SetAuthenticationValues(spt, &v)
if err != nil {
return err
}
}
s := v.Encode()
body := ioutil.NopCloser(strings.NewReader(s))
req.ContentLength = int64(len(s))
req.Header.Set(contentType, mimeTypeFormPost)
req.Body = body
}
if _, ok := spt.inner.Secret.(*ServicePrincipalMSISecret); ok {
req.Method = http.MethodGet
req.Header.Set(metadataHeader, "true")
}
resp, err := spt.sender.Do(req)
var resp *http.Response
if isIMDS(spt.inner.OauthConfig.TokenEndpoint) {
resp, err = retryForIMDS(spt.sender, req, spt.MaxMSIRefreshAttempts)
} else {
resp, err = spt.sender.Do(req)
}
if err != nil {
return fmt.Errorf("adal: Failed to execute the refresh request. Error = '%v'", err)
return newTokenRefreshError(fmt.Sprintf("adal: Failed to execute the refresh request. Error = '%v'", err), nil)
}
defer resp.Body.Close()
@@ -388,11 +844,15 @@ func (spt *ServicePrincipalToken) refreshInternal(resource string) error {
if resp.StatusCode != http.StatusOK {
if err != nil {
return fmt.Errorf("adal: Refresh request failed. Status Code = '%d'. Failed reading response body", resp.StatusCode)
return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Failed reading response body: %v", resp.StatusCode, err), resp)
}
return fmt.Errorf("adal: Refresh request failed. Status Code = '%d'. Response body: %s", resp.StatusCode, string(rb))
return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Response body: %s", resp.StatusCode, string(rb)), resp)
}
// for the following error cases don't return a TokenRefreshError. the operation succeeded
// but some transient failure happened during deserialization. by returning a generic error
// the retry logic will kick in (we don't retry on TokenRefreshError).
if err != nil {
return fmt.Errorf("adal: Failed to read a new service principal token during refresh. Error = '%v'", err)
}
@@ -405,23 +865,116 @@ func (spt *ServicePrincipalToken) refreshInternal(resource string) error {
return fmt.Errorf("adal: Failed to unmarshal the service principal token during refresh. Error = '%v' JSON = '%s'", err, string(rb))
}
spt.Token = token
spt.inner.Token = token
return spt.InvokeRefreshCallbacks(token)
}
// retry logic specific to retrieving a token from the IMDS endpoint
func retryForIMDS(sender Sender, req *http.Request, maxAttempts int) (resp *http.Response, err error) {
// copied from client.go due to circular dependency
retries := []int{
http.StatusRequestTimeout, // 408
http.StatusTooManyRequests, // 429
http.StatusInternalServerError, // 500
http.StatusBadGateway, // 502
http.StatusServiceUnavailable, // 503
http.StatusGatewayTimeout, // 504
}
// extra retry status codes specific to IMDS
retries = append(retries,
http.StatusNotFound,
http.StatusGone,
// all remaining 5xx
http.StatusNotImplemented,
http.StatusHTTPVersionNotSupported,
http.StatusVariantAlsoNegotiates,
http.StatusInsufficientStorage,
http.StatusLoopDetected,
http.StatusNotExtended,
http.StatusNetworkAuthenticationRequired)
// see https://docs.microsoft.com/en-us/azure/active-directory/managed-service-identity/how-to-use-vm-token#retry-guidance
const maxDelay time.Duration = 60 * time.Second
attempt := 0
delay := time.Duration(0)
for attempt < maxAttempts {
resp, err = sender.Do(req)
// retry on temporary network errors, e.g. transient network failures.
// if we don't receive a response then assume we can't connect to the
// endpoint so we're likely not running on an Azure VM so don't retry.
if (err != nil && !isTemporaryNetworkError(err)) || resp == nil || resp.StatusCode == http.StatusOK || !containsInt(retries, resp.StatusCode) {
return
}
// perform exponential backoff with a cap.
// must increment attempt before calculating delay.
attempt++
// the base value of 2 is the "delta backoff" as specified in the guidance doc
delay += (time.Duration(math.Pow(2, float64(attempt))) * time.Second)
if delay > maxDelay {
delay = maxDelay
}
select {
case <-time.After(delay):
// intentionally left blank
case <-req.Context().Done():
err = req.Context().Err()
return
}
}
return
}
// returns true if the specified error is a temporary network error or false if it's not.
// if the error doesn't implement the net.Error interface the return value is true.
func isTemporaryNetworkError(err error) bool {
if netErr, ok := err.(net.Error); !ok || (ok && netErr.Temporary()) {
return true
}
return false
}
// returns true if slice ints contains the value n
func containsInt(ints []int, n int) bool {
for _, i := range ints {
if i == n {
return true
}
}
return false
}
// SetAutoRefresh enables or disables automatic refreshing of stale tokens.
func (spt *ServicePrincipalToken) SetAutoRefresh(autoRefresh bool) {
spt.autoRefresh = autoRefresh
spt.inner.AutoRefresh = autoRefresh
}
// SetRefreshWithin sets the interval within which if the token will expire, EnsureFresh will
// refresh the token.
func (spt *ServicePrincipalToken) SetRefreshWithin(d time.Duration) {
spt.refreshWithin = d
spt.inner.RefreshWithin = d
return
}
// SetSender sets the http.Client used when obtaining the Service Principal token. An
// undecorated http.Client is used by default.
func (spt *ServicePrincipalToken) SetSender(s Sender) { spt.sender = s }
// OAuthToken implements the OAuthTokenProvider interface. It returns the current access token.
func (spt *ServicePrincipalToken) OAuthToken() string {
spt.refreshLock.RLock()
defer spt.refreshLock.RUnlock()
return spt.inner.Token.OAuthToken()
}
// Token returns a copy of the current token.
func (spt *ServicePrincipalToken) Token() Token {
spt.refreshLock.RLock()
defer spt.refreshLock.RUnlock()
return spt.inner.Token
}

View File

@@ -15,24 +15,27 @@ package adal
// limitations under the License.
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/json"
"fmt"
"io/ioutil"
"math/big"
"net/http"
"net/url"
"os"
"reflect"
"strconv"
"strings"
"sync"
"testing"
"time"
"github.com/Azure/go-autorest/autorest/date"
"github.com/Azure/go-autorest/autorest/mocks"
jwt "github.com/dgrijalva/jwt-go"
)
const (
@@ -86,12 +89,12 @@ func TestTokenWillExpireIn(t *testing.T) {
func TestServicePrincipalTokenSetAutoRefresh(t *testing.T) {
spt := newServicePrincipalToken()
if !spt.autoRefresh {
if !spt.inner.AutoRefresh {
t.Fatal("adal: ServicePrincipalToken did not default to automatic token refreshing")
}
spt.SetAutoRefresh(false)
if spt.autoRefresh {
if spt.inner.AutoRefresh {
t.Fatal("adal: ServicePrincipalToken#SetAutoRefresh did not disable automatic token refreshing")
}
}
@@ -99,12 +102,12 @@ func TestServicePrincipalTokenSetAutoRefresh(t *testing.T) {
func TestServicePrincipalTokenSetRefreshWithin(t *testing.T) {
spt := newServicePrincipalToken()
if spt.refreshWithin != defaultRefresh {
if spt.inner.RefreshWithin != defaultRefresh {
t.Fatal("adal: ServicePrincipalToken did not correctly set the default refresh interval")
}
spt.SetRefreshWithin(2 * defaultRefresh)
if spt.refreshWithin != 2*defaultRefresh {
if spt.inner.RefreshWithin != 2*defaultRefresh {
t.Fatal("adal: ServicePrincipalToken#SetRefreshWithin did not set the refresh interval")
}
}
@@ -148,7 +151,7 @@ func TestServicePrincipalTokenRefreshUsesPOST(t *testing.T) {
}
}
func TestServicePrincipalTokenFromMSIRefreshUsesPOST(t *testing.T) {
func TestServicePrincipalTokenFromMSIRefreshUsesGET(t *testing.T) {
resource := "https://resource"
cb := func(token Token) error { return nil }
@@ -165,8 +168,8 @@ func TestServicePrincipalTokenFromMSIRefreshUsesPOST(t *testing.T) {
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
if r.Method != "POST" {
t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set HTTP method -- expected %v, received %v", "POST", r.Method)
if r.Method != "GET" {
t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set HTTP method -- expected %v, received %v", "GET", r.Method)
}
if h := r.Header.Get("Metadata"); h != "true" {
t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set Metadata header for MSI")
@@ -186,6 +189,39 @@ func TestServicePrincipalTokenFromMSIRefreshUsesPOST(t *testing.T) {
}
}
func TestServicePrincipalTokenFromMSIRefreshCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
endpoint, _ := GetMSIVMEndpoint()
spt, err := NewServicePrincipalTokenFromMSI(endpoint, "https://resource")
if err != nil {
t.Fatalf("Failed to get MSI SPT: %v", err)
}
c := mocks.NewSender()
c.AppendAndRepeatResponse(mocks.NewResponseWithStatus("Internal server error", http.StatusInternalServerError), 5)
var wg sync.WaitGroup
wg.Add(1)
start := time.Now()
end := time.Now()
go func() {
spt.SetSender(c)
err = spt.RefreshWithContext(ctx)
end = time.Now()
wg.Done()
}()
cancel()
wg.Wait()
time.Sleep(5 * time.Millisecond)
if end.Sub(start) >= time.Second {
t.Fatalf("TestServicePrincipalTokenFromMSIRefreshCancel failed to cancel")
}
}
func TestServicePrincipalTokenRefreshSetsMimeType(t *testing.T) {
spt := newServicePrincipalToken()
@@ -286,6 +322,63 @@ func TestServicePrincipalTokenCertficateRefreshSetsBody(t *testing.T) {
values["resource"][0] != "resource" {
t.Fatalf("adal: ServicePrincipalTokenCertificate#Refresh did not correctly set the HTTP Request Body.")
}
tok, _ := jwt.Parse(values["client_assertion"][0], nil)
if tok == nil {
t.Fatalf("adal: ServicePrincipalTokenCertificate#Expected client_assertion to be a JWT")
}
if _, ok := tok.Header["x5t"]; !ok {
t.Fatalf("adal: ServicePrincipalTokenCertificate#Expected client_assertion to have an x5t header")
}
if _, ok := tok.Header["x5c"]; !ok {
t.Fatalf("adal: ServicePrincipalTokenCertificate#Expected client_assertion to have an x5c header")
}
})
}
func TestServicePrincipalTokenUsernamePasswordRefreshSetsBody(t *testing.T) {
spt := newServicePrincipalTokenUsernamePassword(t)
testServicePrincipalTokenRefreshSetsBody(t, spt, func(t *testing.T, b []byte) {
body := string(b)
values, _ := url.ParseQuery(body)
if values["client_id"][0] != "id" ||
values["grant_type"][0] != "password" ||
values["username"][0] != "username" ||
values["password"][0] != "password" ||
values["resource"][0] != "resource" {
t.Fatalf("adal: ServicePrincipalTokenUsernamePassword#Refresh did not correctly set the HTTP Request Body.")
}
})
}
func TestServicePrincipalTokenAuthorizationCodeRefreshSetsBody(t *testing.T) {
spt := newServicePrincipalTokenAuthorizationCode(t)
testServicePrincipalTokenRefreshSetsBody(t, spt, func(t *testing.T, b []byte) {
body := string(b)
values, _ := url.ParseQuery(body)
if values["client_id"][0] != "id" ||
values["grant_type"][0] != OAuthGrantTypeAuthorizationCode ||
values["code"][0] != "code" ||
values["client_secret"][0] != "clientSecret" ||
values["redirect_uri"][0] != "http://redirectUri/getToken" ||
values["resource"][0] != "resource" {
t.Fatalf("adal: ServicePrincipalTokenAuthorizationCode#Refresh did not correctly set the HTTP Request Body.")
}
})
testServicePrincipalTokenRefreshSetsBody(t, spt, func(t *testing.T, b []byte) {
body := string(b)
values, _ := url.ParseQuery(body)
if values["client_id"][0] != "id" ||
values["grant_type"][0] != OAuthGrantTypeRefreshToken ||
values["code"][0] != "code" ||
values["client_secret"][0] != "clientSecret" ||
values["redirect_uri"][0] != "http://redirectUri/getToken" ||
values["resource"][0] != "resource" {
t.Fatalf("adal: ServicePrincipalTokenAuthorizationCode#Refresh did not correctly set the HTTP Request Body.")
}
})
}
@@ -412,12 +505,12 @@ func TestServicePrincipalTokenRefreshUnmarshals(t *testing.T) {
err := spt.Refresh()
if err != nil {
t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
} else if spt.AccessToken != "accessToken" ||
spt.ExpiresIn != "3600" ||
spt.ExpiresOn != expiresOn ||
spt.NotBefore != expiresOn ||
spt.Resource != "resource" ||
spt.Type != "Bearer" {
} else if spt.inner.Token.AccessToken != "accessToken" ||
spt.inner.Token.ExpiresIn != "3600" ||
spt.inner.Token.ExpiresOn != json.Number(expiresOn) ||
spt.inner.Token.NotBefore != json.Number(expiresOn) ||
spt.inner.Token.Resource != "resource" ||
spt.inner.Token.Type != "Bearer" {
t.Fatalf("adal: ServicePrincipalToken#Refresh failed correctly unmarshal the JSON -- expected %v, received %v",
j, *spt)
}
@@ -425,7 +518,7 @@ func TestServicePrincipalTokenRefreshUnmarshals(t *testing.T) {
func TestServicePrincipalTokenEnsureFreshRefreshes(t *testing.T) {
spt := newServicePrincipalToken()
expireToken(&spt.Token)
expireToken(&spt.inner.Token)
body := mocks.NewBody(newTokenJSON("test", "test"))
resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
@@ -451,9 +544,26 @@ func TestServicePrincipalTokenEnsureFreshRefreshes(t *testing.T) {
}
}
func TestServicePrincipalTokenEnsureFreshFails(t *testing.T) {
spt := newServicePrincipalToken()
expireToken(&spt.inner.Token)
c := mocks.NewSender()
c.SetError(fmt.Errorf("some failure"))
spt.SetSender(c)
err := spt.EnsureFresh()
if err == nil {
t.Fatal("adal: ServicePrincipalToken#EnsureFresh didn't return an error")
}
if _, ok := err.(TokenRefreshError); !ok {
t.Fatal("adal: ServicePrincipalToken#EnsureFresh didn't return a TokenRefreshError")
}
}
func TestServicePrincipalTokenEnsureFreshSkipsIfFresh(t *testing.T) {
spt := newServicePrincipalToken()
setTokenToExpireIn(&spt.Token, 1000*time.Second)
setTokenToExpireIn(&spt.inner.Token, 1000*time.Second)
f := false
c := mocks.NewSender()
@@ -520,7 +630,7 @@ func TestRefreshCallbackErrorPropagates(t *testing.T) {
// This demonstrates the danger of manual token without a refresh token
func TestServicePrincipalTokenManualRefreshFailsWithoutRefresh(t *testing.T) {
spt := newServicePrincipalTokenManual()
spt.RefreshToken = ""
spt.inner.Token.RefreshToken = ""
err := spt.Refresh()
if err == nil {
t.Fatalf("adal: ServicePrincipalToken#Refresh should have failed with a ManualTokenSecret without a refresh token")
@@ -537,11 +647,11 @@ func TestNewServicePrincipalTokenFromMSI(t *testing.T) {
}
// check some of the SPT fields
if _, ok := spt.secret.(*ServicePrincipalMSISecret); !ok {
if _, ok := spt.inner.Secret.(*ServicePrincipalMSISecret); !ok {
t.Fatal("SPT secret was not of MSI type")
}
if spt.resource != resource {
if spt.inner.Resource != resource {
t.Fatal("SPT came back with incorrect resource")
}
@@ -550,36 +660,188 @@ func TestNewServicePrincipalTokenFromMSI(t *testing.T) {
}
}
func TestGetVMEndpoint(t *testing.T) {
tempSettingsFile, err := ioutil.TempFile("", "ManagedIdentity-Settings")
func TestNewServicePrincipalTokenFromMSIWithUserAssignedID(t *testing.T) {
resource := "https://resource"
userID := "abc123"
cb := func(token Token) error { return nil }
spt, err := NewServicePrincipalTokenFromMSIWithUserAssignedID("http://msiendpoint/", resource, userID, cb)
if err != nil {
t.Fatal("Couldn't write temp settings file")
}
defer os.Remove(tempSettingsFile.Name())
settingsContents := []byte(`{
"url": "http://msiendpoint/"
}`)
if _, err := tempSettingsFile.Write(settingsContents); err != nil {
t.Fatal("Couldn't fill temp settings file")
t.Fatalf("Failed to get MSI SPT: %v", err)
}
endpoint, err := getMSIVMEndpoint(tempSettingsFile.Name())
// check some of the SPT fields
if _, ok := spt.inner.Secret.(*ServicePrincipalMSISecret); !ok {
t.Fatal("SPT secret was not of MSI type")
}
if spt.inner.Resource != resource {
t.Fatal("SPT came back with incorrect resource")
}
if len(spt.refreshCallbacks) != 1 {
t.Fatal("SPT had incorrect refresh callbacks.")
}
if spt.inner.ClientID != userID {
t.Fatal("SPT had incorrect client ID")
}
}
func TestNewServicePrincipalTokenFromManualTokenSecret(t *testing.T) {
token := newToken()
secret := &ServicePrincipalAuthorizationCodeSecret{
ClientSecret: "clientSecret",
AuthorizationCode: "code123",
RedirectURI: "redirect",
}
spt, err := NewServicePrincipalTokenFromManualTokenSecret(TestOAuthConfig, "id", "resource", token, secret, nil)
if err != nil {
t.Fatalf("Failed creating new SPT: %s", err)
}
if !reflect.DeepEqual(token, spt.inner.Token) {
t.Fatalf("Tokens do not match: %s, %s", token, spt.inner.Token)
}
if !reflect.DeepEqual(secret, spt.inner.Secret) {
t.Fatalf("Secrets do not match: %s, %s", secret, spt.inner.Secret)
}
}
func TestGetVMEndpoint(t *testing.T) {
endpoint, err := GetMSIVMEndpoint()
if err != nil {
t.Fatal("Coudn't get VM endpoint")
}
if endpoint != "http://msiendpoint/" {
if endpoint != msiEndpoint {
t.Fatal("Didn't get correct endpoint")
}
}
func newToken() *Token {
return &Token{
AccessToken: "ASECRETVALUE",
Resource: "https://azure.microsoft.com/",
Type: "Bearer",
func TestMarshalServicePrincipalNoSecret(t *testing.T) {
spt := newServicePrincipalTokenManual()
b, err := json.Marshal(spt)
if err != nil {
t.Fatalf("failed to marshal token: %+v", err)
}
var spt2 *ServicePrincipalToken
err = json.Unmarshal(b, &spt2)
if err != nil {
t.Fatalf("failed to unmarshal token: %+v", err)
}
if !reflect.DeepEqual(spt, spt2) {
t.Fatal("tokens don't match")
}
}
func TestMarshalServicePrincipalTokenSecret(t *testing.T) {
spt := newServicePrincipalToken()
b, err := json.Marshal(spt)
if err != nil {
t.Fatalf("failed to marshal token: %+v", err)
}
var spt2 *ServicePrincipalToken
err = json.Unmarshal(b, &spt2)
if err != nil {
t.Fatalf("failed to unmarshal token: %+v", err)
}
if !reflect.DeepEqual(spt, spt2) {
t.Fatal("tokens don't match")
}
}
func TestMarshalServicePrincipalCertificateSecret(t *testing.T) {
spt := newServicePrincipalTokenCertificate(t)
b, err := json.Marshal(spt)
if err == nil {
t.Fatal("expected error when marshalling certificate token")
}
var spt2 *ServicePrincipalToken
err = json.Unmarshal(b, &spt2)
if err == nil {
t.Fatal("expected error when unmarshalling certificate token")
}
}
func TestMarshalServicePrincipalMSISecret(t *testing.T) {
spt, err := newServicePrincipalTokenFromMSI("http://msiendpoint/", "https://resource", nil)
if err != nil {
t.Fatalf("failed to get MSI SPT: %+v", err)
}
b, err := json.Marshal(spt)
if err == nil {
t.Fatal("expected error when marshalling MSI token")
}
var spt2 *ServicePrincipalToken
err = json.Unmarshal(b, &spt2)
if err == nil {
t.Fatal("expected error when unmarshalling MSI token")
}
}
func TestMarshalServicePrincipalUsernamePasswordSecret(t *testing.T) {
spt := newServicePrincipalTokenUsernamePassword(t)
b, err := json.Marshal(spt)
if err != nil {
t.Fatalf("failed to marshal token: %+v", err)
}
var spt2 *ServicePrincipalToken
err = json.Unmarshal(b, &spt2)
if err != nil {
t.Fatalf("failed to unmarshal token: %+v", err)
}
if !reflect.DeepEqual(spt, spt2) {
t.Fatal("tokens don't match")
}
}
func TestMarshalServicePrincipalAuthorizationCodeSecret(t *testing.T) {
spt := newServicePrincipalTokenAuthorizationCode(t)
b, err := json.Marshal(spt)
if err != nil {
t.Fatalf("failed to marshal token: %+v", err)
}
var spt2 *ServicePrincipalToken
err = json.Unmarshal(b, &spt2)
if err != nil {
t.Fatalf("failed to unmarshal token: %+v", err)
}
if !reflect.DeepEqual(spt, spt2) {
t.Fatal("tokens don't match")
}
}
func TestMarshalInnerToken(t *testing.T) {
spt := newServicePrincipalTokenManual()
tokenJSON, err := spt.MarshalTokenJSON()
if err != nil {
t.Fatalf("failed to marshal token: %+v", err)
}
testToken := newToken()
testToken.RefreshToken = "refreshtoken"
testTokenJSON, err := json.Marshal(testToken)
if err != nil {
t.Fatalf("failed to marshal test token: %+v", err)
}
if !reflect.DeepEqual(tokenJSON, testTokenJSON) {
t.Fatalf("tokens don't match: %s, %s", tokenJSON, testTokenJSON)
}
var t1 Token
err = json.Unmarshal(tokenJSON, &t1)
if err != nil {
t.Fatalf("failed to unmarshal token: %+v", err)
}
if !reflect.DeepEqual(t1, testToken) {
t.Fatalf("tokens don't match: %s, %s", t1, testToken)
}
}
@@ -590,17 +852,20 @@ func newTokenJSON(expiresOn string, resource string) string {
"expires_on" : "%s",
"not_before" : "%s",
"resource" : "%s",
"token_type" : "Bearer"
"token_type" : "Bearer",
"refresh_token": "ABC123"
}`,
expiresOn, expiresOn, resource)
}
func newTokenExpiresIn(expireIn time.Duration) *Token {
return setTokenToExpireIn(newToken(), expireIn)
t := newToken()
return setTokenToExpireIn(&t, expireIn)
}
func newTokenExpiresAt(expireAt time.Time) *Token {
return setTokenToExpireAt(newToken(), expireAt)
t := newToken()
return setTokenToExpireAt(&t, expireAt)
}
func expireToken(t *Token) *Token {
@@ -609,7 +874,7 @@ func expireToken(t *Token) *Token {
func setTokenToExpireAt(t *Token, expireAt time.Time) *Token {
t.ExpiresIn = "3600"
t.ExpiresOn = strconv.Itoa(int(expireAt.Sub(date.UnixEpoch()).Seconds()))
t.ExpiresOn = json.Number(strconv.FormatInt(int64(expireAt.Sub(date.UnixEpoch())/time.Second), 10))
t.NotBefore = t.ExpiresOn
return t
}
@@ -626,7 +891,7 @@ func newServicePrincipalToken(callbacks ...TokenRefreshCallback) *ServicePrincip
func newServicePrincipalTokenManual() *ServicePrincipalToken {
token := newToken()
token.RefreshToken = "refreshtoken"
spt, _ := NewServicePrincipalTokenFromManualToken(TestOAuthConfig, "id", "resource", *token)
spt, _ := NewServicePrincipalTokenFromManualToken(TestOAuthConfig, "id", "resource", token)
return spt
}
@@ -652,3 +917,13 @@ func newServicePrincipalTokenCertificate(t *testing.T) *ServicePrincipalToken {
spt, _ := NewServicePrincipalTokenFromCertificate(TestOAuthConfig, "id", certificate, privateKey, "resource")
return spt
}
func newServicePrincipalTokenUsernamePassword(t *testing.T) *ServicePrincipalToken {
spt, _ := NewServicePrincipalTokenFromUsernamePassword(TestOAuthConfig, "id", "username", "password", "resource")
return spt
}
func newServicePrincipalTokenAuthorizationCode(t *testing.T) *ServicePrincipalToken {
spt, _ := NewServicePrincipalTokenFromAuthorizationCode(TestOAuthConfig, "id", "clientSecret", "code", "http://redirectUri/getToken", "resource")
return spt
}

View File

@@ -24,9 +24,12 @@ import (
)
const (
bearerChallengeHeader = "Www-Authenticate"
bearer = "Bearer"
tenantID = "tenantID"
bearerChallengeHeader = "Www-Authenticate"
bearer = "Bearer"
tenantID = "tenantID"
apiKeyAuthorizerHeader = "Ocp-Apim-Subscription-Key"
bingAPISdkHeader = "X-BingApis-SDK-Client"
golangBingAPISdkHeaderValue = "Go-SDK"
)
// Authorizer is the interface that provides a PrepareDecorator used to supply request
@@ -44,6 +47,53 @@ func (na NullAuthorizer) WithAuthorization() PrepareDecorator {
return WithNothing()
}
// APIKeyAuthorizer implements API Key authorization.
type APIKeyAuthorizer struct {
headers map[string]interface{}
queryParameters map[string]interface{}
}
// NewAPIKeyAuthorizerWithHeaders creates an ApiKeyAuthorizer with headers.
func NewAPIKeyAuthorizerWithHeaders(headers map[string]interface{}) *APIKeyAuthorizer {
return NewAPIKeyAuthorizer(headers, nil)
}
// NewAPIKeyAuthorizerWithQueryParameters creates an ApiKeyAuthorizer with query parameters.
func NewAPIKeyAuthorizerWithQueryParameters(queryParameters map[string]interface{}) *APIKeyAuthorizer {
return NewAPIKeyAuthorizer(nil, queryParameters)
}
// NewAPIKeyAuthorizer creates an ApiKeyAuthorizer with headers.
func NewAPIKeyAuthorizer(headers map[string]interface{}, queryParameters map[string]interface{}) *APIKeyAuthorizer {
return &APIKeyAuthorizer{headers: headers, queryParameters: queryParameters}
}
// WithAuthorization returns a PrepareDecorator that adds an HTTP headers and Query Paramaters
func (aka *APIKeyAuthorizer) WithAuthorization() PrepareDecorator {
return func(p Preparer) Preparer {
return DecoratePreparer(p, WithHeaders(aka.headers), WithQueryParameters(aka.queryParameters))
}
}
// CognitiveServicesAuthorizer implements authorization for Cognitive Services.
type CognitiveServicesAuthorizer struct {
subscriptionKey string
}
// NewCognitiveServicesAuthorizer is
func NewCognitiveServicesAuthorizer(subscriptionKey string) *CognitiveServicesAuthorizer {
return &CognitiveServicesAuthorizer{subscriptionKey: subscriptionKey}
}
// WithAuthorization is
func (csa *CognitiveServicesAuthorizer) WithAuthorization() PrepareDecorator {
headers := make(map[string]interface{})
headers[apiKeyAuthorizerHeader] = csa.subscriptionKey
headers[bingAPISdkHeader] = golangBingAPISdkHeaderValue
return NewAPIKeyAuthorizerWithHeaders(headers).WithAuthorization()
}
// BearerAuthorizer implements the bearer authorization
type BearerAuthorizer struct {
tokenProvider adal.OAuthTokenProvider
@@ -54,10 +104,6 @@ func NewBearerAuthorizer(tp adal.OAuthTokenProvider) *BearerAuthorizer {
return &BearerAuthorizer{tokenProvider: tp}
}
func (ba *BearerAuthorizer) withBearerAuthorization() PrepareDecorator {
return WithHeader(headerAuthorization, fmt.Sprintf("Bearer %s", ba.tokenProvider.OAuthToken()))
}
// WithAuthorization returns a PrepareDecorator that adds an HTTP Authorization header whose
// value is "Bearer " followed by the token.
//
@@ -65,15 +111,25 @@ func (ba *BearerAuthorizer) withBearerAuthorization() PrepareDecorator {
func (ba *BearerAuthorizer) WithAuthorization() PrepareDecorator {
return func(p Preparer) Preparer {
return PreparerFunc(func(r *http.Request) (*http.Request, error) {
refresher, ok := ba.tokenProvider.(adal.Refresher)
if ok {
err := refresher.EnsureFresh()
r, err := p.Prepare(r)
if err == nil {
// the ordering is important here, prefer RefresherWithContext if available
if refresher, ok := ba.tokenProvider.(adal.RefresherWithContext); ok {
err = refresher.EnsureFreshWithContext(r.Context())
} else if refresher, ok := ba.tokenProvider.(adal.Refresher); ok {
err = refresher.EnsureFresh()
}
if err != nil {
return r, NewErrorWithError(err, "azure.BearerAuthorizer", "WithAuthorization", nil,
var resp *http.Response
if tokError, ok := err.(adal.TokenRefreshError); ok {
resp = tokError.Response()
}
return r, NewErrorWithError(err, "azure.BearerAuthorizer", "WithAuthorization", resp,
"Failed to refresh the Token for request to %s", r.URL)
}
return Prepare(r, WithHeader(headerAuthorization, fmt.Sprintf("Bearer %s", ba.tokenProvider.OAuthToken())))
}
return (ba.withBearerAuthorization()(p)).Prepare(r)
return r, err
})
}
}
@@ -103,25 +159,28 @@ func NewBearerAuthorizerCallback(sender Sender, callback BearerAuthorizerCallbac
func (bacb *BearerAuthorizerCallback) WithAuthorization() PrepareDecorator {
return func(p Preparer) Preparer {
return PreparerFunc(func(r *http.Request) (*http.Request, error) {
// make a copy of the request and remove the body as it's not
// required and avoids us having to create a copy of it.
rCopy := *r
removeRequestBody(&rCopy)
r, err := p.Prepare(r)
if err == nil {
// make a copy of the request and remove the body as it's not
// required and avoids us having to create a copy of it.
rCopy := *r
removeRequestBody(&rCopy)
resp, err := bacb.sender.Do(&rCopy)
if err == nil && resp.StatusCode == 401 {
defer resp.Body.Close()
if hasBearerChallenge(resp) {
bc, err := newBearerChallenge(resp)
if err != nil {
return r, err
}
if bacb.callback != nil {
ba, err := bacb.callback(bc.values[tenantID], bc.values["resource"])
resp, err := bacb.sender.Do(&rCopy)
if err == nil && resp.StatusCode == 401 {
defer resp.Body.Close()
if hasBearerChallenge(resp) {
bc, err := newBearerChallenge(resp)
if err != nil {
return r, err
}
return ba.WithAuthorization()(p).Prepare(r)
if bacb.callback != nil {
ba, err := bacb.callback(bc.values[tenantID], bc.values["resource"])
if err != nil {
return r, err
}
return Prepare(r, ba.WithAuthorization())
}
}
}
}
@@ -179,3 +238,22 @@ func newBearerChallenge(resp *http.Response) (bc bearerChallenge, err error) {
return bc, err
}
// EventGridKeyAuthorizer implements authorization for event grid using key authentication.
type EventGridKeyAuthorizer struct {
topicKey string
}
// NewEventGridKeyAuthorizer creates a new EventGridKeyAuthorizer
// with the specified topic key.
func NewEventGridKeyAuthorizer(topicKey string) EventGridKeyAuthorizer {
return EventGridKeyAuthorizer{topicKey: topicKey}
}
// WithAuthorization returns a PrepareDecorator that adds the aeg-sas-key authentication header.
func (egta EventGridKeyAuthorizer) WithAuthorization() PrepareDecorator {
headers := map[string]interface{}{
"aeg-sas-key": egta.topicKey,
}
return NewAPIKeyAuthorizerWithHeaders(headers).WithAuthorization()
}

View File

@@ -75,7 +75,7 @@ func TestServicePrincipalTokenWithAuthorizationNoRefresh(t *testing.T) {
req, err := Prepare(mocks.NewRequest(), ba.WithAuthorization())
if err != nil {
t.Fatalf("azure: BearerAuthorizer#WithAuthorization returned an error (%v)", err)
} else if req.Header.Get(http.CanonicalHeaderKey("Authorization")) != fmt.Sprintf("Bearer %s", spt.AccessToken) {
} else if req.Header.Get(http.CanonicalHeaderKey("Authorization")) != fmt.Sprintf("Bearer %s", spt.OAuthToken()) {
t.Fatal("azure: BearerAuthorizer#WithAuthorization failed to set Authorization header")
}
}
@@ -120,7 +120,7 @@ func TestServicePrincipalTokenWithAuthorizationRefresh(t *testing.T) {
req, err := Prepare(mocks.NewRequest(), ba.WithAuthorization())
if err != nil {
t.Fatalf("azure: BearerAuthorizer#WithAuthorization returned an error (%v)", err)
} else if req.Header.Get(http.CanonicalHeaderKey("Authorization")) != fmt.Sprintf("Bearer %s", spt.AccessToken) {
} else if req.Header.Get(http.CanonicalHeaderKey("Authorization")) != fmt.Sprintf("Bearer %s", spt.OAuthToken()) {
t.Fatal("azure: BearerAuthorizer#WithAuthorization failed to set Authorization header")
}
@@ -186,3 +186,45 @@ func TestBearerAuthorizerCallback(t *testing.T) {
t.Fatal("azure: BearerAuthorizerCallback#WithAuthorization failed to return an error when refresh fails")
}
}
func TestApiKeyAuthorization(t *testing.T) {
headers := make(map[string]interface{})
queryParameters := make(map[string]interface{})
dummyAuthHeader := "dummyAuthHeader"
dummyAuthHeaderValue := "dummyAuthHeaderValue"
dummyAuthQueryParameter := "dummyAuthQueryParameter"
dummyAuthQueryParameterValue := "dummyAuthQueryParameterValue"
headers[dummyAuthHeader] = dummyAuthHeaderValue
queryParameters[dummyAuthQueryParameter] = dummyAuthQueryParameterValue
aka := NewAPIKeyAuthorizer(headers, queryParameters)
req, err := Prepare(mocks.NewRequest(), aka.WithAuthorization())
if err != nil {
t.Fatalf("azure: APIKeyAuthorizer#WithAuthorization returned an error (%v)", err)
} else if req.Header.Get(http.CanonicalHeaderKey(dummyAuthHeader)) != dummyAuthHeaderValue {
t.Fatalf("azure: APIKeyAuthorizer#WithAuthorization failed to set %s header", dummyAuthHeader)
} else if req.URL.Query().Get(dummyAuthQueryParameter) != dummyAuthQueryParameterValue {
t.Fatalf("azure: APIKeyAuthorizer#WithAuthorization failed to set %s query parameter", dummyAuthQueryParameterValue)
}
}
func TestCognitivesServicesAuthorization(t *testing.T) {
subscriptionKey := "dummyKey"
csa := NewCognitiveServicesAuthorizer(subscriptionKey)
req, err := Prepare(mocks.NewRequest(), csa.WithAuthorization())
if err != nil {
t.Fatalf("azure: CognitiveServicesAuthorizer#WithAuthorization returned an error (%v)", err)
} else if req.Header.Get(http.CanonicalHeaderKey(bingAPISdkHeader)) != golangBingAPISdkHeaderValue {
t.Fatalf("azure: CognitiveServicesAuthorizer#WithAuthorization failed to set %s header", bingAPISdkHeader)
} else if req.Header.Get(http.CanonicalHeaderKey(apiKeyAuthorizerHeader)) != subscriptionKey {
t.Fatalf("azure: CognitiveServicesAuthorizer#WithAuthorization failed to set %s header", apiKeyAuthorizerHeader)
}
}

View File

@@ -72,6 +72,7 @@ package autorest
// limitations under the License.
import (
"context"
"net/http"
"time"
)
@@ -87,6 +88,9 @@ const (
// ResponseHasStatusCode returns true if the status code in the HTTP Response is in the passed set
// and false otherwise.
func ResponseHasStatusCode(resp *http.Response, codes ...int) bool {
if resp == nil {
return false
}
return containsInt(codes, resp.StatusCode)
}
@@ -127,3 +131,20 @@ func NewPollingRequest(resp *http.Response, cancel <-chan struct{}) (*http.Reque
return req, nil
}
// NewPollingRequestWithContext allocates and returns a new http.Request with the specified context to poll for the passed response.
func NewPollingRequestWithContext(ctx context.Context, resp *http.Response) (*http.Request, error) {
location := GetLocation(resp)
if location == "" {
return nil, NewErrorWithResponse("autorest", "NewPollingRequestWithContext", resp, "Location header missing from response that requires polling")
}
req, err := Prepare((&http.Request{}).WithContext(ctx),
AsGet(),
WithBaseURL(location))
if err != nil {
return nil, NewErrorWithError(err, "autorest", "NewPollingRequestWithContext", nil, "Failure creating poll request to %s", location)
}
return req, nil
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,492 @@
package auth
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"bytes"
"crypto/rsa"
"crypto/x509"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"log"
"os"
"strings"
"unicode/utf16"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/adal"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/Azure/go-autorest/autorest/azure/cli"
"github.com/dimchansky/utfbom"
"golang.org/x/crypto/pkcs12"
)
// NewAuthorizerFromEnvironment creates an Authorizer configured from environment variables in the order:
// 1. Client credentials
// 2. Client certificate
// 3. Username password
// 4. MSI
func NewAuthorizerFromEnvironment() (autorest.Authorizer, error) {
settings, err := getAuthenticationSettings()
if err != nil {
return nil, err
}
if settings.resource == "" {
settings.resource = settings.environment.ResourceManagerEndpoint
}
return settings.getAuthorizer()
}
// NewAuthorizerFromEnvironmentWithResource creates an Authorizer configured from environment variables in the order:
// 1. Client credentials
// 2. Client certificate
// 3. Username password
// 4. MSI
func NewAuthorizerFromEnvironmentWithResource(resource string) (autorest.Authorizer, error) {
settings, err := getAuthenticationSettings()
if err != nil {
return nil, err
}
settings.resource = resource
return settings.getAuthorizer()
}
type settings struct {
tenantID string
clientID string
clientSecret string
certificatePath string
certificatePassword string
username string
password string
envName string
resource string
environment azure.Environment
}
func getAuthenticationSettings() (s settings, err error) {
s = settings{
tenantID: os.Getenv("AZURE_TENANT_ID"),
clientID: os.Getenv("AZURE_CLIENT_ID"),
clientSecret: os.Getenv("AZURE_CLIENT_SECRET"),
certificatePath: os.Getenv("AZURE_CERTIFICATE_PATH"),
certificatePassword: os.Getenv("AZURE_CERTIFICATE_PASSWORD"),
username: os.Getenv("AZURE_USERNAME"),
password: os.Getenv("AZURE_PASSWORD"),
envName: os.Getenv("AZURE_ENVIRONMENT"),
resource: os.Getenv("AZURE_AD_RESOURCE"),
}
if s.envName == "" {
s.environment = azure.PublicCloud
} else {
s.environment, err = azure.EnvironmentFromName(s.envName)
}
return
}
func (settings settings) getAuthorizer() (autorest.Authorizer, error) {
//1.Client Credentials
if settings.clientSecret != "" {
config := NewClientCredentialsConfig(settings.clientID, settings.clientSecret, settings.tenantID)
config.AADEndpoint = settings.environment.ActiveDirectoryEndpoint
config.Resource = settings.resource
return config.Authorizer()
}
//2. Client Certificate
if settings.certificatePath != "" {
config := NewClientCertificateConfig(settings.certificatePath, settings.certificatePassword, settings.clientID, settings.tenantID)
config.AADEndpoint = settings.environment.ActiveDirectoryEndpoint
config.Resource = settings.resource
return config.Authorizer()
}
//3. Username Password
if settings.username != "" && settings.password != "" {
config := NewUsernamePasswordConfig(settings.username, settings.password, settings.clientID, settings.tenantID)
config.AADEndpoint = settings.environment.ActiveDirectoryEndpoint
config.Resource = settings.resource
return config.Authorizer()
}
// 4. MSI
config := NewMSIConfig()
config.Resource = settings.resource
config.ClientID = settings.clientID
return config.Authorizer()
}
// NewAuthorizerFromFile creates an Authorizer configured from a configuration file.
func NewAuthorizerFromFile(baseURI string) (autorest.Authorizer, error) {
file, err := getAuthFile()
if err != nil {
return nil, err
}
resource, err := getResourceForToken(*file, baseURI)
if err != nil {
return nil, err
}
return NewAuthorizerFromFileWithResource(resource)
}
// NewAuthorizerFromFileWithResource creates an Authorizer configured from a configuration file.
func NewAuthorizerFromFileWithResource(resource string) (autorest.Authorizer, error) {
file, err := getAuthFile()
if err != nil {
return nil, err
}
config, err := adal.NewOAuthConfig(file.ActiveDirectoryEndpoint, file.TenantID)
if err != nil {
return nil, err
}
spToken, err := adal.NewServicePrincipalToken(*config, file.ClientID, file.ClientSecret, resource)
if err != nil {
return nil, err
}
return autorest.NewBearerAuthorizer(spToken), nil
}
// NewAuthorizerFromCLI creates an Authorizer configured from Azure CLI 2.0 for local development scenarios.
func NewAuthorizerFromCLI() (autorest.Authorizer, error) {
settings, err := getAuthenticationSettings()
if err != nil {
return nil, err
}
if settings.resource == "" {
settings.resource = settings.environment.ResourceManagerEndpoint
}
return NewAuthorizerFromCLIWithResource(settings.resource)
}
// NewAuthorizerFromCLIWithResource creates an Authorizer configured from Azure CLI 2.0 for local development scenarios.
func NewAuthorizerFromCLIWithResource(resource string) (autorest.Authorizer, error) {
token, err := cli.GetTokenFromCLI(resource)
if err != nil {
return nil, err
}
adalToken, err := token.ToADALToken()
if err != nil {
return nil, err
}
return autorest.NewBearerAuthorizer(&adalToken), nil
}
func getAuthFile() (*file, error) {
fileLocation := os.Getenv("AZURE_AUTH_LOCATION")
if fileLocation == "" {
return nil, errors.New("environment variable AZURE_AUTH_LOCATION is not set")
}
contents, err := ioutil.ReadFile(fileLocation)
if err != nil {
return nil, err
}
// Auth file might be encoded
decoded, err := decode(contents)
if err != nil {
return nil, err
}
authFile := file{}
err = json.Unmarshal(decoded, &authFile)
if err != nil {
return nil, err
}
return &authFile, nil
}
// File represents the authentication file
type file struct {
ClientID string `json:"clientId,omitempty"`
ClientSecret string `json:"clientSecret,omitempty"`
SubscriptionID string `json:"subscriptionId,omitempty"`
TenantID string `json:"tenantId,omitempty"`
ActiveDirectoryEndpoint string `json:"activeDirectoryEndpointUrl,omitempty"`
ResourceManagerEndpoint string `json:"resourceManagerEndpointUrl,omitempty"`
GraphResourceID string `json:"activeDirectoryGraphResourceId,omitempty"`
SQLManagementEndpoint string `json:"sqlManagementEndpointUrl,omitempty"`
GalleryEndpoint string `json:"galleryEndpointUrl,omitempty"`
ManagementEndpoint string `json:"managementEndpointUrl,omitempty"`
}
func decode(b []byte) ([]byte, error) {
reader, enc := utfbom.Skip(bytes.NewReader(b))
switch enc {
case utfbom.UTF16LittleEndian:
u16 := make([]uint16, (len(b)/2)-1)
err := binary.Read(reader, binary.LittleEndian, &u16)
if err != nil {
return nil, err
}
return []byte(string(utf16.Decode(u16))), nil
case utfbom.UTF16BigEndian:
u16 := make([]uint16, (len(b)/2)-1)
err := binary.Read(reader, binary.BigEndian, &u16)
if err != nil {
return nil, err
}
return []byte(string(utf16.Decode(u16))), nil
}
return ioutil.ReadAll(reader)
}
func getResourceForToken(f file, baseURI string) (string, error) {
// Compare dafault base URI from the SDK to the endpoints from the public cloud
// Base URI and token resource are the same string. This func finds the authentication
// file field that matches the SDK base URI. The SDK defines the public cloud
// endpoint as its default base URI
if !strings.HasSuffix(baseURI, "/") {
baseURI += "/"
}
switch baseURI {
case azure.PublicCloud.ServiceManagementEndpoint:
return f.ManagementEndpoint, nil
case azure.PublicCloud.ResourceManagerEndpoint:
return f.ResourceManagerEndpoint, nil
case azure.PublicCloud.ActiveDirectoryEndpoint:
return f.ActiveDirectoryEndpoint, nil
case azure.PublicCloud.GalleryEndpoint:
return f.GalleryEndpoint, nil
case azure.PublicCloud.GraphEndpoint:
return f.GraphResourceID, nil
}
return "", fmt.Errorf("auth: base URI not found in endpoints")
}
// NewClientCredentialsConfig creates an AuthorizerConfig object configured to obtain an Authorizer through Client Credentials.
// Defaults to Public Cloud and Resource Manager Endpoint.
func NewClientCredentialsConfig(clientID string, clientSecret string, tenantID string) ClientCredentialsConfig {
return ClientCredentialsConfig{
ClientID: clientID,
ClientSecret: clientSecret,
TenantID: tenantID,
Resource: azure.PublicCloud.ResourceManagerEndpoint,
AADEndpoint: azure.PublicCloud.ActiveDirectoryEndpoint,
}
}
// NewClientCertificateConfig creates a ClientCertificateConfig object configured to obtain an Authorizer through client certificate.
// Defaults to Public Cloud and Resource Manager Endpoint.
func NewClientCertificateConfig(certificatePath string, certificatePassword string, clientID string, tenantID string) ClientCertificateConfig {
return ClientCertificateConfig{
CertificatePath: certificatePath,
CertificatePassword: certificatePassword,
ClientID: clientID,
TenantID: tenantID,
Resource: azure.PublicCloud.ResourceManagerEndpoint,
AADEndpoint: azure.PublicCloud.ActiveDirectoryEndpoint,
}
}
// NewUsernamePasswordConfig creates an UsernamePasswordConfig object configured to obtain an Authorizer through username and password.
// Defaults to Public Cloud and Resource Manager Endpoint.
func NewUsernamePasswordConfig(username string, password string, clientID string, tenantID string) UsernamePasswordConfig {
return UsernamePasswordConfig{
Username: username,
Password: password,
ClientID: clientID,
TenantID: tenantID,
Resource: azure.PublicCloud.ResourceManagerEndpoint,
AADEndpoint: azure.PublicCloud.ActiveDirectoryEndpoint,
}
}
// NewMSIConfig creates an MSIConfig object configured to obtain an Authorizer through MSI.
func NewMSIConfig() MSIConfig {
return MSIConfig{
Resource: azure.PublicCloud.ResourceManagerEndpoint,
}
}
// NewDeviceFlowConfig creates a DeviceFlowConfig object configured to obtain an Authorizer through device flow.
// Defaults to Public Cloud and Resource Manager Endpoint.
func NewDeviceFlowConfig(clientID string, tenantID string) DeviceFlowConfig {
return DeviceFlowConfig{
ClientID: clientID,
TenantID: tenantID,
Resource: azure.PublicCloud.ResourceManagerEndpoint,
AADEndpoint: azure.PublicCloud.ActiveDirectoryEndpoint,
}
}
//AuthorizerConfig provides an authorizer from the configuration provided.
type AuthorizerConfig interface {
Authorizer() (autorest.Authorizer, error)
}
// ClientCredentialsConfig provides the options to get a bearer authorizer from client credentials.
type ClientCredentialsConfig struct {
ClientID string
ClientSecret string
TenantID string
AADEndpoint string
Resource string
}
// Authorizer gets the authorizer from client credentials.
func (ccc ClientCredentialsConfig) Authorizer() (autorest.Authorizer, error) {
oauthConfig, err := adal.NewOAuthConfig(ccc.AADEndpoint, ccc.TenantID)
if err != nil {
return nil, err
}
spToken, err := adal.NewServicePrincipalToken(*oauthConfig, ccc.ClientID, ccc.ClientSecret, ccc.Resource)
if err != nil {
return nil, fmt.Errorf("failed to get oauth token from client credentials: %v", err)
}
return autorest.NewBearerAuthorizer(spToken), nil
}
// ClientCertificateConfig provides the options to get a bearer authorizer from a client certificate.
type ClientCertificateConfig struct {
ClientID string
CertificatePath string
CertificatePassword string
TenantID string
AADEndpoint string
Resource string
}
// Authorizer gets an authorizer object from client certificate.
func (ccc ClientCertificateConfig) Authorizer() (autorest.Authorizer, error) {
oauthConfig, err := adal.NewOAuthConfig(ccc.AADEndpoint, ccc.TenantID)
certData, err := ioutil.ReadFile(ccc.CertificatePath)
if err != nil {
return nil, fmt.Errorf("failed to read the certificate file (%s): %v", ccc.CertificatePath, err)
}
certificate, rsaPrivateKey, err := decodePkcs12(certData, ccc.CertificatePassword)
if err != nil {
return nil, fmt.Errorf("failed to decode pkcs12 certificate while creating spt: %v", err)
}
spToken, err := adal.NewServicePrincipalTokenFromCertificate(*oauthConfig, ccc.ClientID, certificate, rsaPrivateKey, ccc.Resource)
if err != nil {
return nil, fmt.Errorf("failed to get oauth token from certificate auth: %v", err)
}
return autorest.NewBearerAuthorizer(spToken), nil
}
// DeviceFlowConfig provides the options to get a bearer authorizer using device flow authentication.
type DeviceFlowConfig struct {
ClientID string
TenantID string
AADEndpoint string
Resource string
}
// Authorizer gets the authorizer from device flow.
func (dfc DeviceFlowConfig) Authorizer() (autorest.Authorizer, error) {
oauthClient := &autorest.Client{}
oauthConfig, err := adal.NewOAuthConfig(dfc.AADEndpoint, dfc.TenantID)
deviceCode, err := adal.InitiateDeviceAuth(oauthClient, *oauthConfig, dfc.ClientID, dfc.Resource)
if err != nil {
return nil, fmt.Errorf("failed to start device auth flow: %s", err)
}
log.Println(*deviceCode.Message)
token, err := adal.WaitForUserCompletion(oauthClient, deviceCode)
if err != nil {
return nil, fmt.Errorf("failed to finish device auth flow: %s", err)
}
spToken, err := adal.NewServicePrincipalTokenFromManualToken(*oauthConfig, dfc.ClientID, dfc.Resource, *token)
if err != nil {
return nil, fmt.Errorf("failed to get oauth token from device flow: %v", err)
}
return autorest.NewBearerAuthorizer(spToken), nil
}
func decodePkcs12(pkcs []byte, password string) (*x509.Certificate, *rsa.PrivateKey, error) {
privateKey, certificate, err := pkcs12.Decode(pkcs, password)
if err != nil {
return nil, nil, err
}
rsaPrivateKey, isRsaKey := privateKey.(*rsa.PrivateKey)
if !isRsaKey {
return nil, nil, fmt.Errorf("PKCS#12 certificate must contain an RSA private key")
}
return certificate, rsaPrivateKey, nil
}
// UsernamePasswordConfig provides the options to get a bearer authorizer from a username and a password.
type UsernamePasswordConfig struct {
ClientID string
Username string
Password string
TenantID string
AADEndpoint string
Resource string
}
// Authorizer gets the authorizer from a username and a password.
func (ups UsernamePasswordConfig) Authorizer() (autorest.Authorizer, error) {
oauthConfig, err := adal.NewOAuthConfig(ups.AADEndpoint, ups.TenantID)
spToken, err := adal.NewServicePrincipalTokenFromUsernamePassword(*oauthConfig, ups.ClientID, ups.Username, ups.Password, ups.Resource)
if err != nil {
return nil, fmt.Errorf("failed to get oauth token from username and password auth: %v", err)
}
return autorest.NewBearerAuthorizer(spToken), nil
}
// MSIConfig provides the options to get a bearer authorizer through MSI.
type MSIConfig struct {
Resource string
ClientID string
}
// Authorizer gets the authorizer from MSI.
func (mc MSIConfig) Authorizer() (autorest.Authorizer, error) {
msiEndpoint, err := adal.GetMSIVMEndpoint()
if err != nil {
return nil, err
}
spToken, err := adal.NewServicePrincipalTokenFromMSI(msiEndpoint, mc.Resource)
if err != nil {
return nil, fmt.Errorf("failed to get oauth token from MSI: %v", err)
}
return autorest.NewBearerAuthorizer(spToken), nil
}

View File

@@ -0,0 +1,129 @@
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package auth
import (
"encoding/json"
"io/ioutil"
"os"
"path/filepath"
"reflect"
"testing"
)
var (
expectedFile = file{
ClientID: "client-id-123",
ClientSecret: "client-secret-456",
SubscriptionID: "sub-id-789",
TenantID: "tenant-id-123",
ActiveDirectoryEndpoint: "https://login.microsoftonline.com",
ResourceManagerEndpoint: "https://management.azure.com/",
GraphResourceID: "https://graph.windows.net/",
SQLManagementEndpoint: "https://management.core.windows.net:8443/",
GalleryEndpoint: "https://gallery.azure.com/",
ManagementEndpoint: "https://management.core.windows.net/",
}
)
func TestNewAuthorizerFromFile(t *testing.T) {
os.Setenv("AZURE_AUTH_LOCATION", filepath.Join(getCredsPath(), "credsutf16le.json"))
authorizer, err := NewAuthorizerFromFile("https://management.azure.com")
if err != nil || authorizer == nil {
t.Logf("NewAuthorizerFromFile failed, got error %v", err)
t.Fail()
}
}
func TestNewAuthorizerFromFileWithResource(t *testing.T) {
os.Setenv("AZURE_AUTH_LOCATION", filepath.Join(getCredsPath(), "credsutf16le.json"))
authorizer, err := NewAuthorizerFromFileWithResource("https://my.vault.azure.net")
if err != nil || authorizer == nil {
t.Logf("NewAuthorizerFromFileWithResource failed, got error %v", err)
t.Fail()
}
}
func TestNewAuthorizerFromEnvironment(t *testing.T) {
os.Setenv("AZURE_TENANT_ID", expectedFile.TenantID)
os.Setenv("AZURE_CLIENT_ID", expectedFile.ClientID)
os.Setenv("AZURE_CLIENT_SECRET", expectedFile.ClientSecret)
authorizer, err := NewAuthorizerFromEnvironment()
if err != nil || authorizer == nil {
t.Logf("NewAuthorizerFromEnvironment failed, got error %v", err)
t.Fail()
}
}
func TestNewAuthorizerFromEnvironmentWithResource(t *testing.T) {
os.Setenv("AZURE_TENANT_ID", expectedFile.TenantID)
os.Setenv("AZURE_CLIENT_ID", expectedFile.ClientID)
os.Setenv("AZURE_CLIENT_SECRET", expectedFile.ClientSecret)
authorizer, err := NewAuthorizerFromEnvironmentWithResource("https://my.vault.azure.net")
if err != nil || authorizer == nil {
t.Logf("NewAuthorizerFromEnvironmentWithResource failed, got error %v", err)
t.Fail()
}
}
func TestDecodeAndUnmarshal(t *testing.T) {
tests := []string{
"credsutf8.json",
"credsutf16le.json",
"credsutf16be.json",
}
creds := getCredsPath()
for _, test := range tests {
b, err := ioutil.ReadFile(filepath.Join(creds, test))
if err != nil {
t.Logf("error reading file '%s': %s", test, err)
t.Fail()
}
decoded, err := decode(b)
if err != nil {
t.Logf("error decoding file '%s': %s", test, err)
t.Fail()
}
var got file
err = json.Unmarshal(decoded, &got)
if err != nil {
t.Logf("error unmarshaling file '%s': %s", test, err)
t.Fail()
}
if !reflect.DeepEqual(expectedFile, got) {
t.Logf("unmarshaled map expected %v, got %v", expectedFile, got)
t.Fail()
}
}
}
func getCredsPath() string {
gopath := os.Getenv("GOPATH")
return filepath.Join(gopath, "src", "github.com", "Azure", "go-autorest", "testdata")
}
func areMapsEqual(a, b map[string]string) bool {
if len(a) != len(b) {
return false
}
for k := range a {
if a[k] != b[k] {
return false
}
}
return true
}

View File

@@ -1,129 +0,0 @@
package auth
import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"os"
"strings"
"unicode/utf16"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/adal"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/dimchansky/utfbom"
)
// ClientSetup includes authentication details and cloud specific
// parameters for ARM clients
type ClientSetup struct {
*autorest.BearerAuthorizer
File
BaseURI string
}
// File represents the authentication file
type File struct {
ClientID string `json:"clientId,omitempty"`
ClientSecret string `json:"clientSecret,omitempty"`
SubscriptionID string `json:"subscriptionId,omitempty"`
TenantID string `json:"tenantId,omitempty"`
ActiveDirectoryEndpoint string `json:"activeDirectoryEndpointUrl,omitempty"`
ResourceManagerEndpoint string `json:"resourceManagerEndpointUrl,omitempty"`
GraphResourceID string `json:"activeDirectoryGraphResourceId,omitempty"`
SQLManagementEndpoint string `json:"sqlManagementEndpointUrl,omitempty"`
GalleryEndpoint string `json:"galleryEndpointUrl,omitempty"`
ManagementEndpoint string `json:"managementEndpointUrl,omitempty"`
}
// GetClientSetup provides an authorizer, base URI, subscriptionID and
// tenantID parameters from an Azure CLI auth file
func GetClientSetup(baseURI string) (auth ClientSetup, err error) {
fileLocation := os.Getenv("AZURE_AUTH_LOCATION")
if fileLocation == "" {
return auth, errors.New("auth file not found. Environment variable AZURE_AUTH_LOCATION is not set")
}
contents, err := ioutil.ReadFile(fileLocation)
if err != nil {
return
}
// Auth file might be encoded
decoded, err := decode(contents)
if err != nil {
return
}
err = json.Unmarshal(decoded, &auth.File)
if err != nil {
return
}
resource, err := getResourceForToken(auth.File, baseURI)
if err != nil {
return
}
auth.BaseURI = resource
config, err := adal.NewOAuthConfig(auth.ActiveDirectoryEndpoint, auth.TenantID)
if err != nil {
return
}
spToken, err := adal.NewServicePrincipalToken(*config, auth.ClientID, auth.ClientSecret, resource)
if err != nil {
return
}
auth.BearerAuthorizer = autorest.NewBearerAuthorizer(spToken)
return
}
func decode(b []byte) ([]byte, error) {
reader, enc := utfbom.Skip(bytes.NewReader(b))
switch enc {
case utfbom.UTF16LittleEndian:
u16 := make([]uint16, (len(b)/2)-1)
err := binary.Read(reader, binary.LittleEndian, &u16)
if err != nil {
return nil, err
}
return []byte(string(utf16.Decode(u16))), nil
case utfbom.UTF16BigEndian:
u16 := make([]uint16, (len(b)/2)-1)
err := binary.Read(reader, binary.BigEndian, &u16)
if err != nil {
return nil, err
}
return []byte(string(utf16.Decode(u16))), nil
}
return ioutil.ReadAll(reader)
}
func getResourceForToken(f File, baseURI string) (string, error) {
// Compare dafault base URI from the SDK to the endpoints from the public cloud
// Base URI and token resource are the same string. This func finds the authentication
// file field that matches the SDK base URI. The SDK defines the public cloud
// endpoint as its default base URI
if !strings.HasSuffix(baseURI, "/") {
baseURI += "/"
}
switch baseURI {
case azure.PublicCloud.ServiceManagementEndpoint:
return f.ManagementEndpoint, nil
case azure.PublicCloud.ResourceManagerEndpoint:
return f.ResourceManagerEndpoint, nil
case azure.PublicCloud.ActiveDirectoryEndpoint:
return f.ActiveDirectoryEndpoint, nil
case azure.PublicCloud.GalleryEndpoint:
return f.GalleryEndpoint, nil
case azure.PublicCloud.GraphEndpoint:
return f.GraphResourceID, nil
}
return "", fmt.Errorf("auth: base URI not found in endpoints")
}

View File

@@ -1,97 +0,0 @@
package auth
import (
"encoding/json"
"io/ioutil"
"os"
"path/filepath"
"reflect"
"testing"
)
var (
expectedFile = File{
ClientID: "client-id-123",
ClientSecret: "client-secret-456",
SubscriptionID: "sub-id-789",
TenantID: "tenant-id-123",
ActiveDirectoryEndpoint: "https://login.microsoftonline.com",
ResourceManagerEndpoint: "https://management.azure.com/",
GraphResourceID: "https://graph.windows.net/",
SQLManagementEndpoint: "https://management.core.windows.net:8443/",
GalleryEndpoint: "https://gallery.azure.com/",
ManagementEndpoint: "https://management.core.windows.net/",
}
)
func TestGetClientSetup(t *testing.T) {
os.Setenv("AZURE_AUTH_LOCATION", filepath.Join(getCredsPath(), "credsutf16le.json"))
setup, err := GetClientSetup("https://management.azure.com")
if err != nil {
t.Logf("GetClientSetup failed, got error %v", err)
t.Fail()
}
if setup.BaseURI != "https://management.azure.com/" {
t.Logf("auth.BaseURI not set correctly, expected 'https://management.azure.com/', got '%s'", setup.BaseURI)
t.Fail()
}
if !reflect.DeepEqual(expectedFile, setup.File) {
t.Logf("auth.File not set correctly, expected %v, got %v", expectedFile, setup.File)
t.Fail()
}
if setup.BearerAuthorizer == nil {
t.Log("auth.Authorizer not set correctly, got nil")
t.Fail()
}
}
func TestDecodeAndUnmarshal(t *testing.T) {
tests := []string{
"credsutf8.json",
"credsutf16le.json",
"credsutf16be.json",
}
creds := getCredsPath()
for _, test := range tests {
b, err := ioutil.ReadFile(filepath.Join(creds, test))
if err != nil {
t.Logf("error reading file '%s': %s", test, err)
t.Fail()
}
decoded, err := decode(b)
if err != nil {
t.Logf("error decoding file '%s': %s", test, err)
t.Fail()
}
var got File
err = json.Unmarshal(decoded, &got)
if err != nil {
t.Logf("error unmarshaling file '%s': %s", test, err)
t.Fail()
}
if !reflect.DeepEqual(expectedFile, got) {
t.Logf("unmarshaled map expected %v, got %v", expectedFile, got)
t.Fail()
}
}
}
func getCredsPath() string {
gopath := os.Getenv("GOPATH")
return filepath.Join(gopath, "src", "github.com", "Azure", "go-autorest", "testdata")
}
func areMapsEqual(a, b map[string]string) bool {
if len(a) != len(b) {
return false
}
for k := range a {
if a[k] != b[k] {
return false
}
}
return true
}

View File

@@ -1,8 +1,5 @@
/*
Package azure provides Azure-specific implementations used with AutoRest.
See the included examples for more detail.
*/
// Package azure provides Azure-specific implementations used with AutoRest.
// See the included examples for more detail.
package azure
// Copyright 2017 Microsoft Corporation
@@ -24,7 +21,9 @@ import (
"fmt"
"io/ioutil"
"net/http"
"regexp"
"strconv"
"strings"
"github.com/Azure/go-autorest/autorest"
)
@@ -43,21 +42,100 @@ const (
)
// ServiceError encapsulates the error response from an Azure service.
// It adhears to the OData v4 specification for error responses.
type ServiceError struct {
Code string `json:"code"`
Message string `json:"message"`
Details *[]interface{} `json:"details"`
Code string `json:"code"`
Message string `json:"message"`
Target *string `json:"target"`
Details []map[string]interface{} `json:"details"`
InnerError map[string]interface{} `json:"innererror"`
AdditionalInfo []map[string]interface{} `json:"additionalInfo"`
}
func (se ServiceError) Error() string {
if se.Details != nil {
d, err := json.Marshal(*(se.Details))
if err != nil {
return fmt.Sprintf("Code=%q Message=%q Details=%v", se.Code, se.Message, *se.Details)
}
return fmt.Sprintf("Code=%q Message=%q Details=%v", se.Code, se.Message, string(d))
result := fmt.Sprintf("Code=%q Message=%q", se.Code, se.Message)
if se.Target != nil {
result += fmt.Sprintf(" Target=%q", *se.Target)
}
return fmt.Sprintf("Code=%q Message=%q", se.Code, se.Message)
if se.Details != nil {
d, err := json.Marshal(se.Details)
if err != nil {
result += fmt.Sprintf(" Details=%v", se.Details)
}
result += fmt.Sprintf(" Details=%v", string(d))
}
if se.InnerError != nil {
d, err := json.Marshal(se.InnerError)
if err != nil {
result += fmt.Sprintf(" InnerError=%v", se.InnerError)
}
result += fmt.Sprintf(" InnerError=%v", string(d))
}
if se.AdditionalInfo != nil {
d, err := json.Marshal(se.AdditionalInfo)
if err != nil {
result += fmt.Sprintf(" AdditionalInfo=%v", se.AdditionalInfo)
}
result += fmt.Sprintf(" AdditionalInfo=%v", string(d))
}
return result
}
// UnmarshalJSON implements the json.Unmarshaler interface for the ServiceError type.
func (se *ServiceError) UnmarshalJSON(b []byte) error {
// per the OData v4 spec the details field must be an array of JSON objects.
// unfortunately not all services adhear to the spec and just return a single
// object instead of an array with one object. so we have to perform some
// shenanigans to accommodate both cases.
// http://docs.oasis-open.org/odata/odata-json-format/v4.0/os/odata-json-format-v4.0-os.html#_Toc372793091
type serviceError1 struct {
Code string `json:"code"`
Message string `json:"message"`
Target *string `json:"target"`
Details []map[string]interface{} `json:"details"`
InnerError map[string]interface{} `json:"innererror"`
AdditionalInfo []map[string]interface{} `json:"additionalInfo"`
}
type serviceError2 struct {
Code string `json:"code"`
Message string `json:"message"`
Target *string `json:"target"`
Details map[string]interface{} `json:"details"`
InnerError map[string]interface{} `json:"innererror"`
AdditionalInfo []map[string]interface{} `json:"additionalInfo"`
}
se1 := serviceError1{}
err := json.Unmarshal(b, &se1)
if err == nil {
se.populate(se1.Code, se1.Message, se1.Target, se1.Details, se1.InnerError, se1.AdditionalInfo)
return nil
}
se2 := serviceError2{}
err = json.Unmarshal(b, &se2)
if err == nil {
se.populate(se2.Code, se2.Message, se2.Target, nil, se2.InnerError, se2.AdditionalInfo)
se.Details = append(se.Details, se2.Details)
return nil
}
return err
}
func (se *ServiceError) populate(code, message string, target *string, details []map[string]interface{}, inner map[string]interface{}, additional []map[string]interface{}) {
se.Code = code
se.Message = message
se.Target = target
se.Details = details
se.InnerError = inner
se.AdditionalInfo = additional
}
// RequestError describes an error response returned by Azure service.
@@ -83,6 +161,41 @@ func IsAzureError(e error) bool {
return ok
}
// Resource contains details about an Azure resource.
type Resource struct {
SubscriptionID string
ResourceGroup string
Provider string
ResourceType string
ResourceName string
}
// ParseResourceID parses a resource ID into a ResourceDetails struct.
// See https://docs.microsoft.com/en-us/azure/azure-resource-manager/resource-group-template-functions-resource#return-value-4.
func ParseResourceID(resourceID string) (Resource, error) {
const resourceIDPatternText = `(?i)subscriptions/(.+)/resourceGroups/(.+)/providers/(.+?)/(.+?)/(.+)`
resourceIDPattern := regexp.MustCompile(resourceIDPatternText)
match := resourceIDPattern.FindStringSubmatch(resourceID)
if len(match) == 0 {
return Resource{}, fmt.Errorf("parsing failed for %s. Invalid resource Id format", resourceID)
}
v := strings.Split(match[5], "/")
resourceName := v[len(v)-1]
result := Resource{
SubscriptionID: match[1],
ResourceGroup: match[2],
Provider: match[3],
ResourceType: match[4],
ResourceName: resourceName,
}
return result, nil
}
// NewErrorWithError creates a new Error conforming object from the
// passed packageType, method, statusCode of the given resp (UndefinedStatusCode
// if resp is nil), message, and original error. message is treated as a format
@@ -178,16 +291,29 @@ func WithErrorUnlessStatusCode(codes ...int) autorest.RespondDecorator {
resp.Body = ioutil.NopCloser(&b)
if decodeErr != nil {
return fmt.Errorf("autorest/azure: error response cannot be parsed: %q error: %v", b.String(), decodeErr)
} else if e.ServiceError == nil {
}
if e.ServiceError == nil {
// Check if error is unwrapped ServiceError
if err := json.Unmarshal(b.Bytes(), &e.ServiceError); err != nil || e.ServiceError.Message == "" {
e.ServiceError = &ServiceError{
Code: "Unknown",
Message: "Unknown service error",
}
if err := json.Unmarshal(b.Bytes(), &e.ServiceError); err != nil {
return err
}
}
if e.ServiceError.Message == "" {
// if we're here it means the returned error wasn't OData v4 compliant.
// try to unmarshal the body as raw JSON in hopes of getting something.
rawBody := map[string]interface{}{}
if err := json.Unmarshal(b.Bytes(), &rawBody); err != nil {
return err
}
e.ServiceError = &ServiceError{
Code: "Unknown",
Message: "Unknown service error",
}
if len(rawBody) > 0 {
e.ServiceError.Details = []map[string]interface{}{rawBody}
}
}
e.Response = resp
e.RequestID = ExtractRequestID(resp)
if e.StatusCode == nil {
e.StatusCode = resp.StatusCode

View File

@@ -235,13 +235,16 @@ func TestWithErrorUnlessStatusCode_FoundAzureErrorWithoutDetails(t *testing.T) {
}
func TestWithErrorUnlessStatusCode_FoundAzureErrorWithDetails(t *testing.T) {
func TestWithErrorUnlessStatusCode_FoundAzureFullError(t *testing.T) {
j := `{
"error": {
"code": "InternalError",
"message": "Azure is having trouble right now.",
"target": "target1",
"details": [{"code": "conflict1", "message":"error message1"},
{"code": "conflict2", "message":"error message2"}]
{"code": "conflict2", "message":"error message2"}],
"innererror": { "customKey": "customValue" },
"additionalInfo": [{"type": "someErrorType", "info": {"someProperty": "someValue"}}]
}
}`
uuid := "71FDB9F4-5E49-4C12-B266-DE7B4FD999A6"
@@ -266,14 +269,30 @@ func TestWithErrorUnlessStatusCode_FoundAzureErrorWithDetails(t *testing.T) {
if expected := "InternalError"; azErr.ServiceError.Code != expected {
t.Fatalf("azure: wrong error code. expected=%q; got=%q", expected, azErr.ServiceError.Code)
}
if azErr.ServiceError.Message == "" {
t.Fatalf("azure: error message is not unmarshaled properly")
}
b, _ := json.Marshal(*azErr.ServiceError.Details)
if string(b) != `[{"code":"conflict1","message":"error message1"},{"code":"conflict2","message":"error message2"}]` {
if *azErr.ServiceError.Target == "" {
t.Fatalf("azure: error target is not unmarshaled properly")
}
d, _ := json.Marshal(azErr.ServiceError.Details)
if string(d) != `[{"code":"conflict1","message":"error message1"},{"code":"conflict2","message":"error message2"}]` {
t.Fatalf("azure: error details is not unmarshaled properly")
}
i, _ := json.Marshal(azErr.ServiceError.InnerError)
if string(i) != `{"customKey":"customValue"}` {
t.Fatalf("azure: inner error is not unmarshaled properly")
}
a, _ := json.Marshal(azErr.ServiceError.AdditionalInfo)
if string(a) != `[{"info":{"someProperty":"someValue"},"type":"someErrorType"}]` {
t.Fatalf("azure: error additional info is not unmarshaled properly")
}
if expected := http.StatusInternalServerError; azErr.StatusCode != expected {
t.Fatalf("azure: got wrong StatusCode=%v Expected=%d", azErr.StatusCode, expected)
}
@@ -285,7 +304,7 @@ func TestWithErrorUnlessStatusCode_FoundAzureErrorWithDetails(t *testing.T) {
// the error body should still be there
defer r.Body.Close()
b, err = ioutil.ReadAll(r.Body)
b, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatal(err)
}
@@ -297,8 +316,8 @@ func TestWithErrorUnlessStatusCode_FoundAzureErrorWithDetails(t *testing.T) {
func TestWithErrorUnlessStatusCode_NoAzureError(t *testing.T) {
j := `{
"Status":"NotFound"
}`
"Status":"NotFound"
}`
uuid := "71FDB9F4-5E49-4C12-B266-DE7B4FD999A6"
r := mocks.NewResponseWithContent(j)
mocks.SetResponseHeader(r, HeaderRequestID, uuid)
@@ -320,6 +339,9 @@ func TestWithErrorUnlessStatusCode_NoAzureError(t *testing.T) {
expected := &ServiceError{
Code: "Unknown",
Message: "Unknown service error",
Details: []map[string]interface{}{
{"Status": "NotFound"},
},
}
if !reflect.DeepEqual(expected, azErr.ServiceError) {
@@ -349,13 +371,14 @@ func TestWithErrorUnlessStatusCode_NoAzureError(t *testing.T) {
func TestWithErrorUnlessStatusCode_UnwrappedError(t *testing.T) {
j := `{
"target": null,
"code": "InternalError",
"message": "Azure is having trouble right now.",
"target": "target1",
"details": [{"code": "conflict1", "message":"error message1"},
{"code": "conflict2", "message":"error message2"}],
"innererror": []
}`
"innererror": { "customKey": "customValue" },
"additionalInfo": [{"type": "someErrorType", "info": {"someProperty": "someValue"}}]
}`
uuid := "71FDB9F4-5E49-4C12-B266-DE7B4FD999A6"
r := mocks.NewResponseWithContent(j)
mocks.SetResponseHeader(r, HeaderRequestID, uuid)
@@ -381,7 +404,7 @@ func TestWithErrorUnlessStatusCode_UnwrappedError(t *testing.T) {
t.Fail()
}
if expected := "Azure is having trouble right now."; azErr.ServiceError.Message != expected {
if expected := "Azure is having trouble right now."; azErr.Message != expected {
t.Logf("Incorrect Message\n\tgot: %q\n\twant: %q", azErr.Message, expected)
t.Fail()
}
@@ -391,15 +414,40 @@ func TestWithErrorUnlessStatusCode_UnwrappedError(t *testing.T) {
t.Fail()
}
expectedServiceErrorDetails := `[{"code":"conflict1","message":"error message1"},{"code":"conflict2","message":"error message2"}]`
if azErr.ServiceError == nil {
t.Logf("`ServiceError` was nil when it shouldn't have been.")
t.Fail()
} else if azErr.ServiceError.Details == nil {
}
if expected := "target1"; *azErr.ServiceError.Target != expected {
t.Logf("Incorrect Target\n\tgot: %q\n\twant: %q", *azErr.ServiceError.Target, expected)
t.Fail()
}
expectedServiceErrorDetails := `[{"code":"conflict1","message":"error message1"},{"code":"conflict2","message":"error message2"}]`
if azErr.ServiceError.Details == nil {
t.Logf("`ServiceError.Details` was nil when it should have been %q", expectedServiceErrorDetails)
t.Fail()
} else if details, _ := json.Marshal(*azErr.ServiceError.Details); expectedServiceErrorDetails != string(details) {
t.Logf("Error detaisl was not unmarshaled properly.\n\tgot: %q\n\twant: %q", string(details), expectedServiceErrorDetails)
} else if details, _ := json.Marshal(azErr.ServiceError.Details); expectedServiceErrorDetails != string(details) {
t.Logf("Error details was not unmarshaled properly.\n\tgot: %q\n\twant: %q", string(details), expectedServiceErrorDetails)
t.Fail()
}
expectedServiceErrorInnerError := `{"customKey":"customValue"}`
if azErr.ServiceError.InnerError == nil {
t.Logf("`ServiceError.InnerError` was nil when it should have been %q", expectedServiceErrorInnerError)
t.Fail()
} else if innerError, _ := json.Marshal(azErr.ServiceError.InnerError); expectedServiceErrorInnerError != string(innerError) {
t.Logf("Inner error was not unmarshaled properly.\n\tgot: %q\n\twant: %q", string(innerError), expectedServiceErrorInnerError)
t.Fail()
}
expectedServiceErrorAdditionalInfo := `[{"info":{"someProperty":"someValue"},"type":"someErrorType"}]`
if azErr.ServiceError.AdditionalInfo == nil {
t.Logf("`ServiceError.AdditionalInfo` was nil when it should have been %q", expectedServiceErrorAdditionalInfo)
t.Fail()
} else if additionalInfo, _ := json.Marshal(azErr.ServiceError.AdditionalInfo); expectedServiceErrorAdditionalInfo != string(additionalInfo) {
t.Logf("Additional info was not unmarshaled properly.\n\tgot: %q\n\twant: %q", string(additionalInfo), expectedServiceErrorAdditionalInfo)
t.Fail()
}
@@ -420,7 +468,39 @@ func TestRequestErrorString_WithError(t *testing.T) {
"error": {
"code": "InternalError",
"message": "Conflict",
"details": [{"code": "conflict1", "message":"error message1"}]
"target": "target1",
"details": [{"code": "conflict1", "message":"error message1"}],
"innererror": { "customKey": "customValue" },
"additionalInfo": [{"type": "someErrorType", "info": {"someProperty": "someValue"}}]
}
}`
uuid := "71FDB9F4-5E49-4C12-B266-DE7B4FD999A6"
r := mocks.NewResponseWithContent(j)
mocks.SetResponseHeader(r, HeaderRequestID, uuid)
r.Request = mocks.NewRequest()
r.StatusCode = http.StatusInternalServerError
r.Status = http.StatusText(r.StatusCode)
err := autorest.Respond(r,
WithErrorUnlessStatusCode(http.StatusOK),
autorest.ByClosing())
if err == nil {
t.Fatalf("azure: returned nil error for proper error response")
}
azErr, _ := err.(*RequestError)
expected := "autorest/azure: Service returned an error. Status=500 Code=\"InternalError\" Message=\"Conflict\" Target=\"target1\" Details=[{\"code\":\"conflict1\",\"message\":\"error message1\"}] InnerError={\"customKey\":\"customValue\"} AdditionalInfo=[{\"info\":{\"someProperty\":\"someValue\"},\"type\":\"someErrorType\"}]"
if expected != azErr.Error() {
t.Fatalf("azure: send wrong RequestError.\nexpected=%v\ngot=%v", expected, azErr.Error())
}
}
func TestRequestErrorString_WithErrorNonConforming(t *testing.T) {
j := `{
"error": {
"code": "InternalError",
"message": "Conflict",
"details": {"code": "conflict1", "message":"error message1"}
}
}`
uuid := "71FDB9F4-5E49-4C12-B266-DE7B4FD999A6"
@@ -444,6 +524,81 @@ func TestRequestErrorString_WithError(t *testing.T) {
}
}
func TestParseResourceID_WithValidBasicResourceID(t *testing.T) {
basicResourceID := "/subscriptions/subid-3-3-4/resourceGroups/regGroupVladdb/providers/Microsoft.Network/LoadBalancer/testResourceName"
want := Resource{
SubscriptionID: "subid-3-3-4",
ResourceGroup: "regGroupVladdb",
Provider: "Microsoft.Network",
ResourceType: "LoadBalancer",
ResourceName: "testResourceName",
}
got, err := ParseResourceID(basicResourceID)
if err != nil {
t.Fatalf("azure: error returned while parsing valid resourceId")
}
if got != want {
t.Logf("got: %+v\nwant: %+v", got, want)
t.Fail()
}
}
func TestParseResourceID_WithValidSubResourceID(t *testing.T) {
subresourceID := "/subscriptions/subid-3-3-4/resourceGroups/regGroupVladdb/providers/Microsoft.Network/LoadBalancer/resource/is/a/subresource/actualresourceName"
want := Resource{
SubscriptionID: "subid-3-3-4",
ResourceGroup: "regGroupVladdb",
Provider: "Microsoft.Network",
ResourceType: "LoadBalancer",
ResourceName: "actualresourceName",
}
got, err := ParseResourceID(subresourceID)
if err != nil {
t.Fatalf("azure: error returned while parsing valid resourceId")
}
if got != want {
t.Logf("got: %+v\nwant: %+v", got, want)
t.Fail()
}
}
func TestParseResourceID_WithIncompleteResourceID(t *testing.T) {
basicResourceID := "/subscriptions/subid-3-3-4/resourceGroups/regGroupVladdb/providers/Microsoft.Network/"
want := Resource{}
got, err := ParseResourceID(basicResourceID)
if err == nil {
t.Fatalf("azure: no error returned on incomplete resource id")
}
if got != want {
t.Logf("got: %+v\nwant: %+v", got, want)
t.Fail()
}
}
func TestParseResourceID_WithMalformedResourceID(t *testing.T) {
malformedResourceID := "/providers/subid-3-3-4/resourceGroups/regGroupVladdb/subscriptions/Microsoft.Network/LoadBalancer/testResourceName"
want := Resource{}
got, err := ParseResourceID(malformedResourceID)
if err == nil {
t.Fatalf("azure: error returned while parsing malformed resourceID")
}
if got != want {
t.Logf("got: %+v\nwant: %+v", got, want)
t.Fail()
}
}
func withErrorPrepareDecorator(e *error) autorest.PrepareDecorator {
return func(p autorest.Preparer) autorest.Preparer {
return autorest.PreparerFunc(func(r *http.Request) (*http.Request, error) {

View File

@@ -38,6 +38,13 @@ type Subscription struct {
Name string `json:"name"`
State string `json:"state"`
TenantID string `json:"tenantId"`
User *User `json:"user"`
}
// User represents a User from the Azure CLI
type User struct {
Name string `json:"name"`
Type string `json:"type"`
}
// ProfilePath returns the path where the Azure Profile is stored from the Azure CLI

View File

@@ -15,9 +15,13 @@ package cli
// limitations under the License.
import (
"bytes"
"encoding/json"
"fmt"
"os"
"os/exec"
"regexp"
"runtime"
"strconv"
"time"
@@ -54,7 +58,7 @@ func (t Token) ToADALToken() (converted adal.Token, err error) {
AccessToken: t.AccessToken,
Type: t.TokenType,
ExpiresIn: "3600",
ExpiresOn: strconv.Itoa(int(difference.Seconds())),
ExpiresOn: json.Number(strconv.Itoa(int(difference.Seconds()))),
RefreshToken: t.RefreshToken,
Resource: t.Resource,
}
@@ -62,8 +66,19 @@ func (t Token) ToADALToken() (converted adal.Token, err error) {
}
// AccessTokensPath returns the path where access tokens are stored from the Azure CLI
// TODO(#199): add unit test.
func AccessTokensPath() (string, error) {
return homedir.Expand("~/.azure/accessTokens.json")
// Azure-CLI allows user to customize the path of access tokens thorugh environment variable.
var accessTokenPath = os.Getenv("AZURE_ACCESS_TOKEN_FILE")
var err error
// Fallback logic to default path on non-cloud-shell environment.
// TODO(#200): remove the dependency on hard-coding path.
if accessTokenPath == "" {
accessTokenPath, err = homedir.Expand("~/.azure/accessTokens.json")
}
return accessTokenPath, err
}
// ParseExpirationDate parses either a Azure CLI or CloudShell date into a time object
@@ -101,3 +116,55 @@ func LoadTokens(path string) ([]Token, error) {
return tokens, nil
}
// GetTokenFromCLI gets a token using Azure CLI 2.0 for local development scenarios.
func GetTokenFromCLI(resource string) (*Token, error) {
// This is the path that a developer can set to tell this class what the install path for Azure CLI is.
const azureCLIPath = "AzureCLIPath"
// The default install paths are used to find Azure CLI. This is for security, so that any path in the calling program's Path environment is not used to execute Azure CLI.
azureCLIDefaultPathWindows := fmt.Sprintf("%s\\Microsoft SDKs\\Azure\\CLI2\\wbin; %s\\Microsoft SDKs\\Azure\\CLI2\\wbin", os.Getenv("ProgramFiles(x86)"), os.Getenv("ProgramFiles"))
// Default path for non-Windows.
const azureCLIDefaultPath = "/usr/bin:/usr/local/bin"
// Validate resource, since it gets sent as a command line argument to Azure CLI
const invalidResourceErrorTemplate = "Resource %s is not in expected format. Only alphanumeric characters, [dot], [colon], [hyphen], and [forward slash] are allowed."
match, err := regexp.MatchString("^[0-9a-zA-Z-.:/]+$", resource)
if err != nil {
return nil, err
}
if !match {
return nil, fmt.Errorf(invalidResourceErrorTemplate, resource)
}
// Execute Azure CLI to get token
var cliCmd *exec.Cmd
if runtime.GOOS == "windows" {
cliCmd = exec.Command(fmt.Sprintf("%s\\system32\\cmd.exe", os.Getenv("windir")))
cliCmd.Env = os.Environ()
cliCmd.Env = append(cliCmd.Env, fmt.Sprintf("PATH=%s;%s", os.Getenv(azureCLIPath), azureCLIDefaultPathWindows))
cliCmd.Args = append(cliCmd.Args, "/c")
} else {
cliCmd = exec.Command(os.Getenv("SHELL"))
cliCmd.Env = os.Environ()
cliCmd.Env = append(cliCmd.Env, fmt.Sprintf("PATH=%s:%s", os.Getenv(azureCLIPath), azureCLIDefaultPath))
}
cliCmd.Args = append(cliCmd.Args, "az", "account", "get-access-token", "-o", "json", "--resource", resource)
var stderr bytes.Buffer
cliCmd.Stderr = &stderr
output, err := cliCmd.Output()
if err != nil {
return nil, fmt.Errorf("Invoking Azure CLI failed with the following error: %s", stderr.String())
}
tokenResponse := Token{}
err = json.Unmarshal(output, &tokenResponse)
if err != nil {
return nil, err
}
return &tokenResponse, err
}

View File

@@ -15,10 +15,17 @@ package azure
// limitations under the License.
import (
"encoding/json"
"fmt"
"io/ioutil"
"os"
"strings"
)
// EnvironmentFilepathName captures the name of the environment variable containing the path to the file
// to be used while populating the Azure Environment.
const EnvironmentFilepathName = "AZURE_ENVIRONMENT_FILEPATH"
var environments = map[string]Environment{
"AZURECHINACLOUD": ChinaCloud,
"AZUREGERMANCLOUD": GermanCloud,
@@ -37,6 +44,8 @@ type Environment struct {
GalleryEndpoint string `json:"galleryEndpoint"`
KeyVaultEndpoint string `json:"keyVaultEndpoint"`
GraphEndpoint string `json:"graphEndpoint"`
ServiceBusEndpoint string `json:"serviceBusEndpoint"`
BatchManagementEndpoint string `json:"batchManagementEndpoint"`
StorageEndpointSuffix string `json:"storageEndpointSuffix"`
SQLDatabaseDNSSuffix string `json:"sqlDatabaseDNSSuffix"`
TrafficManagerDNSSuffix string `json:"trafficManagerDNSSuffix"`
@@ -45,6 +54,7 @@ type Environment struct {
ServiceManagementVMDNSSuffix string `json:"serviceManagementVMDNSSuffix"`
ResourceManagerVMDNSSuffix string `json:"resourceManagerVMDNSSuffix"`
ContainerRegistryDNSSuffix string `json:"containerRegistryDNSSuffix"`
TokenAudience string `json:"tokenAudience"`
}
var (
@@ -59,14 +69,17 @@ var (
GalleryEndpoint: "https://gallery.azure.com/",
KeyVaultEndpoint: "https://vault.azure.net/",
GraphEndpoint: "https://graph.windows.net/",
ServiceBusEndpoint: "https://servicebus.windows.net/",
BatchManagementEndpoint: "https://batch.core.windows.net/",
StorageEndpointSuffix: "core.windows.net",
SQLDatabaseDNSSuffix: "database.windows.net",
TrafficManagerDNSSuffix: "trafficmanager.net",
KeyVaultDNSSuffix: "vault.azure.net",
ServiceBusEndpointSuffix: "servicebus.azure.com",
ServiceBusEndpointSuffix: "servicebus.windows.net",
ServiceManagementVMDNSSuffix: "cloudapp.net",
ResourceManagerVMDNSSuffix: "cloudapp.azure.com",
ContainerRegistryDNSSuffix: "azurecr.io",
TokenAudience: "https://management.azure.com/",
}
// USGovernmentCloud is the cloud environment for the US Government
@@ -76,10 +89,12 @@ var (
PublishSettingsURL: "https://manage.windowsazure.us/publishsettings/index",
ServiceManagementEndpoint: "https://management.core.usgovcloudapi.net/",
ResourceManagerEndpoint: "https://management.usgovcloudapi.net/",
ActiveDirectoryEndpoint: "https://login.microsoftonline.com/",
ActiveDirectoryEndpoint: "https://login.microsoftonline.us/",
GalleryEndpoint: "https://gallery.usgovcloudapi.net/",
KeyVaultEndpoint: "https://vault.usgovcloudapi.net/",
GraphEndpoint: "https://graph.usgovcloudapi.net/",
GraphEndpoint: "https://graph.windows.net/",
ServiceBusEndpoint: "https://servicebus.usgovcloudapi.net/",
BatchManagementEndpoint: "https://batch.core.usgovcloudapi.net/",
StorageEndpointSuffix: "core.usgovcloudapi.net",
SQLDatabaseDNSSuffix: "database.usgovcloudapi.net",
TrafficManagerDNSSuffix: "usgovtrafficmanager.net",
@@ -88,6 +103,7 @@ var (
ServiceManagementVMDNSSuffix: "usgovcloudapp.net",
ResourceManagerVMDNSSuffix: "cloudapp.windowsazure.us",
ContainerRegistryDNSSuffix: "azurecr.io",
TokenAudience: "https://management.usgovcloudapi.net/",
}
// ChinaCloud is the cloud environment operated in China
@@ -101,14 +117,17 @@ var (
GalleryEndpoint: "https://gallery.chinacloudapi.cn/",
KeyVaultEndpoint: "https://vault.azure.cn/",
GraphEndpoint: "https://graph.chinacloudapi.cn/",
ServiceBusEndpoint: "https://servicebus.chinacloudapi.cn/",
BatchManagementEndpoint: "https://batch.chinacloudapi.cn/",
StorageEndpointSuffix: "core.chinacloudapi.cn",
SQLDatabaseDNSSuffix: "database.chinacloudapi.cn",
TrafficManagerDNSSuffix: "trafficmanager.cn",
KeyVaultDNSSuffix: "vault.azure.cn",
ServiceBusEndpointSuffix: "servicebus.chinacloudapi.net",
ServiceBusEndpointSuffix: "servicebus.chinacloudapi.cn",
ServiceManagementVMDNSSuffix: "chinacloudapp.cn",
ResourceManagerVMDNSSuffix: "cloudapp.azure.cn",
ContainerRegistryDNSSuffix: "azurecr.io",
TokenAudience: "https://management.chinacloudapi.cn/",
}
// GermanCloud is the cloud environment operated in Germany
@@ -122,6 +141,8 @@ var (
GalleryEndpoint: "https://gallery.cloudapi.de/",
KeyVaultEndpoint: "https://vault.microsoftazure.de/",
GraphEndpoint: "https://graph.cloudapi.de/",
ServiceBusEndpoint: "https://servicebus.cloudapi.de/",
BatchManagementEndpoint: "https://batch.cloudapi.de/",
StorageEndpointSuffix: "core.cloudapi.de",
SQLDatabaseDNSSuffix: "database.cloudapi.de",
TrafficManagerDNSSuffix: "azuretrafficmanager.de",
@@ -130,15 +151,41 @@ var (
ServiceManagementVMDNSSuffix: "azurecloudapp.de",
ResourceManagerVMDNSSuffix: "cloudapp.microsoftazure.de",
ContainerRegistryDNSSuffix: "azurecr.io",
TokenAudience: "https://management.microsoftazure.de/",
}
)
// EnvironmentFromName returns an Environment based on the common name specified
// EnvironmentFromName returns an Environment based on the common name specified.
func EnvironmentFromName(name string) (Environment, error) {
// IMPORTANT
// As per @radhikagupta5:
// This is technical debt, fundamentally here because Kubernetes is not currently accepting
// contributions to the providers. Once that is an option, the provider should be updated to
// directly call `EnvironmentFromFile`. Until then, we rely on dispatching Azure Stack environment creation
// from this method based on the name that is provided to us.
if strings.EqualFold(name, "AZURESTACKCLOUD") {
return EnvironmentFromFile(os.Getenv(EnvironmentFilepathName))
}
name = strings.ToUpper(name)
env, ok := environments[name]
if !ok {
return env, fmt.Errorf("autorest/azure: There is no cloud environment matching the name %q", name)
}
return env, nil
}
// EnvironmentFromFile loads an Environment from a configuration file available on disk.
// This function is particularly useful in the Hybrid Cloud model, where one must define their own
// endpoints.
func EnvironmentFromFile(location string) (unmarshaled Environment, err error) {
fileContents, err := ioutil.ReadFile(location)
if err != nil {
return
}
err = json.Unmarshal(fileContents, &unmarshaled)
return
}

View File

@@ -17,9 +17,119 @@ package azure
import (
"encoding/json"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"path"
"path/filepath"
"runtime"
"testing"
)
// This correlates to the expected contents of ./testdata/test_environment_1.json
var testEnvironment1 = Environment{
Name: "--unit-test--",
ManagementPortalURL: "--management-portal-url",
PublishSettingsURL: "--publish-settings-url--",
ServiceManagementEndpoint: "--service-management-endpoint--",
ResourceManagerEndpoint: "--resource-management-endpoint--",
ActiveDirectoryEndpoint: "--active-directory-endpoint--",
GalleryEndpoint: "--gallery-endpoint--",
KeyVaultEndpoint: "--key-vault--endpoint--",
GraphEndpoint: "--graph-endpoint--",
StorageEndpointSuffix: "--storage-endpoint-suffix--",
SQLDatabaseDNSSuffix: "--sql-database-dns-suffix--",
TrafficManagerDNSSuffix: "--traffic-manager-dns-suffix--",
KeyVaultDNSSuffix: "--key-vault-dns-suffix--",
ServiceBusEndpointSuffix: "--service-bus-endpoint-suffix--",
ServiceManagementVMDNSSuffix: "--asm-vm-dns-suffix--",
ResourceManagerVMDNSSuffix: "--arm-vm-dns-suffix--",
ContainerRegistryDNSSuffix: "--container-registry-dns-suffix--",
TokenAudience: "--token-audience",
}
func TestEnvironment_EnvironmentFromURL_NoOverride_Success(t *testing.T) {
fileContents, _ := ioutil.ReadFile(filepath.Join("testdata", "test_metadata_environment_1.json"))
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(fileContents))
}))
defer ts.Close()
got, err := EnvironmentFromURL(ts.URL)
if err != nil {
t.Error(err)
}
if got.Name != "HybridEnvironment" {
t.Logf("got: %v want: HybridEnvironment", got.Name)
t.Fail()
}
}
func TestEnvironment_EnvironmentFromURL_OverrideStorageSuffix_Success(t *testing.T) {
fileContents, _ := ioutil.ReadFile(filepath.Join("testdata", "test_metadata_environment_1.json"))
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(fileContents))
}))
defer ts.Close()
overrideProperty := OverrideProperty{
Key: EnvironmentStorageEndpointSuffix,
Value: "fakeStorageSuffix",
}
got, err := EnvironmentFromURL(ts.URL, overrideProperty)
if err != nil {
t.Error(err)
}
if got.StorageEndpointSuffix != "fakeStorageSuffix" {
t.Logf("got: %v want: fakeStorageSuffix", got.StorageEndpointSuffix)
t.Fail()
}
}
func TestEnvironment_EnvironmentFromURL_EmptyEndpoint_Failure(t *testing.T) {
_, err := EnvironmentFromURL("")
if err == nil {
t.Fail()
}
if err.Error() != "Metadata resource manager endpoint is empty" {
t.Fail()
}
}
func TestEnvironment_EnvironmentFromFile(t *testing.T) {
got, err := EnvironmentFromFile(filepath.Join("testdata", "test_environment_1.json"))
if err != nil {
t.Error(err)
}
if got != testEnvironment1 {
t.Logf("got: %v want: %v", got, testEnvironment1)
t.Fail()
}
}
func TestEnvironment_EnvironmentFromName_Stack(t *testing.T) {
_, currentFile, _, _ := runtime.Caller(0)
prevEnvFilepathValue := os.Getenv(EnvironmentFilepathName)
os.Setenv(EnvironmentFilepathName, filepath.Join(path.Dir(currentFile), "testdata", "test_environment_1.json"))
defer os.Setenv(EnvironmentFilepathName, prevEnvFilepathValue)
got, err := EnvironmentFromName("AZURESTACKCLOUD")
if err != nil {
t.Error(err)
}
if got != testEnvironment1 {
t.Logf("got: %v want: %v", got, testEnvironment1)
t.Fail()
}
}
func TestEnvironmentFromName(t *testing.T) {
name := "azurechinacloud"
if env, _ := EnvironmentFromName(name); env != ChinaCloud {
@@ -73,6 +183,7 @@ func TestDeserializeEnvironment(t *testing.T) {
"ActiveDirectoryEndpoint": "--active-directory-endpoint--",
"galleryEndpoint": "--gallery-endpoint--",
"graphEndpoint": "--graph-endpoint--",
"serviceBusEndpoint": "--service-bus-endpoint--",
"keyVaultDNSSuffix": "--key-vault-dns-suffix--",
"keyVaultEndpoint": "--key-vault-endpoint--",
"managementPortalURL": "--management-portal-url--",
@@ -118,6 +229,9 @@ func TestDeserializeEnvironment(t *testing.T) {
if "--key-vault-endpoint--" != testSubject.KeyVaultEndpoint {
t.Errorf("Expected KeyVaultEndpoint to be \"--key-vault-endpoint--\", but got %q", testSubject.KeyVaultEndpoint)
}
if "--service-bus-endpoint--" != testSubject.ServiceBusEndpoint {
t.Errorf("Expected ServiceBusEndpoint to be \"--service-bus-endpoint--\", but goet %q", testSubject.ServiceBusEndpoint)
}
if "--graph-endpoint--" != testSubject.GraphEndpoint {
t.Errorf("Expected GraphEndpoint to be \"--graph-endpoint--\", but got %q", testSubject.GraphEndpoint)
}
@@ -155,6 +269,7 @@ func TestRoundTripSerialization(t *testing.T) {
GalleryEndpoint: "--gallery-endpoint--",
KeyVaultEndpoint: "--key-vault--endpoint--",
GraphEndpoint: "--graph-endpoint--",
ServiceBusEndpoint: "--service-bus-endpoint--",
StorageEndpointSuffix: "--storage-endpoint-suffix--",
SQLDatabaseDNSSuffix: "--sql-database-dns-suffix--",
TrafficManagerDNSSuffix: "--traffic-manager-dns-suffix--",
@@ -197,6 +312,9 @@ func TestRoundTripSerialization(t *testing.T) {
if env.GalleryEndpoint != testSubject.GalleryEndpoint {
t.Errorf("Expected GalleryEndpoint to be %q, but got %q", env.GalleryEndpoint, testSubject.GalleryEndpoint)
}
if env.ServiceBusEndpoint != testSubject.ServiceBusEndpoint {
t.Errorf("Expected ServiceBusEnpoint to be %q, but got %q", env.ServiceBusEndpoint, testSubject.ServiceBusEndpoint)
}
if env.KeyVaultEndpoint != testSubject.KeyVaultEndpoint {
t.Errorf("Expected KeyVaultEndpoint to be %q, but got %q", env.KeyVaultEndpoint, testSubject.KeyVaultEndpoint)
}

View File

@@ -253,7 +253,7 @@ func main() {
// should save it as soon as you get it since Refresh won't be called for some time
if tokenCachePath != "" {
saveToken(spt.Token)
saveToken(spt.Token())
}
}

View File

@@ -0,0 +1,245 @@
package azure
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"strings"
"github.com/Azure/go-autorest/autorest"
)
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
type audience []string
type authentication struct {
LoginEndpoint string `json:"loginEndpoint"`
Audiences audience `json:"audiences"`
}
type environmentMetadataInfo struct {
GalleryEndpoint string `json:"galleryEndpoint"`
GraphEndpoint string `json:"graphEndpoint"`
PortalEndpoint string `json:"portalEndpoint"`
Authentication authentication `json:"authentication"`
}
// EnvironmentProperty represent property names that clients can override
type EnvironmentProperty string
const (
// EnvironmentName ...
EnvironmentName EnvironmentProperty = "name"
// EnvironmentManagementPortalURL ..
EnvironmentManagementPortalURL EnvironmentProperty = "managementPortalURL"
// EnvironmentPublishSettingsURL ...
EnvironmentPublishSettingsURL EnvironmentProperty = "publishSettingsURL"
// EnvironmentServiceManagementEndpoint ...
EnvironmentServiceManagementEndpoint EnvironmentProperty = "serviceManagementEndpoint"
// EnvironmentResourceManagerEndpoint ...
EnvironmentResourceManagerEndpoint EnvironmentProperty = "resourceManagerEndpoint"
// EnvironmentActiveDirectoryEndpoint ...
EnvironmentActiveDirectoryEndpoint EnvironmentProperty = "activeDirectoryEndpoint"
// EnvironmentGalleryEndpoint ...
EnvironmentGalleryEndpoint EnvironmentProperty = "galleryEndpoint"
// EnvironmentKeyVaultEndpoint ...
EnvironmentKeyVaultEndpoint EnvironmentProperty = "keyVaultEndpoint"
// EnvironmentGraphEndpoint ...
EnvironmentGraphEndpoint EnvironmentProperty = "graphEndpoint"
// EnvironmentServiceBusEndpoint ...
EnvironmentServiceBusEndpoint EnvironmentProperty = "serviceBusEndpoint"
// EnvironmentBatchManagementEndpoint ...
EnvironmentBatchManagementEndpoint EnvironmentProperty = "batchManagementEndpoint"
// EnvironmentStorageEndpointSuffix ...
EnvironmentStorageEndpointSuffix EnvironmentProperty = "storageEndpointSuffix"
// EnvironmentSQLDatabaseDNSSuffix ...
EnvironmentSQLDatabaseDNSSuffix EnvironmentProperty = "sqlDatabaseDNSSuffix"
// EnvironmentTrafficManagerDNSSuffix ...
EnvironmentTrafficManagerDNSSuffix EnvironmentProperty = "trafficManagerDNSSuffix"
// EnvironmentKeyVaultDNSSuffix ...
EnvironmentKeyVaultDNSSuffix EnvironmentProperty = "keyVaultDNSSuffix"
// EnvironmentServiceBusEndpointSuffix ...
EnvironmentServiceBusEndpointSuffix EnvironmentProperty = "serviceBusEndpointSuffix"
// EnvironmentServiceManagementVMDNSSuffix ...
EnvironmentServiceManagementVMDNSSuffix EnvironmentProperty = "serviceManagementVMDNSSuffix"
// EnvironmentResourceManagerVMDNSSuffix ...
EnvironmentResourceManagerVMDNSSuffix EnvironmentProperty = "resourceManagerVMDNSSuffix"
// EnvironmentContainerRegistryDNSSuffix ...
EnvironmentContainerRegistryDNSSuffix EnvironmentProperty = "containerRegistryDNSSuffix"
// EnvironmentTokenAudience ...
EnvironmentTokenAudience EnvironmentProperty = "tokenAudience"
)
// OverrideProperty represents property name and value that clients can override
type OverrideProperty struct {
Key EnvironmentProperty
Value string
}
// EnvironmentFromURL loads an Environment from a URL
// This function is particularly useful in the Hybrid Cloud model, where one may define their own
// endpoints.
func EnvironmentFromURL(resourceManagerEndpoint string, properties ...OverrideProperty) (environment Environment, err error) {
var metadataEnvProperties environmentMetadataInfo
if resourceManagerEndpoint == "" {
return environment, fmt.Errorf("Metadata resource manager endpoint is empty")
}
if metadataEnvProperties, err = retrieveMetadataEnvironment(resourceManagerEndpoint); err != nil {
return environment, err
}
// Give priority to user's override values
overrideProperties(&environment, properties)
if environment.Name == "" {
environment.Name = "HybridEnvironment"
}
stampDNSSuffix := environment.StorageEndpointSuffix
if stampDNSSuffix == "" {
stampDNSSuffix = strings.TrimSuffix(strings.TrimPrefix(strings.Replace(resourceManagerEndpoint, strings.Split(resourceManagerEndpoint, ".")[0], "", 1), "."), "/")
environment.StorageEndpointSuffix = stampDNSSuffix
}
if environment.KeyVaultDNSSuffix == "" {
environment.KeyVaultDNSSuffix = fmt.Sprintf("%s.%s", "vault", stampDNSSuffix)
}
if environment.KeyVaultEndpoint == "" {
environment.KeyVaultEndpoint = fmt.Sprintf("%s%s", "https://", environment.KeyVaultDNSSuffix)
}
if environment.TokenAudience == "" {
environment.TokenAudience = metadataEnvProperties.Authentication.Audiences[0]
}
if environment.ActiveDirectoryEndpoint == "" {
environment.ActiveDirectoryEndpoint = metadataEnvProperties.Authentication.LoginEndpoint
}
if environment.ResourceManagerEndpoint == "" {
environment.ResourceManagerEndpoint = resourceManagerEndpoint
}
if environment.GalleryEndpoint == "" {
environment.GalleryEndpoint = metadataEnvProperties.GalleryEndpoint
}
if environment.GraphEndpoint == "" {
environment.GraphEndpoint = metadataEnvProperties.GraphEndpoint
}
return environment, nil
}
func overrideProperties(environment *Environment, properties []OverrideProperty) {
for _, property := range properties {
switch property.Key {
case EnvironmentName:
{
environment.Name = property.Value
}
case EnvironmentManagementPortalURL:
{
environment.ManagementPortalURL = property.Value
}
case EnvironmentPublishSettingsURL:
{
environment.PublishSettingsURL = property.Value
}
case EnvironmentServiceManagementEndpoint:
{
environment.ServiceManagementEndpoint = property.Value
}
case EnvironmentResourceManagerEndpoint:
{
environment.ResourceManagerEndpoint = property.Value
}
case EnvironmentActiveDirectoryEndpoint:
{
environment.ActiveDirectoryEndpoint = property.Value
}
case EnvironmentGalleryEndpoint:
{
environment.GalleryEndpoint = property.Value
}
case EnvironmentKeyVaultEndpoint:
{
environment.KeyVaultEndpoint = property.Value
}
case EnvironmentGraphEndpoint:
{
environment.GraphEndpoint = property.Value
}
case EnvironmentServiceBusEndpoint:
{
environment.ServiceBusEndpoint = property.Value
}
case EnvironmentBatchManagementEndpoint:
{
environment.BatchManagementEndpoint = property.Value
}
case EnvironmentStorageEndpointSuffix:
{
environment.StorageEndpointSuffix = property.Value
}
case EnvironmentSQLDatabaseDNSSuffix:
{
environment.SQLDatabaseDNSSuffix = property.Value
}
case EnvironmentTrafficManagerDNSSuffix:
{
environment.TrafficManagerDNSSuffix = property.Value
}
case EnvironmentKeyVaultDNSSuffix:
{
environment.KeyVaultDNSSuffix = property.Value
}
case EnvironmentServiceBusEndpointSuffix:
{
environment.ServiceBusEndpointSuffix = property.Value
}
case EnvironmentServiceManagementVMDNSSuffix:
{
environment.ServiceManagementVMDNSSuffix = property.Value
}
case EnvironmentResourceManagerVMDNSSuffix:
{
environment.ResourceManagerVMDNSSuffix = property.Value
}
case EnvironmentContainerRegistryDNSSuffix:
{
environment.ContainerRegistryDNSSuffix = property.Value
}
case EnvironmentTokenAudience:
{
environment.TokenAudience = property.Value
}
}
}
}
func retrieveMetadataEnvironment(endpoint string) (environment environmentMetadataInfo, err error) {
client := autorest.NewClientWithUserAgent("")
managementEndpoint := fmt.Sprintf("%s%s", strings.TrimSuffix(endpoint, "/"), "/metadata/endpoints?api-version=1.0")
req, _ := http.NewRequest("GET", managementEndpoint, nil)
response, err := client.Do(req)
if err != nil {
return environment, err
}
defer response.Body.Close()
jsonResponse, err := ioutil.ReadAll(response.Body)
if err != nil {
return environment, err
}
err = json.Unmarshal(jsonResponse, &environment)
return environment, err
}

View File

@@ -1,3 +1,17 @@
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package azure
import (
@@ -30,7 +44,7 @@ func DoRetryWithRegistration(client autorest.Client) autorest.SendDecorator {
return resp, err
}
if resp.StatusCode != http.StatusConflict {
if resp.StatusCode != http.StatusConflict || client.SkipResourceProviderRegistration {
return resp, err
}
var re RequestError
@@ -41,25 +55,23 @@ func DoRetryWithRegistration(client autorest.Client) autorest.SendDecorator {
if err != nil {
return resp, err
}
err = re
if re.ServiceError != nil && re.ServiceError.Code == "MissingSubscriptionRegistration" {
err = register(client, r, re)
if err != nil {
return resp, fmt.Errorf("failed auto registering Resource Provider: %s", err)
regErr := register(client, r, re)
if regErr != nil {
return resp, fmt.Errorf("failed auto registering Resource Provider: %s. Original error: %s", regErr, err)
}
}
}
return resp, errors.New("failed request and resource provider registration")
return resp, err
})
}
}
func getProvider(re RequestError) (string, error) {
if re.ServiceError != nil {
if re.ServiceError.Details != nil && len(*re.ServiceError.Details) > 0 {
detail := (*re.ServiceError.Details)[0].(map[string]interface{})
return detail["target"].(string), nil
}
if re.ServiceError != nil && len(re.ServiceError.Details) > 0 {
return re.ServiceError.Details[0]["target"].(string), nil
}
return "", errors.New("provider was not found in the response")
}
@@ -103,7 +115,7 @@ func register(client autorest.Client, originalReq *http.Request, re RequestError
if err != nil {
return err
}
req.Cancel = originalReq.Cancel
req = req.WithContext(originalReq.Context())
resp, err := autorest.SendWithSender(client, req,
autorest.DoRetryForStatusCodes(client.RetryAttempts, client.RetryDuration, autorest.StatusCodesForRetry...),
@@ -128,8 +140,8 @@ func register(client autorest.Client, originalReq *http.Request, re RequestError
}
// poll for registered provisioning state
now := time.Now()
for err == nil && time.Since(now) < client.PollingDuration {
registrationStartTime := time.Now()
for err == nil && (client.PollingDuration == 0 || (client.PollingDuration != 0 && time.Since(registrationStartTime) < client.PollingDuration)) {
// taken from the resources SDK
// https://github.com/Azure/azure-sdk-for-go/blob/9f366792afa3e0ddaecdc860e793ba9d75e76c27/arm/resources/resources/providers.go#L45
preparer := autorest.CreatePreparer(
@@ -142,9 +154,9 @@ func register(client autorest.Client, originalReq *http.Request, re RequestError
if err != nil {
return err
}
req.Cancel = originalReq.Cancel
req = req.WithContext(originalReq.Context())
resp, err := autorest.SendWithSender(client.Sender, req,
resp, err := autorest.SendWithSender(client, req,
autorest.DoRetryForStatusCodes(client.RetryAttempts, client.RetryDuration, autorest.StatusCodesForRetry...),
)
if err != nil {
@@ -166,12 +178,12 @@ func register(client autorest.Client, originalReq *http.Request, re RequestError
break
}
delayed := autorest.DelayWithRetryAfter(resp, originalReq.Cancel)
if !delayed {
autorest.DelayForBackoff(client.PollingDelay, 0, originalReq.Cancel)
delayed := autorest.DelayWithRetryAfter(resp, originalReq.Context().Done())
if !delayed && !autorest.DelayForBackoff(client.PollingDelay, 0, originalReq.Context().Done()) {
return originalReq.Context().Err()
}
}
if !(time.Since(now) < client.PollingDuration) {
if client.PollingDuration != 0 && !(time.Since(registrationStartTime) < client.PollingDuration) {
return errors.New("polling for resource provider registration has exceeded the polling duration")
}
return err

View File

@@ -1,7 +1,23 @@
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package azure
import (
"context"
"net/http"
"sync"
"testing"
"time"
@@ -65,3 +81,89 @@ func TestDoRetryWithRegistration(t *testing.T) {
t.Fatalf("azure: Sender#DoRetryWithRegistration -- Got: StatusCode %v; Want: StatusCode 200 OK", r.StatusCode)
}
}
func TestDoRetrySkipRegistration(t *testing.T) {
client := mocks.NewSender()
// first response, should retry because it is a transient error
client.AppendResponse(mocks.NewResponseWithStatus("Internal server error", http.StatusInternalServerError))
// response indicates the resource provider has not been registered
client.AppendResponse(mocks.NewResponseWithBodyAndStatus(mocks.NewBody(`{
"error":{
"code":"MissingSubscriptionRegistration",
"message":"The subscription registration is in 'Unregistered' state. The subscription must be registered to use namespace 'Microsoft.EventGrid'. See https://aka.ms/rps-not-found for how to register subscriptions.",
"details":[
{
"code":"MissingSubscriptionRegistration",
"target":"Microsoft.EventGrid",
"message":"The subscription registration is in 'Unregistered' state. The subscription must be registered to use namespace 'Microsoft.EventGrid'. See https://aka.ms/rps-not-found for how to register subscriptions."
}
]
}
}`), http.StatusConflict, "MissingSubscriptionRegistration"))
req := mocks.NewRequestForURL("https://lol/subscriptions/rofl")
req.Body = mocks.NewBody("lolol")
r, err := autorest.SendWithSender(client, req,
DoRetryWithRegistration(autorest.Client{
PollingDelay: time.Second,
PollingDuration: time.Second * 10,
RetryAttempts: 5,
RetryDuration: time.Second,
Sender: client,
SkipResourceProviderRegistration: true,
}),
)
if err != nil {
t.Fatalf("got error: %v", err)
}
autorest.Respond(r,
autorest.ByDiscardingBody(),
autorest.ByClosing(),
)
if r.StatusCode != http.StatusConflict {
t.Fatalf("azure: Sender#DoRetryWithRegistration -- Got: StatusCode %v; Want: StatusCode 409 Conflict", r.StatusCode)
}
}
func TestDoRetryWithRegistration_CanBeCancelled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
delay := 5 * time.Second
client := mocks.NewSender()
client.AppendAndRepeatResponse(mocks.NewResponseWithStatus("Internal server error", http.StatusInternalServerError), 5)
var wg sync.WaitGroup
wg.Add(1)
start := time.Now()
end := time.Now()
var err error
go func() {
req := mocks.NewRequestForURL("https://lol/subscriptions/rofl")
req = req.WithContext(ctx)
req.Body = mocks.NewBody("lolol")
_, err = autorest.SendWithSender(client, req,
DoRetryWithRegistration(autorest.Client{
PollingDelay: time.Second,
PollingDuration: delay,
RetryAttempts: 5,
RetryDuration: time.Second,
Sender: client,
SkipResourceProviderRegistration: true,
}),
)
end = time.Now()
wg.Done()
}()
cancel()
wg.Wait()
time.Sleep(5 * time.Millisecond)
if err == nil {
t.Fatalf("azure: DoRetryWithRegistration didn't cancel")
}
if end.Sub(start) >= delay {
t.Fatalf("azure: DoRetryWithRegistration failed to cancel")
}
}

View File

@@ -0,0 +1,20 @@
{
"name": "--unit-test--",
"managementPortalURL": "--management-portal-url",
"publishSettingsURL": "--publish-settings-url--",
"serviceManagementEndpoint": "--service-management-endpoint--",
"resourceManagerEndpoint": "--resource-management-endpoint--",
"activeDirectoryEndpoint": "--active-directory-endpoint--",
"galleryEndpoint": "--gallery-endpoint--",
"keyVaultEndpoint": "--key-vault--endpoint--",
"graphEndpoint": "--graph-endpoint--",
"storageEndpointSuffix": "--storage-endpoint-suffix--",
"sqlDatabaseDNSSuffix": "--sql-database-dns-suffix--",
"trafficManagerDNSSuffix": "--traffic-manager-dns-suffix--",
"keyVaultDNSSuffix": "--key-vault-dns-suffix--",
"serviceBusEndpointSuffix": "--service-bus-endpoint-suffix--",
"serviceManagementVMDNSSuffix": "--asm-vm-dns-suffix--",
"resourceManagerVMDNSSuffix": "--arm-vm-dns-suffix--",
"containerRegistryDNSSuffix": "--container-registry-dns-suffix--",
"tokenAudience": "--token-audience"
}

View File

@@ -0,0 +1,11 @@
{
"galleryEndpoint":"https://portal.local.azurestack.external:30015/",
"graphEndpoint":"https://graph.windows.net/",
"portalEndpoint":"https://portal.local.azurestack.external/",
"authentication":{
"loginEndpoint":"https://login.windows.net/",
"audiences":[
"https://management.azurestackci04.onmicrosoft.com/3ee899c8-c137-4ce4-b230-619961a09a73"
]
}
}

View File

@@ -22,8 +22,11 @@ import (
"log"
"net/http"
"net/http/cookiejar"
"runtime"
"strings"
"time"
"github.com/Azure/go-autorest/logger"
"github.com/Azure/go-autorest/version"
)
const (
@@ -35,18 +38,12 @@ const (
// DefaultRetryAttempts is number of attempts for retry status codes (5xx).
DefaultRetryAttempts = 3
// DefaultRetryDuration is the duration to wait between retries.
DefaultRetryDuration = 30 * time.Second
)
var (
// defaultUserAgent builds a string containing the Go version, system archityecture and OS,
// and the go-autorest version.
defaultUserAgent = fmt.Sprintf("Go/%s (%s-%s) go-autorest/%s",
runtime.Version(),
runtime.GOARCH,
runtime.GOOS,
Version(),
)
// StatusCodesForRetry are a defined group of status code for which the client will retry
StatusCodesForRetry = []int{
http.StatusRequestTimeout, // 408
@@ -150,6 +147,7 @@ type Client struct {
PollingDelay time.Duration
// PollingDuration sets the maximum polling time after which an error is returned.
// Setting this to zero will use the provided context to control the duration.
PollingDuration time.Duration
// RetryAttempts sets the default number of retry attempts for client.
@@ -163,6 +161,9 @@ type Client struct {
UserAgent string
Jar http.CookieJar
// Set to true to skip attempted registration of resource providers (false by default).
SkipResourceProviderRegistration bool
}
// NewClientWithUserAgent returns an instance of a Client with the UserAgent set to the passed
@@ -172,9 +173,10 @@ func NewClientWithUserAgent(ua string) Client {
PollingDelay: DefaultPollingDelay,
PollingDuration: DefaultPollingDuration,
RetryAttempts: DefaultRetryAttempts,
RetryDuration: 30 * time.Second,
UserAgent: defaultUserAgent,
RetryDuration: DefaultRetryDuration,
UserAgent: version.UserAgent(),
}
c.Sender = c.sender()
c.AddToUserAgent(ua)
return c
}
@@ -196,15 +198,31 @@ func (c Client) Do(r *http.Request) (*http.Response, error) {
r, _ = Prepare(r,
WithUserAgent(c.UserAgent))
}
// NOTE: c.WithInspection() must be last in the list so that it can inspect all preceding operations
r, err := Prepare(r,
c.WithInspection(),
c.WithAuthorization())
c.WithAuthorization(),
c.WithInspection())
if err != nil {
return nil, NewErrorWithError(err, "autorest/Client", "Do", nil, "Preparing request failed")
var resp *http.Response
if detErr, ok := err.(DetailedError); ok {
// if the authorization failed (e.g. invalid credentials) there will
// be a response associated with the error, be sure to return it.
resp = detErr.Response
}
return resp, NewErrorWithError(err, "autorest/Client", "Do", nil, "Preparing request failed")
}
logger.Instance.WriteRequest(r, logger.Filter{
Header: func(k string, v []string) (bool, []string) {
// remove the auth token from the log
if strings.EqualFold(k, "Authorization") || strings.EqualFold(k, "Ocp-Apim-Subscription-Key") {
v = []string{"**REDACTED**"}
}
return true, v
},
})
resp, err := SendWithSender(c.sender(), r)
Respond(resp,
c.ByInspecting())
logger.Instance.WriteResponse(resp, logger.Filter{})
Respond(resp, c.ByInspecting())
return resp, err
}

View File

@@ -21,11 +21,13 @@ import (
"log"
"math/rand"
"net/http"
"net/http/httptest"
"reflect"
"testing"
"time"
"github.com/Azure/go-autorest/autorest/mocks"
"github.com/Azure/go-autorest/version"
)
func TestLoggingInspectorWithInspection(t *testing.T) {
@@ -125,7 +127,7 @@ func TestLoggingInspectorByInspectingRestoresBody(t *testing.T) {
func TestNewClientWithUserAgent(t *testing.T) {
ua := "UserAgent"
c := NewClientWithUserAgent(ua)
completeUA := fmt.Sprintf("%s %s", defaultUserAgent, ua)
completeUA := fmt.Sprintf("%s %s", version.UserAgent(), ua)
if c.UserAgent != completeUA {
t.Fatalf("autorest: NewClientWithUserAgent failed to set the UserAgent -- expected %s, received %s",
@@ -141,7 +143,7 @@ func TestAddToUserAgent(t *testing.T) {
if err != nil {
t.Fatalf("autorest: AddToUserAgent returned error -- expected nil, received %s", err)
}
completeUA := fmt.Sprintf("%s %s %s", defaultUserAgent, ua, ext)
completeUA := fmt.Sprintf("%s %s %s", version.UserAgent(), ua, ext)
if c.UserAgent != completeUA {
t.Fatalf("autorest: AddToUserAgent failed to add an extension to the UserAgent -- expected %s, received %s",
@@ -345,6 +347,51 @@ func TestClientByInspectingSetsDefault(t *testing.T) {
}
}
func TestCookies(t *testing.T) {
second := "second"
expected := http.Cookie{
Name: "tastes",
Value: "delicious",
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.SetCookie(w, &expected)
b, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("autorest: ioutil.ReadAll failed reading request body: %s", err)
}
if string(b) == second {
cookie, err := r.Cookie(expected.Name)
if err != nil {
t.Fatalf("autorest: r.Cookie could not get request cookie: %s", err)
}
if cookie == nil {
t.Fatalf("autorest: got nil cookie, expecting %v", expected)
}
if cookie.Value != expected.Value {
t.Fatalf("autorest: got cookie value '%s', expecting '%s'", cookie.Value, expected.Name)
}
}
}))
defer server.Close()
client := NewClientWithUserAgent("")
_, err := SendWithSender(client, mocks.NewRequestForURL(server.URL))
if err != nil {
t.Fatalf("autorest: first request failed: %s", err)
}
r2, err := http.NewRequest(http.MethodGet, server.URL, mocks.NewBody(second))
if err != nil {
t.Fatalf("autorest: failed creating second request: %s", err)
}
_, err = SendWithSender(client, r2)
if err != nil {
t.Fatalf("autorest: second request failed: %s", err)
}
}
func randomString(n int) string {
const chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
r := rand.New(rand.NewSource(time.Now().UTC().UnixNano()))

View File

@@ -16,6 +16,7 @@ package mocks
import (
"fmt"
"io"
"net/http"
"time"
)
@@ -72,7 +73,12 @@ func NewRequestWithCloseBodyContent(c string) *http.Request {
// NewRequestForURL instantiates a new request using the passed URL.
func NewRequestForURL(u string) *http.Request {
r, err := http.NewRequest("GET", u, NewBody(""))
return NewRequestWithParams("GET", u, NewBody(""))
}
// NewRequestWithParams instantiates a new request using the provided parameters.
func NewRequestWithParams(method, u string, body io.Reader) *http.Request {
r, err := http.NewRequest(method, u, body)
if err != nil {
panic(fmt.Sprintf("mocks: ERROR (%v) parsing testing URL %s", err, u))
}
@@ -111,6 +117,7 @@ func NewResponseWithStatus(s string, c int) *http.Response {
func NewResponseWithBodyAndStatus(body *Body, c int, s string) *http.Response {
resp := NewResponse()
resp.Body = body
resp.ContentLength = body.Length()
resp.Status = s
resp.StatusCode = c
return resp

View File

@@ -21,6 +21,7 @@ import (
"fmt"
"io"
"net/http"
"time"
)
// Body implements acceptable body over a string.
@@ -79,10 +80,25 @@ func (body *Body) reset() *Body {
return body
}
// Length returns the number of bytes in the body.
func (body *Body) Length() int64 {
if body == nil {
return 0
}
return int64(len(body.b))
}
type response struct {
r *http.Response
e error
d time.Duration
}
// Sender implements a simple null sender.
type Sender struct {
attempts int
responses []*http.Response
responses []response
numResponses int
repeatResponse []int
err error
repeatError int
@@ -99,12 +115,15 @@ func (c *Sender) Do(r *http.Request) (resp *http.Response, err error) {
c.attempts++
if len(c.responses) > 0 {
resp = c.responses[0]
resp = c.responses[0].r
if resp != nil {
if b, ok := resp.Body.(*Body); ok {
b.reset()
}
} else {
err = c.responses[0].e
}
time.Sleep(c.responses[0].d)
c.repeatResponse[0]--
if c.repeatResponse[0] == 0 {
c.responses = c.responses[1:]
@@ -135,16 +154,43 @@ func (c *Sender) AppendResponse(resp *http.Response) {
c.AppendAndRepeatResponse(resp, 1)
}
// AppendResponseWithDelay adds the passed http.Response to the response stack with the specified delay.
func (c *Sender) AppendResponseWithDelay(resp *http.Response, delay time.Duration) {
c.AppendAndRepeatResponseWithDelay(resp, delay, 1)
}
// AppendAndRepeatResponse adds the passed http.Response to the response stack along with a
// repeat count. A negative repeat count will return the response for all remaining calls to Do.
func (c *Sender) AppendAndRepeatResponse(resp *http.Response, repeat int) {
c.appendAndRepeat(response{r: resp}, repeat)
}
// AppendAndRepeatResponseWithDelay adds the passed http.Response to the response stack with the specified
// delay along with a repeat count. A negative repeat count will return the response for all remaining calls to Do.
func (c *Sender) AppendAndRepeatResponseWithDelay(resp *http.Response, delay time.Duration, repeat int) {
c.appendAndRepeat(response{r: resp, d: delay}, repeat)
}
// AppendError adds the passed error to the response stack.
func (c *Sender) AppendError(err error) {
c.AppendAndRepeatError(err, 1)
}
// AppendAndRepeatError adds the passed error to the response stack along with a repeat
// count. A negative repeat count will return the response for all remaining calls to Do.
func (c *Sender) AppendAndRepeatError(err error, repeat int) {
c.appendAndRepeat(response{e: err}, repeat)
}
func (c *Sender) appendAndRepeat(resp response, repeat int) {
if c.responses == nil {
c.responses = []*http.Response{resp}
c.responses = []response{resp}
c.repeatResponse = []int{repeat}
} else {
c.responses = append(c.responses, resp)
c.repeatResponse = append(c.repeatResponse, repeat)
}
c.numResponses++
}
// Attempts returns the number of times Do was called.
@@ -169,6 +215,11 @@ func (c *Sender) SetEmitErrorAfter(ea int) {
c.emitErrorAfter = ea
}
// NumResponses returns the number of responses that have been added to the sender.
func (c *Sender) NumResponses() int {
return c.numResponses
}
// T is a simple testing struct.
type T struct {
Name string `json:"name" xml:"Name"`

View File

@@ -27,8 +27,9 @@ import (
)
const (
mimeTypeJSON = "application/json"
mimeTypeFormPost = "application/x-www-form-urlencoded"
mimeTypeJSON = "application/json"
mimeTypeOctetStream = "application/octet-stream"
mimeTypeFormPost = "application/x-www-form-urlencoded"
headerAuthorization = "Authorization"
headerContentType = "Content-Type"
@@ -112,6 +113,28 @@ func WithHeader(header string, value string) PrepareDecorator {
}
}
// WithHeaders returns a PrepareDecorator that sets the specified HTTP headers of the http.Request to
// the passed value. It canonicalizes the passed headers name (via http.CanonicalHeaderKey) before
// adding them.
func WithHeaders(headers map[string]interface{}) PrepareDecorator {
h := ensureValueStrings(headers)
return func(p Preparer) Preparer {
return PreparerFunc(func(r *http.Request) (*http.Request, error) {
r, err := p.Prepare(r)
if err == nil {
if r.Header == nil {
r.Header = make(http.Header)
}
for name, value := range h {
r.Header.Set(http.CanonicalHeaderKey(name), value)
}
}
return r, err
})
}
}
// WithBearerAuthorization returns a PrepareDecorator that adds an HTTP Authorization header whose
// value is "Bearer " followed by the supplied token.
func WithBearerAuthorization(token string) PrepareDecorator {
@@ -142,6 +165,11 @@ func AsJSON() PrepareDecorator {
return AsContentType(mimeTypeJSON)
}
// AsOctetStream returns a PrepareDecorator that adds the "application/octet-stream" Content-Type header.
func AsOctetStream() PrepareDecorator {
return AsContentType(mimeTypeOctetStream)
}
// WithMethod returns a PrepareDecorator that sets the HTTP method of the passed request. The
// decorator does not validate that the passed method string is a known HTTP method.
func WithMethod(method string) PrepareDecorator {
@@ -215,6 +243,11 @@ func WithFormData(v url.Values) PrepareDecorator {
r, err := p.Prepare(r)
if err == nil {
s := v.Encode()
if r.Header == nil {
r.Header = make(http.Header)
}
r.Header.Set(http.CanonicalHeaderKey(headerContentType), mimeTypeFormPost)
r.ContentLength = int64(len(s))
r.Body = ioutil.NopCloser(strings.NewReader(s))
}
@@ -430,11 +463,16 @@ func WithQueryParameters(queryParameters map[string]interface{}) PrepareDecorato
if r.URL == nil {
return r, NewError("autorest", "WithQueryParameters", "Invoked with a nil URL")
}
v := r.URL.Query()
for key, value := range parameters {
v.Add(key, value)
d, err := url.QueryUnescape(value)
if err != nil {
return r, err
}
v.Add(key, d)
}
r.URL.RawQuery = createQuery(v)
r.URL.RawQuery = v.Encode()
}
return r, err
})

View File

@@ -483,7 +483,7 @@ func TestPrepareWithNullRequest(t *testing.T) {
}
}
func TestWithFormDataSetsContentLength(t *testing.T) {
func TestWithFormData(t *testing.T) {
v := url.Values{}
v.Add("name", "Rob Pike")
v.Add("age", "42")
@@ -507,6 +507,10 @@ func TestWithFormDataSetsContentLength(t *testing.T) {
if r.ContentLength != int64(len(b)) {
t.Fatalf("autorest:WithFormData set Content-Length to %v, expected %v", r.ContentLength, len(b))
}
if expected, got := r.Header.Get(http.CanonicalHeaderKey(headerContentType)), mimeTypeFormPost; expected != got {
t.Fatalf("autorest:WithFormData Content Type not set or set to wrong value. Expected %v and got %v", expected, got)
}
}
func TestWithMultiPartFormDataSetsContentLength(t *testing.T) {
@@ -760,7 +764,39 @@ func TestModifyingExistingRequest(t *testing.T) {
if err != nil {
t.Fatalf("autorest: Preparing an existing request returned an error (%v)", err)
}
if r.URL.String() != "https:/search?q=golang" && r.URL.Host != "bing.com" {
t.Fatalf("autorest: Preparing an existing request failed (%s)", r.URL)
if r.URL.Host != "bing.com" {
t.Fatalf("autorest: Preparing an existing request failed when setting the host (%s)", r.URL)
}
if r.URL.Path != "/search" {
t.Fatalf("autorest: Preparing an existing request failed when setting the path (%s)", r.URL.Path)
}
if r.URL.RawQuery != "q=golang" {
t.Fatalf("autorest: Preparing an existing request failed when setting the query parameters (%s)", r.URL.RawQuery)
}
}
func TestModifyingRequestWithExistingQueryParameters(t *testing.T) {
r, err := Prepare(
mocks.NewRequestForURL("https://bing.com"),
WithPath("search"),
WithQueryParameters(map[string]interface{}{"q": "golang the best"}),
WithQueryParameters(map[string]interface{}{"pq": "golang+encoded"}),
)
if err != nil {
t.Fatalf("autorest: Preparing an existing request returned an error (%v)", err)
}
if r.URL.Host != "bing.com" {
t.Fatalf("autorest: Preparing an existing request failed when setting the host (%s)", r.URL)
}
if r.URL.Path != "/search" {
t.Fatalf("autorest: Preparing an existing request failed when setting the path (%s)", r.URL.Path)
}
if r.URL.RawQuery != "pq=golang+encoded&q=golang+the+best" {
t.Fatalf("autorest: Preparing an existing request failed when setting the query parameters (%s)", r.URL.RawQuery)
}
}

View File

@@ -86,7 +86,7 @@ func SendWithSender(s Sender, r *http.Request, decorators ...SendDecorator) (*ht
func AfterDelay(d time.Duration) SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
if !DelayForBackoff(d, 0, r.Cancel) {
if !DelayForBackoff(d, 0, r.Context().Done()) {
return nil, fmt.Errorf("autorest: AfterDelay canceled before full delay")
}
return s.Do(r)
@@ -165,7 +165,7 @@ func DoPollForStatusCodes(duration time.Duration, delay time.Duration, codes ...
resp, err = s.Do(r)
if err == nil && ResponseHasStatusCode(resp, codes...) {
r, err = NewPollingRequest(resp, r.Cancel)
r, err = NewPollingRequestWithContext(r.Context(), resp)
for err == nil && ResponseHasStatusCode(resp, codes...) {
Respond(resp,
@@ -198,7 +198,9 @@ func DoRetryForAttempts(attempts int, backoff time.Duration) SendDecorator {
if err == nil {
return resp, err
}
DelayForBackoff(backoff, attempt, r.Cancel)
if !DelayForBackoff(backoff, attempt, r.Context().Done()) {
return nil, r.Context().Err()
}
}
return resp, err
})
@@ -215,18 +217,29 @@ func DoRetryForStatusCodes(attempts int, backoff time.Duration, codes ...int) Se
rr := NewRetriableRequest(r)
// Increment to add the first call (attempts denotes number of retries)
attempts++
for attempt := 0; attempt < attempts; attempt++ {
for attempt := 0; attempt < attempts; {
err = rr.Prepare()
if err != nil {
return resp, err
}
resp, err = s.Do(rr.Request())
if err != nil || !ResponseHasStatusCode(resp, codes...) {
// if the error isn't temporary don't bother retrying
if err != nil && !IsTemporaryNetworkError(err) {
return nil, err
}
// we want to retry if err is not nil (e.g. transient network failure). note that for failed authentication
// resp and err will both have a value, so in this case we don't want to retry as it will never succeed.
if err == nil && !ResponseHasStatusCode(resp, codes...) || IsTokenRefreshError(err) {
return resp, err
}
delayed := DelayWithRetryAfter(resp, r.Cancel)
if !delayed {
DelayForBackoff(backoff, attempt, r.Cancel)
delayed := DelayWithRetryAfter(resp, r.Context().Done())
if !delayed && !DelayForBackoff(backoff, attempt, r.Context().Done()) {
return resp, r.Context().Err()
}
// don't count a 429 against the number of attempts
// so that we continue to retry until it succeeds
if resp == nil || resp.StatusCode != http.StatusTooManyRequests {
attempt++
}
}
return resp, err
@@ -237,6 +250,9 @@ func DoRetryForStatusCodes(attempts int, backoff time.Duration, codes ...int) Se
// DelayWithRetryAfter invokes time.After for the duration specified in the "Retry-After" header in
// responses with status code 429
func DelayWithRetryAfter(resp *http.Response, cancel <-chan struct{}) bool {
if resp == nil {
return false
}
retryAfter, _ := strconv.Atoi(resp.Header.Get("Retry-After"))
if resp.StatusCode == http.StatusTooManyRequests && retryAfter > 0 {
select {
@@ -267,7 +283,9 @@ func DoRetryForDuration(d time.Duration, backoff time.Duration) SendDecorator {
if err == nil {
return resp, err
}
DelayForBackoff(backoff, attempt, r.Cancel)
if !DelayForBackoff(backoff, attempt, r.Context().Done()) {
return nil, r.Context().Err()
}
}
return resp, err
})

View File

@@ -16,6 +16,7 @@ package autorest
import (
"bytes"
"context"
"fmt"
"log"
"net/http"
@@ -179,24 +180,30 @@ func TestAfterDelayWaits(t *testing.T) {
func TestAfterDelay_Cancels(t *testing.T) {
client := mocks.NewSender()
cancel := make(chan struct{})
ctx, cancel := context.WithCancel(context.Background())
delay := 5 * time.Second
var wg sync.WaitGroup
wg.Add(1)
tt := time.Now()
start := time.Now()
end := time.Now()
var err error
go func() {
req := mocks.NewRequest()
req.Cancel = cancel
wg.Done()
SendWithSender(client, req,
req = req.WithContext(ctx)
_, err = SendWithSender(client, req,
AfterDelay(delay))
end = time.Now()
wg.Done()
}()
cancel()
wg.Wait()
close(cancel)
time.Sleep(5 * time.Millisecond)
if time.Since(tt) >= delay {
t.Fatal("autorest: AfterDelay failed to cancel")
if end.Sub(start) >= delay {
t.Fatal("autorest: AfterDelay elapsed")
}
if err == nil {
t.Fatal("autorest: AfterDelay didn't cancel")
}
}
@@ -781,7 +788,7 @@ func newAcceptedResponse() *http.Response {
}
func TestDelayWithRetryAfterWithSuccess(t *testing.T) {
after, retries := 5, 2
after, retries := 2, 2
totalSecs := after * retries
client := mocks.NewSender()
@@ -793,7 +800,7 @@ func TestDelayWithRetryAfterWithSuccess(t *testing.T) {
d := time.Second * time.Duration(totalSecs)
start := time.Now()
r, _ := SendWithSender(client, mocks.NewRequest(),
DoRetryForStatusCodes(5, time.Duration(time.Second), http.StatusTooManyRequests),
DoRetryForStatusCodes(1, time.Duration(time.Second), http.StatusTooManyRequests),
)
if time.Since(start) < d {
@@ -809,3 +816,116 @@ func TestDelayWithRetryAfterWithSuccess(t *testing.T) {
r.Status, client.Attempts()-1)
}
}
type temporaryError struct {
message string
}
func (te temporaryError) Error() string {
return te.message
}
func (te temporaryError) Timeout() bool {
return true
}
func (te temporaryError) Temporary() bool {
return true
}
func TestDoRetryForStatusCodes_NilResponseTemporaryError(t *testing.T) {
client := mocks.NewSender()
client.AppendResponse(nil)
client.SetError(temporaryError{message: "faux error"})
r, err := SendWithSender(client, mocks.NewRequest(),
DoRetryForStatusCodes(3, time.Duration(1*time.Second), StatusCodesForRetry...),
)
Respond(r,
ByDiscardingBody(),
ByClosing())
if err != nil || client.Attempts() != 2 {
t.Fatalf("autorest: Sender#TestDoRetryForStatusCodes_NilResponseTemporaryError -- Got: non-nil error or wrong number of attempts - %v", err)
}
}
func TestDoRetryForStatusCodes_NilResponseTemporaryError2(t *testing.T) {
client := mocks.NewSender()
client.AppendResponse(nil)
client.SetError(fmt.Errorf("faux error"))
r, err := SendWithSender(client, mocks.NewRequest(),
DoRetryForStatusCodes(3, time.Duration(1*time.Second), StatusCodesForRetry...),
)
Respond(r,
ByDiscardingBody(),
ByClosing())
if err != nil || client.Attempts() != 2 {
t.Fatalf("autorest: Sender#TestDoRetryForStatusCodes_NilResponseTemporaryError2 -- Got: nil error or wrong number of attempts - %v", err)
}
}
type fatalError struct {
message string
}
func (fe fatalError) Error() string {
return fe.message
}
func (fe fatalError) Timeout() bool {
return false
}
func (fe fatalError) Temporary() bool {
return false
}
func TestDoRetryForStatusCodes_NilResponseFatalError(t *testing.T) {
client := mocks.NewSender()
client.AppendResponse(nil)
client.SetError(fatalError{"fatal error"})
r, err := SendWithSender(client, mocks.NewRequest(),
DoRetryForStatusCodes(3, time.Duration(1*time.Second), StatusCodesForRetry...),
)
Respond(r,
ByDiscardingBody(),
ByClosing())
if err == nil || client.Attempts() > 1 {
t.Fatalf("autorest: Sender#TestDoRetryForStatusCodes_NilResponseFatalError -- Got: nil error or wrong number of attempts - %v", err)
}
}
func TestDoRetryForStatusCodes_Cancel429(t *testing.T) {
retries := 6
client := mocks.NewSender()
resp := mocks.NewResponseWithStatus("429 Too many requests", http.StatusTooManyRequests)
client.AppendAndRepeatResponse(resp, retries)
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(retries/2)*time.Second)
defer cancel()
req := mocks.NewRequest().WithContext(ctx)
r, err := SendWithSender(client, req,
DoRetryForStatusCodes(1, time.Duration(time.Second), http.StatusTooManyRequests),
)
if err == nil {
t.Fatal("unexpected nil-error")
}
if r == nil {
t.Fatal("unexpected nil response")
}
if r.StatusCode != http.StatusTooManyRequests {
t.Fatalf("expected status code 429, got: %d", r.StatusCode)
}
if client.Attempts() >= retries {
t.Fatalf("too many attemps: %d", client.Attempts())
}
}

View File

@@ -20,10 +20,13 @@ import (
"encoding/xml"
"fmt"
"io"
"net"
"net/http"
"net/url"
"reflect"
"sort"
"strings"
"github.com/Azure/go-autorest/autorest/adal"
)
// EncodedAs is a series of constants specifying various data encodings
@@ -137,13 +140,38 @@ func MapToValues(m map[string]interface{}) url.Values {
return v
}
// String method converts interface v to string. If interface is a list, it
// joins list elements using separator.
func String(v interface{}, sep ...string) string {
if len(sep) > 0 {
return ensureValueString(strings.Join(v.([]string), sep[0]))
// AsStringSlice method converts interface{} to []string. This expects a
//that the parameter passed to be a slice or array of a type that has the underlying
//type a string.
func AsStringSlice(s interface{}) ([]string, error) {
v := reflect.ValueOf(s)
if v.Kind() != reflect.Slice && v.Kind() != reflect.Array {
return nil, NewError("autorest", "AsStringSlice", "the value's type is not an array.")
}
return ensureValueString(v)
stringSlice := make([]string, 0, v.Len())
for i := 0; i < v.Len(); i++ {
stringSlice = append(stringSlice, v.Index(i).String())
}
return stringSlice, nil
}
// String method converts interface v to string. If interface is a list, it
// joins list elements using the seperator. Note that only sep[0] will be used for
// joining if any separator is specified.
func String(v interface{}, sep ...string) string {
if len(sep) == 0 {
return ensureValueString(v)
}
stringSlice, ok := v.([]string)
if ok == false {
var err error
stringSlice, err = AsStringSlice(v)
if err != nil {
panic(fmt.Sprintf("autorest: Couldn't convert value to a string %s.", err))
}
}
return ensureValueString(strings.Join(stringSlice, sep[0]))
}
// Encode method encodes url path and query parameters.
@@ -167,26 +195,34 @@ func queryEscape(s string) string {
return url.QueryEscape(s)
}
// This method is same as Encode() method of "net/url" go package,
// except it does not encode the query parameters because they
// already come encoded. It formats values map in query format (bar=foo&a=b).
func createQuery(v url.Values) string {
var buf bytes.Buffer
keys := make([]string, 0, len(v))
for k := range v {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
vs := v[k]
prefix := url.QueryEscape(k) + "="
for _, v := range vs {
if buf.Len() > 0 {
buf.WriteByte('&')
}
buf.WriteString(prefix)
buf.WriteString(v)
}
}
return buf.String()
// ChangeToGet turns the specified http.Request into a GET (it assumes it wasn't).
// This is mainly useful for long-running operations that use the Azure-AsyncOperation
// header, so we change the initial PUT into a GET to retrieve the final result.
func ChangeToGet(req *http.Request) *http.Request {
req.Method = "GET"
req.Body = nil
req.ContentLength = 0
req.Header.Del("Content-Length")
return req
}
// IsTokenRefreshError returns true if the specified error implements the TokenRefreshError
// interface. If err is a DetailedError it will walk the chain of Original errors.
func IsTokenRefreshError(err error) bool {
if _, ok := err.(adal.TokenRefreshError); ok {
return true
}
if de, ok := err.(DetailedError); ok {
return IsTokenRefreshError(de.Original)
}
return false
}
// IsTemporaryNetworkError returns true if the specified error is a temporary network error or false
// if it's not. If the error doesn't implement the net.Error interface the return value is true.
func IsTemporaryNetworkError(err error) bool {
if netErr, ok := err.(net.Error); !ok || (ok && netErr.Temporary()) {
return true
}
return false
}

View File

@@ -18,6 +18,7 @@ import (
"bytes"
"encoding/json"
"encoding/xml"
"errors"
"fmt"
"net/http"
"net/url"
@@ -229,11 +230,51 @@ func ExampleString() {
func TestStringWithValidString(t *testing.T) {
i := 123
if String(i) != "123" {
t.Fatal("autorest: String method failed to convert integer 123 to string")
if got, want := String(i), "123"; got != want {
t.Logf("got: %q\nwant: %q", got, want)
t.Fail()
}
}
func TestStringWithStringSlice(t *testing.T) {
s := []string{"string1", "string2"}
if got, want := String(s, ","), "string1,string2"; got != want {
t.Logf("got: %q\nwant: %q", got, want)
t.Fail()
}
}
func TestStringWithEnum(t *testing.T) {
type TestEnumType string
s := TestEnumType("string1")
if got, want := String(s), "string1"; got != want {
t.Logf("got: %q\nwant: %q", got, want)
t.Fail()
}
}
func TestStringWithEnumSlice(t *testing.T) {
type TestEnumType string
s := []TestEnumType{"string1", "string2"}
if got, want := String(s, ","), "string1,string2"; got != want {
t.Logf("got: %q\nwant: %q", got, want)
t.Fail()
}
}
func ExampleAsStringSlice() {
type TestEnumType string
a := []TestEnumType{"value1", "value2"}
b, _ := AsStringSlice(a)
for _, c := range b {
fmt.Println(c)
}
// Output:
// value1
// value2
}
func TestEncodeWithValidPath(t *testing.T) {
s := Encode("Path", "Hello Gopher")
if s != "Hello%20Gopher" {
@@ -286,6 +327,45 @@ func TestMapToValuesWithArrayValues(t *testing.T) {
}
}
type someTempError struct{}
func (s someTempError) Error() string {
return "temporary error"
}
func (s someTempError) Timeout() bool {
return true
}
func (s someTempError) Temporary() bool {
return true
}
func TestIsTemporaryNetworkErrorTrue(t *testing.T) {
if !IsTemporaryNetworkError(someTempError{}) {
t.Fatal("expected someTempError to be a temporary network error")
}
if !IsTemporaryNetworkError(errors.New("non-temporary network error")) {
t.Fatal("expected random error to be a temporary network error")
}
}
type someFatalError struct{}
func (s someFatalError) Error() string {
return "fatal error"
}
func (s someFatalError) Timeout() bool {
return false
}
func (s someFatalError) Temporary() bool {
return false
}
func TestIsTemporaryNetworkErrorFalse(t *testing.T) {
if IsTemporaryNetworkError(someFatalError{}) {
t.Fatal("expected someFatalError to be a fatal network error")
}
}
func isEqual(v, u url.Values) bool {
for key, value := range v {
if len(u[key]) == 0 {

View File

@@ -26,6 +26,9 @@ import (
// GetAuthorizer gets an Azure Service Principal authorizer.
// This func assumes "AZURE_TENANT_ID", "AZURE_CLIENT_ID",
// "AZURE_CLIENT_SECRET" are set as environment variables.
//
// Deprecated: Use ClientCredentialsConfig.Authorizer() from the
// github.com/Azure/go-autorest/autorest/azure/auth package instead.
func GetAuthorizer(env azure.Environment) (*autorest.BearerAuthorizer, error) {
tenantID := GetEnvVarOrExit("AZURE_TENANT_ID")
@@ -46,6 +49,8 @@ func GetAuthorizer(env azure.Environment) (*autorest.BearerAuthorizer, error) {
}
// GetEnvVarOrExit returns the value of specified environment variable or terminates if it's not defined.
//
// Deprecated: This has been supreseeded by the github.com/Azure/go-autorest/autorest/azure/auth package.
func GetEnvVarOrExit(varName string) string {
value := os.Getenv(varName)
if value == "" {

View File

@@ -0,0 +1,48 @@
package validation
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"fmt"
)
// Error is the type that's returned when the validation of an APIs arguments constraints fails.
type Error struct {
// PackageType is the package type of the object emitting the error. For types, the value
// matches that produced the the '%T' format specifier of the fmt package. For other elements,
// such as functions, it is just the package name (e.g., "autorest").
PackageType string
// Method is the name of the method raising the error.
Method string
// Message is the error message.
Message string
}
// Error returns a string containing the details of the validation failure.
func (e Error) Error() string {
return fmt.Sprintf("%s#%s: Invalid input: %s", e.PackageType, e.Method, e.Message)
}
// NewError creates a new Error object with the specified parameters.
// message is treated as a format string to which the optional args apply.
func NewError(packageType string, method string, message string, args ...interface{}) Error {
return Error{
PackageType: packageType,
Method: method,
Message: fmt.Sprintf(message, args...),
}
}

View File

@@ -136,29 +136,29 @@ func validatePtr(x reflect.Value, v Constraint) error {
func validateInt(x reflect.Value, v Constraint) error {
i := x.Int()
r, ok := v.Rule.(int)
r, ok := toInt64(v.Rule)
if !ok {
return createError(x, v, fmt.Sprintf("rule must be integer value for %v constraint; got: %v", v.Name, v.Rule))
}
switch v.Name {
case MultipleOf:
if i%int64(r) != 0 {
if i%r != 0 {
return createError(x, v, fmt.Sprintf("value must be a multiple of %v", r))
}
case ExclusiveMinimum:
if i <= int64(r) {
if i <= r {
return createError(x, v, fmt.Sprintf("value must be greater than %v", r))
}
case ExclusiveMaximum:
if i >= int64(r) {
if i >= r {
return createError(x, v, fmt.Sprintf("value must be less than %v", r))
}
case InclusiveMinimum:
if i < int64(r) {
if i < r {
return createError(x, v, fmt.Sprintf("value must be greater than or equal to %v", r))
}
case InclusiveMaximum:
if i > int64(r) {
if i > r {
return createError(x, v, fmt.Sprintf("value must be less than or equal to %v", r))
}
default:
@@ -388,8 +388,21 @@ func createError(x reflect.Value, v Constraint, err string) error {
v.Target, v.Name, getInterfaceValue(x), err)
}
func toInt64(v interface{}) (int64, bool) {
if i64, ok := v.(int64); ok {
return i64, true
}
// older generators emit max constants as int, so if int64 fails fall back to int
if i32, ok := v.(int); ok {
return int64(i32), true
}
return 0, false
}
// NewErrorWithValidationError appends package type and method name in
// validation error.
//
// Deprecated: Please use validation.NewError() instead.
func NewErrorWithValidationError(err error, packageType, method string) error {
return fmt.Errorf("%s#%s: Invalid input: %v", packageType, method, err)
return NewError(packageType, method, err.Error())
}

View File

@@ -1090,7 +1090,7 @@ func TestValidateInt_inclusiveMaximumConstraintValid(t *testing.T) {
c := Constraint{
Target: "str",
Name: InclusiveMaximum,
Rule: 2,
Rule: int64(2),
Chain: nil,
}
require.Nil(t, validateInt(reflect.ValueOf(1), c))
@@ -1101,7 +1101,7 @@ func TestValidateInt_inclusiveMaximumConstraintInvalid(t *testing.T) {
c := Constraint{
Target: "str",
Name: InclusiveMaximum,
Rule: 1,
Rule: int64(1),
Chain: nil,
}
require.Equal(t, validateInt(reflect.ValueOf(x), c).Error(),
@@ -1112,7 +1112,7 @@ func TestValidateInt_inclusiveMaximumConstraintBoundary(t *testing.T) {
c := Constraint{
Target: "str",
Name: InclusiveMaximum,
Rule: 1,
Rule: int64(1),
Chain: nil,
}
require.Nil(t, validateInt(reflect.ValueOf(1), c))
@@ -2417,6 +2417,21 @@ func TestValidate_StructInStruct(t *testing.T) {
require.Nil(t, Validate(v))
}
func TestNewError(t *testing.T) {
p := &Product{}
v := []Validation{
{p, []Constraint{{"p", Null, true,
[]Constraint{{"p.C", Null, true,
[]Constraint{{"p.C.I", Empty, true, nil}}},
},
}}},
}
err := createError(reflect.ValueOf(p.C), v[0].Constraints[0].Chain[0], "value can not be null; required parameter")
z := fmt.Sprintf("batch.AccountClient#Create: Invalid input: %s",
err.Error())
require.Equal(t, NewError("batch.AccountClient", "Create", err.Error()).Error(), z)
}
func TestNewErrorWithValidationError(t *testing.T) {
p := &Product{}
v := []Validation{
@@ -2429,5 +2444,7 @@ func TestNewErrorWithValidationError(t *testing.T) {
err := createError(reflect.ValueOf(p.C), v[0].Constraints[0].Chain[0], "value can not be null; required parameter")
z := fmt.Sprintf("batch.AccountClient#Create: Invalid input: %s",
err.Error())
require.Equal(t, NewErrorWithValidationError(err, "batch.AccountClient", "Create").Error(), z)
valError := NewErrorWithValidationError(err, "batch.AccountClient", "Create")
require.IsType(t, valError, Error{})
require.Equal(t, valError.Error(), z)
}

View File

@@ -1,5 +1,7 @@
package autorest
import "github.com/Azure/go-autorest/version"
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,36 +16,7 @@ package autorest
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"bytes"
"fmt"
"strings"
"sync"
)
const (
major = 8
minor = 0
patch = 0
tag = ""
)
var once sync.Once
var version string
// Version returns the semantic version (see http://semver.org).
func Version() string {
once.Do(func() {
semver := fmt.Sprintf("%d.%d.%d", major, minor, patch)
verBuilder := bytes.NewBufferString(semver)
if tag != "" && tag != "-" {
updated := strings.TrimPrefix(tag, "-")
_, err := verBuilder.WriteString("-" + updated)
if err == nil {
verBuilder = bytes.NewBufferString(semver)
}
}
version = verBuilder.String()
})
return version
return version.Number
}

View File

@@ -0,0 +1,50 @@
pool:
vmImage: 'Ubuntu 16.04'
variables:
GOROOT: '/usr/local/go1.12'
GOPATH: '$(system.defaultWorkingDirectory)/work'
sdkPath: '$(GOPATH)/src/github.com/$(build.repository.name)'
steps:
- script: |
set -e
mkdir -p '$(GOPATH)/bin'
mkdir -p '$(sdkPath)'
shopt -s extglob
mv !(work) '$(sdkPath)'
echo '##vso[task.prependpath]$(GOROOT)/bin'
echo '##vso[task.prependpath]$(GOPATH)/bin'
displayName: 'Create Go Workspace'
- script: |
set -e
curl -sSL https://raw.githubusercontent.com/golang/dep/master/install.sh | sh
dep ensure -v
go get -u golang.org/x/lint/golint
go get -u github.com/stretchr/testify
workingDirectory: '$(sdkPath)'
displayName: 'Install Dependencies'
- script: go vet ./autorest/...
workingDirectory: '$(sdkPath)'
displayName: 'Vet'
- script: go build -v ./autorest/...
workingDirectory: '$(sdkPath)'
displayName: 'Build'
- script: go test -race -v ./autorest/...
workingDirectory: '$(sdkPath)'
displayName: 'Run Tests'
- script: grep -L -r --include *.go --exclude-dir vendor -P "Copyright (\d{4}|\(c\)) Microsoft" ./ | tee >&2
workingDirectory: '$(sdkPath)'
displayName: 'Copyright Header Check'
failOnStderr: true
condition: succeededOrFailed()
- script: gofmt -s -l -w ./autorest/. >&2
workingDirectory: '$(sdkPath)'
displayName: 'Format Check'
failOnStderr: true
condition: succeededOrFailed()
- script: golint ./autorest/... >&2
workingDirectory: '$(sdkPath)'
displayName: 'Linter Check'
failOnStderr: true
condition: succeededOrFailed()

View File

@@ -1,44 +0,0 @@
hash: 6e0121d946623e7e609280b1b18627e1c8a767fdece54cb97c4447c1167cbc46
updated: 2017-08-31T13:58:01.034822883+01:00
imports:
- name: github.com/dgrijalva/jwt-go
version: 2268707a8f0843315e2004ee4f1d021dc08baedf
subpackages:
- .
- name: github.com/dimchansky/utfbom
version: 6c6132ff69f0f6c088739067407b5d32c52e1d0f
- name: github.com/mitchellh/go-homedir
version: b8bc1bf767474819792c23f32d8286a45736f1c6
- name: golang.org/x/crypto
version: 81e90905daefcd6fd217b62423c0908922eadb30
repo: https://github.com/golang/crypto.git
vcs: git
subpackages:
- pkcs12
- pkcs12/internal/rc2
- name: golang.org/x/net
version: 66aacef3dd8a676686c7ae3716979581e8b03c47
repo: https://github.com/golang/net.git
vcs: git
subpackages:
- .
- name: golang.org/x/text
version: 21e35d45962262c8ee80f6cb048dcf95ad0e9d79
repo: https://github.com/golang/text.git
vcs: git
subpackages:
- .
testImports:
- name: github.com/davecgh/go-spew
version: 04cdfd42973bb9c8589fd6a731800cf222fde1a9
subpackages:
- spew
- name: github.com/pmezard/go-difflib
version: d8ed2627bdf02c080bf22230dbb337003b7aba2d
subpackages:
- difflib
- name: github.com/stretchr/testify
version: 890a5c3458b43e6104ff5da8dfa139d013d77544
subpackages:
- assert
- require

View File

@@ -1,22 +0,0 @@
package: github.com/Azure/go-autorest
import:
- package: github.com/dgrijalva/jwt-go
subpackages:
- .
- package: golang.org/x/crypto
repo: https://github.com/golang/crypto.git
vcs: git
subpackages:
- pkcs12
- package: golang.org/x/net
repo: https://github.com/golang/net.git
vcs: git
subpackages:
- .
- package: golang.org/x/text
repo: https://github.com/golang/text.git
vcs: git
subpackages:
- .
- package: github.com/mitchellh/go-homedir
- package: github.com/dimchansky/utfbom

328
vendor/github.com/Azure/go-autorest/logger/logger.go generated vendored Normal file
View File

@@ -0,0 +1,328 @@
package logger
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"os"
"strings"
"sync"
"time"
)
// LevelType tells a logger the minimum level to log. When code reports a log entry,
// the LogLevel indicates the level of the log entry. The logger only records entries
// whose level is at least the level it was told to log. See the Log* constants.
// For example, if a logger is configured with LogError, then LogError, LogPanic,
// and LogFatal entries will be logged; lower level entries are ignored.
type LevelType uint32
const (
// LogNone tells a logger not to log any entries passed to it.
LogNone LevelType = iota
// LogFatal tells a logger to log all LogFatal entries passed to it.
LogFatal
// LogPanic tells a logger to log all LogPanic and LogFatal entries passed to it.
LogPanic
// LogError tells a logger to log all LogError, LogPanic and LogFatal entries passed to it.
LogError
// LogWarning tells a logger to log all LogWarning, LogError, LogPanic and LogFatal entries passed to it.
LogWarning
// LogInfo tells a logger to log all LogInfo, LogWarning, LogError, LogPanic and LogFatal entries passed to it.
LogInfo
// LogDebug tells a logger to log all LogDebug, LogInfo, LogWarning, LogError, LogPanic and LogFatal entries passed to it.
LogDebug
)
const (
logNone = "NONE"
logFatal = "FATAL"
logPanic = "PANIC"
logError = "ERROR"
logWarning = "WARNING"
logInfo = "INFO"
logDebug = "DEBUG"
logUnknown = "UNKNOWN"
)
// ParseLevel converts the specified string into the corresponding LevelType.
func ParseLevel(s string) (lt LevelType, err error) {
switch strings.ToUpper(s) {
case logFatal:
lt = LogFatal
case logPanic:
lt = LogPanic
case logError:
lt = LogError
case logWarning:
lt = LogWarning
case logInfo:
lt = LogInfo
case logDebug:
lt = LogDebug
default:
err = fmt.Errorf("bad log level '%s'", s)
}
return
}
// String implements the stringer interface for LevelType.
func (lt LevelType) String() string {
switch lt {
case LogNone:
return logNone
case LogFatal:
return logFatal
case LogPanic:
return logPanic
case LogError:
return logError
case LogWarning:
return logWarning
case LogInfo:
return logInfo
case LogDebug:
return logDebug
default:
return logUnknown
}
}
// Filter defines functions for filtering HTTP request/response content.
type Filter struct {
// URL returns a potentially modified string representation of a request URL.
URL func(u *url.URL) string
// Header returns a potentially modified set of values for the specified key.
// To completely exclude the header key/values return false.
Header func(key string, val []string) (bool, []string)
// Body returns a potentially modified request/response body.
Body func(b []byte) []byte
}
func (f Filter) processURL(u *url.URL) string {
if f.URL == nil {
return u.String()
}
return f.URL(u)
}
func (f Filter) processHeader(k string, val []string) (bool, []string) {
if f.Header == nil {
return true, val
}
return f.Header(k, val)
}
func (f Filter) processBody(b []byte) []byte {
if f.Body == nil {
return b
}
return f.Body(b)
}
// Writer defines methods for writing to a logging facility.
type Writer interface {
// Writeln writes the specified message with the standard log entry header and new-line character.
Writeln(level LevelType, message string)
// Writef writes the specified format specifier with the standard log entry header and no new-line character.
Writef(level LevelType, format string, a ...interface{})
// WriteRequest writes the specified HTTP request to the logger if the log level is greater than
// or equal to LogInfo. The request body, if set, is logged at level LogDebug or higher.
// Custom filters can be specified to exclude URL, header, and/or body content from the log.
// By default no request content is excluded.
WriteRequest(req *http.Request, filter Filter)
// WriteResponse writes the specified HTTP response to the logger if the log level is greater than
// or equal to LogInfo. The response body, if set, is logged at level LogDebug or higher.
// Custom filters can be specified to exclude URL, header, and/or body content from the log.
// By default no respone content is excluded.
WriteResponse(resp *http.Response, filter Filter)
}
// Instance is the default log writer initialized during package init.
// This can be replaced with a custom implementation as required.
var Instance Writer
// default log level
var logLevel = LogNone
// Level returns the value specified in AZURE_GO_AUTOREST_LOG_LEVEL.
// If no value was specified the default value is LogNone.
// Custom loggers can call this to retrieve the configured log level.
func Level() LevelType {
return logLevel
}
func init() {
// separated for testing purposes
initDefaultLogger()
}
func initDefaultLogger() {
// init with nilLogger so callers don't have to do a nil check on Default
Instance = nilLogger{}
llStr := strings.ToLower(os.Getenv("AZURE_GO_SDK_LOG_LEVEL"))
if llStr == "" {
return
}
var err error
logLevel, err = ParseLevel(llStr)
if err != nil {
fmt.Fprintf(os.Stderr, "go-autorest: failed to parse log level: %s\n", err.Error())
return
}
if logLevel == LogNone {
return
}
// default to stderr
dest := os.Stderr
lfStr := os.Getenv("AZURE_GO_SDK_LOG_FILE")
if strings.EqualFold(lfStr, "stdout") {
dest = os.Stdout
} else if lfStr != "" {
lf, err := os.Create(lfStr)
if err == nil {
dest = lf
} else {
fmt.Fprintf(os.Stderr, "go-autorest: failed to create log file, using stderr: %s\n", err.Error())
}
}
Instance = fileLogger{
logLevel: logLevel,
mu: &sync.Mutex{},
logFile: dest,
}
}
// the nil logger does nothing
type nilLogger struct{}
func (nilLogger) Writeln(LevelType, string) {}
func (nilLogger) Writef(LevelType, string, ...interface{}) {}
func (nilLogger) WriteRequest(*http.Request, Filter) {}
func (nilLogger) WriteResponse(*http.Response, Filter) {}
// A File is used instead of a Logger so the stream can be flushed after every write.
type fileLogger struct {
logLevel LevelType
mu *sync.Mutex // for synchronizing writes to logFile
logFile *os.File
}
func (fl fileLogger) Writeln(level LevelType, message string) {
fl.Writef(level, "%s\n", message)
}
func (fl fileLogger) Writef(level LevelType, format string, a ...interface{}) {
if fl.logLevel >= level {
fl.mu.Lock()
defer fl.mu.Unlock()
fmt.Fprintf(fl.logFile, "%s %s", entryHeader(level), fmt.Sprintf(format, a...))
fl.logFile.Sync()
}
}
func (fl fileLogger) WriteRequest(req *http.Request, filter Filter) {
if req == nil || fl.logLevel < LogInfo {
return
}
b := &bytes.Buffer{}
fmt.Fprintf(b, "%s REQUEST: %s %s\n", entryHeader(LogInfo), req.Method, filter.processURL(req.URL))
// dump headers
for k, v := range req.Header {
if ok, mv := filter.processHeader(k, v); ok {
fmt.Fprintf(b, "%s: %s\n", k, strings.Join(mv, ","))
}
}
if fl.shouldLogBody(req.Header, req.Body) {
// dump body
body, err := ioutil.ReadAll(req.Body)
if err == nil {
fmt.Fprintln(b, string(filter.processBody(body)))
if nc, ok := req.Body.(io.Seeker); ok {
// rewind to the beginning
nc.Seek(0, io.SeekStart)
} else {
// recreate the body
req.Body = ioutil.NopCloser(bytes.NewReader(body))
}
} else {
fmt.Fprintf(b, "failed to read body: %v\n", err)
}
}
fl.mu.Lock()
defer fl.mu.Unlock()
fmt.Fprint(fl.logFile, b.String())
fl.logFile.Sync()
}
func (fl fileLogger) WriteResponse(resp *http.Response, filter Filter) {
if resp == nil || fl.logLevel < LogInfo {
return
}
b := &bytes.Buffer{}
fmt.Fprintf(b, "%s RESPONSE: %d %s\n", entryHeader(LogInfo), resp.StatusCode, filter.processURL(resp.Request.URL))
// dump headers
for k, v := range resp.Header {
if ok, mv := filter.processHeader(k, v); ok {
fmt.Fprintf(b, "%s: %s\n", k, strings.Join(mv, ","))
}
}
if fl.shouldLogBody(resp.Header, resp.Body) {
// dump body
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err == nil {
fmt.Fprintln(b, string(filter.processBody(body)))
resp.Body = ioutil.NopCloser(bytes.NewReader(body))
} else {
fmt.Fprintf(b, "failed to read body: %v\n", err)
}
}
fl.mu.Lock()
defer fl.mu.Unlock()
fmt.Fprint(fl.logFile, b.String())
fl.logFile.Sync()
}
// returns true if the provided body should be included in the log
func (fl fileLogger) shouldLogBody(header http.Header, body io.ReadCloser) bool {
ct := header.Get("Content-Type")
return fl.logLevel >= LogDebug && body != nil && strings.Index(ct, "application/octet-stream") == -1
}
// creates standard header for log entries, it contains a timestamp and the log level
func entryHeader(level LevelType) string {
// this format provides a fixed number of digits so the size of the timestamp is constant
return fmt.Sprintf("(%s) %s:", time.Now().Format("2006-01-02T15:04:05.0000000Z07:00"), level.String())
}

View File

@@ -0,0 +1,185 @@
package logger
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"fmt"
"io/ioutil"
"net/http"
"os"
"path"
"regexp"
"strings"
"testing"
)
func TestNilLogger(t *testing.T) {
// verify no crash with no logging
Instance.WriteRequest(nil, Filter{})
}
const (
reqURL = "https://fakething/dot/com"
reqHeaderKey = "x-header"
reqHeaderVal = "value"
reqBody = "the request body"
respHeaderKey = "response-header"
respHeaderVal = "something"
respBody = "the response body"
logFileTimeStampRegex = `\(\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{7}(-\d{2}:\d{2}|Z)\)`
)
func TestLogReqRespNoBody(t *testing.T) {
err := os.Setenv("AZURE_GO_SDK_LOG_LEVEL", "info")
if err != nil {
t.Fatalf("failed to set log level: %v", err)
}
lf := path.Join(os.TempDir(), "testloggingbasic.log")
err = os.Setenv("AZURE_GO_SDK_LOG_FILE", lf)
if err != nil {
t.Fatalf("failed to set log file: %v", err)
}
initDefaultLogger()
if Level() != LogInfo {
t.Fatalf("wrong log level: %d", Level())
}
// create mock request and response for logging
req, err := http.NewRequest(http.MethodGet, reqURL, nil)
if err != nil {
t.Fatalf("failed to create mock request: %v", err)
}
req.Header.Add(reqHeaderKey, reqHeaderVal)
Instance.WriteRequest(req, Filter{})
resp := &http.Response{
Status: "200 OK",
StatusCode: 200,
Proto: "HTTP/1.0",
ProtoMajor: 1,
ProtoMinor: 0,
Request: req,
Header: http.Header{},
}
resp.Header.Add(respHeaderKey, respHeaderVal)
Instance.WriteResponse(resp, Filter{})
if fl, ok := Instance.(fileLogger); ok {
fl.logFile.Close()
} else {
t.Fatal("expected Instance to be fileLogger")
}
// parse log file to ensure contents match
b, err := ioutil.ReadFile(lf)
if err != nil {
t.Fatalf("failed to read log file: %v", err)
}
parts := strings.Split(string(b), "\n")
reqMatch := fmt.Sprintf("%s INFO: REQUEST: %s %s", logFileTimeStampRegex, req.Method, req.URL.String())
respMatch := fmt.Sprintf("%s INFO: RESPONSE: %d %s", logFileTimeStampRegex, resp.StatusCode, resp.Request.URL.String())
if !matchRegex(t, reqMatch, parts[0]) {
t.Fatalf("request header doesn't match: %s", parts[0])
}
if !matchRegex(t, fmt.Sprintf("(?i)%s: %s", reqHeaderKey, reqHeaderVal), parts[1]) {
t.Fatalf("request header entry doesn't match: %s", parts[1])
}
if !matchRegex(t, respMatch, parts[2]) {
t.Fatalf("response header doesn't match: %s", parts[2])
}
if !matchRegex(t, fmt.Sprintf("(?i)%s: %s", respHeaderKey, respHeaderVal), parts[3]) {
t.Fatalf("response header value doesn't match: %s", parts[3])
}
// disable logging
err = os.Setenv("AZURE_GO_SDK_LOG_LEVEL", "")
if err != nil {
t.Fatalf("failed to clear log level: %v", err)
}
}
func TestLogReqRespWithBody(t *testing.T) {
err := os.Setenv("AZURE_GO_SDK_LOG_LEVEL", "debug")
if err != nil {
t.Fatalf("failed to set log level: %v", err)
}
lf := path.Join(os.TempDir(), "testloggingfull.log")
err = os.Setenv("AZURE_GO_SDK_LOG_FILE", lf)
if err != nil {
t.Fatalf("failed to set log file: %v", err)
}
initDefaultLogger()
if Level() != LogDebug {
t.Fatalf("wrong log level: %d", Level())
}
// create mock request and response for logging
req, err := http.NewRequest(http.MethodGet, reqURL, strings.NewReader(reqBody))
if err != nil {
t.Fatalf("failed to create mock request: %v", err)
}
req.Header.Add(reqHeaderKey, reqHeaderVal)
Instance.WriteRequest(req, Filter{})
resp := &http.Response{
Status: "200 OK",
StatusCode: 200,
Proto: "HTTP/1.0",
ProtoMajor: 1,
ProtoMinor: 0,
Request: req,
Header: http.Header{},
Body: ioutil.NopCloser(strings.NewReader(respBody)),
}
resp.Header.Add(respHeaderKey, respHeaderVal)
Instance.WriteResponse(resp, Filter{})
if fl, ok := Instance.(fileLogger); ok {
fl.logFile.Close()
} else {
t.Fatal("expected Instance to be fileLogger")
}
// parse log file to ensure contents match
b, err := ioutil.ReadFile(lf)
if err != nil {
t.Fatalf("failed to read log file: %v", err)
}
parts := strings.Split(string(b), "\n")
reqMatch := fmt.Sprintf("%s INFO: REQUEST: %s %s", logFileTimeStampRegex, req.Method, req.URL.String())
respMatch := fmt.Sprintf("%s INFO: RESPONSE: %d %s", logFileTimeStampRegex, resp.StatusCode, resp.Request.URL.String())
if !matchRegex(t, reqMatch, parts[0]) {
t.Fatalf("request header doesn't match: %s", parts[0])
}
if !matchRegex(t, fmt.Sprintf("(?i)%s: %s", reqHeaderKey, reqHeaderVal), parts[1]) {
t.Fatalf("request header value doesn't match: %s", parts[1])
}
if !matchRegex(t, reqBody, parts[2]) {
t.Fatalf("request body doesn't match: %s", parts[2])
}
if !matchRegex(t, respMatch, parts[3]) {
t.Fatalf("response header doesn't match: %s", parts[3])
}
if !matchRegex(t, fmt.Sprintf("(?i)%s: %s", respHeaderKey, respHeaderVal), parts[4]) {
t.Fatalf("response header value doesn't match: %s", parts[4])
}
if !matchRegex(t, respBody, parts[5]) {
t.Fatalf("response body doesn't match: %s", parts[5])
}
// disable logging
err = os.Setenv("AZURE_GO_SDK_LOG_LEVEL", "")
if err != nil {
t.Fatalf("failed to clear log level: %v", err)
}
}
func matchRegex(t *testing.T, pattern, s string) bool {
match, err := regexp.MatchString(pattern, s)
if err != nil {
t.Fatalf("regexp failure: %v", err)
}
return match
}

37
vendor/github.com/Azure/go-autorest/version/version.go generated vendored Normal file
View File

@@ -0,0 +1,37 @@
package version
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"fmt"
"runtime"
)
// Number contains the semantic version of this SDK.
const Number = "v11.1.2"
var (
userAgent = fmt.Sprintf("Go/%s (%s-%s) go-autorest/%s",
runtime.Version(),
runtime.GOARCH,
runtime.GOOS,
Number,
)
)
// UserAgent returns a string containing the Go version, system archityecture and OS, and the go-autorest version.
func UserAgent() string {
return userAgent
}

View File

@@ -1,5 +0,0 @@
*.sublime-*
.DS_Store
*.swp
*.swo
tags

View File

@@ -1,7 +0,0 @@
language: go
go:
- 1.4
- 1.5
- 1.6
- tip

View File

@@ -1,12 +0,0 @@
Copyright (c) 2012, Martin Angers
All rights reserved.
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
* Neither the name of the author nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@@ -1,185 +0,0 @@
# Purell
Purell is a tiny Go library to normalize URLs. It returns a pure URL. Pure-ell. Sanitizer and all. Yeah, I know...
Based on the [wikipedia paper][wiki] and the [RFC 3986 document][rfc].
[![build status](https://secure.travis-ci.org/PuerkitoBio/purell.png)](http://travis-ci.org/PuerkitoBio/purell)
## Install
`go get github.com/PuerkitoBio/purell`
## Changelog
* **2016-07-27 (v1.0.0)** : Normalize IDN to ASCII (thanks to @zenovich).
* **2015-02-08** : Add fix for relative paths issue ([PR #5][pr5]) and add fix for unnecessary encoding of reserved characters ([see issue #7][iss7]).
* **v0.2.0** : Add benchmarks, Attempt IDN support.
* **v0.1.0** : Initial release.
## Examples
From `example_test.go` (note that in your code, you would import "github.com/PuerkitoBio/purell", and would prefix references to its methods and constants with "purell."):
```go
package purell
import (
"fmt"
"net/url"
)
func ExampleNormalizeURLString() {
if normalized, err := NormalizeURLString("hTTp://someWEBsite.com:80/Amazing%3f/url/",
FlagLowercaseScheme|FlagLowercaseHost|FlagUppercaseEscapes); err != nil {
panic(err)
} else {
fmt.Print(normalized)
}
// Output: http://somewebsite.com:80/Amazing%3F/url/
}
func ExampleMustNormalizeURLString() {
normalized := MustNormalizeURLString("hTTpS://someWEBsite.com:443/Amazing%fa/url/",
FlagsUnsafeGreedy)
fmt.Print(normalized)
// Output: http://somewebsite.com/Amazing%FA/url
}
func ExampleNormalizeURL() {
if u, err := url.Parse("Http://SomeUrl.com:8080/a/b/.././c///g?c=3&a=1&b=9&c=0#target"); err != nil {
panic(err)
} else {
normalized := NormalizeURL(u, FlagsUsuallySafeGreedy|FlagRemoveDuplicateSlashes|FlagRemoveFragment)
fmt.Print(normalized)
}
// Output: http://someurl.com:8080/a/c/g?c=3&a=1&b=9&c=0
}
```
## API
As seen in the examples above, purell offers three methods, `NormalizeURLString(string, NormalizationFlags) (string, error)`, `MustNormalizeURLString(string, NormalizationFlags) (string)` and `NormalizeURL(*url.URL, NormalizationFlags) (string)`. They all normalize the provided URL based on the specified flags. Here are the available flags:
```go
const (
// Safe normalizations
FlagLowercaseScheme NormalizationFlags = 1 << iota // HTTP://host -> http://host, applied by default in Go1.1
FlagLowercaseHost // http://HOST -> http://host
FlagUppercaseEscapes // http://host/t%ef -> http://host/t%EF
FlagDecodeUnnecessaryEscapes // http://host/t%41 -> http://host/tA
FlagEncodeNecessaryEscapes // http://host/!"#$ -> http://host/%21%22#$
FlagRemoveDefaultPort // http://host:80 -> http://host
FlagRemoveEmptyQuerySeparator // http://host/path? -> http://host/path
// Usually safe normalizations
FlagRemoveTrailingSlash // http://host/path/ -> http://host/path
FlagAddTrailingSlash // http://host/path -> http://host/path/ (should choose only one of these add/remove trailing slash flags)
FlagRemoveDotSegments // http://host/path/./a/b/../c -> http://host/path/a/c
// Unsafe normalizations
FlagRemoveDirectoryIndex // http://host/path/index.html -> http://host/path/
FlagRemoveFragment // http://host/path#fragment -> http://host/path
FlagForceHTTP // https://host -> http://host
FlagRemoveDuplicateSlashes // http://host/path//a///b -> http://host/path/a/b
FlagRemoveWWW // http://www.host/ -> http://host/
FlagAddWWW // http://host/ -> http://www.host/ (should choose only one of these add/remove WWW flags)
FlagSortQuery // http://host/path?c=3&b=2&a=1&b=1 -> http://host/path?a=1&b=1&b=2&c=3
// Normalizations not in the wikipedia article, required to cover tests cases
// submitted by jehiah
FlagDecodeDWORDHost // http://1113982867 -> http://66.102.7.147
FlagDecodeOctalHost // http://0102.0146.07.0223 -> http://66.102.7.147
FlagDecodeHexHost // http://0x42660793 -> http://66.102.7.147
FlagRemoveUnnecessaryHostDots // http://.host../path -> http://host/path
FlagRemoveEmptyPortSeparator // http://host:/path -> http://host/path
// Convenience set of safe normalizations
FlagsSafe NormalizationFlags = FlagLowercaseHost | FlagLowercaseScheme | FlagUppercaseEscapes | FlagDecodeUnnecessaryEscapes | FlagEncodeNecessaryEscapes | FlagRemoveDefaultPort | FlagRemoveEmptyQuerySeparator
// For convenience sets, "greedy" uses the "remove trailing slash" and "remove www. prefix" flags,
// while "non-greedy" uses the "add (or keep) the trailing slash" and "add www. prefix".
// Convenience set of usually safe normalizations (includes FlagsSafe)
FlagsUsuallySafeGreedy NormalizationFlags = FlagsSafe | FlagRemoveTrailingSlash | FlagRemoveDotSegments
FlagsUsuallySafeNonGreedy NormalizationFlags = FlagsSafe | FlagAddTrailingSlash | FlagRemoveDotSegments
// Convenience set of unsafe normalizations (includes FlagsUsuallySafe)
FlagsUnsafeGreedy NormalizationFlags = FlagsUsuallySafeGreedy | FlagRemoveDirectoryIndex | FlagRemoveFragment | FlagForceHTTP | FlagRemoveDuplicateSlashes | FlagRemoveWWW | FlagSortQuery
FlagsUnsafeNonGreedy NormalizationFlags = FlagsUsuallySafeNonGreedy | FlagRemoveDirectoryIndex | FlagRemoveFragment | FlagForceHTTP | FlagRemoveDuplicateSlashes | FlagAddWWW | FlagSortQuery
// Convenience set of all available flags
FlagsAllGreedy = FlagsUnsafeGreedy | FlagDecodeDWORDHost | FlagDecodeOctalHost | FlagDecodeHexHost | FlagRemoveUnnecessaryHostDots | FlagRemoveEmptyPortSeparator
FlagsAllNonGreedy = FlagsUnsafeNonGreedy | FlagDecodeDWORDHost | FlagDecodeOctalHost | FlagDecodeHexHost | FlagRemoveUnnecessaryHostDots | FlagRemoveEmptyPortSeparator
)
```
For convenience, the set of flags `FlagsSafe`, `FlagsUsuallySafe[Greedy|NonGreedy]`, `FlagsUnsafe[Greedy|NonGreedy]` and `FlagsAll[Greedy|NonGreedy]` are provided for the similarly grouped normalizations on [wikipedia's URL normalization page][wiki]. You can add (using the bitwise OR `|` operator) or remove (using the bitwise AND NOT `&^` operator) individual flags from the sets if required, to build your own custom set.
The [full godoc reference is available on gopkgdoc][godoc].
Some things to note:
* `FlagDecodeUnnecessaryEscapes`, `FlagEncodeNecessaryEscapes`, `FlagUppercaseEscapes` and `FlagRemoveEmptyQuerySeparator` are always implicitly set, because internally, the URL string is parsed as an URL object, which automatically decodes unnecessary escapes, uppercases and encodes necessary ones, and removes empty query separators (an unnecessary `?` at the end of the url). So this operation cannot **not** be done. For this reason, `FlagRemoveEmptyQuerySeparator` (as well as the other three) has been included in the `FlagsSafe` convenience set, instead of `FlagsUnsafe`, where Wikipedia puts it.
* The `FlagDecodeUnnecessaryEscapes` decodes the following escapes (*from -> to*):
- %24 -> $
- %26 -> &
- %2B-%3B -> +,-./0123456789:;
- %3D -> =
- %40-%5A -> @ABCDEFGHIJKLMNOPQRSTUVWXYZ
- %5F -> _
- %61-%7A -> abcdefghijklmnopqrstuvwxyz
- %7E -> ~
* When the `NormalizeURL` function is used (passing an URL object), this source URL object is modified (that is, after the call, the URL object will be modified to reflect the normalization).
* The *replace IP with domain name* normalization (`http://208.77.188.166/ → http://www.example.com/`) is obviously not possible for a library without making some network requests. This is not implemented in purell.
* The *remove unused query string parameters* and *remove default query parameters* are also not implemented, since this is a very case-specific normalization, and it is quite trivial to do with an URL object.
### Safe vs Usually Safe vs Unsafe
Purell allows you to control the level of risk you take while normalizing an URL. You can aggressively normalize, play it totally safe, or anything in between.
Consider the following URL:
`HTTPS://www.RooT.com/toto/t%45%1f///a/./b/../c/?z=3&w=2&a=4&w=1#invalid`
Normalizing with the `FlagsSafe` gives:
`https://www.root.com/toto/tE%1F///a/./b/../c/?z=3&w=2&a=4&w=1#invalid`
With the `FlagsUsuallySafeGreedy`:
`https://www.root.com/toto/tE%1F///a/c?z=3&w=2&a=4&w=1#invalid`
And with `FlagsUnsafeGreedy`:
`http://root.com/toto/tE%1F/a/c?a=4&w=1&w=2&z=3`
## TODOs
* Add a class/default instance to allow specifying custom directory index names? At the moment, removing directory index removes `(^|/)((?:default|index)\.\w{1,4})$`.
## Thanks / Contributions
@rogpeppe
@jehiah
@opennota
@pchristopher1275
@zenovich
## License
The [BSD 3-Clause license][bsd].
[bsd]: http://opensource.org/licenses/BSD-3-Clause
[wiki]: http://en.wikipedia.org/wiki/URL_normalization
[rfc]: http://tools.ietf.org/html/rfc3986#section-6
[godoc]: http://go.pkgdoc.org/github.com/PuerkitoBio/purell
[pr5]: https://github.com/PuerkitoBio/purell/pull/5
[iss7]: https://github.com/PuerkitoBio/purell/issues/7

View File

@@ -1,57 +0,0 @@
package purell
import (
"testing"
)
var (
safeUrl = "HttPS://..iaMHost..Test:443/paTh^A%ef//./%41PaTH/..//?"
usuallySafeUrl = "HttPS://..iaMHost..Test:443/paTh^A%ef//./%41PaTH/../final/"
unsafeUrl = "HttPS://..www.iaMHost..Test:443/paTh^A%ef//./%41PaTH/../final/index.html?t=val1&a=val4&z=val5&a=val1#fragment"
allDWORDUrl = "HttPS://1113982867:/paTh^A%ef//./%41PaTH/../final/index.html?t=val1&a=val4&z=val5&a=val1#fragment"
allOctalUrl = "HttPS://0102.0146.07.0223:/paTh^A%ef//./%41PaTH/../final/index.html?t=val1&a=val4&z=val5&a=val1#fragment"
allHexUrl = "HttPS://0x42660793:/paTh^A%ef//./%41PaTH/../final/index.html?t=val1&a=val4&z=val5&a=val1#fragment"
allCombinedUrl = "HttPS://..0x42660793.:/paTh^A%ef//./%41PaTH/../final/index.html?t=val1&a=val4&z=val5&a=val1#fragment"
)
func BenchmarkSafe(b *testing.B) {
for i := 0; i < b.N; i++ {
NormalizeURLString(safeUrl, FlagsSafe)
}
}
func BenchmarkUsuallySafe(b *testing.B) {
for i := 0; i < b.N; i++ {
NormalizeURLString(usuallySafeUrl, FlagsUsuallySafeGreedy)
}
}
func BenchmarkUnsafe(b *testing.B) {
for i := 0; i < b.N; i++ {
NormalizeURLString(unsafeUrl, FlagsUnsafeGreedy)
}
}
func BenchmarkAllDWORD(b *testing.B) {
for i := 0; i < b.N; i++ {
NormalizeURLString(allDWORDUrl, FlagsAllGreedy)
}
}
func BenchmarkAllOctal(b *testing.B) {
for i := 0; i < b.N; i++ {
NormalizeURLString(allOctalUrl, FlagsAllGreedy)
}
}
func BenchmarkAllHex(b *testing.B) {
for i := 0; i < b.N; i++ {
NormalizeURLString(allHexUrl, FlagsAllGreedy)
}
}
func BenchmarkAllCombined(b *testing.B) {
for i := 0; i < b.N; i++ {
NormalizeURLString(allCombinedUrl, FlagsAllGreedy)
}
}

View File

@@ -1,9 +0,0 @@
PASS
BenchmarkSafe 500000 6131 ns/op
BenchmarkUsuallySafe 200000 7864 ns/op
BenchmarkUnsafe 100000 28560 ns/op
BenchmarkAllDWORD 50000 38722 ns/op
BenchmarkAllOctal 50000 40941 ns/op
BenchmarkAllHex 50000 44063 ns/op
BenchmarkAllCombined 50000 33613 ns/op
ok github.com/PuerkitoBio/purell 17.404s

View File

@@ -1,35 +0,0 @@
package purell
import (
"fmt"
"net/url"
)
func ExampleNormalizeURLString() {
if normalized, err := NormalizeURLString("hTTp://someWEBsite.com:80/Amazing%3f/url/",
FlagLowercaseScheme|FlagLowercaseHost|FlagUppercaseEscapes); err != nil {
panic(err)
} else {
fmt.Print(normalized)
}
// Output: http://somewebsite.com:80/Amazing%3F/url/
}
func ExampleMustNormalizeURLString() {
normalized := MustNormalizeURLString("hTTpS://someWEBsite.com:443/Amazing%fa/url/",
FlagsUnsafeGreedy)
fmt.Print(normalized)
// Output: http://somewebsite.com/Amazing%FA/url
}
func ExampleNormalizeURL() {
if u, err := url.Parse("Http://SomeUrl.com:8080/a/b/.././c///g?c=3&a=1&b=9&c=0#target"); err != nil {
panic(err)
} else {
normalized := NormalizeURL(u, FlagsUsuallySafeGreedy|FlagRemoveDuplicateSlashes|FlagRemoveFragment)
fmt.Print(normalized)
}
// Output: http://someurl.com:8080/a/c/g?c=3&a=1&b=9&c=0
}

View File

@@ -1,375 +0,0 @@
/*
Package purell offers URL normalization as described on the wikipedia page:
http://en.wikipedia.org/wiki/URL_normalization
*/
package purell
import (
"bytes"
"fmt"
"net/url"
"regexp"
"sort"
"strconv"
"strings"
"github.com/PuerkitoBio/urlesc"
"golang.org/x/net/idna"
"golang.org/x/text/secure/precis"
"golang.org/x/text/unicode/norm"
)
// A set of normalization flags determines how a URL will
// be normalized.
type NormalizationFlags uint
const (
// Safe normalizations
FlagLowercaseScheme NormalizationFlags = 1 << iota // HTTP://host -> http://host, applied by default in Go1.1
FlagLowercaseHost // http://HOST -> http://host
FlagUppercaseEscapes // http://host/t%ef -> http://host/t%EF
FlagDecodeUnnecessaryEscapes // http://host/t%41 -> http://host/tA
FlagEncodeNecessaryEscapes // http://host/!"#$ -> http://host/%21%22#$
FlagRemoveDefaultPort // http://host:80 -> http://host
FlagRemoveEmptyQuerySeparator // http://host/path? -> http://host/path
// Usually safe normalizations
FlagRemoveTrailingSlash // http://host/path/ -> http://host/path
FlagAddTrailingSlash // http://host/path -> http://host/path/ (should choose only one of these add/remove trailing slash flags)
FlagRemoveDotSegments // http://host/path/./a/b/../c -> http://host/path/a/c
// Unsafe normalizations
FlagRemoveDirectoryIndex // http://host/path/index.html -> http://host/path/
FlagRemoveFragment // http://host/path#fragment -> http://host/path
FlagForceHTTP // https://host -> http://host
FlagRemoveDuplicateSlashes // http://host/path//a///b -> http://host/path/a/b
FlagRemoveWWW // http://www.host/ -> http://host/
FlagAddWWW // http://host/ -> http://www.host/ (should choose only one of these add/remove WWW flags)
FlagSortQuery // http://host/path?c=3&b=2&a=1&b=1 -> http://host/path?a=1&b=1&b=2&c=3
// Normalizations not in the wikipedia article, required to cover tests cases
// submitted by jehiah
FlagDecodeDWORDHost // http://1113982867 -> http://66.102.7.147
FlagDecodeOctalHost // http://0102.0146.07.0223 -> http://66.102.7.147
FlagDecodeHexHost // http://0x42660793 -> http://66.102.7.147
FlagRemoveUnnecessaryHostDots // http://.host../path -> http://host/path
FlagRemoveEmptyPortSeparator // http://host:/path -> http://host/path
// Convenience set of safe normalizations
FlagsSafe NormalizationFlags = FlagLowercaseHost | FlagLowercaseScheme | FlagUppercaseEscapes | FlagDecodeUnnecessaryEscapes | FlagEncodeNecessaryEscapes | FlagRemoveDefaultPort | FlagRemoveEmptyQuerySeparator
// For convenience sets, "greedy" uses the "remove trailing slash" and "remove www. prefix" flags,
// while "non-greedy" uses the "add (or keep) the trailing slash" and "add www. prefix".
// Convenience set of usually safe normalizations (includes FlagsSafe)
FlagsUsuallySafeGreedy NormalizationFlags = FlagsSafe | FlagRemoveTrailingSlash | FlagRemoveDotSegments
FlagsUsuallySafeNonGreedy NormalizationFlags = FlagsSafe | FlagAddTrailingSlash | FlagRemoveDotSegments
// Convenience set of unsafe normalizations (includes FlagsUsuallySafe)
FlagsUnsafeGreedy NormalizationFlags = FlagsUsuallySafeGreedy | FlagRemoveDirectoryIndex | FlagRemoveFragment | FlagForceHTTP | FlagRemoveDuplicateSlashes | FlagRemoveWWW | FlagSortQuery
FlagsUnsafeNonGreedy NormalizationFlags = FlagsUsuallySafeNonGreedy | FlagRemoveDirectoryIndex | FlagRemoveFragment | FlagForceHTTP | FlagRemoveDuplicateSlashes | FlagAddWWW | FlagSortQuery
// Convenience set of all available flags
FlagsAllGreedy = FlagsUnsafeGreedy | FlagDecodeDWORDHost | FlagDecodeOctalHost | FlagDecodeHexHost | FlagRemoveUnnecessaryHostDots | FlagRemoveEmptyPortSeparator
FlagsAllNonGreedy = FlagsUnsafeNonGreedy | FlagDecodeDWORDHost | FlagDecodeOctalHost | FlagDecodeHexHost | FlagRemoveUnnecessaryHostDots | FlagRemoveEmptyPortSeparator
)
const (
defaultHttpPort = ":80"
defaultHttpsPort = ":443"
)
// Regular expressions used by the normalizations
var rxPort = regexp.MustCompile(`(:\d+)/?$`)
var rxDirIndex = regexp.MustCompile(`(^|/)((?:default|index)\.\w{1,4})$`)
var rxDupSlashes = regexp.MustCompile(`/{2,}`)
var rxDWORDHost = regexp.MustCompile(`^(\d+)((?:\.+)?(?:\:\d*)?)$`)
var rxOctalHost = regexp.MustCompile(`^(0\d*)\.(0\d*)\.(0\d*)\.(0\d*)((?:\.+)?(?:\:\d*)?)$`)
var rxHexHost = regexp.MustCompile(`^0x([0-9A-Fa-f]+)((?:\.+)?(?:\:\d*)?)$`)
var rxHostDots = regexp.MustCompile(`^(.+?)(:\d+)?$`)
var rxEmptyPort = regexp.MustCompile(`:+$`)
// Map of flags to implementation function.
// FlagDecodeUnnecessaryEscapes has no action, since it is done automatically
// by parsing the string as an URL. Same for FlagUppercaseEscapes and FlagRemoveEmptyQuerySeparator.
// Since maps have undefined traversing order, make a slice of ordered keys
var flagsOrder = []NormalizationFlags{
FlagLowercaseScheme,
FlagLowercaseHost,
FlagRemoveDefaultPort,
FlagRemoveDirectoryIndex,
FlagRemoveDotSegments,
FlagRemoveFragment,
FlagForceHTTP, // Must be after remove default port (because https=443/http=80)
FlagRemoveDuplicateSlashes,
FlagRemoveWWW,
FlagAddWWW,
FlagSortQuery,
FlagDecodeDWORDHost,
FlagDecodeOctalHost,
FlagDecodeHexHost,
FlagRemoveUnnecessaryHostDots,
FlagRemoveEmptyPortSeparator,
FlagRemoveTrailingSlash, // These two (add/remove trailing slash) must be last
FlagAddTrailingSlash,
}
// ... and then the map, where order is unimportant
var flags = map[NormalizationFlags]func(*url.URL){
FlagLowercaseScheme: lowercaseScheme,
FlagLowercaseHost: lowercaseHost,
FlagRemoveDefaultPort: removeDefaultPort,
FlagRemoveDirectoryIndex: removeDirectoryIndex,
FlagRemoveDotSegments: removeDotSegments,
FlagRemoveFragment: removeFragment,
FlagForceHTTP: forceHTTP,
FlagRemoveDuplicateSlashes: removeDuplicateSlashes,
FlagRemoveWWW: removeWWW,
FlagAddWWW: addWWW,
FlagSortQuery: sortQuery,
FlagDecodeDWORDHost: decodeDWORDHost,
FlagDecodeOctalHost: decodeOctalHost,
FlagDecodeHexHost: decodeHexHost,
FlagRemoveUnnecessaryHostDots: removeUnncessaryHostDots,
FlagRemoveEmptyPortSeparator: removeEmptyPortSeparator,
FlagRemoveTrailingSlash: removeTrailingSlash,
FlagAddTrailingSlash: addTrailingSlash,
}
// MustNormalizeURLString returns the normalized string, and panics if an error occurs.
// It takes an URL string as input, as well as the normalization flags.
func MustNormalizeURLString(u string, f NormalizationFlags) string {
result, e := NormalizeURLString(u, f)
if e != nil {
panic(e)
}
return result
}
// NormalizeURLString returns the normalized string, or an error if it can't be parsed into an URL object.
// It takes an URL string as input, as well as the normalization flags.
func NormalizeURLString(u string, f NormalizationFlags) (string, error) {
if parsed, e := url.Parse(u); e != nil {
return "", e
} else {
options := make([]precis.Option, 1, 3)
options[0] = precis.IgnoreCase
if f&FlagLowercaseHost == FlagLowercaseHost {
options = append(options, precis.FoldCase())
}
options = append(options, precis.Norm(norm.NFC))
profile := precis.NewFreeform(options...)
if parsed.Host, e = idna.ToASCII(profile.NewTransformer().String(parsed.Host)); e != nil {
return "", e
}
return NormalizeURL(parsed, f), nil
}
panic("Unreachable code.")
}
// NormalizeURL returns the normalized string.
// It takes a parsed URL object as input, as well as the normalization flags.
func NormalizeURL(u *url.URL, f NormalizationFlags) string {
for _, k := range flagsOrder {
if f&k == k {
flags[k](u)
}
}
return urlesc.Escape(u)
}
func lowercaseScheme(u *url.URL) {
if len(u.Scheme) > 0 {
u.Scheme = strings.ToLower(u.Scheme)
}
}
func lowercaseHost(u *url.URL) {
if len(u.Host) > 0 {
u.Host = strings.ToLower(u.Host)
}
}
func removeDefaultPort(u *url.URL) {
if len(u.Host) > 0 {
scheme := strings.ToLower(u.Scheme)
u.Host = rxPort.ReplaceAllStringFunc(u.Host, func(val string) string {
if (scheme == "http" && val == defaultHttpPort) || (scheme == "https" && val == defaultHttpsPort) {
return ""
}
return val
})
}
}
func removeTrailingSlash(u *url.URL) {
if l := len(u.Path); l > 0 {
if strings.HasSuffix(u.Path, "/") {
u.Path = u.Path[:l-1]
}
} else if l = len(u.Host); l > 0 {
if strings.HasSuffix(u.Host, "/") {
u.Host = u.Host[:l-1]
}
}
}
func addTrailingSlash(u *url.URL) {
if l := len(u.Path); l > 0 {
if !strings.HasSuffix(u.Path, "/") {
u.Path += "/"
}
} else if l = len(u.Host); l > 0 {
if !strings.HasSuffix(u.Host, "/") {
u.Host += "/"
}
}
}
func removeDotSegments(u *url.URL) {
if len(u.Path) > 0 {
var dotFree []string
var lastIsDot bool
sections := strings.Split(u.Path, "/")
for _, s := range sections {
if s == ".." {
if len(dotFree) > 0 {
dotFree = dotFree[:len(dotFree)-1]
}
} else if s != "." {
dotFree = append(dotFree, s)
}
lastIsDot = (s == "." || s == "..")
}
// Special case if host does not end with / and new path does not begin with /
u.Path = strings.Join(dotFree, "/")
if u.Host != "" && !strings.HasSuffix(u.Host, "/") && !strings.HasPrefix(u.Path, "/") {
u.Path = "/" + u.Path
}
// Special case if the last segment was a dot, make sure the path ends with a slash
if lastIsDot && !strings.HasSuffix(u.Path, "/") {
u.Path += "/"
}
}
}
func removeDirectoryIndex(u *url.URL) {
if len(u.Path) > 0 {
u.Path = rxDirIndex.ReplaceAllString(u.Path, "$1")
}
}
func removeFragment(u *url.URL) {
u.Fragment = ""
}
func forceHTTP(u *url.URL) {
if strings.ToLower(u.Scheme) == "https" {
u.Scheme = "http"
}
}
func removeDuplicateSlashes(u *url.URL) {
if len(u.Path) > 0 {
u.Path = rxDupSlashes.ReplaceAllString(u.Path, "/")
}
}
func removeWWW(u *url.URL) {
if len(u.Host) > 0 && strings.HasPrefix(strings.ToLower(u.Host), "www.") {
u.Host = u.Host[4:]
}
}
func addWWW(u *url.URL) {
if len(u.Host) > 0 && !strings.HasPrefix(strings.ToLower(u.Host), "www.") {
u.Host = "www." + u.Host
}
}
func sortQuery(u *url.URL) {
q := u.Query()
if len(q) > 0 {
arKeys := make([]string, len(q))
i := 0
for k, _ := range q {
arKeys[i] = k
i++
}
sort.Strings(arKeys)
buf := new(bytes.Buffer)
for _, k := range arKeys {
sort.Strings(q[k])
for _, v := range q[k] {
if buf.Len() > 0 {
buf.WriteRune('&')
}
buf.WriteString(fmt.Sprintf("%s=%s", k, urlesc.QueryEscape(v)))
}
}
// Rebuild the raw query string
u.RawQuery = buf.String()
}
}
func decodeDWORDHost(u *url.URL) {
if len(u.Host) > 0 {
if matches := rxDWORDHost.FindStringSubmatch(u.Host); len(matches) > 2 {
var parts [4]int64
dword, _ := strconv.ParseInt(matches[1], 10, 0)
for i, shift := range []uint{24, 16, 8, 0} {
parts[i] = dword >> shift & 0xFF
}
u.Host = fmt.Sprintf("%d.%d.%d.%d%s", parts[0], parts[1], parts[2], parts[3], matches[2])
}
}
}
func decodeOctalHost(u *url.URL) {
if len(u.Host) > 0 {
if matches := rxOctalHost.FindStringSubmatch(u.Host); len(matches) > 5 {
var parts [4]int64
for i := 1; i <= 4; i++ {
parts[i-1], _ = strconv.ParseInt(matches[i], 8, 0)
}
u.Host = fmt.Sprintf("%d.%d.%d.%d%s", parts[0], parts[1], parts[2], parts[3], matches[5])
}
}
}
func decodeHexHost(u *url.URL) {
if len(u.Host) > 0 {
if matches := rxHexHost.FindStringSubmatch(u.Host); len(matches) > 2 {
// Conversion is safe because of regex validation
parsed, _ := strconv.ParseInt(matches[1], 16, 0)
// Set host as DWORD (base 10) encoded host
u.Host = fmt.Sprintf("%d%s", parsed, matches[2])
// The rest is the same as decoding a DWORD host
decodeDWORDHost(u)
}
}
}
func removeUnncessaryHostDots(u *url.URL) {
if len(u.Host) > 0 {
if matches := rxHostDots.FindStringSubmatch(u.Host); len(matches) > 1 {
// Trim the leading and trailing dots
u.Host = strings.Trim(matches[1], ".")
if len(matches) > 2 {
u.Host += matches[2]
}
}
}
}
func removeEmptyPortSeparator(u *url.URL) {
if len(u.Host) > 0 {
u.Host = rxEmptyPort.ReplaceAllString(u.Host, "")
}
}

View File

@@ -1,768 +0,0 @@
package purell
import (
"fmt"
"net/url"
"testing"
)
type testCase struct {
nm string
src string
flgs NormalizationFlags
res string
parsed bool
}
var (
cases = [...]*testCase{
&testCase{
"LowerScheme",
"HTTP://www.SRC.ca",
FlagLowercaseScheme,
"http://www.SRC.ca",
false,
},
&testCase{
"LowerScheme2",
"http://www.SRC.ca",
FlagLowercaseScheme,
"http://www.SRC.ca",
false,
},
&testCase{
"LowerHost",
"HTTP://www.SRC.ca/",
FlagLowercaseHost,
"http://www.src.ca/", // Since Go1.1, scheme is automatically lowercased
false,
},
&testCase{
"UpperEscapes",
`http://www.whatever.com/Some%aa%20Special%8Ecases/`,
FlagUppercaseEscapes,
"http://www.whatever.com/Some%AA%20Special%8Ecases/",
false,
},
&testCase{
"UnnecessaryEscapes",
`http://www.toto.com/%41%42%2E%44/%32%33%52%2D/%5f%7E`,
FlagDecodeUnnecessaryEscapes,
"http://www.toto.com/AB.D/23R-/_~",
false,
},
&testCase{
"RemoveDefaultPort",
"HTTP://www.SRC.ca:80/",
FlagRemoveDefaultPort,
"http://www.SRC.ca/", // Since Go1.1, scheme is automatically lowercased
false,
},
&testCase{
"RemoveDefaultPort2",
"HTTP://www.SRC.ca:80",
FlagRemoveDefaultPort,
"http://www.SRC.ca", // Since Go1.1, scheme is automatically lowercased
false,
},
&testCase{
"RemoveDefaultPort3",
"HTTP://www.SRC.ca:8080",
FlagRemoveDefaultPort,
"http://www.SRC.ca:8080", // Since Go1.1, scheme is automatically lowercased
false,
},
&testCase{
"Safe",
"HTTP://www.SRC.ca:80/to%1ato%8b%ee/OKnow%41%42%43%7e",
FlagsSafe,
"http://www.src.ca/to%1Ato%8B%EE/OKnowABC~",
false,
},
&testCase{
"BothLower",
"HTTP://www.SRC.ca:80/to%1ato%8b%ee/OKnow%41%42%43%7e",
FlagLowercaseHost | FlagLowercaseScheme,
"http://www.src.ca:80/to%1Ato%8B%EE/OKnowABC~",
false,
},
&testCase{
"RemoveTrailingSlash",
"HTTP://www.SRC.ca:80/",
FlagRemoveTrailingSlash,
"http://www.SRC.ca:80", // Since Go1.1, scheme is automatically lowercased
false,
},
&testCase{
"RemoveTrailingSlash2",
"HTTP://www.SRC.ca:80/toto/titi/",
FlagRemoveTrailingSlash,
"http://www.SRC.ca:80/toto/titi", // Since Go1.1, scheme is automatically lowercased
false,
},
&testCase{
"RemoveTrailingSlash3",
"HTTP://www.SRC.ca:80/toto/titi/fin/?a=1",
FlagRemoveTrailingSlash,
"http://www.SRC.ca:80/toto/titi/fin?a=1", // Since Go1.1, scheme is automatically lowercased
false,
},
&testCase{
"AddTrailingSlash",
"HTTP://www.SRC.ca:80",
FlagAddTrailingSlash,
"http://www.SRC.ca:80/", // Since Go1.1, scheme is automatically lowercased
false,
},
&testCase{
"AddTrailingSlash2",
"HTTP://www.SRC.ca:80/toto/titi.html",
FlagAddTrailingSlash,
"http://www.SRC.ca:80/toto/titi.html/", // Since Go1.1, scheme is automatically lowercased
false,
},
&testCase{
"AddTrailingSlash3",
"HTTP://www.SRC.ca:80/toto/titi/fin?a=1",
FlagAddTrailingSlash,
"http://www.SRC.ca:80/toto/titi/fin/?a=1", // Since Go1.1, scheme is automatically lowercased
false,
},
&testCase{
"RemoveDotSegments",
"HTTP://root/a/b/./../../c/",
FlagRemoveDotSegments,
"http://root/c/", // Since Go1.1, scheme is automatically lowercased
false,
},
&testCase{
"RemoveDotSegments2",
"HTTP://root/../a/b/./../c/../d",
FlagRemoveDotSegments,
"http://root/a/d", // Since Go1.1, scheme is automatically lowercased
false,
},
&testCase{
"UsuallySafe",
"HTTP://www.SRC.ca:80/to%1ato%8b%ee/./c/d/../OKnow%41%42%43%7e/?a=b#test",
FlagsUsuallySafeGreedy,
"http://www.src.ca/to%1Ato%8B%EE/c/OKnowABC~?a=b#test",
false,
},
&testCase{
"RemoveDirectoryIndex",
"HTTP://root/a/b/c/default.aspx",
FlagRemoveDirectoryIndex,
"http://root/a/b/c/", // Since Go1.1, scheme is automatically lowercased
false,
},
&testCase{
"RemoveDirectoryIndex2",
"HTTP://root/a/b/c/default#a=b",
FlagRemoveDirectoryIndex,
"http://root/a/b/c/default#a=b", // Since Go1.1, scheme is automatically lowercased
false,
},
&testCase{
"RemoveFragment",
"HTTP://root/a/b/c/default#toto=tata",
FlagRemoveFragment,
"http://root/a/b/c/default", // Since Go1.1, scheme is automatically lowercased
false,
},
&testCase{
"ForceHTTP",
"https://root/a/b/c/default#toto=tata",
FlagForceHTTP,
"http://root/a/b/c/default#toto=tata",
false,
},
&testCase{
"RemoveDuplicateSlashes",
"https://root/a//b///c////default#toto=tata",
FlagRemoveDuplicateSlashes,
"https://root/a/b/c/default#toto=tata",
false,
},
&testCase{
"RemoveDuplicateSlashes2",
"https://root//a//b///c////default#toto=tata",
FlagRemoveDuplicateSlashes,
"https://root/a/b/c/default#toto=tata",
false,
},
&testCase{
"RemoveWWW",
"https://www.root/a/b/c/",
FlagRemoveWWW,
"https://root/a/b/c/",
false,
},
&testCase{
"RemoveWWW2",
"https://WwW.Root/a/b/c/",
FlagRemoveWWW,
"https://Root/a/b/c/",
false,
},
&testCase{
"AddWWW",
"https://Root/a/b/c/",
FlagAddWWW,
"https://www.Root/a/b/c/",
false,
},
&testCase{
"SortQuery",
"http://root/toto/?b=4&a=1&c=3&b=2&a=5",
FlagSortQuery,
"http://root/toto/?a=1&a=5&b=2&b=4&c=3",
false,
},
&testCase{
"RemoveEmptyQuerySeparator",
"http://root/toto/?",
FlagRemoveEmptyQuerySeparator,
"http://root/toto/",
false,
},
&testCase{
"Unsafe",
"HTTPS://www.RooT.com/toto/t%45%1f///a/./b/../c/?z=3&w=2&a=4&w=1#invalid",
FlagsUnsafeGreedy,
"http://root.com/toto/tE%1F/a/c?a=4&w=1&w=2&z=3",
false,
},
&testCase{
"Safe2",
"HTTPS://www.RooT.com/toto/t%45%1f///a/./b/../c/?z=3&w=2&a=4&w=1#invalid",
FlagsSafe,
"https://www.root.com/toto/tE%1F///a/./b/../c/?z=3&w=2&a=4&w=1#invalid",
false,
},
&testCase{
"UsuallySafe2",
"HTTPS://www.RooT.com/toto/t%45%1f///a/./b/../c/?z=3&w=2&a=4&w=1#invalid",
FlagsUsuallySafeGreedy,
"https://www.root.com/toto/tE%1F///a/c?z=3&w=2&a=4&w=1#invalid",
false,
},
&testCase{
"AddTrailingSlashBug",
"http://src.ca/",
FlagsAllNonGreedy,
"http://www.src.ca/",
false,
},
&testCase{
"SourceModified",
"HTTPS://www.RooT.com/toto/t%45%1f///a/./b/../c/?z=3&w=2&a=4&w=1#invalid",
FlagsUnsafeGreedy,
"http://root.com/toto/tE%1F/a/c?a=4&w=1&w=2&z=3",
true,
},
&testCase{
"IPv6-1",
"http://[2001:db8:1f70::999:de8:7648:6e8]/test",
FlagsSafe | FlagRemoveDotSegments,
"http://[2001:db8:1f70::999:de8:7648:6e8]/test",
false,
},
&testCase{
"IPv6-2",
"http://[::ffff:192.168.1.1]/test",
FlagsSafe | FlagRemoveDotSegments,
"http://[::ffff:192.168.1.1]/test",
false,
},
&testCase{
"IPv6-3",
"http://[::ffff:192.168.1.1]:80/test",
FlagsSafe | FlagRemoveDotSegments,
"http://[::ffff:192.168.1.1]/test",
false,
},
&testCase{
"IPv6-4",
"htTps://[::fFff:192.168.1.1]:443/test",
FlagsSafe | FlagRemoveDotSegments,
"https://[::ffff:192.168.1.1]/test",
false,
},
&testCase{
"FTP",
"ftp://user:pass@ftp.foo.net/foo/bar",
FlagsSafe | FlagRemoveDotSegments,
"ftp://user:pass@ftp.foo.net/foo/bar",
false,
},
&testCase{
"Standard-1",
"http://www.foo.com:80/foo",
FlagsSafe | FlagRemoveDotSegments,
"http://www.foo.com/foo",
false,
},
&testCase{
"Standard-2",
"http://www.foo.com:8000/foo",
FlagsSafe | FlagRemoveDotSegments,
"http://www.foo.com:8000/foo",
false,
},
&testCase{
"Standard-3",
"http://www.foo.com/%7ebar",
FlagsSafe | FlagRemoveDotSegments,
"http://www.foo.com/~bar",
false,
},
&testCase{
"Standard-4",
"http://www.foo.com/%7Ebar",
FlagsSafe | FlagRemoveDotSegments,
"http://www.foo.com/~bar",
false,
},
&testCase{
"Standard-5",
"http://USER:pass@www.Example.COM/foo/bar",
FlagsSafe | FlagRemoveDotSegments,
"http://USER:pass@www.example.com/foo/bar",
false,
},
&testCase{
"Standard-6",
"http://test.example/?a=%26&b=1",
FlagsSafe | FlagRemoveDotSegments,
"http://test.example/?a=%26&b=1",
false,
},
&testCase{
"Standard-7",
"http://test.example/%25/?p=%20val%20%25",
FlagsSafe | FlagRemoveDotSegments,
"http://test.example/%25/?p=%20val%20%25",
false,
},
&testCase{
"Standard-8",
"http://test.example/path/with a%20space+/",
FlagsSafe | FlagRemoveDotSegments,
"http://test.example/path/with%20a%20space+/",
false,
},
&testCase{
"Standard-9",
"http://test.example/?",
FlagsSafe | FlagRemoveDotSegments,
"http://test.example/",
false,
},
&testCase{
"Standard-10",
"http://a.COM/path/?b&a",
FlagsSafe | FlagRemoveDotSegments,
"http://a.com/path/?b&a",
false,
},
&testCase{
"StandardCasesAddTrailingSlash",
"http://test.example?",
FlagsSafe | FlagAddTrailingSlash,
"http://test.example/",
false,
},
&testCase{
"OctalIP-1",
"http://0123.011.0.4/",
FlagsSafe | FlagDecodeOctalHost,
"http://0123.011.0.4/",
false,
},
&testCase{
"OctalIP-2",
"http://0102.0146.07.0223/",
FlagsSafe | FlagDecodeOctalHost,
"http://66.102.7.147/",
false,
},
&testCase{
"OctalIP-3",
"http://0102.0146.07.0223.:23/",
FlagsSafe | FlagDecodeOctalHost,
"http://66.102.7.147.:23/",
false,
},
&testCase{
"OctalIP-4",
"http://USER:pass@0102.0146.07.0223../",
FlagsSafe | FlagDecodeOctalHost,
"http://USER:pass@66.102.7.147../",
false,
},
&testCase{
"DWORDIP-1",
"http://123.1113982867/",
FlagsSafe | FlagDecodeDWORDHost,
"http://123.1113982867/",
false,
},
&testCase{
"DWORDIP-2",
"http://1113982867/",
FlagsSafe | FlagDecodeDWORDHost,
"http://66.102.7.147/",
false,
},
&testCase{
"DWORDIP-3",
"http://1113982867.:23/",
FlagsSafe | FlagDecodeDWORDHost,
"http://66.102.7.147.:23/",
false,
},
&testCase{
"DWORDIP-4",
"http://USER:pass@1113982867../",
FlagsSafe | FlagDecodeDWORDHost,
"http://USER:pass@66.102.7.147../",
false,
},
&testCase{
"HexIP-1",
"http://0x123.1113982867/",
FlagsSafe | FlagDecodeHexHost,
"http://0x123.1113982867/",
false,
},
&testCase{
"HexIP-2",
"http://0x42660793/",
FlagsSafe | FlagDecodeHexHost,
"http://66.102.7.147/",
false,
},
&testCase{
"HexIP-3",
"http://0x42660793.:23/",
FlagsSafe | FlagDecodeHexHost,
"http://66.102.7.147.:23/",
false,
},
&testCase{
"HexIP-4",
"http://USER:pass@0x42660793../",
FlagsSafe | FlagDecodeHexHost,
"http://USER:pass@66.102.7.147../",
false,
},
&testCase{
"UnnecessaryHostDots-1",
"http://.www.foo.com../foo/bar.html",
FlagsSafe | FlagRemoveUnnecessaryHostDots,
"http://www.foo.com/foo/bar.html",
false,
},
&testCase{
"UnnecessaryHostDots-2",
"http://www.foo.com./foo/bar.html",
FlagsSafe | FlagRemoveUnnecessaryHostDots,
"http://www.foo.com/foo/bar.html",
false,
},
&testCase{
"UnnecessaryHostDots-3",
"http://www.foo.com.:81/foo",
FlagsSafe | FlagRemoveUnnecessaryHostDots,
"http://www.foo.com:81/foo",
false,
},
&testCase{
"UnnecessaryHostDots-4",
"http://www.example.com./",
FlagsSafe | FlagRemoveUnnecessaryHostDots,
"http://www.example.com/",
false,
},
&testCase{
"EmptyPort-1",
"http://www.thedraymin.co.uk:/main/?p=308",
FlagsSafe | FlagRemoveEmptyPortSeparator,
"http://www.thedraymin.co.uk/main/?p=308",
false,
},
&testCase{
"EmptyPort-2",
"http://www.src.ca:",
FlagsSafe | FlagRemoveEmptyPortSeparator,
"http://www.src.ca",
false,
},
&testCase{
"Slashes-1",
"http://test.example/foo/bar/.",
FlagsSafe | FlagRemoveDotSegments | FlagRemoveDuplicateSlashes,
"http://test.example/foo/bar/",
false,
},
&testCase{
"Slashes-2",
"http://test.example/foo/bar/./",
FlagsSafe | FlagRemoveDotSegments | FlagRemoveDuplicateSlashes,
"http://test.example/foo/bar/",
false,
},
&testCase{
"Slashes-3",
"http://test.example/foo/bar/..",
FlagsSafe | FlagRemoveDotSegments | FlagRemoveDuplicateSlashes,
"http://test.example/foo/",
false,
},
&testCase{
"Slashes-4",
"http://test.example/foo/bar/../",
FlagsSafe | FlagRemoveDotSegments | FlagRemoveDuplicateSlashes,
"http://test.example/foo/",
false,
},
&testCase{
"Slashes-5",
"http://test.example/foo/bar/../baz",
FlagsSafe | FlagRemoveDotSegments | FlagRemoveDuplicateSlashes,
"http://test.example/foo/baz",
false,
},
&testCase{
"Slashes-6",
"http://test.example/foo/bar/../..",
FlagsSafe | FlagRemoveDotSegments | FlagRemoveDuplicateSlashes,
"http://test.example/",
false,
},
&testCase{
"Slashes-7",
"http://test.example/foo/bar/../../",
FlagsSafe | FlagRemoveDotSegments | FlagRemoveDuplicateSlashes,
"http://test.example/",
false,
},
&testCase{
"Slashes-8",
"http://test.example/foo/bar/../../baz",
FlagsSafe | FlagRemoveDotSegments | FlagRemoveDuplicateSlashes,
"http://test.example/baz",
false,
},
&testCase{
"Slashes-9",
"http://test.example/foo/bar/../../../baz",
FlagsSafe | FlagRemoveDotSegments | FlagRemoveDuplicateSlashes,
"http://test.example/baz",
false,
},
&testCase{
"Slashes-10",
"http://test.example/foo/bar/../../../../baz",
FlagsSafe | FlagRemoveDotSegments | FlagRemoveDuplicateSlashes,
"http://test.example/baz",
false,
},
&testCase{
"Slashes-11",
"http://test.example/./foo",
FlagsSafe | FlagRemoveDotSegments | FlagRemoveDuplicateSlashes,
"http://test.example/foo",
false,
},
&testCase{
"Slashes-12",
"http://test.example/../foo",
FlagsSafe | FlagRemoveDotSegments | FlagRemoveDuplicateSlashes,
"http://test.example/foo",
false,
},
&testCase{
"Slashes-13",
"http://test.example/foo.",
FlagsSafe | FlagRemoveDotSegments | FlagRemoveDuplicateSlashes,
"http://test.example/foo.",
false,
},
&testCase{
"Slashes-14",
"http://test.example/.foo",
FlagsSafe | FlagRemoveDotSegments | FlagRemoveDuplicateSlashes,
"http://test.example/.foo",
false,
},
&testCase{
"Slashes-15",
"http://test.example/foo..",
FlagsSafe | FlagRemoveDotSegments | FlagRemoveDuplicateSlashes,
"http://test.example/foo..",
false,
},
&testCase{
"Slashes-16",
"http://test.example/..foo",
FlagsSafe | FlagRemoveDotSegments | FlagRemoveDuplicateSlashes,
"http://test.example/..foo",
false,
},
&testCase{
"Slashes-17",
"http://test.example/./../foo",
FlagsSafe | FlagRemoveDotSegments | FlagRemoveDuplicateSlashes,
"http://test.example/foo",
false,
},
&testCase{
"Slashes-18",
"http://test.example/./foo/.",
FlagsSafe | FlagRemoveDotSegments | FlagRemoveDuplicateSlashes,
"http://test.example/foo/",
false,
},
&testCase{
"Slashes-19",
"http://test.example/foo/./bar",
FlagsSafe | FlagRemoveDotSegments | FlagRemoveDuplicateSlashes,
"http://test.example/foo/bar",
false,
},
&testCase{
"Slashes-20",
"http://test.example/foo/../bar",
FlagsSafe | FlagRemoveDotSegments | FlagRemoveDuplicateSlashes,
"http://test.example/bar",
false,
},
&testCase{
"Slashes-21",
"http://test.example/foo//",
FlagsSafe | FlagRemoveDotSegments | FlagRemoveDuplicateSlashes,
"http://test.example/foo/",
false,
},
&testCase{
"Slashes-22",
"http://test.example/foo///bar//",
FlagsSafe | FlagRemoveDotSegments | FlagRemoveDuplicateSlashes,
"http://test.example/foo/bar/",
false,
},
&testCase{
"Relative",
"foo/bar",
FlagsAllGreedy,
"foo/bar",
false,
},
&testCase{
"Relative-1",
"./../foo",
FlagsSafe | FlagRemoveDotSegments | FlagRemoveDuplicateSlashes,
"foo",
false,
},
&testCase{
"Relative-2",
"./foo/bar/../baz/../bang/..",
FlagsSafe | FlagRemoveDotSegments | FlagRemoveDuplicateSlashes,
"foo/",
false,
},
&testCase{
"Relative-3",
"foo///bar//",
FlagsSafe | FlagRemoveDotSegments | FlagRemoveDuplicateSlashes,
"foo/bar/",
false,
},
&testCase{
"Relative-4",
"www.youtube.com",
FlagsUsuallySafeGreedy,
"www.youtube.com",
false,
},
/*&testCase{
"UrlNorm-5",
"http://ja.wikipedia.org/wiki/%E3%82%AD%E3%83%A3%E3%82%BF%E3%83%94%E3%83%A9%E3%83%BC%E3%82%B8%E3%83%A3%E3%83%91%E3%83%B3",
FlagsSafe | FlagRemoveDotSegments,
"http://ja.wikipedia.org/wiki/\xe3\x82\xad\xe3\x83\xa3\xe3\x82\xbf\xe3\x83\x94\xe3\x83\xa9\xe3\x83\xbc\xe3\x82\xb8\xe3\x83\xa3\xe3\x83\x91\xe3\x83\xb3",
false,
},
&testCase{
"UrlNorm-1",
"http://test.example/?a=%e3%82%82%26",
FlagsAllGreedy,
"http://test.example/?a=\xe3\x82\x82%26",
false,
},*/
}
)
func TestRunner(t *testing.T) {
for _, tc := range cases {
runCase(tc, t)
}
}
func runCase(tc *testCase, t *testing.T) {
t.Logf("running %s...", tc.nm)
if tc.parsed {
u, e := url.Parse(tc.src)
if e != nil {
t.Errorf("%s - FAIL : %s", tc.nm, e)
return
} else {
NormalizeURL(u, tc.flgs)
if s := u.String(); s != tc.res {
t.Errorf("%s - FAIL expected '%s', got '%s'", tc.nm, tc.res, s)
}
}
} else {
if s, e := NormalizeURLString(tc.src, tc.flgs); e != nil {
t.Errorf("%s - FAIL : %s", tc.nm, e)
} else if s != tc.res {
t.Errorf("%s - FAIL expected '%s', got '%s'", tc.nm, tc.res, s)
}
}
}
func TestDecodeUnnecessaryEscapesAll(t *testing.T) {
var url = "http://host/"
for i := 0; i < 256; i++ {
url += fmt.Sprintf("%%%02x", i)
}
if s, e := NormalizeURLString(url, FlagDecodeUnnecessaryEscapes); e != nil {
t.Fatalf("Got error %s", e.Error())
} else {
const want = "http://host/%00%01%02%03%04%05%06%07%08%09%0A%0B%0C%0D%0E%0F%10%11%12%13%14%15%16%17%18%19%1A%1B%1C%1D%1E%1F%20!%22%23$%25&'()*+,-./0123456789:;%3C=%3E%3F@ABCDEFGHIJKLMNOPQRSTUVWXYZ[%5C]%5E_%60abcdefghijklmnopqrstuvwxyz%7B%7C%7D~%7F%80%81%82%83%84%85%86%87%88%89%8A%8B%8C%8D%8E%8F%90%91%92%93%94%95%96%97%98%99%9A%9B%9C%9D%9E%9F%A0%A1%A2%A3%A4%A5%A6%A7%A8%A9%AA%AB%AC%AD%AE%AF%B0%B1%B2%B3%B4%B5%B6%B7%B8%B9%BA%BB%BC%BD%BE%BF%C0%C1%C2%C3%C4%C5%C6%C7%C8%C9%CA%CB%CC%CD%CE%CF%D0%D1%D2%D3%D4%D5%D6%D7%D8%D9%DA%DB%DC%DD%DE%DF%E0%E1%E2%E3%E4%E5%E6%E7%E8%E9%EA%EB%EC%ED%EE%EF%F0%F1%F2%F3%F4%F5%F6%F7%F8%F9%FA%FB%FC%FD%FE%FF"
if s != want {
t.Errorf("DecodeUnnecessaryEscapesAll:\nwant\n%s\ngot\n%s", want, s)
}
}
}
func TestEncodeNecessaryEscapesAll(t *testing.T) {
var url = "http://host/"
for i := 0; i < 256; i++ {
if i != 0x25 {
url += string(i)
}
}
if s, e := NormalizeURLString(url, FlagEncodeNecessaryEscapes); e != nil {
t.Fatalf("Got error %s", e.Error())
} else {
const want = "http://host/%00%01%02%03%04%05%06%07%08%09%0A%0B%0C%0D%0E%0F%10%11%12%13%14%15%16%17%18%19%1A%1B%1C%1D%1E%1F%20!%22#$&'()*+,-./0123456789:;%3C=%3E?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[%5C]%5E_%60abcdefghijklmnopqrstuvwxyz%7B%7C%7D~%7F%C2%80%C2%81%C2%82%C2%83%C2%84%C2%85%C2%86%C2%87%C2%88%C2%89%C2%8A%C2%8B%C2%8C%C2%8D%C2%8E%C2%8F%C2%90%C2%91%C2%92%C2%93%C2%94%C2%95%C2%96%C2%97%C2%98%C2%99%C2%9A%C2%9B%C2%9C%C2%9D%C2%9E%C2%9F%C2%A0%C2%A1%C2%A2%C2%A3%C2%A4%C2%A5%C2%A6%C2%A7%C2%A8%C2%A9%C2%AA%C2%AB%C2%AC%C2%AD%C2%AE%C2%AF%C2%B0%C2%B1%C2%B2%C2%B3%C2%B4%C2%B5%C2%B6%C2%B7%C2%B8%C2%B9%C2%BA%C2%BB%C2%BC%C2%BD%C2%BE%C2%BF%C3%80%C3%81%C3%82%C3%83%C3%84%C3%85%C3%86%C3%87%C3%88%C3%89%C3%8A%C3%8B%C3%8C%C3%8D%C3%8E%C3%8F%C3%90%C3%91%C3%92%C3%93%C3%94%C3%95%C3%96%C3%97%C3%98%C3%99%C3%9A%C3%9B%C3%9C%C3%9D%C3%9E%C3%9F%C3%A0%C3%A1%C3%A2%C3%A3%C3%A4%C3%A5%C3%A6%C3%A7%C3%A8%C3%A9%C3%AA%C3%AB%C3%AC%C3%AD%C3%AE%C3%AF%C3%B0%C3%B1%C3%B2%C3%B3%C3%B4%C3%B5%C3%B6%C3%B7%C3%B8%C3%B9%C3%BA%C3%BB%C3%BC%C3%BD%C3%BE%C3%BF"
if s != want {
t.Errorf("EncodeNecessaryEscapesAll:\nwant\n%s\ngot\n%s", want, s)
}
}
}

View File

@@ -1,52 +0,0 @@
package purell
import (
"testing"
)
// Test cases merged from PR #1
// Originally from https://github.com/jehiah/urlnorm/blob/master/test_urlnorm.py
func assertMap(t *testing.T, cases map[string]string, f NormalizationFlags) {
for bad, good := range cases {
s, e := NormalizeURLString(bad, f)
if e != nil {
t.Errorf("%s normalizing %v to %v", e.Error(), bad, good)
} else {
if s != good {
t.Errorf("source: %v expected: %v got: %v", bad, good, s)
}
}
}
}
// This tests normalization to a unicode representation
// precent escapes for unreserved values are unescaped to their unicode value
// tests normalization to idna domains
// test ip word handling, ipv6 address handling, and trailing domain periods
// in general, this matches google chromes unescaping for things in the address bar.
// spaces are converted to '+' (perhaphs controversial)
// http://code.google.com/p/google-url/ probably is another good reference for this approach
func TestUrlnorm(t *testing.T) {
testcases := map[string]string{
"http://test.example/?a=%e3%82%82%26": "http://test.example/?a=%e3%82%82%26",
//"http://test.example/?a=%e3%82%82%26": "http://test.example/?a=\xe3\x82\x82%26", //should return a unicode character
"http://s.xn--q-bga.DE/": "http://s.xn--q-bga.de/", //should be in idna format
"http://XBLA\u306eXbox.com": "http://xn--xblaxbox-jf4g.com", //test utf8 and unicode
"http://президент.рф": "http://xn--d1abbgf6aiiy.xn--p1ai",
"http://ПРЕЗИДЕНТ.РФ": "http://xn--d1abbgf6aiiy.xn--p1ai",
"http://\u00e9.com": "http://xn--9ca.com",
"http://e\u0301.com": "http://xn--9ca.com",
"http://ja.wikipedia.org/wiki/%E3%82%AD%E3%83%A3%E3%82%BF%E3%83%94%E3%83%A9%E3%83%BC%E3%82%B8%E3%83%A3%E3%83%91%E3%83%B3": "http://ja.wikipedia.org/wiki/%E3%82%AD%E3%83%A3%E3%82%BF%E3%83%94%E3%83%A9%E3%83%BC%E3%82%B8%E3%83%A3%E3%83%91%E3%83%B3",
//"http://ja.wikipedia.org/wiki/%E3%82%AD%E3%83%A3%E3%82%BF%E3%83%94%E3%83%A9%E3%83%BC%E3%82%B8%E3%83%A3%E3%83%91%E3%83%B3": "http://ja.wikipedia.org/wiki/\xe3\x82\xad\xe3\x83\xa3\xe3\x82\xbf\xe3\x83\x94\xe3\x83\xa9\xe3\x83\xbc\xe3\x82\xb8\xe3\x83\xa3\xe3\x83\x91\xe3\x83\xb3",
"http://test.example/\xe3\x82\xad": "http://test.example/%E3%82%AD",
//"http://test.example/\xe3\x82\xad": "http://test.example/\xe3\x82\xad",
"http://test.example/?p=%23val#test-%23-val%25": "http://test.example/?p=%23val#test-%23-val%25", //check that %23 (#) is not escaped where it shouldn't be
"http://test.domain/I%C3%B1t%C3%ABrn%C3%A2ti%C3%B4n%EF%BF%BDliz%C3%A6ti%C3%B8n": "http://test.domain/I%C3%B1t%C3%ABrn%C3%A2ti%C3%B4n%EF%BF%BDliz%C3%A6ti%C3%B8n",
//"http://test.domain/I%C3%B1t%C3%ABrn%C3%A2ti%C3%B4n%EF%BF%BDliz%C3%A6ti%C3%B8n": "http://test.domain/I\xc3\xb1t\xc3\xabrn\xc3\xa2ti\xc3\xb4n\xef\xbf\xbdliz\xc3\xa6ti\xc3\xb8n",
}
assertMap(t, testcases, FlagsSafe|FlagRemoveDotSegments)
}

View File

@@ -1,11 +0,0 @@
language: go
go:
- 1.4
- tip
install:
- go build .
script:
- go test -v

View File

@@ -1,27 +0,0 @@
Copyright (c) 2012 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@@ -1,16 +0,0 @@
urlesc [![Build Status](https://travis-ci.org/PuerkitoBio/urlesc.png?branch=master)](https://travis-ci.org/PuerkitoBio/urlesc) [![GoDoc](http://godoc.org/github.com/PuerkitoBio/urlesc?status.svg)](http://godoc.org/github.com/PuerkitoBio/urlesc)
======
Package urlesc implements query escaping as per RFC 3986.
It contains some parts of the net/url package, modified so as to allow
some reserved characters incorrectly escaped by net/url (see [issue 5684](https://github.com/golang/go/issues/5684)).
## Install
go get github.com/PuerkitoBio/urlesc
## License
Go license (BSD-3-Clause)

View File

@@ -1,180 +0,0 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package urlesc implements query escaping as per RFC 3986.
// It contains some parts of the net/url package, modified so as to allow
// some reserved characters incorrectly escaped by net/url.
// See https://github.com/golang/go/issues/5684
package urlesc
import (
"bytes"
"net/url"
"strings"
)
type encoding int
const (
encodePath encoding = 1 + iota
encodeUserPassword
encodeQueryComponent
encodeFragment
)
// Return true if the specified character should be escaped when
// appearing in a URL string, according to RFC 3986.
func shouldEscape(c byte, mode encoding) bool {
// §2.3 Unreserved characters (alphanum)
if 'A' <= c && c <= 'Z' || 'a' <= c && c <= 'z' || '0' <= c && c <= '9' {
return false
}
switch c {
case '-', '.', '_', '~': // §2.3 Unreserved characters (mark)
return false
// §2.2 Reserved characters (reserved)
case ':', '/', '?', '#', '[', ']', '@', // gen-delims
'!', '$', '&', '\'', '(', ')', '*', '+', ',', ';', '=': // sub-delims
// Different sections of the URL allow a few of
// the reserved characters to appear unescaped.
switch mode {
case encodePath: // §3.3
// The RFC allows sub-delims and : @.
// '/', '[' and ']' can be used to assign meaning to individual path
// segments. This package only manipulates the path as a whole,
// so we allow those as well. That leaves only ? and # to escape.
return c == '?' || c == '#'
case encodeUserPassword: // §3.2.1
// The RFC allows : and sub-delims in
// userinfo. The parsing of userinfo treats ':' as special so we must escape
// all the gen-delims.
return c == ':' || c == '/' || c == '?' || c == '#' || c == '[' || c == ']' || c == '@'
case encodeQueryComponent: // §3.4
// The RFC allows / and ?.
return c != '/' && c != '?'
case encodeFragment: // §4.1
// The RFC text is silent but the grammar allows
// everything, so escape nothing but #
return c == '#'
}
}
// Everything else must be escaped.
return true
}
// QueryEscape escapes the string so it can be safely placed
// inside a URL query.
func QueryEscape(s string) string {
return escape(s, encodeQueryComponent)
}
func escape(s string, mode encoding) string {
spaceCount, hexCount := 0, 0
for i := 0; i < len(s); i++ {
c := s[i]
if shouldEscape(c, mode) {
if c == ' ' && mode == encodeQueryComponent {
spaceCount++
} else {
hexCount++
}
}
}
if spaceCount == 0 && hexCount == 0 {
return s
}
t := make([]byte, len(s)+2*hexCount)
j := 0
for i := 0; i < len(s); i++ {
switch c := s[i]; {
case c == ' ' && mode == encodeQueryComponent:
t[j] = '+'
j++
case shouldEscape(c, mode):
t[j] = '%'
t[j+1] = "0123456789ABCDEF"[c>>4]
t[j+2] = "0123456789ABCDEF"[c&15]
j += 3
default:
t[j] = s[i]
j++
}
}
return string(t)
}
var uiReplacer = strings.NewReplacer(
"%21", "!",
"%27", "'",
"%28", "(",
"%29", ")",
"%2A", "*",
)
// unescapeUserinfo unescapes some characters that need not to be escaped as per RFC3986.
func unescapeUserinfo(s string) string {
return uiReplacer.Replace(s)
}
// Escape reassembles the URL into a valid URL string.
// The general form of the result is one of:
//
// scheme:opaque
// scheme://userinfo@host/path?query#fragment
//
// If u.Opaque is non-empty, String uses the first form;
// otherwise it uses the second form.
//
// In the second form, the following rules apply:
// - if u.Scheme is empty, scheme: is omitted.
// - if u.User is nil, userinfo@ is omitted.
// - if u.Host is empty, host/ is omitted.
// - if u.Scheme and u.Host are empty and u.User is nil,
// the entire scheme://userinfo@host/ is omitted.
// - if u.Host is non-empty and u.Path begins with a /,
// the form host/path does not add its own /.
// - if u.RawQuery is empty, ?query is omitted.
// - if u.Fragment is empty, #fragment is omitted.
func Escape(u *url.URL) string {
var buf bytes.Buffer
if u.Scheme != "" {
buf.WriteString(u.Scheme)
buf.WriteByte(':')
}
if u.Opaque != "" {
buf.WriteString(u.Opaque)
} else {
if u.Scheme != "" || u.Host != "" || u.User != nil {
buf.WriteString("//")
if ui := u.User; ui != nil {
buf.WriteString(unescapeUserinfo(ui.String()))
buf.WriteByte('@')
}
if h := u.Host; h != "" {
buf.WriteString(h)
}
}
if u.Path != "" && u.Path[0] != '/' && u.Host != "" {
buf.WriteByte('/')
}
buf.WriteString(escape(u.Path, encodePath))
}
if u.RawQuery != "" {
buf.WriteByte('?')
buf.WriteString(u.RawQuery)
}
if u.Fragment != "" {
buf.WriteByte('#')
buf.WriteString(escape(u.Fragment, encodeFragment))
}
return buf.String()
}

View File

@@ -1,641 +0,0 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package urlesc
import (
"net/url"
"testing"
)
type URLTest struct {
in string
out *url.URL
roundtrip string // expected result of reserializing the URL; empty means same as "in".
}
var urltests = []URLTest{
// no path
{
"http://www.google.com",
&url.URL{
Scheme: "http",
Host: "www.google.com",
},
"",
},
// path
{
"http://www.google.com/",
&url.URL{
Scheme: "http",
Host: "www.google.com",
Path: "/",
},
"",
},
// path with hex escaping
{
"http://www.google.com/file%20one%26two",
&url.URL{
Scheme: "http",
Host: "www.google.com",
Path: "/file one&two",
},
"http://www.google.com/file%20one&two",
},
// user
{
"ftp://webmaster@www.google.com/",
&url.URL{
Scheme: "ftp",
User: url.User("webmaster"),
Host: "www.google.com",
Path: "/",
},
"",
},
// escape sequence in username
{
"ftp://john%20doe@www.google.com/",
&url.URL{
Scheme: "ftp",
User: url.User("john doe"),
Host: "www.google.com",
Path: "/",
},
"ftp://john%20doe@www.google.com/",
},
// query
{
"http://www.google.com/?q=go+language",
&url.URL{
Scheme: "http",
Host: "www.google.com",
Path: "/",
RawQuery: "q=go+language",
},
"",
},
// query with hex escaping: NOT parsed
{
"http://www.google.com/?q=go%20language",
&url.URL{
Scheme: "http",
Host: "www.google.com",
Path: "/",
RawQuery: "q=go%20language",
},
"",
},
// %20 outside query
{
"http://www.google.com/a%20b?q=c+d",
&url.URL{
Scheme: "http",
Host: "www.google.com",
Path: "/a b",
RawQuery: "q=c+d",
},
"",
},
// path without leading /, so no parsing
{
"http:www.google.com/?q=go+language",
&url.URL{
Scheme: "http",
Opaque: "www.google.com/",
RawQuery: "q=go+language",
},
"http:www.google.com/?q=go+language",
},
// path without leading /, so no parsing
{
"http:%2f%2fwww.google.com/?q=go+language",
&url.URL{
Scheme: "http",
Opaque: "%2f%2fwww.google.com/",
RawQuery: "q=go+language",
},
"http:%2f%2fwww.google.com/?q=go+language",
},
// non-authority with path
{
"mailto:/webmaster@golang.org",
&url.URL{
Scheme: "mailto",
Path: "/webmaster@golang.org",
},
"mailto:///webmaster@golang.org", // unfortunate compromise
},
// non-authority
{
"mailto:webmaster@golang.org",
&url.URL{
Scheme: "mailto",
Opaque: "webmaster@golang.org",
},
"",
},
// unescaped :// in query should not create a scheme
{
"/foo?query=http://bad",
&url.URL{
Path: "/foo",
RawQuery: "query=http://bad",
},
"",
},
// leading // without scheme should create an authority
{
"//foo",
&url.URL{
Host: "foo",
},
"",
},
// leading // without scheme, with userinfo, path, and query
{
"//user@foo/path?a=b",
&url.URL{
User: url.User("user"),
Host: "foo",
Path: "/path",
RawQuery: "a=b",
},
"",
},
// Three leading slashes isn't an authority, but doesn't return an error.
// (We can't return an error, as this code is also used via
// ServeHTTP -> ReadRequest -> Parse, which is arguably a
// different URL parsing context, but currently shares the
// same codepath)
{
"///threeslashes",
&url.URL{
Path: "///threeslashes",
},
"",
},
{
"http://user:password@google.com",
&url.URL{
Scheme: "http",
User: url.UserPassword("user", "password"),
Host: "google.com",
},
"http://user:password@google.com",
},
// unescaped @ in username should not confuse host
{
"http://j@ne:password@google.com",
&url.URL{
Scheme: "http",
User: url.UserPassword("j@ne", "password"),
Host: "google.com",
},
"http://j%40ne:password@google.com",
},
// unescaped @ in password should not confuse host
{
"http://jane:p@ssword@google.com",
&url.URL{
Scheme: "http",
User: url.UserPassword("jane", "p@ssword"),
Host: "google.com",
},
"http://jane:p%40ssword@google.com",
},
{
"http://j@ne:password@google.com/p@th?q=@go",
&url.URL{
Scheme: "http",
User: url.UserPassword("j@ne", "password"),
Host: "google.com",
Path: "/p@th",
RawQuery: "q=@go",
},
"http://j%40ne:password@google.com/p@th?q=@go",
},
{
"http://www.google.com/?q=go+language#foo",
&url.URL{
Scheme: "http",
Host: "www.google.com",
Path: "/",
RawQuery: "q=go+language",
Fragment: "foo",
},
"",
},
{
"http://www.google.com/?q=go+language#foo%26bar",
&url.URL{
Scheme: "http",
Host: "www.google.com",
Path: "/",
RawQuery: "q=go+language",
Fragment: "foo&bar",
},
"http://www.google.com/?q=go+language#foo&bar",
},
{
"file:///home/adg/rabbits",
&url.URL{
Scheme: "file",
Host: "",
Path: "/home/adg/rabbits",
},
"file:///home/adg/rabbits",
},
// "Windows" paths are no exception to the rule.
// See golang.org/issue/6027, especially comment #9.
{
"file:///C:/FooBar/Baz.txt",
&url.URL{
Scheme: "file",
Host: "",
Path: "/C:/FooBar/Baz.txt",
},
"file:///C:/FooBar/Baz.txt",
},
// case-insensitive scheme
{
"MaIlTo:webmaster@golang.org",
&url.URL{
Scheme: "mailto",
Opaque: "webmaster@golang.org",
},
"mailto:webmaster@golang.org",
},
// Relative path
{
"a/b/c",
&url.URL{
Path: "a/b/c",
},
"a/b/c",
},
// escaped '?' in username and password
{
"http://%3Fam:pa%3Fsword@google.com",
&url.URL{
Scheme: "http",
User: url.UserPassword("?am", "pa?sword"),
Host: "google.com",
},
"",
},
// escaped '?' and '#' in path
{
"http://example.com/%3F%23",
&url.URL{
Scheme: "http",
Host: "example.com",
Path: "?#",
},
"",
},
// unescaped [ ] ! ' ( ) * in path
{
"http://example.com/[]!'()*",
&url.URL{
Scheme: "http",
Host: "example.com",
Path: "[]!'()*",
},
"http://example.com/[]!'()*",
},
// escaped : / ? # [ ] @ in username and password
{
"http://%3A%2F%3F:%23%5B%5D%40@example.com",
&url.URL{
Scheme: "http",
User: url.UserPassword(":/?", "#[]@"),
Host: "example.com",
},
"",
},
// unescaped ! $ & ' ( ) * + , ; = in username and password
{
"http://!$&'():*+,;=@example.com",
&url.URL{
Scheme: "http",
User: url.UserPassword("!$&'()", "*+,;="),
Host: "example.com",
},
"",
},
// unescaped = : / . ? = in query component
{
"http://example.com/?q=http://google.com/?q=",
&url.URL{
Scheme: "http",
Host: "example.com",
Path: "/",
RawQuery: "q=http://google.com/?q=",
},
"",
},
// unescaped : / ? [ ] @ ! $ & ' ( ) * + , ; = in fragment
{
"http://example.com/#:/?%23[]@!$&'()*+,;=",
&url.URL{
Scheme: "http",
Host: "example.com",
Path: "/",
Fragment: ":/?#[]@!$&'()*+,;=",
},
"",
},
}
func DoTestString(t *testing.T, parse func(string) (*url.URL, error), name string, tests []URLTest) {
for _, tt := range tests {
u, err := parse(tt.in)
if err != nil {
t.Errorf("%s(%q) returned error %s", name, tt.in, err)
continue
}
expected := tt.in
if len(tt.roundtrip) > 0 {
expected = tt.roundtrip
}
s := Escape(u)
if s != expected {
t.Errorf("Escape(%s(%q)) == %q (expected %q)", name, tt.in, s, expected)
}
}
}
func TestURLString(t *testing.T) {
DoTestString(t, url.Parse, "Parse", urltests)
// no leading slash on path should prepend
// slash on String() call
noslash := URLTest{
"http://www.google.com/search",
&url.URL{
Scheme: "http",
Host: "www.google.com",
Path: "search",
},
"",
}
s := Escape(noslash.out)
if s != noslash.in {
t.Errorf("Expected %s; go %s", noslash.in, s)
}
}
type EscapeTest struct {
in string
out string
err error
}
var escapeTests = []EscapeTest{
{
"",
"",
nil,
},
{
"abc",
"abc",
nil,
},
{
"one two",
"one+two",
nil,
},
{
"10%",
"10%25",
nil,
},
{
" ?&=#+%!<>#\"{}|\\^[]`☺\t:/@$'()*,;",
"+?%26%3D%23%2B%25%21%3C%3E%23%22%7B%7D%7C%5C%5E%5B%5D%60%E2%98%BA%09%3A/%40%24%27%28%29%2A%2C%3B",
nil,
},
}
func TestEscape(t *testing.T) {
for _, tt := range escapeTests {
actual := QueryEscape(tt.in)
if tt.out != actual {
t.Errorf("QueryEscape(%q) = %q, want %q", tt.in, actual, tt.out)
}
// for bonus points, verify that escape:unescape is an identity.
roundtrip, err := url.QueryUnescape(actual)
if roundtrip != tt.in || err != nil {
t.Errorf("QueryUnescape(%q) = %q, %s; want %q, %s", actual, roundtrip, err, tt.in, "[no error]")
}
}
}
var resolveReferenceTests = []struct {
base, rel, expected string
}{
// Absolute URL references
{"http://foo.com?a=b", "https://bar.com/", "https://bar.com/"},
{"http://foo.com/", "https://bar.com/?a=b", "https://bar.com/?a=b"},
{"http://foo.com/bar", "mailto:foo@example.com", "mailto:foo@example.com"},
// Path-absolute references
{"http://foo.com/bar", "/baz", "http://foo.com/baz"},
{"http://foo.com/bar?a=b#f", "/baz", "http://foo.com/baz"},
{"http://foo.com/bar?a=b", "/baz?c=d", "http://foo.com/baz?c=d"},
// Scheme-relative
{"https://foo.com/bar?a=b", "//bar.com/quux", "https://bar.com/quux"},
// Path-relative references:
// ... current directory
{"http://foo.com", ".", "http://foo.com/"},
{"http://foo.com/bar", ".", "http://foo.com/"},
{"http://foo.com/bar/", ".", "http://foo.com/bar/"},
// ... going down
{"http://foo.com", "bar", "http://foo.com/bar"},
{"http://foo.com/", "bar", "http://foo.com/bar"},
{"http://foo.com/bar/baz", "quux", "http://foo.com/bar/quux"},
// ... going up
{"http://foo.com/bar/baz", "../quux", "http://foo.com/quux"},
{"http://foo.com/bar/baz", "../../../../../quux", "http://foo.com/quux"},
{"http://foo.com/bar", "..", "http://foo.com/"},
{"http://foo.com/bar/baz", "./..", "http://foo.com/"},
// ".." in the middle (issue 3560)
{"http://foo.com/bar/baz", "quux/dotdot/../tail", "http://foo.com/bar/quux/tail"},
{"http://foo.com/bar/baz", "quux/./dotdot/../tail", "http://foo.com/bar/quux/tail"},
{"http://foo.com/bar/baz", "quux/./dotdot/.././tail", "http://foo.com/bar/quux/tail"},
{"http://foo.com/bar/baz", "quux/./dotdot/./../tail", "http://foo.com/bar/quux/tail"},
{"http://foo.com/bar/baz", "quux/./dotdot/dotdot/././../../tail", "http://foo.com/bar/quux/tail"},
{"http://foo.com/bar/baz", "quux/./dotdot/dotdot/./.././../tail", "http://foo.com/bar/quux/tail"},
{"http://foo.com/bar/baz", "quux/./dotdot/dotdot/dotdot/./../../.././././tail", "http://foo.com/bar/quux/tail"},
{"http://foo.com/bar/baz", "quux/./dotdot/../dotdot/../dot/./tail/..", "http://foo.com/bar/quux/dot/"},
// Remove any dot-segments prior to forming the target URI.
// http://tools.ietf.org/html/rfc3986#section-5.2.4
{"http://foo.com/dot/./dotdot/../foo/bar", "../baz", "http://foo.com/dot/baz"},
// Triple dot isn't special
{"http://foo.com/bar", "...", "http://foo.com/..."},
// Fragment
{"http://foo.com/bar", ".#frag", "http://foo.com/#frag"},
// RFC 3986: Normal Examples
// http://tools.ietf.org/html/rfc3986#section-5.4.1
{"http://a/b/c/d;p?q", "g:h", "g:h"},
{"http://a/b/c/d;p?q", "g", "http://a/b/c/g"},
{"http://a/b/c/d;p?q", "./g", "http://a/b/c/g"},
{"http://a/b/c/d;p?q", "g/", "http://a/b/c/g/"},
{"http://a/b/c/d;p?q", "/g", "http://a/g"},
{"http://a/b/c/d;p?q", "//g", "http://g"},
{"http://a/b/c/d;p?q", "?y", "http://a/b/c/d;p?y"},
{"http://a/b/c/d;p?q", "g?y", "http://a/b/c/g?y"},
{"http://a/b/c/d;p?q", "#s", "http://a/b/c/d;p?q#s"},
{"http://a/b/c/d;p?q", "g#s", "http://a/b/c/g#s"},
{"http://a/b/c/d;p?q", "g?y#s", "http://a/b/c/g?y#s"},
{"http://a/b/c/d;p?q", ";x", "http://a/b/c/;x"},
{"http://a/b/c/d;p?q", "g;x", "http://a/b/c/g;x"},
{"http://a/b/c/d;p?q", "g;x?y#s", "http://a/b/c/g;x?y#s"},
{"http://a/b/c/d;p?q", "", "http://a/b/c/d;p?q"},
{"http://a/b/c/d;p?q", ".", "http://a/b/c/"},
{"http://a/b/c/d;p?q", "./", "http://a/b/c/"},
{"http://a/b/c/d;p?q", "..", "http://a/b/"},
{"http://a/b/c/d;p?q", "../", "http://a/b/"},
{"http://a/b/c/d;p?q", "../g", "http://a/b/g"},
{"http://a/b/c/d;p?q", "../..", "http://a/"},
{"http://a/b/c/d;p?q", "../../", "http://a/"},
{"http://a/b/c/d;p?q", "../../g", "http://a/g"},
// RFC 3986: Abnormal Examples
// http://tools.ietf.org/html/rfc3986#section-5.4.2
{"http://a/b/c/d;p?q", "../../../g", "http://a/g"},
{"http://a/b/c/d;p?q", "../../../../g", "http://a/g"},
{"http://a/b/c/d;p?q", "/./g", "http://a/g"},
{"http://a/b/c/d;p?q", "/../g", "http://a/g"},
{"http://a/b/c/d;p?q", "g.", "http://a/b/c/g."},
{"http://a/b/c/d;p?q", ".g", "http://a/b/c/.g"},
{"http://a/b/c/d;p?q", "g..", "http://a/b/c/g.."},
{"http://a/b/c/d;p?q", "..g", "http://a/b/c/..g"},
{"http://a/b/c/d;p?q", "./../g", "http://a/b/g"},
{"http://a/b/c/d;p?q", "./g/.", "http://a/b/c/g/"},
{"http://a/b/c/d;p?q", "g/./h", "http://a/b/c/g/h"},
{"http://a/b/c/d;p?q", "g/../h", "http://a/b/c/h"},
{"http://a/b/c/d;p?q", "g;x=1/./y", "http://a/b/c/g;x=1/y"},
{"http://a/b/c/d;p?q", "g;x=1/../y", "http://a/b/c/y"},
{"http://a/b/c/d;p?q", "g?y/./x", "http://a/b/c/g?y/./x"},
{"http://a/b/c/d;p?q", "g?y/../x", "http://a/b/c/g?y/../x"},
{"http://a/b/c/d;p?q", "g#s/./x", "http://a/b/c/g#s/./x"},
{"http://a/b/c/d;p?q", "g#s/../x", "http://a/b/c/g#s/../x"},
// Extras.
{"https://a/b/c/d;p?q", "//g?q", "https://g?q"},
{"https://a/b/c/d;p?q", "//g#s", "https://g#s"},
{"https://a/b/c/d;p?q", "//g/d/e/f?y#s", "https://g/d/e/f?y#s"},
{"https://a/b/c/d;p#s", "?y", "https://a/b/c/d;p?y"},
{"https://a/b/c/d;p?q#s", "?y", "https://a/b/c/d;p?y"},
}
func TestResolveReference(t *testing.T) {
mustParse := func(url_ string) *url.URL {
u, err := url.Parse(url_)
if err != nil {
t.Fatalf("Expected URL to parse: %q, got error: %v", url_, err)
}
return u
}
opaque := &url.URL{Scheme: "scheme", Opaque: "opaque"}
for _, test := range resolveReferenceTests {
base := mustParse(test.base)
rel := mustParse(test.rel)
url := base.ResolveReference(rel)
if Escape(url) != test.expected {
t.Errorf("URL(%q).ResolveReference(%q) == %q, got %q", test.base, test.rel, test.expected, Escape(url))
}
// Ensure that new instances are returned.
if base == url {
t.Errorf("Expected URL.ResolveReference to return new URL instance.")
}
// Test the convenience wrapper too.
url, err := base.Parse(test.rel)
if err != nil {
t.Errorf("URL(%q).Parse(%q) failed: %v", test.base, test.rel, err)
} else if Escape(url) != test.expected {
t.Errorf("URL(%q).Parse(%q) == %q, got %q", test.base, test.rel, test.expected, Escape(url))
} else if base == url {
// Ensure that new instances are returned for the wrapper too.
t.Errorf("Expected URL.Parse to return new URL instance.")
}
// Ensure Opaque resets the URL.
url = base.ResolveReference(opaque)
if *url != *opaque {
t.Errorf("ResolveReference failed to resolve opaque URL: want %#v, got %#v", url, opaque)
}
// Test the convenience wrapper with an opaque URL too.
url, err = base.Parse("scheme:opaque")
if err != nil {
t.Errorf(`URL(%q).Parse("scheme:opaque") failed: %v`, test.base, err)
} else if *url != *opaque {
t.Errorf("Parse failed to resolve opaque URL: want %#v, got %#v", url, opaque)
} else if base == url {
// Ensure that new instances are returned, again.
t.Errorf("Expected URL.Parse to return new URL instance.")
}
}
}
type shouldEscapeTest struct {
in byte
mode encoding
escape bool
}
var shouldEscapeTests = []shouldEscapeTest{
// Unreserved characters (§2.3)
{'a', encodePath, false},
{'a', encodeUserPassword, false},
{'a', encodeQueryComponent, false},
{'a', encodeFragment, false},
{'z', encodePath, false},
{'A', encodePath, false},
{'Z', encodePath, false},
{'0', encodePath, false},
{'9', encodePath, false},
{'-', encodePath, false},
{'-', encodeUserPassword, false},
{'-', encodeQueryComponent, false},
{'-', encodeFragment, false},
{'.', encodePath, false},
{'_', encodePath, false},
{'~', encodePath, false},
// User information (§3.2.1)
{':', encodeUserPassword, true},
{'/', encodeUserPassword, true},
{'?', encodeUserPassword, true},
{'@', encodeUserPassword, true},
{'$', encodeUserPassword, false},
{'&', encodeUserPassword, false},
{'+', encodeUserPassword, false},
{',', encodeUserPassword, false},
{';', encodeUserPassword, false},
{'=', encodeUserPassword, false},
}
func TestShouldEscape(t *testing.T) {
for _, tt := range shouldEscapeTests {
if shouldEscape(tt.in, tt.mode) != tt.escape {
t.Errorf("shouldEscape(%q, %v) returned %v; expected %v", tt.in, tt.mode, !tt.escape, tt.escape)
}
}
}

View File

@@ -1,14 +1,28 @@
language: go
go_import_path: github.com/davecgh/go-spew
go:
- 1.5.4
- 1.6.3
- 1.7
- 1.6.x
- 1.7.x
- 1.8.x
- 1.9.x
- 1.10.x
- tip
sudo: false
install:
- go get -v golang.org/x/tools/cmd/cover
- go get -v github.com/alecthomas/gometalinter
- gometalinter --install
script:
- go test -v -tags=safe ./spew
- go test -v -tags=testcgo ./spew -covermode=count -coverprofile=profile.cov
- export PATH=$PATH:$HOME/gopath/bin
- export GORACE="halt_on_error=1"
- test -z "$(gometalinter --disable-all
--enable=gofmt
--enable=golint
--enable=vet
--enable=gosimple
--enable=unconvert
--deadline=4m ./spew | tee /dev/stderr)"
- go test -v -race -tags safe ./spew
- go test -v -race -tags testcgo ./spew -covermode=atomic -coverprofile=profile.cov
after_success:
- go get -v github.com/mattn/goveralls
- export PATH=$PATH:$HOME/gopath/bin
- goveralls -coverprofile=profile.cov -service=travis-ci

View File

@@ -2,7 +2,7 @@ ISC License
Copyright (c) 2012-2016 Dave Collins <dave@davec.name>
Permission to use, copy, modify, and distribute this software for any
Permission to use, copy, modify, and/or distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.

View File

@@ -1,12 +1,9 @@
go-spew
=======
[![Build Status](https://img.shields.io/travis/davecgh/go-spew.svg)]
(https://travis-ci.org/davecgh/go-spew) [![ISC License]
(http://img.shields.io/badge/license-ISC-blue.svg)](http://copyfree.org) [![Coverage Status]
(https://img.shields.io/coveralls/davecgh/go-spew.svg)]
(https://coveralls.io/r/davecgh/go-spew?branch=master)
[![Build Status](https://img.shields.io/travis/davecgh/go-spew.svg)](https://travis-ci.org/davecgh/go-spew)
[![ISC License](http://img.shields.io/badge/license-ISC-blue.svg)](http://copyfree.org)
[![Coverage Status](https://img.shields.io/coveralls/davecgh/go-spew.svg)](https://coveralls.io/r/davecgh/go-spew?branch=master)
Go-spew implements a deep pretty printer for Go data structures to aid in
debugging. A comprehensive suite of tests with 100% test coverage is provided
@@ -21,8 +18,7 @@ post about it
## Documentation
[![GoDoc](https://img.shields.io/badge/godoc-reference-blue.svg)]
(http://godoc.org/github.com/davecgh/go-spew/spew)
[![GoDoc](https://img.shields.io/badge/godoc-reference-blue.svg)](http://godoc.org/github.com/davecgh/go-spew/spew)
Full `go doc` style documentation for the project can be viewed online without
installing this package by using the excellent GoDoc site here:

View File

@@ -16,7 +16,9 @@
// when the code is not running on Google App Engine, compiled by GopherJS, and
// "-tags safe" is not added to the go build command line. The "disableunsafe"
// tag is deprecated and thus should not be used.
// +build !js,!appengine,!safe,!disableunsafe
// Go versions prior to 1.4 are disabled because they use a different layout
// for interfaces which make the implementation of unsafeReflectValue more complex.
// +build !js,!appengine,!safe,!disableunsafe,go1.4
package spew
@@ -34,80 +36,49 @@ const (
ptrSize = unsafe.Sizeof((*byte)(nil))
)
var (
// offsetPtr, offsetScalar, and offsetFlag are the offsets for the
// internal reflect.Value fields. These values are valid before golang
// commit ecccf07e7f9d which changed the format. The are also valid
// after commit 82f48826c6c7 which changed the format again to mirror
// the original format. Code in the init function updates these offsets
// as necessary.
offsetPtr = uintptr(ptrSize)
offsetScalar = uintptr(0)
offsetFlag = uintptr(ptrSize * 2)
type flag uintptr
// flagKindWidth and flagKindShift indicate various bits that the
// reflect package uses internally to track kind information.
//
// flagRO indicates whether or not the value field of a reflect.Value is
// read-only.
//
// flagIndir indicates whether the value field of a reflect.Value is
// the actual data or a pointer to the data.
//
// These values are valid before golang commit 90a7c3c86944 which
// changed their positions. Code in the init function updates these
// flags as necessary.
flagKindWidth = uintptr(5)
flagKindShift = uintptr(flagKindWidth - 1)
flagRO = uintptr(1 << 0)
flagIndir = uintptr(1 << 1)
var (
// flagRO indicates whether the value field of a reflect.Value
// is read-only.
flagRO flag
// flagAddr indicates whether the address of the reflect.Value's
// value may be taken.
flagAddr flag
)
func init() {
// Older versions of reflect.Value stored small integers directly in the
// ptr field (which is named val in the older versions). Versions
// between commits ecccf07e7f9d and 82f48826c6c7 added a new field named
// scalar for this purpose which unfortunately came before the flag
// field, so the offset of the flag field is different for those
// versions.
//
// This code constructs a new reflect.Value from a known small integer
// and checks if the size of the reflect.Value struct indicates it has
// the scalar field. When it does, the offsets are updated accordingly.
vv := reflect.ValueOf(0xf00)
if unsafe.Sizeof(vv) == (ptrSize * 4) {
offsetScalar = ptrSize * 2
offsetFlag = ptrSize * 3
}
// flagKindMask holds the bits that make up the kind
// part of the flags field. In all the supported versions,
// it is in the lower 5 bits.
const flagKindMask = flag(0x1f)
// Commit 90a7c3c86944 changed the flag positions such that the low
// order bits are the kind. This code extracts the kind from the flags
// field and ensures it's the correct type. When it's not, the flag
// order has been changed to the newer format, so the flags are updated
// accordingly.
upf := unsafe.Pointer(uintptr(unsafe.Pointer(&vv)) + offsetFlag)
upfv := *(*uintptr)(upf)
flagKindMask := uintptr((1<<flagKindWidth - 1) << flagKindShift)
if (upfv&flagKindMask)>>flagKindShift != uintptr(reflect.Int) {
flagKindShift = 0
flagRO = 1 << 5
flagIndir = 1 << 6
// Different versions of Go have used different
// bit layouts for the flags type. This table
// records the known combinations.
var okFlags = []struct {
ro, addr flag
}{{
// From Go 1.4 to 1.5
ro: 1 << 5,
addr: 1 << 7,
}, {
// Up to Go tip.
ro: 1<<5 | 1<<6,
addr: 1 << 8,
}}
// Commit adf9b30e5594 modified the flags to separate the
// flagRO flag into two bits which specifies whether or not the
// field is embedded. This causes flagIndir to move over a bit
// and means that flagRO is the combination of either of the
// original flagRO bit and the new bit.
//
// This code detects the change by extracting what used to be
// the indirect bit to ensure it's set. When it's not, the flag
// order has been changed to the newer format, so the flags are
// updated accordingly.
if upfv&flagIndir == 0 {
flagRO = 3 << 5
flagIndir = 1 << 7
}
var flagValOffset = func() uintptr {
field, ok := reflect.TypeOf(reflect.Value{}).FieldByName("flag")
if !ok {
panic("reflect.Value has no flag field")
}
return field.Offset
}()
// flagField returns a pointer to the flag field of a reflect.Value.
func flagField(v *reflect.Value) *flag {
return (*flag)(unsafe.Pointer(uintptr(unsafe.Pointer(v)) + flagValOffset))
}
// unsafeReflectValue converts the passed reflect.Value into a one that bypasses
@@ -119,34 +90,56 @@ func init() {
// This allows us to check for implementations of the Stringer and error
// interfaces to be used for pretty printing ordinarily unaddressable and
// inaccessible values such as unexported struct fields.
func unsafeReflectValue(v reflect.Value) (rv reflect.Value) {
indirects := 1
vt := v.Type()
upv := unsafe.Pointer(uintptr(unsafe.Pointer(&v)) + offsetPtr)
rvf := *(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&v)) + offsetFlag))
if rvf&flagIndir != 0 {
vt = reflect.PtrTo(v.Type())
indirects++
} else if offsetScalar != 0 {
// The value is in the scalar field when it's not one of the
// reference types.
switch vt.Kind() {
case reflect.Uintptr:
case reflect.Chan:
case reflect.Func:
case reflect.Map:
case reflect.Ptr:
case reflect.UnsafePointer:
default:
upv = unsafe.Pointer(uintptr(unsafe.Pointer(&v)) +
offsetScalar)
func unsafeReflectValue(v reflect.Value) reflect.Value {
if !v.IsValid() || (v.CanInterface() && v.CanAddr()) {
return v
}
flagFieldPtr := flagField(&v)
*flagFieldPtr &^= flagRO
*flagFieldPtr |= flagAddr
return v
}
// Sanity checks against future reflect package changes
// to the type or semantics of the Value.flag field.
func init() {
field, ok := reflect.TypeOf(reflect.Value{}).FieldByName("flag")
if !ok {
panic("reflect.Value has no flag field")
}
if field.Type.Kind() != reflect.TypeOf(flag(0)).Kind() {
panic("reflect.Value flag field has changed kind")
}
type t0 int
var t struct {
A t0
// t0 will have flagEmbedRO set.
t0
// a will have flagStickyRO set
a t0
}
vA := reflect.ValueOf(t).FieldByName("A")
va := reflect.ValueOf(t).FieldByName("a")
vt0 := reflect.ValueOf(t).FieldByName("t0")
// Infer flagRO from the difference between the flags
// for the (otherwise identical) fields in t.
flagPublic := *flagField(&vA)
flagWithRO := *flagField(&va) | *flagField(&vt0)
flagRO = flagPublic ^ flagWithRO
// Infer flagAddr from the difference between a value
// taken from a pointer and not.
vPtrA := reflect.ValueOf(&t).Elem().FieldByName("A")
flagNoPtr := *flagField(&vA)
flagPtr := *flagField(&vPtrA)
flagAddr = flagNoPtr ^ flagPtr
// Check that the inferred flags tally with one of the known versions.
for _, f := range okFlags {
if flagRO == f.ro && flagAddr == f.addr {
return
}
}
pv := reflect.NewAt(vt, upv)
rv = pv
for i := 0; i < indirects; i++ {
rv = rv.Elem()
}
return rv
panic("reflect.Value read-only flag has changed semantics")
}

View File

@@ -16,7 +16,7 @@
// when the code is running on Google App Engine, compiled by GopherJS, or
// "-tags safe" is added to the go build command line. The "disableunsafe"
// tag is deprecated and thus should not be used.
// +build js appengine safe disableunsafe
// +build js appengine safe disableunsafe !go1.4
package spew

View File

@@ -35,16 +35,16 @@ var (
// cCharRE is a regular expression that matches a cgo char.
// It is used to detect character arrays to hexdump them.
cCharRE = regexp.MustCompile("^.*\\._Ctype_char$")
cCharRE = regexp.MustCompile(`^.*\._Ctype_char$`)
// cUnsignedCharRE is a regular expression that matches a cgo unsigned
// char. It is used to detect unsigned character arrays to hexdump
// them.
cUnsignedCharRE = regexp.MustCompile("^.*\\._Ctype_unsignedchar$")
cUnsignedCharRE = regexp.MustCompile(`^.*\._Ctype_unsignedchar$`)
// cUint8tCharRE is a regular expression that matches a cgo uint8_t.
// It is used to detect uint8_t arrays to hexdump them.
cUint8tCharRE = regexp.MustCompile("^.*\\._Ctype_uint8_t$")
cUint8tCharRE = regexp.MustCompile(`^.*\._Ctype_uint8_t$`)
)
// dumpState contains information about the state of a dump operation.
@@ -143,10 +143,10 @@ func (d *dumpState) dumpPtr(v reflect.Value) {
// Display dereferenced value.
d.w.Write(openParenBytes)
switch {
case nilFound == true:
case nilFound:
d.w.Write(nilAngleBytes)
case cycleFound == true:
case cycleFound:
d.w.Write(circularBytes)
default:

View File

@@ -768,7 +768,7 @@ func addUintptrDumpTests() {
func addUnsafePointerDumpTests() {
// Null pointer.
v := unsafe.Pointer(uintptr(0))
v := unsafe.Pointer(nil)
nv := (*unsafe.Pointer)(nil)
pv := &v
vAddr := fmt.Sprintf("%p", pv)

View File

@@ -82,18 +82,20 @@ func addCgoDumpTests() {
v5Len := fmt.Sprintf("%d", v5l)
v5Cap := fmt.Sprintf("%d", v5c)
v5t := "[6]testdata._Ctype_uint8_t"
v5t2 := "[6]testdata._Ctype_uchar"
v5s := "(len=" + v5Len + " cap=" + v5Cap + ") " +
"{\n 00000000 74 65 73 74 35 00 " +
" |test5.|\n}"
addDumpTest(v5, "("+v5t+") "+v5s+"\n")
addDumpTest(v5, "("+v5t+") "+v5s+"\n", "("+v5t2+") "+v5s+"\n")
// C typedefed unsigned char array.
v6, v6l, v6c := testdata.GetCgoTypdefedUnsignedCharArray()
v6Len := fmt.Sprintf("%d", v6l)
v6Cap := fmt.Sprintf("%d", v6c)
v6t := "[6]testdata._Ctype_custom_uchar_t"
v6t2 := "[6]testdata._Ctype_uchar"
v6s := "(len=" + v6Len + " cap=" + v6Cap + ") " +
"{\n 00000000 74 65 73 74 36 00 " +
" |test6.|\n}"
addDumpTest(v6, "("+v6t+") "+v6s+"\n")
addDumpTest(v6, "("+v6t+") "+v6s+"\n", "("+v6t2+") "+v6s+"\n")
}

View File

@@ -182,10 +182,10 @@ func (f *formatState) formatPtr(v reflect.Value) {
// Display dereferenced value.
switch {
case nilFound == true:
case nilFound:
f.fs.Write(nilAngleBytes)
case cycleFound == true:
case cycleFound:
f.fs.Write(circularShortBytes)
default:

View File

@@ -1083,7 +1083,7 @@ func addUintptrFormatterTests() {
func addUnsafePointerFormatterTests() {
// Null pointer.
v := unsafe.Pointer(uintptr(0))
v := unsafe.Pointer(nil)
nv := (*unsafe.Pointer)(nil)
pv := &v
vAddr := fmt.Sprintf("%p", pv)
@@ -1536,14 +1536,14 @@ func TestPrintSortedKeys(t *testing.T) {
t.Errorf("Sorted keys mismatch 3:\n %v %v", s, expected)
}
s = cfg.Sprint(map[testStruct]int{testStruct{1}: 1, testStruct{3}: 3, testStruct{2}: 2})
s = cfg.Sprint(map[testStruct]int{{1}: 1, {3}: 3, {2}: 2})
expected = "map[ts.1:1 ts.2:2 ts.3:3]"
if s != expected {
t.Errorf("Sorted keys mismatch 4:\n %v %v", s, expected)
}
if !spew.UnsafeDisabled {
s = cfg.Sprint(map[testStructP]int{testStructP{1}: 1, testStructP{3}: 3, testStructP{2}: 2})
s = cfg.Sprint(map[testStructP]int{{1}: 1, {3}: 3, {2}: 2})
expected = "map[ts.1:1 ts.2:2 ts.3:3]"
if s != expected {
t.Errorf("Sorted keys mismatch 5:\n %v %v", s, expected)

View File

@@ -36,10 +36,7 @@ type dummyFmtState struct {
}
func (dfs *dummyFmtState) Flag(f int) bool {
if f == int('+') {
return true
}
return false
return f == int('+')
}
func (dfs *dummyFmtState) Precision() (int, bool) {

View File

@@ -16,7 +16,7 @@
// when the code is not running on Google App Engine, compiled by GopherJS, and
// "-tags safe" is not added to the go build command line. The "disableunsafe"
// tag is deprecated and thus should not be used.
// +build !js,!appengine,!safe,!disableunsafe
// +build !js,!appengine,!safe,!disableunsafe,go1.4
/*
This test file is part of the spew package rather than than the spew_test
@@ -30,7 +30,6 @@ import (
"bytes"
"reflect"
"testing"
"unsafe"
)
// changeKind uses unsafe to intentionally change the kind of a reflect.Value to
@@ -38,13 +37,13 @@ import (
// fallback code which punts to the standard fmt library for new types that
// might get added to the language.
func changeKind(v *reflect.Value, readOnly bool) {
rvf := (*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(v)) + offsetFlag))
*rvf = *rvf | ((1<<flagKindWidth - 1) << flagKindShift)
flags := flagField(v)
if readOnly {
*rvf |= flagRO
*flags |= flagRO
} else {
*rvf &= ^uintptr(flagRO)
*flags &^= flagRO
}
*flags |= flagKindMask
}
// TestAddedReflectValue tests functionaly of the dump and formatter code which

View File

@@ -1,8 +1,13 @@
language: go
script:
- go vet ./...
- go test -v ./...
go:
- 1.3
- 1.4
- 1.5
- 1.6
- 1.7
- tip

View File

@@ -56,8 +56,9 @@ This simple parsing example:
is directly mapped to:
```go
if token, err := request.ParseFromRequest(tokenString, request.OAuth2Extractor, req, keyLookupFunc); err == nil {
fmt.Printf("Token for user %v expires %v", token.Claims["user"], token.Claims["exp"])
if token, err := request.ParseFromRequest(req, request.OAuth2Extractor, keyLookupFunc); err == nil {
claims := token.Claims.(jwt.MapClaims)
fmt.Printf("Token for user %v expires %v", claims["user"], claims["exp"])
}
```

View File

@@ -1,11 +1,15 @@
A [go](http://www.golang.org) (or 'golang' for search engine friendliness) implementation of [JSON Web Tokens](http://self-issued.info/docs/draft-ietf-oauth-json-web-token.html)
# jwt-go
[![Build Status](https://travis-ci.org/dgrijalva/jwt-go.svg?branch=master)](https://travis-ci.org/dgrijalva/jwt-go)
[![GoDoc](https://godoc.org/github.com/dgrijalva/jwt-go?status.svg)](https://godoc.org/github.com/dgrijalva/jwt-go)
**BREAKING CHANGES:*** Version 3.0.0 is here. It includes _a lot_ of changes including a few that break the API. We've tried to break as few things as possible, so there should just be a few type signature changes. A full list of breaking changes is available in `VERSION_HISTORY.md`. See `MIGRATION_GUIDE.md` for more information on updating your code.
A [go](http://www.golang.org) (or 'golang' for search engine friendliness) implementation of [JSON Web Tokens](http://self-issued.info/docs/draft-ietf-oauth-json-web-token.html)
**NOTICE:** A vulnerability in JWT was [recently published](https://auth0.com/blog/2015/03/31/critical-vulnerabilities-in-json-web-token-libraries/). As this library doesn't force users to validate the `alg` is what they expected, it's possible your usage is effected. There will be an update soon to remedy this, and it will likey require backwards-incompatible changes to the API. In the short term, please make sure your implementation verifies the `alg` is what you expect.
**NEW VERSION COMING:** There have been a lot of improvements suggested since the version 3.0.0 released in 2016. I'm working now on cutting two different releases: 3.2.0 will contain any non-breaking changes or enhancements. 4.0.0 will follow shortly which will include breaking changes. See the 4.0.0 milestone to get an idea of what's coming. If you have other ideas, or would like to participate in 4.0.0, now's the time. If you depend on this library and don't want to be interrupted, I recommend you use your dependency mangement tool to pin to version 3.
**SECURITY NOTICE:** Some older versions of Go have a security issue in the cryotp/elliptic. Recommendation is to upgrade to at least 1.8.3. See issue #216 for more detail.
**SECURITY NOTICE:** It's important that you [validate the `alg` presented is what you expect](https://auth0.com/blog/2015/03/31/critical-vulnerabilities-in-json-web-token-libraries/). This library attempts to make it easy to do the right thing by requiring key types match the expected alg, but you should take the extra step to verify it in your usage. See the examples provided.
## What the heck is a JWT?
@@ -37,7 +41,7 @@ Here's an example of an extension that integrates with the Google App Engine sig
## Compliance
This library was last reviewed to comply with [RTF 7519](http://www.rfc-editor.org/info/rfc7519) dated May 2015 with a few notable differences:
This library was last reviewed to comply with [RTF 7519](http://www.rfc-editor.org/info/rfc7519) dated May 2015 with a few notable differences:
* In order to protect against accidental use of [Unsecured JWTs](http://self-issued.info/docs/draft-ietf-oauth-json-web-token.html#UnsecuredJWT), tokens using `alg=none` will only be accepted if the constant `jwt.UnsafeAllowNoneSignatureType` is provided as the key.
@@ -47,7 +51,10 @@ This library is considered production ready. Feedback and feature requests are
This project uses [Semantic Versioning 2.0.0](http://semver.org). Accepted pull requests will land on `master`. Periodically, versions will be tagged from `master`. You can find all the releases on [the project releases page](https://github.com/dgrijalva/jwt-go/releases).
While we try to make it obvious when we make breaking changes, there isn't a great mechanism for pushing announcements out to users. You may want to use this alternative package include: `gopkg.in/dgrijalva/jwt-go.v2`. It will do the right thing WRT semantic versioning.
While we try to make it obvious when we make breaking changes, there isn't a great mechanism for pushing announcements out to users. You may want to use this alternative package include: `gopkg.in/dgrijalva/jwt-go.v3`. It will do the right thing WRT semantic versioning.
**BREAKING CHANGES:***
* Version 3.0.0 includes _a lot_ of changes from the 2.x line, including a few that break the API. We've tried to break as few things as possible, so there should just be a few type signature changes. A full list of breaking changes is available in `VERSION_HISTORY.md`. See `MIGRATION_GUIDE.md` for more information on updating your code.
## Usage Tips
@@ -68,18 +75,26 @@ Symmetric signing methods, such as HSA, use only a single secret. This is probab
Asymmetric signing methods, such as RSA, use different keys for signing and verifying tokens. This makes it possible to produce tokens with a private key, and allow any consumer to access the public key for verification.
### Signing Methods and Key Types
Each signing method expects a different object type for its signing keys. See the package documentation for details. Here are the most common ones:
* The [HMAC signing method](https://godoc.org/github.com/dgrijalva/jwt-go#SigningMethodHMAC) (`HS256`,`HS384`,`HS512`) expect `[]byte` values for signing and validation
* The [RSA signing method](https://godoc.org/github.com/dgrijalva/jwt-go#SigningMethodRSA) (`RS256`,`RS384`,`RS512`) expect `*rsa.PrivateKey` for signing and `*rsa.PublicKey` for validation
* The [ECDSA signing method](https://godoc.org/github.com/dgrijalva/jwt-go#SigningMethodECDSA) (`ES256`,`ES384`,`ES512`) expect `*ecdsa.PrivateKey` for signing and `*ecdsa.PublicKey` for validation
### JWT and OAuth
It's worth mentioning that OAuth and JWT are not the same thing. A JWT token is simply a signed JSON object. It can be used anywhere such a thing is useful. There is some confusion, though, as JWT is the most common type of bearer token used in OAuth2 authentication.
Without going too far down the rabbit hole, here's a description of the interaction of these technologies:
* OAuth is a protocol for allowing an identity provider to be separate from the service a user is logging in to. For example, whenever you use Facebook to log into a different service (Yelp, Spotify, etc), you are using OAuth.
* OAuth is a protocol for allowing an identity provider to be separate from the service a user is logging in to. For example, whenever you use Facebook to log into a different service (Yelp, Spotify, etc), you are using OAuth.
* OAuth defines several options for passing around authentication data. One popular method is called a "bearer token". A bearer token is simply a string that _should_ only be held by an authenticated user. Thus, simply presenting this token proves your identity. You can probably derive from here why a JWT might make a good bearer token.
* Because bearer tokens are used for authentication, it's important they're kept secret. This is why transactions that use bearer tokens typically happen over SSL.
## More
Documentation can be found [on godoc.org](http://godoc.org/github.com/dgrijalva/jwt-go).
The command line utility included in this project (cmd/jwt) provides a straightforward example of token creation and parsing as well as a useful tool for debugging your own integration. You'll also find several implementation examples in to documentation.
The command line utility included in this project (cmd/jwt) provides a straightforward example of token creation and parsing as well as a useful tool for debugging your own integration. You'll also find several implementation examples in the documentation.

View File

@@ -1,5 +1,18 @@
## `jwt-go` Version History
#### 3.2.0
* Added method `ParseUnverified` to allow users to split up the tasks of parsing and validation
* HMAC signing method returns `ErrInvalidKeyType` instead of `ErrInvalidKey` where appropriate
* Added options to `request.ParseFromRequest`, which allows for an arbitrary list of modifiers to parsing behavior. Initial set include `WithClaims` and `WithParser`. Existing usage of this function will continue to work as before.
* Deprecated `ParseFromRequestWithClaims` to simplify API in the future.
#### 3.1.0
* Improvements to `jwt` command line tool
* Added `SkipClaimsValidation` option to `Parser`
* Documentation updates
#### 3.0.0
* **Compatibility Breaking Changes**: See MIGRATION_GUIDE.md for tips on updating your code

View File

@@ -6,8 +6,8 @@ the command line.
The following will create and sign a token, then verify it and output the original claims:
echo {\"foo\":\"bar\"} | bin/jwt -key test/sample_key -alg RS256 -sign - | bin/jwt -key test/sample_key.pub -verify -
echo {\"foo\":\"bar\"} | ./jwt -key ../../test/sample_key -alg RS256 -sign - | ./jwt -key ../../test/sample_key.pub -alg RS256 -verify -
To simply display a token, use:
echo $JWT | jwt -show -
echo $JWT | ./jwt -show -

View File

@@ -25,14 +25,20 @@ var (
flagKey = flag.String("key", "", "path to key file or '-' to read from stdin")
flagCompact = flag.Bool("compact", false, "output compact JSON")
flagDebug = flag.Bool("debug", false, "print out all kinds of debug data")
flagClaims = make(ArgList)
flagHead = make(ArgList)
// Modes - exactly one of these is required
flagSign = flag.String("sign", "", "path to claims object to sign or '-' to read from stdin")
flagSign = flag.String("sign", "", "path to claims object to sign, '-' to read from stdin, or '+' to use only -claim args")
flagVerify = flag.String("verify", "", "path to JWT token to verify or '-' to read from stdin")
flagShow = flag.String("show", "", "path to JWT file or '-' to read from stdin")
)
func main() {
// Plug in Var flags
flag.Var(flagClaims, "claim", "add additional claims. may be used more than once")
flag.Var(flagHead, "header", "add additional header params. may be used more than once")
// Usage message if you ask for -help or if you mess up inputs.
flag.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0])
@@ -74,6 +80,8 @@ func loadData(p string) ([]byte, error) {
var rdr io.Reader
if p == "-" {
rdr = os.Stdin
} else if p == "+" {
return []byte("{}"), nil
} else {
if f, err := os.Open(p); err == nil {
rdr = f
@@ -126,6 +134,8 @@ func verifyToken() error {
}
if isEs() {
return jwt.ParseECPublicKeyFromPEM(data)
} else if isRs() {
return jwt.ParseRSAPublicKeyFromPEM(data)
}
return data, nil
})
@@ -171,6 +181,13 @@ func signToken() error {
return fmt.Errorf("Couldn't parse claims JSON: %v", err)
}
// add command line claims
if len(flagClaims) > 0 {
for k, v := range flagClaims {
claims[k] = v
}
}
// get the key
var key interface{}
key, err = loadData(*flagKey)
@@ -187,6 +204,13 @@ func signToken() error {
// create a new token
token := jwt.NewWithClaims(alg, claims)
// add command line headers
if len(flagHead) > 0 {
for k, v := range flagHead {
token.Header[k] = v
}
}
if isEs() {
if k, ok := key.([]byte); !ok {
return fmt.Errorf("Couldn't convert key data to key")
@@ -196,6 +220,15 @@ func signToken() error {
return err
}
}
} else if isRs() {
if k, ok := key.([]byte); !ok {
return fmt.Errorf("Couldn't convert key data to key")
} else {
key, err = jwt.ParseRSAPrivateKeyFromPEM(k)
if err != nil {
return err
}
}
}
if out, err := token.SignedString(key); err == nil {
@@ -243,3 +276,7 @@ func showToken() error {
func isEs() bool {
return strings.HasPrefix(*flagAlg, "ES")
}
func isRs() bool {
return strings.HasPrefix(*flagAlg, "RS")
}

23
vendor/github.com/dgrijalva/jwt-go/cmd/jwt/args.go generated vendored Normal file
View File

@@ -0,0 +1,23 @@
package main
import (
"encoding/json"
"fmt"
"strings"
)
type ArgList map[string]string
func (l ArgList) String() string {
data, _ := json.Marshal(l)
return string(data)
}
func (l ArgList) Set(arg string) error {
parts := strings.SplitN(arg, "=", 2)
if len(parts) != 2 {
return fmt.Errorf("Invalid argument '%v'. Must use format 'key=value'. %v", arg, parts)
}
l[parts[0]] = parts[1]
return nil
}

View File

@@ -14,6 +14,7 @@ var (
)
// Implements the ECDSA family of signing methods signing methods
// Expects *ecdsa.PrivateKey for signing and *ecdsa.PublicKey for verification
type SigningMethodECDSA struct {
Name string
Hash crypto.Hash

View File

@@ -51,13 +51,9 @@ func (e ValidationError) Error() string {
} else {
return "token is invalid"
}
return e.Inner.Error()
}
// No errors
func (e *ValidationError) valid() bool {
if e.Errors > 0 {
return false
}
return true
return e.Errors == 0
}

View File

@@ -7,6 +7,7 @@ import (
)
// Implements the HMAC-SHA family of signing methods signing methods
// Expects key type of []byte for both signing and validation
type SigningMethodHMAC struct {
Name string
Hash crypto.Hash
@@ -90,5 +91,5 @@ func (m *SigningMethodHMAC) Sign(signingString string, key interface{}) (string,
return EncodeSegment(hasher.Sum(nil)), nil
}
return "", ErrInvalidKey
return "", ErrInvalidKeyType
}

View File

@@ -51,6 +51,8 @@ func ExampleParse_hmac() {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
}
// hmacSampleSecret is a []byte containing your secret, e.g. []byte("my_secret_key")
return hmacSampleSecret, nil
})

View File

@@ -21,55 +21,9 @@ func (p *Parser) Parse(tokenString string, keyFunc Keyfunc) (*Token, error) {
}
func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyfunc) (*Token, error) {
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
return nil, NewValidationError("token contains an invalid number of segments", ValidationErrorMalformed)
}
var err error
token := &Token{Raw: tokenString}
// parse Header
var headerBytes []byte
if headerBytes, err = DecodeSegment(parts[0]); err != nil {
if strings.HasPrefix(strings.ToLower(tokenString), "bearer ") {
return token, NewValidationError("tokenstring should not contain 'bearer '", ValidationErrorMalformed)
}
return token, &ValidationError{Inner: err, Errors: ValidationErrorMalformed}
}
if err = json.Unmarshal(headerBytes, &token.Header); err != nil {
return token, &ValidationError{Inner: err, Errors: ValidationErrorMalformed}
}
// parse Claims
var claimBytes []byte
token.Claims = claims
if claimBytes, err = DecodeSegment(parts[1]); err != nil {
return token, &ValidationError{Inner: err, Errors: ValidationErrorMalformed}
}
dec := json.NewDecoder(bytes.NewBuffer(claimBytes))
if p.UseJSONNumber {
dec.UseNumber()
}
// JSON Decode. Special case for map type to avoid weird pointer behavior
if c, ok := token.Claims.(MapClaims); ok {
err = dec.Decode(&c)
} else {
err = dec.Decode(&claims)
}
// Handle decode error
token, parts, err := p.ParseUnverified(tokenString, claims)
if err != nil {
return token, &ValidationError{Inner: err, Errors: ValidationErrorMalformed}
}
// Lookup signature method
if method, ok := token.Header["alg"].(string); ok {
if token.Method = GetSigningMethod(method); token.Method == nil {
return token, NewValidationError("signing method (alg) is unavailable.", ValidationErrorUnverifiable)
}
} else {
return token, NewValidationError("signing method (alg) is unspecified.", ValidationErrorUnverifiable)
return token, err
}
// Verify signing method is in the required set
@@ -96,6 +50,9 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf
}
if key, err = keyFunc(token); err != nil {
// keyFunc returned an error
if ve, ok := err.(*ValidationError); ok {
return token, ve
}
return token, &ValidationError{Inner: err, Errors: ValidationErrorUnverifiable}
}
@@ -129,3 +86,63 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf
return token, vErr
}
// WARNING: Don't use this method unless you know what you're doing
//
// This method parses the token but doesn't validate the signature. It's only
// ever useful in cases where you know the signature is valid (because it has
// been checked previously in the stack) and you want to extract values from
// it.
func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Token, parts []string, err error) {
parts = strings.Split(tokenString, ".")
if len(parts) != 3 {
return nil, parts, NewValidationError("token contains an invalid number of segments", ValidationErrorMalformed)
}
token = &Token{Raw: tokenString}
// parse Header
var headerBytes []byte
if headerBytes, err = DecodeSegment(parts[0]); err != nil {
if strings.HasPrefix(strings.ToLower(tokenString), "bearer ") {
return token, parts, NewValidationError("tokenstring should not contain 'bearer '", ValidationErrorMalformed)
}
return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed}
}
if err = json.Unmarshal(headerBytes, &token.Header); err != nil {
return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed}
}
// parse Claims
var claimBytes []byte
token.Claims = claims
if claimBytes, err = DecodeSegment(parts[1]); err != nil {
return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed}
}
dec := json.NewDecoder(bytes.NewBuffer(claimBytes))
if p.UseJSONNumber {
dec.UseNumber()
}
// JSON Decode. Special case for map type to avoid weird pointer behavior
if c, ok := token.Claims.(MapClaims); ok {
err = dec.Decode(&c)
} else {
err = dec.Decode(&claims)
}
// Handle decode error
if err != nil {
return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed}
}
// Lookup signature method
if method, ok := token.Header["alg"].(string); ok {
if token.Method = GetSigningMethod(method); token.Method == nil {
return token, parts, NewValidationError("signing method (alg) is unavailable.", ValidationErrorUnverifiable)
}
} else {
return token, parts, NewValidationError("signing method (alg) is unspecified.", ValidationErrorUnverifiable)
}
return token, parts, nil
}

View File

@@ -222,7 +222,7 @@ func TestParser_Parse(t *testing.T) {
}
if (err == nil && !token.Valid) || (err != nil && token.Valid) {
t.Errorf("[%v] Inconsistent behavior between returned error and token.Valid")
t.Errorf("[%v] Inconsistent behavior between returned error and token.Valid", data.name)
}
if data.errors != 0 {
@@ -247,6 +247,46 @@ func TestParser_Parse(t *testing.T) {
}
}
func TestParser_ParseUnverified(t *testing.T) {
privateKey := test.LoadRSAPrivateKeyFromDisk("test/sample_key")
// Iterate over test data set and run tests
for _, data := range jwtTestData {
// If the token string is blank, use helper function to generate string
if data.tokenString == "" {
data.tokenString = test.MakeSampleToken(data.claims, privateKey)
}
// Parse the token
var token *jwt.Token
var err error
var parser = data.parser
if parser == nil {
parser = new(jwt.Parser)
}
// Figure out correct claims type
switch data.claims.(type) {
case jwt.MapClaims:
token, _, err = parser.ParseUnverified(data.tokenString, jwt.MapClaims{})
case *jwt.StandardClaims:
token, _, err = parser.ParseUnverified(data.tokenString, &jwt.StandardClaims{})
}
if err != nil {
t.Errorf("[%v] Invalid token", data.name)
}
// Verify result matches expectation
if !reflect.DeepEqual(data.claims, token.Claims) {
t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims)
}
if data.valid && err != nil {
t.Errorf("[%v] Error while verifying token: %T:%v", data.name, err, err)
}
}
}
// Helper method for benchmarking various methods
func benchmarkSigning(b *testing.B, method jwt.SigningMethod, key interface{}) {
t := jwt.New(method)

View File

@@ -9,16 +9,60 @@ import (
// This behaves the same as Parse, but accepts a request and an extractor
// instead of a token string. The Extractor interface allows you to define
// the logic for extracting a token. Several useful implementations are provided.
func ParseFromRequest(req *http.Request, extractor Extractor, keyFunc jwt.Keyfunc) (token *jwt.Token, err error) {
return ParseFromRequestWithClaims(req, extractor, jwt.MapClaims{}, keyFunc)
//
// You can provide options to modify parsing behavior
func ParseFromRequest(req *http.Request, extractor Extractor, keyFunc jwt.Keyfunc, options ...ParseFromRequestOption) (token *jwt.Token, err error) {
// Create basic parser struct
p := &fromRequestParser{req, extractor, nil, nil}
// Handle options
for _, option := range options {
option(p)
}
// Set defaults
if p.claims == nil {
p.claims = jwt.MapClaims{}
}
if p.parser == nil {
p.parser = &jwt.Parser{}
}
// perform extract
tokenString, err := p.extractor.ExtractToken(req)
if err != nil {
return nil, err
}
// perform parse
return p.parser.ParseWithClaims(tokenString, p.claims, keyFunc)
}
// ParseFromRequest but with custom Claims type
// DEPRECATED: use ParseFromRequest and the WithClaims option
func ParseFromRequestWithClaims(req *http.Request, extractor Extractor, claims jwt.Claims, keyFunc jwt.Keyfunc) (token *jwt.Token, err error) {
// Extract token from request
if tokStr, err := extractor.ExtractToken(req); err == nil {
return jwt.ParseWithClaims(tokStr, claims, keyFunc)
} else {
return nil, err
return ParseFromRequest(req, extractor, keyFunc, WithClaims(claims))
}
type fromRequestParser struct {
req *http.Request
extractor Extractor
claims jwt.Claims
parser *jwt.Parser
}
type ParseFromRequestOption func(*fromRequestParser)
// Parse with custom claims
func WithClaims(claims jwt.Claims) ParseFromRequestOption {
return func(p *fromRequestParser) {
p.claims = claims
}
}
// Parse using a custom parser
func WithParser(parser *jwt.Parser) ParseFromRequestOption {
return func(p *fromRequestParser) {
p.parser = parser
}
}

Some files were not shown because too many files have changed in this diff Show More