mirror of https://github.com/kubernetes-sigs/descheduler.git synced 2026-01-27 05:46:13 +01:00

bump(*): kubernetes release-1.16.0 dependencies

Mike Dame
2019-10-12 11:11:43 -04:00
parent 5af668e89a
commit 1652ba7976
28121 changed files with 3491095 additions and 2280257 deletions

View File

@@ -36,9 +36,5 @@ func ServeJSON(w http.ResponseWriter, err error) error {
w.WriteHeader(sc)
if err := json.NewEncoder(w).Encode(err); err != nil {
return err
}
return nil
return json.NewEncoder(w).Encode(err)
}
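
The hunk above folds the encode-and-check sequence into a single return, since Encode already yields nil on success. A minimal sketch of the same idiom (writeJSON is a hypothetical helper, not part of this commit):

package handlers

import (
	"encoding/json"
	"net/http"
)

// writeJSON returns the encoder's error directly instead of checking it
// and returning nil in a separate branch.
func writeJSON(w http.ResponseWriter, v interface{}) error {
	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	return json.NewEncoder(w).Encode(v)
}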

View File

@@ -14,15 +14,6 @@ const (
RouteNameCatalog = "catalog"
)
var allEndpoints = []string{
RouteNameManifest,
RouteNameCatalog,
RouteNameTags,
RouteNameBlob,
RouteNameBlobUpload,
RouteNameBlobUploadChunk,
}
// Router builds a gorilla router with named routes for the various API
// methods. This can be used directly by both server implementations and
// clients.

View File

@@ -33,11 +33,10 @@
package auth
import (
"context"
"errors"
"fmt"
"net/http"
"github.com/docker/distribution/context"
)
const (

View File

@@ -6,13 +6,19 @@
package htpasswd
import (
"context"
"crypto/rand"
"encoding/base64"
"fmt"
"net/http"
"os"
"path/filepath"
"sync"
"time"
"github.com/docker/distribution/context"
"golang.org/x/crypto/bcrypt"
dcontext "github.com/docker/distribution/context"
"github.com/docker/distribution/registry/auth"
)
@@ -32,16 +38,19 @@ func newAccessController(options map[string]interface{}) (auth.AccessController,
return nil, fmt.Errorf(`"realm" must be set for htpasswd access controller`)
}
path, present := options["path"]
if _, ok := path.(string); !present || !ok {
pathOpt, present := options["path"]
path, ok := pathOpt.(string)
if !present || !ok {
return nil, fmt.Errorf(`"path" must be set for htpasswd access controller`)
}
return &accessController{realm: realm.(string), path: path.(string)}, nil
if err := createHtpasswdFile(path); err != nil {
return nil, err
}
return &accessController{realm: realm.(string), path: path}, nil
}
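
The rewrite above splits the presence check from the type assertion so the asserted string can be reused without re-asserting. A small sketch of the comma-ok pattern over an options map (getStringOption is hypothetical):

package htpasswd

import "fmt"

// getStringOption returns options[key] as a string, failing when the key
// is absent or holds a non-string value.
func getStringOption(options map[string]interface{}, key string) (string, error) {
	raw, present := options[key]
	s, ok := raw.(string)
	if !present || !ok {
		return "", fmt.Errorf("%q must be set to a string", key)
	}
	return s, nil
}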
func (ac *accessController) Authorized(ctx context.Context, accessRecords ...auth.Access) (context.Context, error) {
req, err := context.GetRequest(ctx)
req, err := dcontext.GetRequest(ctx)
if err != nil {
return nil, err
}
@@ -83,7 +92,7 @@ func (ac *accessController) Authorized(ctx context.Context, accessRecords ...aut
ac.mu.Unlock()
if err := localHTPasswd.authenticateUser(username, password); err != nil {
context.GetLogger(ctx).Errorf("error authenticating user %q: %v", username, err)
dcontext.GetLogger(ctx).Errorf("error authenticating user %q: %v", username, err)
return nil, &challenge{
realm: ac.realm,
err: auth.ErrAuthenticationFailure,
@@ -110,6 +119,42 @@ func (ch challenge) Error() string {
return fmt.Sprintf("basic authentication challenge for realm %q: %s", ch.realm, ch.err)
}
// createHtpasswdFile creates and populates an htpasswd file with a new user if the file is missing
func createHtpasswdFile(path string) error {
if f, err := os.Open(path); err == nil {
f.Close()
return nil
} else if !os.IsNotExist(err) {
return err
}
if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil {
return err
}
f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE, 0600)
if err != nil {
return fmt.Errorf("failed to open htpasswd path %s", err)
}
defer f.Close()
var secretBytes [32]byte
if _, err := rand.Read(secretBytes[:]); err != nil {
return err
}
pass := base64.RawURLEncoding.EncodeToString(secretBytes[:])
encryptedPass, err := bcrypt.GenerateFromPassword([]byte(pass), bcrypt.DefaultCost)
if err != nil {
return err
}
if _, err := f.Write([]byte(fmt.Sprintf("docker:%s", string(encryptedPass[:])))); err != nil {
return err
}
dcontext.GetLoggerWithFields(context.Background(), map[interface{}]interface{}{
"user": "docker",
"password": pass,
}).Warnf("htpasswd is missing, provisioning with default user")
return nil
}
func init() {
auth.Register("htpasswd", auth.InitFunc(newAccessController))
}
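
The provisioning flow above in isolation, as a runnable sketch assuming only the standard library and golang.org/x/crypto/bcrypt: generate a random password, hash it the way createHtpasswdFile does, and verify it the way authenticateUser later will.

package main

import (
	"crypto/rand"
	"encoding/base64"
	"fmt"

	"golang.org/x/crypto/bcrypt"
)

func main() {
	var secret [32]byte
	if _, err := rand.Read(secret[:]); err != nil {
		panic(err)
	}
	pass := base64.RawURLEncoding.EncodeToString(secret[:])

	// an htpasswd entry is "user:bcrypt-hash"
	hash, err := bcrypt.GenerateFromPassword([]byte(pass), bcrypt.DefaultCost)
	if err != nil {
		panic(err)
	}
	fmt.Printf("docker:%s\n", hash)

	// CompareHashAndPassword returns nil when the password matches the hash
	if err := bcrypt.CompareHashAndPassword(hash, []byte(pass)); err != nil {
		panic(err)
	}
}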

View File

@@ -1,9 +1,11 @@
package htpasswd
import (
"bytes"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"testing"
"github.com/docker/distribution/context"
@@ -120,3 +122,41 @@ func TestBasicAccessController(t *testing.T) {
}
}
func TestCreateHtpasswdFile(t *testing.T) {
tempFile, err := ioutil.TempFile("", "htpasswd-test")
if err != nil {
t.Fatalf("could not create temporary htpasswd file %v", err)
}
defer tempFile.Close()
options := map[string]interface{}{
"realm": "/auth/htpasswd",
"path": tempFile.Name(),
}
// Ensure file is not populated
if _, err := newAccessController(options); err != nil {
t.Fatalf("error creating access controller %v", err)
}
content, err := ioutil.ReadAll(tempFile)
if err != nil {
t.Fatalf("failed to read file %v", err)
}
if !bytes.Equal([]byte{}, content) {
t.Fatalf("htpasswd file should not be populated %v", string(content))
}
if err := os.Remove(tempFile.Name()); err != nil {
t.Fatalf("failed to remove temp file %v", err)
}
// Ensure htpasswd file is populated
if _, err := newAccessController(options); err != nil {
t.Fatalf("error creating access controller %v", err)
}
content, err = ioutil.ReadFile(tempFile.Name())
if err != nil {
t.Fatalf("failed to read file %v", err)
}
if !bytes.HasPrefix(content, []byte("docker:$2a$")) {
t.Fatalf("failed to find default user in file %s", string(content))
}
}

View File

@@ -38,7 +38,7 @@ func (htpasswd *htpasswd) authenticateUser(username string, password string) err
return auth.ErrAuthenticationFailure
}
err := bcrypt.CompareHashAndPassword([]byte(credentials), []byte(password))
err := bcrypt.CompareHashAndPassword(credentials, []byte(password))
if err != nil {
return auth.ErrAuthenticationFailure
}

View File

@@ -8,11 +8,12 @@
package silly
import (
"context"
"fmt"
"net/http"
"strings"
"github.com/docker/distribution/context"
dcontext "github.com/docker/distribution/context"
"github.com/docker/distribution/registry/auth"
)
@@ -43,7 +44,7 @@ func newAccessController(options map[string]interface{}) (auth.AccessController,
// Authorized simply checks for the existence of the authorization header,
// responding with a bearer challenge if it doesn't exist.
func (ac *accessController) Authorized(ctx context.Context, accessRecords ...auth.Access) (context.Context, error) {
req, err := context.GetRequest(ctx)
req, err := dcontext.GetRequest(ctx)
if err != nil {
return nil, err
}
@@ -65,7 +66,11 @@ func (ac *accessController) Authorized(ctx context.Context, accessRecords ...aut
return nil, &challenge
}
return auth.WithUser(ctx, auth.UserInfo{Name: "silly"}), nil
ctx = auth.WithUser(ctx, auth.UserInfo{Name: "silly"})
ctx = dcontext.WithLogger(ctx, dcontext.GetLogger(ctx, auth.UserNameKey, auth.UserKey))
return ctx, nil
}
type challenge struct {

View File

@@ -1,6 +1,7 @@
package token
import (
"context"
"crypto"
"crypto/x509"
"encoding/pem"
@@ -11,7 +12,7 @@ import (
"os"
"strings"
"github.com/docker/distribution/context"
dcontext "github.com/docker/distribution/context"
"github.com/docker/distribution/registry/auth"
"github.com/docker/libtrust"
)
@@ -221,7 +222,7 @@ func (ac *accessController) Authorized(ctx context.Context, accessItems ...auth.
accessSet: newAccessSet(accessItems...),
}
req, err := context.GetRequest(ctx)
req, err := dcontext.GetRequest(ctx)
if err != nil {
return nil, err
}

View File

@@ -45,13 +45,13 @@ type Manager interface {
// to a backend.
func NewSimpleManager() Manager {
return &simpleManager{
Challanges: make(map[string][]Challenge),
Challenges: make(map[string][]Challenge),
}
}
type simpleManager struct {
sync.RWMutex
Challanges map[string][]Challenge
Challenges map[string][]Challenge
}
func normalizeURL(endpoint *url.URL) {
@@ -64,7 +64,7 @@ func (m *simpleManager) GetChallenges(endpoint url.URL) ([]Challenge, error) {
m.RLock()
defer m.RUnlock()
challenges := m.Challanges[endpoint.String()]
challenges := m.Challenges[endpoint.String()]
return challenges, nil
}
@@ -82,7 +82,7 @@ func (m *simpleManager) AddResponse(resp *http.Response) error {
m.Lock()
defer m.Unlock()
m.Challanges[urlCopy.String()] = challenges
m.Challenges[urlCopy.String()] = challenges
return nil
}
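
A usage sketch for the corrected manager, assuming the challenge package import path from this vendor tree and a hypothetical registry endpoint: record challenges from a 401 response, then look them up per endpoint.

package main

import (
	"fmt"
	"net/http"
	"net/url"

	"github.com/docker/distribution/registry/client/auth/challenge"
)

func main() {
	m := challenge.NewSimpleManager()

	// an unauthenticated /v2/ ping returns 401 with a WWW-Authenticate header
	resp, err := http.Get("https://registry.example.com/v2/")
	if err != nil {
		panic(err)
	}
	resp.Body.Close()
	if err := m.AddResponse(resp); err != nil {
		panic(err)
	}

	endpoint, _ := url.Parse("https://registry.example.com/v2/")
	challenges, err := m.GetChallenges(*endpoint)
	if err != nil {
		panic(err)
	}
	fmt.Println(challenges)
}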

View File

@@ -13,7 +13,6 @@ import (
"github.com/docker/distribution/registry/client"
"github.com/docker/distribution/registry/client/auth/challenge"
"github.com/docker/distribution/registry/client/transport"
"github.com/sirupsen/logrus"
)
var (
@@ -69,7 +68,6 @@ func NewAuthorizer(manager challenge.Manager, handlers ...AuthenticationHandler)
type endpointAuthorizer struct {
challenges challenge.Manager
handlers []AuthenticationHandler
transport http.RoundTripper
}
func (ea *endpointAuthorizer) ModifyRequest(req *http.Request) error {
@@ -122,7 +120,6 @@ type clock interface {
}
type tokenHandler struct {
header http.Header
creds CredentialStore
transport http.RoundTripper
clock clock
@@ -135,6 +132,8 @@ type tokenHandler struct {
tokenLock sync.Mutex
tokenCache string
tokenExpiration time.Time
logger Logger
}
// Scope is a type which is serializable to a string
@@ -176,6 +175,18 @@ func (rs RegistryScope) String() string {
return fmt.Sprintf("registry:%s:%s", rs.Name, strings.Join(rs.Actions, ","))
}
// Logger defines the injectable logging interface, used on TokenHandlers.
type Logger interface {
Debugf(format string, args ...interface{})
}
func logDebugf(logger Logger, format string, args ...interface{}) {
if logger == nil {
return
}
logger.Debugf(format, args...)
}
// TokenHandlerOptions is used to configure a new token handler
type TokenHandlerOptions struct {
Transport http.RoundTripper
@@ -185,6 +196,7 @@ type TokenHandlerOptions struct {
ForceOAuth bool
ClientID string
Scopes []Scope
Logger Logger
}
// An implementation of clock for providing real time data.
@@ -220,6 +232,7 @@ func NewTokenHandlerWithOptions(options TokenHandlerOptions) AuthenticationHandl
clientID: options.ClientID,
scopes: options.Scopes,
clock: realClock{},
logger: options.Logger,
}
return handler
@@ -264,6 +277,9 @@ func (th *tokenHandler) getToken(params map[string]string, additionalScopes ...s
}
var addedScopes bool
for _, scope := range additionalScopes {
if hasScope(scopes, scope) {
continue
}
scopes = append(scopes, scope)
addedScopes = true
}
@@ -287,6 +303,15 @@ func (th *tokenHandler) getToken(params map[string]string, additionalScopes ...s
return th.tokenCache, nil
}
func hasScope(scopes []string, scope string) bool {
for _, s := range scopes {
if s == scope {
return true
}
}
return false
}
type postTokenResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
@@ -348,7 +373,7 @@ func (th *tokenHandler) fetchTokenWithOAuth(realm *url.URL, refreshToken, servic
if tr.ExpiresIn < minimumTokenLifetimeSeconds {
// The default/minimum lifetime.
tr.ExpiresIn = minimumTokenLifetimeSeconds
logrus.Debugf("Increasing token expiration to: %d seconds", tr.ExpiresIn)
logDebugf(th.logger, "Increasing token expiration to: %d seconds", tr.ExpiresIn)
}
if tr.IssuedAt.IsZero() {
@@ -439,7 +464,7 @@ func (th *tokenHandler) fetchTokenWithBasicAuth(realm *url.URL, service string,
if tr.ExpiresIn < minimumTokenLifetimeSeconds {
// The default/minimum lifetime.
tr.ExpiresIn = minimumTokenLifetimeSeconds
logrus.Debugf("Increasing token expiration to: %d seconds", tr.ExpiresIn)
logDebugf(th.logger, "Increasing token expiration to: %d seconds", tr.ExpiresIn)
}
if tr.IssuedAt.IsZero() {
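
With the logrus dependency dropped, callers that still want this debug output now inject a logger. A sketch, assuming the registry/client/auth import path and that a *logrus.Logger (which has a matching Debugf) is passed as the Logger:

package main

import (
	"net/http"

	"github.com/docker/distribution/registry/client/auth"
	"github.com/sirupsen/logrus"
)

func main() {
	// *logrus.Logger satisfies the one-method Logger interface above
	handler := auth.NewTokenHandlerWithOptions(auth.TokenHandlerOptions{
		Transport: http.DefaultTransport,
		Logger:    logrus.StandardLogger(),
	})
	_ = handler
}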

View File

@@ -2,6 +2,7 @@ package client
import (
"bytes"
"context"
"fmt"
"io"
"io/ioutil"
@@ -9,7 +10,6 @@ import (
"time"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
)
type httpBlobUpload struct {

View File

@@ -2,6 +2,7 @@ package client
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
@@ -14,7 +15,6 @@ import (
"time"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
"github.com/docker/distribution/reference"
"github.com/docker/distribution/registry/api/v2"
"github.com/docker/distribution/registry/client/transport"
@@ -62,7 +62,7 @@ func checkHTTPRedirect(req *http.Request, via []*http.Request) error {
}
// NewRegistry creates a registry namespace which can be used to get a listing of repositories
func NewRegistry(ctx context.Context, baseURL string, transport http.RoundTripper) (Registry, error) {
func NewRegistry(baseURL string, transport http.RoundTripper) (Registry, error) {
ub, err := v2.NewURLBuilderFromString(baseURL, false)
if err != nil {
return nil, err
@@ -75,16 +75,14 @@ func NewRegistry(ctx context.Context, baseURL string, transport http.RoundTrippe
}
return &registry{
client: client,
ub: ub,
context: ctx,
client: client,
ub: ub,
}, nil
}
type registry struct {
client *http.Client
ub *v2.URLBuilder
context context.Context
client *http.Client
ub *v2.URLBuilder
}
// Repositories returns a lexicographically sorted catalog given a base URL. The 'entries' slice will be filled up to the size
@@ -133,7 +131,7 @@ func (r *registry) Repositories(ctx context.Context, entries []string, last stri
}
// NewRepository creates a new Repository for the given repository name and base URL.
func NewRepository(ctx context.Context, name reference.Named, baseURL string, transport http.RoundTripper) (distribution.Repository, error) {
func NewRepository(name reference.Named, baseURL string, transport http.RoundTripper) (distribution.Repository, error) {
ub, err := v2.NewURLBuilderFromString(baseURL, false)
if err != nil {
return nil, err
@@ -146,18 +144,16 @@ func NewRepository(ctx context.Context, name reference.Named, baseURL string, tr
}
return &repository{
client: client,
ub: ub,
name: name,
context: ctx,
client: client,
ub: ub,
name: name,
}, nil
}
type repository struct {
client *http.Client
ub *v2.URLBuilder
context context.Context
name reference.Named
client *http.Client
ub *v2.URLBuilder
name reference.Named
}
func (r *repository) Named() reference.Named {
@@ -190,32 +186,35 @@ func (r *repository) Manifests(ctx context.Context, options ...distribution.Mani
func (r *repository) Tags(ctx context.Context) distribution.TagService {
return &tags{
client: r.client,
ub: r.ub,
context: r.context,
name: r.Named(),
client: r.client,
ub: r.ub,
name: r.Named(),
}
}
// tags implements remote tagging operations.
type tags struct {
client *http.Client
ub *v2.URLBuilder
context context.Context
name reference.Named
client *http.Client
ub *v2.URLBuilder
name reference.Named
}
// All returns all tags
func (t *tags) All(ctx context.Context) ([]string, error) {
var tags []string
u, err := t.ub.BuildTagsURL(t.name)
listURLStr, err := t.ub.BuildTagsURL(t.name)
if err != nil {
return tags, err
}
listURL, err := url.Parse(listURLStr)
if err != nil {
return tags, err
}
for {
resp, err := t.client.Get(u)
resp, err := t.client.Get(listURL.String())
if err != nil {
return tags, err
}
@@ -235,7 +234,13 @@ func (t *tags) All(ctx context.Context) ([]string, error) {
}
tags = append(tags, tagsResponse.Tags...)
if link := resp.Header.Get("Link"); link != "" {
u = strings.Trim(strings.Split(link, ";")[0], "<>")
linkURLStr := strings.Trim(strings.Split(link, ";")[0], "<>")
linkURL, err := url.Parse(linkURLStr)
if err != nil {
return tags, err
}
listURL = listURL.ResolveReference(linkURL)
} else {
return tags, nil
}
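
The resolution step above matters because the Link header's target may be relative; resolving it against the current list URL handles both relative and absolute forms. A standard-library-only sketch:

package main

import (
	"fmt"
	"net/url"
)

func main() {
	listURL, _ := url.Parse("https://registry.example.com/v2/repo/tags/list")

	// a relative next-page target resolves against the list URL
	next, _ := url.Parse("/v2/repo/tags/list?n=1&last=foo")
	fmt.Println(listURL.ResolveReference(next))
	// https://registry.example.com/v2/repo/tags/list?n=1&last=foo

	// an absolute target resolves to itself
	abs, _ := url.Parse("https://other.example.com/v2/repo/tags/list?n=1")
	fmt.Println(listURL.ResolveReference(abs))
	// https://other.example.com/v2/repo/tags/list?n=1
}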
@@ -321,7 +326,8 @@ func (t *tags) Get(ctx context.Context, tag string) (distribution.Descriptor, er
defer resp.Body.Close()
switch {
case resp.StatusCode >= 200 && resp.StatusCode < 400:
case resp.StatusCode >= 200 && resp.StatusCode < 400 && len(resp.Header.Get("Docker-Content-Digest")) > 0:
// if the response is a success AND a Docker-Content-Digest can be retrieved from the headers
return descriptorFromResponse(resp)
default:
// if the response is an error - there will be no body to decode.
@@ -421,18 +427,22 @@ func (ms *manifests) Get(ctx context.Context, dgst digest.Digest, options ...dis
ref reference.Named
err error
contentDgst *digest.Digest
mediaTypes []string
)
for _, option := range options {
if opt, ok := option.(distribution.WithTagOption); ok {
switch opt := option.(type) {
case distribution.WithTagOption:
digestOrTag = opt.Tag
ref, err = reference.WithTag(ms.name, opt.Tag)
if err != nil {
return nil, err
}
} else if opt, ok := option.(contentDigestOption); ok {
case contentDigestOption:
contentDgst = opt.digest
} else {
case distribution.WithManifestMediaTypesOption:
mediaTypes = opt.MediaTypes
default:
err := option.Apply(ms)
if err != nil {
return nil, err
@@ -448,6 +458,10 @@ func (ms *manifests) Get(ctx context.Context, dgst digest.Digest, options ...dis
}
}
if len(mediaTypes) == 0 {
mediaTypes = distribution.ManifestMediaTypes()
}
u, err := ms.ub.BuildManifestURL(ref)
if err != nil {
return nil, err
@@ -458,7 +472,7 @@ func (ms *manifests) Get(ctx context.Context, dgst digest.Digest, options ...dis
return nil, err
}
for _, t := range distribution.ManifestMediaTypes() {
for _, t := range mediaTypes {
req.Header.Add("Accept", t)
}
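
A caller-side sketch of the new option, assuming a repository handle from NewRepository and the go-digest import path used by this vendor tree; when no media types are supplied, Get still falls back to distribution.ManifestMediaTypes():

package client

import (
	"context"

	"github.com/docker/distribution"
	digest "github.com/opencontainers/go-digest"
)

// fetchV2Manifest narrows the Accept header to a single media type.
func fetchV2Manifest(ctx context.Context, r distribution.Repository, dgst digest.Digest) (distribution.Manifest, error) {
	ms, err := r.Manifests(ctx)
	if err != nil {
		return nil, err
	}
	return ms.Get(ctx, dgst, distribution.WithManifestMediaTypes([]string{
		"application/vnd.docker.distribution.manifest.v2+json",
	}))
}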

View File

@@ -9,6 +9,8 @@ import (
"log"
"net/http"
"net/http/httptest"
"reflect"
"sort"
"strconv"
"strings"
"testing"
@@ -118,7 +120,7 @@ func TestBlobDelete(t *testing.T) {
defer c()
ctx := context.Background()
r, err := NewRepository(ctx, repo, e, nil)
r, err := NewRepository(repo, e, nil)
if err != nil {
t.Fatal(err)
}
@@ -140,7 +142,7 @@ func TestBlobFetch(t *testing.T) {
ctx := context.Background()
repo, _ := reference.WithName("test.example.com/repo1")
r, err := NewRepository(ctx, repo, e, nil)
r, err := NewRepository(repo, e, nil)
if err != nil {
t.Fatal(err)
}
@@ -194,7 +196,7 @@ func TestBlobExistsNoContentLength(t *testing.T) {
defer c()
ctx := context.Background()
r, err := NewRepository(ctx, repo, e, nil)
r, err := NewRepository(repo, e, nil)
if err != nil {
t.Fatal(err)
}
@@ -220,7 +222,7 @@ func TestBlobExists(t *testing.T) {
ctx := context.Background()
repo, _ := reference.WithName("test.example.com/repo1")
r, err := NewRepository(ctx, repo, e, nil)
r, err := NewRepository(repo, e, nil)
if err != nil {
t.Fatal(err)
}
@@ -325,7 +327,7 @@ func TestBlobUploadChunked(t *testing.T) {
defer c()
ctx := context.Background()
r, err := NewRepository(ctx, repo, e, nil)
r, err := NewRepository(repo, e, nil)
if err != nil {
t.Fatal(err)
}
@@ -435,7 +437,7 @@ func TestBlobUploadMonolithic(t *testing.T) {
defer c()
ctx := context.Background()
r, err := NewRepository(ctx, repo, e, nil)
r, err := NewRepository(repo, e, nil)
if err != nil {
t.Fatal(err)
}
@@ -512,7 +514,7 @@ func TestBlobMount(t *testing.T) {
defer c()
ctx := context.Background()
r, err := NewRepository(ctx, repo, e, nil)
r, err := NewRepository(repo, e, nil)
if err != nil {
t.Fatal(err)
}
@@ -647,7 +649,38 @@ func addTestManifest(repo reference.Named, reference string, mediatype string, c
}),
},
})
}
func addTestManifestWithoutDigestHeader(repo reference.Named, reference string, mediatype string, content []byte, m *testutil.RequestResponseMap) {
*m = append(*m, testutil.RequestResponseMapping{
Request: testutil.Request{
Method: "GET",
Route: "/v2/" + repo.Name() + "/manifests/" + reference,
},
Response: testutil.Response{
StatusCode: http.StatusOK,
Body: content,
Headers: http.Header(map[string][]string{
"Content-Length": {fmt.Sprint(len(content))},
"Last-Modified": {time.Now().Add(-1 * time.Second).Format(time.ANSIC)},
"Content-Type": {mediatype},
}),
},
})
*m = append(*m, testutil.RequestResponseMapping{
Request: testutil.Request{
Method: "HEAD",
Route: "/v2/" + repo.Name() + "/manifests/" + reference,
},
Response: testutil.Response{
StatusCode: http.StatusOK,
Headers: http.Header(map[string][]string{
"Content-Length": {fmt.Sprint(len(content))},
"Last-Modified": {time.Now().Add(-1 * time.Second).Format(time.ANSIC)},
"Content-Type": {mediatype},
}),
},
})
}
func checkEqualManifest(m1, m2 *schema1.SignedManifest) error {
@@ -692,7 +725,7 @@ func TestV1ManifestFetch(t *testing.T) {
e, c := testServer(m)
defer c()
r, err := NewRepository(context.Background(), repo, e, nil)
r, err := NewRepository(repo, e, nil)
if err != nil {
t.Fatal(err)
}
@@ -764,7 +797,7 @@ func TestManifestFetchWithEtag(t *testing.T) {
defer c()
ctx := context.Background()
r, err := NewRepository(ctx, repo, e, nil)
r, err := NewRepository(repo, e, nil)
if err != nil {
t.Fatal(err)
}
@@ -784,6 +817,65 @@ func TestManifestFetchWithEtag(t *testing.T) {
}
}
func TestManifestFetchWithAccept(t *testing.T) {
ctx := context.Background()
repo, _ := reference.WithName("test.example.com/repo")
_, dgst, _ := newRandomSchemaV1Manifest(repo, "latest", 6)
headers := make(chan []string, 1)
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
headers <- req.Header["Accept"]
}))
defer close(headers)
defer s.Close()
r, err := NewRepository(repo, s.URL, nil)
if err != nil {
t.Fatal(err)
}
ms, err := r.Manifests(ctx)
if err != nil {
t.Fatal(err)
}
testCases := []struct {
// the media types we send
mediaTypes []string
// the expected Accept headers the server should receive
expect []string
// whether to sort the request and response values for comparison
sort bool
}{
{
mediaTypes: []string{},
expect: distribution.ManifestMediaTypes(),
sort: true,
},
{
mediaTypes: []string{"test1", "test2"},
expect: []string{"test1", "test2"},
},
{
mediaTypes: []string{"test1"},
expect: []string{"test1"},
},
{
mediaTypes: []string{""},
expect: []string{""},
},
}
for _, testCase := range testCases {
ms.Get(ctx, dgst, distribution.WithManifestMediaTypes(testCase.mediaTypes))
actual := <-headers
if testCase.sort {
sort.Strings(actual)
sort.Strings(testCase.expect)
}
if !reflect.DeepEqual(actual, testCase.expect) {
t.Fatalf("unexpected Accept header values: %v", actual)
}
}
}
func TestManifestDelete(t *testing.T) {
repo, _ := reference.WithName("test.example.com/repo/delete")
_, dgst1, _ := newRandomSchemaV1Manifest(repo, "latest", 6)
@@ -805,7 +897,7 @@ func TestManifestDelete(t *testing.T) {
e, c := testServer(m)
defer c()
r, err := NewRepository(context.Background(), repo, e, nil)
r, err := NewRepository(repo, e, nil)
if err != nil {
t.Fatal(err)
}
@@ -868,7 +960,7 @@ func TestManifestPut(t *testing.T) {
e, c := testServer(m)
defer c()
r, err := NewRepository(context.Background(), repo, e, nil)
r, err := NewRepository(repo, e, nil)
if err != nil {
t.Fatal(err)
}
@@ -921,7 +1013,7 @@ func TestManifestTags(t *testing.T) {
e, c := testServer(m)
defer c()
r, err := NewRepository(context.Background(), repo, e, nil)
r, err := NewRepository(repo, e, nil)
if err != nil {
t.Fatal(err)
}
@@ -978,7 +1070,7 @@ func TestObtainsErrorForMissingTag(t *testing.T) {
defer c()
ctx := context.Background()
r, err := NewRepository(ctx, repo, e, nil)
r, err := NewRepository(repo, e, nil)
if err != nil {
t.Fatal(err)
}
@@ -994,6 +1086,36 @@ func TestObtainsErrorForMissingTag(t *testing.T) {
}
}
func TestObtainsManifestForTagWithoutHeaders(t *testing.T) {
repo, _ := reference.WithName("test.example.com/repo")
var m testutil.RequestResponseMap
m1, dgst, _ := newRandomSchemaV1Manifest(repo, "latest", 6)
_, pl, err := m1.Payload()
if err != nil {
t.Fatal(err)
}
addTestManifestWithoutDigestHeader(repo, "1.0.0", schema1.MediaTypeSignedManifest, pl, &m)
e, c := testServer(m)
defer c()
ctx := context.Background()
r, err := NewRepository(repo, e, nil)
if err != nil {
t.Fatal(err)
}
tagService := r.Tags(ctx)
desc, err := tagService.Get(ctx, "1.0.0")
if err != nil {
t.Fatalf("Expected no error")
}
if desc.Digest != dgst {
t.Fatalf("Unexpected digest")
}
}
func TestManifestTagsPaginated(t *testing.T) {
s := httptest.NewServer(http.NotFoundHandler())
defer s.Close()
@@ -1014,13 +1136,27 @@ func TestManifestTagsPaginated(t *testing.T) {
queryParams["n"] = []string{"1"}
queryParams["last"] = []string{tagsList[i-1]}
}
// Test both relative and absolute links.
relativeLink := "/v2/" + repo.Name() + "/tags/list?n=1&last=" + tagsList[i]
var link string
switch i {
case 0:
link = relativeLink
case len(tagsList) - 1:
link = ""
default:
link = s.URL + relativeLink
}
headers := http.Header(map[string][]string{
"Content-Length": {fmt.Sprint(len(body))},
"Last-Modified": {time.Now().Add(-1 * time.Second).Format(time.ANSIC)},
})
if i < 2 {
headers.Set("Link", "<"+s.URL+"/v2/"+repo.Name()+"/tags/list?n=1&last="+tagsList[i]+`>; rel="next"`)
if link != "" {
headers.Set("Link", fmt.Sprintf(`<%s>; rel="next"`, link))
}
m = append(m, testutil.RequestResponseMapping{
Request: testutil.Request{
Method: "GET",
@@ -1037,7 +1173,7 @@ func TestManifestTagsPaginated(t *testing.T) {
s.Config.Handler = testutil.NewHandler(m)
r, err := NewRepository(context.Background(), repo, s.URL, nil)
r, err := NewRepository(repo, s.URL, nil)
if err != nil {
t.Fatal(err)
}
@@ -1085,7 +1221,7 @@ func TestManifestUnauthorized(t *testing.T) {
e, c := testServer(m)
defer c()
r, err := NewRepository(context.Background(), repo, e, nil)
r, err := NewRepository(repo, e, nil)
if err != nil {
t.Fatal(err)
}
@@ -1122,7 +1258,7 @@ func TestCatalog(t *testing.T) {
entries := make([]string, 5)
r, err := NewRegistry(context.Background(), e, nil)
r, err := NewRegistry(e, nil)
if err != nil {
t.Fatal(err)
}
@@ -1154,7 +1290,7 @@ func TestCatalogInParts(t *testing.T) {
entries := make([]string, 2)
r, err := NewRegistry(context.Background(), e, nil)
r, err := NewRegistry(e, nil)
if err != nil {
t.Fatal(err)
}

View File

@@ -5,7 +5,6 @@ import (
"fmt"
"io"
"net/http"
"os"
"regexp"
"strconv"
)
@@ -97,7 +96,7 @@ func (hrs *httpReadSeeker) Seek(offset int64, whence int) (int64, error) {
lastReaderOffset := hrs.readerOffset
if whence == os.SEEK_SET && hrs.rc == nil {
if whence == io.SeekStart && hrs.rc == nil {
// If no request has been made yet, and we are seeking to an
// absolute position, set the read offset as well to avoid an
// unnecessary request.
@@ -113,14 +112,14 @@ func (hrs *httpReadSeeker) Seek(offset int64, whence int) (int64, error) {
newOffset := hrs.seekOffset
switch whence {
case os.SEEK_CUR:
case io.SeekCurrent:
newOffset += offset
case os.SEEK_END:
case io.SeekEnd:
if hrs.size < 0 {
return 0, errors.New("content length not known")
}
newOffset = hrs.size + offset
case os.SEEK_SET:
case io.SeekStart:
newOffset = offset
}
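
The os.SEEK_* constants have been deprecated since Go 1.7 in favor of io.SeekStart, io.SeekCurrent, and io.SeekEnd, which this hunk adopts. A standard-library sketch of the three whence modes:

package main

import (
	"fmt"
	"io"
	"strings"
)

func main() {
	r := strings.NewReader("0123456789")

	off, _ := r.Seek(4, io.SeekStart) // absolute offset: 4
	fmt.Println(off)

	off, _ = r.Seek(2, io.SeekCurrent) // relative to current: 6
	fmt.Println(off)

	off, _ = r.Seek(-1, io.SeekEnd) // relative to end: 9
	fmt.Println(off)
}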

View File

@@ -2,6 +2,7 @@ package handlers
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
@@ -21,7 +22,6 @@ import (
"github.com/docker/distribution"
"github.com/docker/distribution/configuration"
"github.com/docker/distribution/context"
"github.com/docker/distribution/manifest"
"github.com/docker/distribution/manifest/manifestlist"
"github.com/docker/distribution/manifest/schema1"
@@ -478,7 +478,7 @@ func testBlobAPI(t *testing.T, env *testEnv, args blobArgs) *testEnv {
// -----------------------------------------
// Do layer push with an empty body and different digest
uploadURLBase, uploadUUID = startPushLayer(t, env, imageName)
uploadURLBase, _ = startPushLayer(t, env, imageName)
resp, err = doPushLayer(t, env.builder, imageName, layerDigest, uploadURLBase, bytes.NewReader([]byte{}))
if err != nil {
t.Fatalf("unexpected error doing bad layer push: %v", err)
@@ -494,7 +494,7 @@ func testBlobAPI(t *testing.T, env *testEnv, args blobArgs) *testEnv {
t.Fatalf("unexpected error digesting empty buffer: %v", err)
}
uploadURLBase, uploadUUID = startPushLayer(t, env, imageName)
uploadURLBase, _ = startPushLayer(t, env, imageName)
pushLayer(t, env.builder, imageName, zeroDigest, uploadURLBase, bytes.NewReader([]byte{}))
// -----------------------------------------
@@ -507,15 +507,15 @@ func testBlobAPI(t *testing.T, env *testEnv, args blobArgs) *testEnv {
t.Fatalf("unexpected error digesting empty tar: %v", err)
}
uploadURLBase, uploadUUID = startPushLayer(t, env, imageName)
uploadURLBase, _ = startPushLayer(t, env, imageName)
pushLayer(t, env.builder, imageName, emptyDigest, uploadURLBase, bytes.NewReader(emptyTar))
// ------------------------------------------
// Now, actually do successful upload.
layerLength, _ := layerFile.Seek(0, os.SEEK_END)
layerFile.Seek(0, os.SEEK_SET)
layerLength, _ := layerFile.Seek(0, io.SeekEnd)
layerFile.Seek(0, io.SeekStart)
uploadURLBase, uploadUUID = startPushLayer(t, env, imageName)
uploadURLBase, _ = startPushLayer(t, env, imageName)
pushLayer(t, env.builder, imageName, layerDigest, uploadURLBase, layerFile)
// ------------------------------------------
@@ -529,7 +529,7 @@ func testBlobAPI(t *testing.T, env *testEnv, args blobArgs) *testEnv {
canonicalDigest := canonicalDigester.Digest()
layerFile.Seek(0, 0)
uploadURLBase, uploadUUID = startPushLayer(t, env, imageName)
uploadURLBase, _ = startPushLayer(t, env, imageName)
uploadURLBase, dgst := pushChunk(t, env.builder, imageName, uploadURLBase, layerFile, layerLength)
finishUpload(t, env.builder, imageName, uploadURLBase, dgst)
@@ -612,7 +612,7 @@ func testBlobAPI(t *testing.T, env *testEnv, args blobArgs) *testEnv {
t.Fatalf("Error constructing request: %s", err)
}
req.Header.Set("If-None-Match", "")
resp, err = http.DefaultClient.Do(req)
resp, _ = http.DefaultClient.Do(req)
checkResponse(t, "fetching layer with invalid etag", resp, http.StatusOK)
// Missing tests:
@@ -674,12 +674,12 @@ func testBlobDelete(t *testing.T, env *testEnv, args blobArgs) {
// ----------------
// Reupload previously deleted blob
layerFile.Seek(0, os.SEEK_SET)
layerFile.Seek(0, io.SeekStart)
uploadURLBase, _ := startPushLayer(t, env, imageName)
pushLayer(t, env.builder, imageName, layerDigest, uploadURLBase, layerFile)
layerFile.Seek(0, os.SEEK_SET)
layerFile.Seek(0, io.SeekStart)
canonicalDigester := digest.Canonical.Digester()
if _, err := io.Copy(canonicalDigester.Hash(), layerFile); err != nil {
t.Fatalf("error copying to digest: %v", err)
@@ -693,7 +693,7 @@ func testBlobDelete(t *testing.T, env *testEnv, args blobArgs) {
t.Fatalf("unexpected error checking head on existing layer: %v", err)
}
layerLength, _ := layerFile.Seek(0, os.SEEK_END)
layerLength, _ := layerFile.Seek(0, io.SeekEnd)
checkResponse(t, "checking head on reuploaded layer", resp, http.StatusOK)
checkHeaders(t, resp, http.Header{
"Content-Length": []string{fmt.Sprint(layerLength)},
@@ -1874,7 +1874,7 @@ func testManifestDelete(t *testing.T, env *testEnv, args manifestArgs) {
manifest := args.manifest
ref, _ := reference.WithDigest(imageName, dgst)
manifestDigestURL, err := env.builder.BuildManifestURL(ref)
manifestDigestURL, _ := env.builder.BuildManifestURL(ref)
// ---------------
// Delete by digest
resp, err := httpDelete(manifestDigestURL)
@@ -1935,7 +1935,7 @@ func testManifestDelete(t *testing.T, env *testEnv, args manifestArgs) {
// Upload manifest by tag
tag := "atag"
tagRef, _ := reference.WithTag(imageName, tag)
manifestTagURL, err := env.builder.BuildManifestURL(tagRef)
manifestTagURL, _ := env.builder.BuildManifestURL(tagRef)
resp = putManifest(t, "putting manifest by tag", manifestTagURL, args.mediaType, manifest)
checkResponse(t, "putting manifest by tag", resp, http.StatusCreated)
checkHeaders(t, resp, http.Header{
@@ -2027,6 +2027,7 @@ func newTestEnvMirror(t *testing.T, deleteEnabled bool) *testEnv {
RemoteURL: "http://example.com",
},
}
config.Compatibility.Schema1.Enabled = true
return newTestEnvWithConfig(t, &config)
@@ -2043,6 +2044,7 @@ func newTestEnv(t *testing.T, deleteEnabled bool) *testEnv {
},
}
config.Compatibility.Schema1.Enabled = true
config.HTTP.Headers = headerConfig
return newTestEnvWithConfig(t, &config)
@@ -2502,7 +2504,7 @@ func TestRegistryAsCacheMutationAPIs(t *testing.T) {
checkResponse(t, "putting signed manifest to cache", resp, errcode.ErrorCodeUnsupported.Descriptor().HTTPStatusCode)
// Manifest Delete
resp, err = httpDelete(manifestURL)
resp, _ = httpDelete(manifestURL)
checkResponse(t, "deleting signed manifest from cache", resp, errcode.ErrorCodeUnsupported.Descriptor().HTTPStatusCode)
// Blob upload initialization
@@ -2521,41 +2523,12 @@ func TestRegistryAsCacheMutationAPIs(t *testing.T) {
// Blob Delete
ref, _ := reference.WithDigest(imageName, digestSha256EmptyTar)
blobURL, err := env.builder.BuildBlobURL(ref)
resp, err = httpDelete(blobURL)
blobURL, _ := env.builder.BuildBlobURL(ref)
resp, _ = httpDelete(blobURL)
checkResponse(t, "deleting blob from cache", resp, errcode.ErrorCodeUnsupported.Descriptor().HTTPStatusCode)
}
// TestCheckContextNotifier makes sure the API endpoints get a ResponseWriter
// that implements http.ContextNotifier.
func TestCheckContextNotifier(t *testing.T) {
env := newTestEnv(t, false)
defer env.Shutdown()
// Register a new endpoint for testing
env.app.router.Handle("/unittest/{name}/", env.app.dispatcher(func(ctx *Context, r *http.Request) http.Handler {
return handlers.MethodHandler{
"GET": http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if _, ok := w.(http.CloseNotifier); !ok {
t.Fatal("could not cast ResponseWriter to CloseNotifier")
}
w.WriteHeader(200)
}),
}
}))
resp, err := http.Get(env.server.URL + "/unittest/reponame/")
if err != nil {
t.Fatalf("unexpected error issuing request: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
t.Fatalf("wrong status code - expected 200, got %d", resp.StatusCode)
}
}
func TestProxyManifestGetByTag(t *testing.T) {
truthConfig := configuration.Configuration{
Storage: configuration.Storage{
@@ -2565,6 +2538,7 @@ func TestProxyManifestGetByTag(t *testing.T) {
}},
},
}
truthConfig.Compatibility.Schema1.Enabled = true
truthConfig.HTTP.Headers = headerConfig
imageName, _ := reference.WithName("foo/bar")
@@ -2583,6 +2557,7 @@ func TestProxyManifestGetByTag(t *testing.T) {
RemoteURL: truthEnv.server.URL,
},
}
proxyConfig.Compatibility.Schema1.Enabled = true
proxyConfig.HTTP.Headers = headerConfig
proxyEnv := newTestEnvWithConfig(t, &proxyConfig)
@@ -2601,9 +2576,9 @@ func TestProxyManifestGetByTag(t *testing.T) {
checkErr(t, err, "building manifest url")
resp, err = http.Get(manifestTagURL)
checkErr(t, err, "fetching manifest from proxy by tag")
checkErr(t, err, "fetching manifest from proxy by tag (error check 1)")
defer resp.Body.Close()
checkResponse(t, "fetching manifest from proxy by tag", resp, http.StatusOK)
checkResponse(t, "fetching manifest from proxy by tag (response check 1)", resp, http.StatusOK)
checkHeaders(t, resp, http.Header{
"Docker-Content-Digest": []string{dgst.String()},
})
@@ -2616,9 +2591,9 @@ func TestProxyManifestGetByTag(t *testing.T) {
// fetch it with the same proxy URL as before. Ensure the updated content is at the same tag
resp, err = http.Get(manifestTagURL)
checkErr(t, err, "fetching manifest from proxy by tag")
checkErr(t, err, "fetching manifest from proxy by tag (error check 2)")
defer resp.Body.Close()
checkResponse(t, "fetching manifest from proxy by tag", resp, http.StatusOK)
checkResponse(t, "fetching manifest from proxy by tag (response check 2)", resp, http.StatusOK)
checkHeaders(t, resp, http.Header{
"Docker-Content-Digest": []string{newDigest.String()},
})

View File

@@ -1,6 +1,7 @@
package handlers
import (
"context"
cryptorand "crypto/rand"
"expvar"
"fmt"
@@ -16,9 +17,10 @@ import (
"github.com/docker/distribution"
"github.com/docker/distribution/configuration"
ctxu "github.com/docker/distribution/context"
dcontext "github.com/docker/distribution/context"
"github.com/docker/distribution/health"
"github.com/docker/distribution/health/checks"
prometheus "github.com/docker/distribution/metrics"
"github.com/docker/distribution/notifications"
"github.com/docker/distribution/reference"
"github.com/docker/distribution/registry/api/errcode"
@@ -34,11 +36,11 @@ import (
"github.com/docker/distribution/registry/storage/driver/factory"
storagemiddleware "github.com/docker/distribution/registry/storage/driver/middleware"
"github.com/docker/distribution/version"
"github.com/docker/go-metrics"
"github.com/docker/libtrust"
"github.com/garyburd/redigo/redis"
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"golang.org/x/net/context"
"github.com/sirupsen/logrus"
)
// randomSecretSize is the number of random bytes to generate if no secret
@@ -56,10 +58,11 @@ type App struct {
Config *configuration.Configuration
router *mux.Router // main application router, configured with dispatchers
driver storagedriver.StorageDriver // driver maintains the app global storage driver instance.
registry distribution.Namespace // registry is the primary registry backend for the app instance.
accessController auth.AccessController // main access controller for application
router *mux.Router // main application router, configured with dispatchers
driver storagedriver.StorageDriver // driver maintains the app global storage driver instance.
registry distribution.Namespace // registry is the primary registry backend for the app instance.
repoRemover distribution.RepositoryRemover // repoRemover provides ability to delete repos
accessController auth.AccessController // main access controller for application
// httpHost is a parsed representation of the http.host parameter from
// the configuration. Only the Scheme and Host fields are used.
@@ -145,7 +148,7 @@ func NewApp(ctx context.Context, config *configuration.Configuration) *App {
}
}
startUploadPurger(app, app.driver, ctxu.GetLogger(app), purgeConfig)
startUploadPurger(app, app.driver, dcontext.GetLogger(app), purgeConfig)
app.driver, err = applyStorageMiddleware(app.driver, config.Middleware["storage"])
if err != nil {
@@ -174,6 +177,10 @@ func NewApp(ctx context.Context, config *configuration.Configuration) *App {
options = append(options, storage.Schema1SigningKey(app.trustKey))
if config.Compatibility.Schema1.Enabled {
options = append(options, storage.EnableSchema1)
}
if config.HTTP.Host != "" {
u, err := url.Parse(config.HTTP.Host)
if err != nil {
@@ -208,7 +215,7 @@ func NewApp(ctx context.Context, config *configuration.Configuration) *App {
}
}
if redirectDisabled {
ctxu.GetLogger(app).Infof("backend redirection disabled")
dcontext.GetLogger(app).Infof("backend redirection disabled")
} else {
options = append(options, storage.EnableRedirect)
}
@@ -269,7 +276,7 @@ func NewApp(ctx context.Context, config *configuration.Configuration) *App {
if err != nil {
panic("could not create registry: " + err.Error())
}
ctxu.GetLogger(app).Infof("using redis blob descriptor cache")
dcontext.GetLogger(app).Infof("using redis blob descriptor cache")
case "inmemory":
cacheProvider := memorycache.NewInMemoryBlobDescriptorCacheProvider()
localOptions := append(options, storage.BlobDescriptorCacheProvider(cacheProvider))
@@ -277,10 +284,10 @@ func NewApp(ctx context.Context, config *configuration.Configuration) *App {
if err != nil {
panic("could not create registry: " + err.Error())
}
ctxu.GetLogger(app).Infof("using inmemory blob descriptor cache")
dcontext.GetLogger(app).Infof("using inmemory blob descriptor cache")
default:
if v != "" {
ctxu.GetLogger(app).Warnf("unknown cache type %q, caching disabled", config.Storage["cache"])
dcontext.GetLogger(app).Warnf("unknown cache type %q, caching disabled", config.Storage["cache"])
}
}
}
@@ -300,13 +307,13 @@ func NewApp(ctx context.Context, config *configuration.Configuration) *App {
authType := config.Auth.Type()
if authType != "" {
if authType != "" && !strings.EqualFold(authType, "none") {
accessController, err := auth.GetAccessController(config.Auth.Type(), config.Auth.Parameters())
if err != nil {
panic(fmt.Sprintf("unable to configure authorization (%s): %v", authType, err))
}
app.accessController = accessController
ctxu.GetLogger(app).Debugf("configured %q access controller", authType)
dcontext.GetLogger(app).Debugf("configured %q access controller", authType)
}
// configure as a pull through cache
@@ -316,7 +323,12 @@ func NewApp(ctx context.Context, config *configuration.Configuration) *App {
panic(err.Error())
}
app.isCache = true
ctxu.GetLogger(app).Info("Registry configured as a proxy cache to ", config.Proxy.RemoteURL)
dcontext.GetLogger(app).Info("Registry configured as a proxy cache to ", config.Proxy.RemoteURL)
}
var ok bool
app.repoRemover, ok = app.registry.(distribution.RepositoryRemover)
if !ok {
dcontext.GetLogger(app).Warnf("Registry does not implement RempositoryRemover. Will not be able to delete repos and tags")
}
return app
@@ -346,7 +358,10 @@ func (app *App) RegisterHealthChecks(healthRegistries ...*health.Registry) {
storageDriverCheck := func() error {
_, err := app.driver.Stat(app, "/") // "/" should always exist
return err // any error will be treated as failure
if _, ok := err.(storagedriver.PathNotFoundError); ok {
err = nil // pass this through, backend is responding, but this path doesn't exist.
}
return err
}
if app.Config.Health.StorageDriver.Threshold != 0 {
@@ -361,7 +376,7 @@ func (app *App) RegisterHealthChecks(healthRegistries ...*health.Registry) {
if interval == 0 {
interval = defaultCheckInterval
}
ctxu.GetLogger(app).Infof("configuring file health check path=%s, interval=%d", fileChecker.File, interval/time.Second)
dcontext.GetLogger(app).Infof("configuring file health check path=%s, interval=%d", fileChecker.File, interval/time.Second)
healthRegistry.Register(fileChecker.File, health.PeriodicChecker(checks.FileChecker(fileChecker.File), interval))
}
@@ -379,10 +394,10 @@ func (app *App) RegisterHealthChecks(healthRegistries ...*health.Registry) {
checker := checks.HTTPChecker(httpChecker.URI, statusCode, httpChecker.Timeout, httpChecker.Headers)
if httpChecker.Threshold != 0 {
ctxu.GetLogger(app).Infof("configuring HTTP health check uri=%s, interval=%d, threshold=%d", httpChecker.URI, interval/time.Second, httpChecker.Threshold)
dcontext.GetLogger(app).Infof("configuring HTTP health check uri=%s, interval=%d, threshold=%d", httpChecker.URI, interval/time.Second, httpChecker.Threshold)
healthRegistry.Register(httpChecker.URI, health.PeriodicThresholdChecker(checker, interval, httpChecker.Threshold))
} else {
ctxu.GetLogger(app).Infof("configuring HTTP health check uri=%s, interval=%d", httpChecker.URI, interval/time.Second)
dcontext.GetLogger(app).Infof("configuring HTTP health check uri=%s, interval=%d", httpChecker.URI, interval/time.Second)
healthRegistry.Register(httpChecker.URI, health.PeriodicChecker(checker, interval))
}
}
@@ -396,10 +411,10 @@ func (app *App) RegisterHealthChecks(healthRegistries ...*health.Registry) {
checker := checks.TCPChecker(tcpChecker.Addr, tcpChecker.Timeout)
if tcpChecker.Threshold != 0 {
ctxu.GetLogger(app).Infof("configuring TCP health check addr=%s, interval=%d, threshold=%d", tcpChecker.Addr, interval/time.Second, tcpChecker.Threshold)
dcontext.GetLogger(app).Infof("configuring TCP health check addr=%s, interval=%d, threshold=%d", tcpChecker.Addr, interval/time.Second, tcpChecker.Threshold)
healthRegistry.Register(tcpChecker.Addr, health.PeriodicThresholdChecker(checker, interval, tcpChecker.Threshold))
} else {
ctxu.GetLogger(app).Infof("configuring TCP health check addr=%s, interval=%d", tcpChecker.Addr, interval/time.Second)
dcontext.GetLogger(app).Infof("configuring TCP health check addr=%s, interval=%d", tcpChecker.Addr, interval/time.Second)
healthRegistry.Register(tcpChecker.Addr, health.PeriodicChecker(checker, interval))
}
}
@@ -409,6 +424,15 @@ func (app *App) RegisterHealthChecks(healthRegistries ...*health.Registry) {
// passed through the application filters and context will be constructed at
// request time.
func (app *App) register(routeName string, dispatch dispatchFunc) {
handler := app.dispatcher(dispatch)
// Chain the handler with prometheus instrumented handler
if app.Config.HTTP.Debug.Prometheus.Enabled {
namespace := metrics.NewNamespace(prometheus.NamespacePrefix, "http", nil)
httpMetrics := namespace.NewDefaultHttpMetrics(strings.Replace(routeName, "-", "_", -1))
metrics.Register(namespace)
handler = metrics.InstrumentHandler(httpMetrics, handler)
}
// TODO(stevvooe): This odd dispatcher/route registration is a by-product of
// some limitations in the gorilla/mux router. We are using it to keep
@@ -416,7 +440,7 @@ func (app *App) register(routeName string, dispatch dispatchFunc) {
// replace it with manual routing and structure-based dispatch for better
// control over the request execution.
app.router.GetRoute(routeName).Handler(app.dispatcher(dispatch))
app.router.GetRoute(routeName).Handler(handler)
}
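
The chaining above wraps each route's handler with docker/go-metrics when Prometheus debugging is enabled. A standalone sketch reusing only the calls shown (the instrument helper and "registry" prefix are illustrative):

package handlers

import (
	"net/http"
	"strings"

	metrics "github.com/docker/go-metrics"
)

// instrument wraps h with default HTTP metrics for the named route,
// mirroring the register() chaining above.
func instrument(routeName string, h http.Handler) http.Handler {
	ns := metrics.NewNamespace("registry", "http", nil)
	httpMetrics := ns.NewDefaultHttpMetrics(strings.Replace(routeName, "-", "_", -1))
	metrics.Register(ns)
	return metrics.InstrumentHandler(httpMetrics, h)
}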
// configureEvents prepares the event sink for action.
@@ -425,17 +449,18 @@ func (app *App) configureEvents(configuration *configuration.Configuration) {
var sinks []notifications.Sink
for _, endpoint := range configuration.Notifications.Endpoints {
if endpoint.Disabled {
ctxu.GetLogger(app).Infof("endpoint %s disabled, skipping", endpoint.Name)
dcontext.GetLogger(app).Infof("endpoint %s disabled, skipping", endpoint.Name)
continue
}
ctxu.GetLogger(app).Infof("configuring endpoint %v (%v), timeout=%s, headers=%v", endpoint.Name, endpoint.URL, endpoint.Timeout, endpoint.Headers)
dcontext.GetLogger(app).Infof("configuring endpoint %v (%v), timeout=%s, headers=%v", endpoint.Name, endpoint.URL, endpoint.Timeout, endpoint.Headers)
endpoint := notifications.NewEndpoint(endpoint.Name, endpoint.URL, notifications.EndpointConfig{
Timeout: endpoint.Timeout,
Threshold: endpoint.Threshold,
Backoff: endpoint.Backoff,
Headers: endpoint.Headers,
IgnoredMediaTypes: endpoint.IgnoredMediaTypes,
Ignore: endpoint.Ignore,
})
sinks = append(sinks, endpoint)
@@ -461,7 +486,7 @@ func (app *App) configureEvents(configuration *configuration.Configuration) {
app.events.source = notifications.SourceRecord{
Addr: hostname,
InstanceID: ctxu.GetStringValue(app, "instance.id"),
InstanceID: dcontext.GetStringValue(app, "instance.id"),
}
}
@@ -469,7 +494,7 @@ type redisStartAtKey struct{}
func (app *App) configureRedis(configuration *configuration.Configuration) {
if configuration.Redis.Addr == "" {
ctxu.GetLogger(app).Infof("redis not configured")
dcontext.GetLogger(app).Infof("redis not configured")
return
}
@@ -479,8 +504,8 @@ func (app *App) configureRedis(configuration *configuration.Configuration) {
ctx := context.WithValue(app, redisStartAtKey{}, time.Now())
done := func(err error) {
logger := ctxu.GetLoggerWithField(ctx, "redis.connect.duration",
ctxu.Since(ctx, redisStartAtKey{}))
logger := dcontext.GetLoggerWithField(ctx, "redis.connect.duration",
dcontext.Since(ctx, redisStartAtKey{}))
if err != nil {
logger.Errorf("redis: error connecting: %v", err)
} else {
@@ -494,7 +519,7 @@ func (app *App) configureRedis(configuration *configuration.Configuration) {
configuration.Redis.ReadTimeout,
configuration.Redis.WriteTimeout)
if err != nil {
ctxu.GetLogger(app).Errorf("error connecting to redis instance %s: %v",
dcontext.GetLogger(app).Errorf("error connecting to redis instance %s: %v",
configuration.Redis.Addr, err)
done(err)
return nil, err
@@ -551,7 +576,7 @@ func (app *App) configureRedis(configuration *configuration.Configuration) {
// configureLogHook prepares logging hook parameters.
func (app *App) configureLogHook(configuration *configuration.Configuration) {
entry, ok := ctxu.GetLogger(app).(*log.Entry)
entry, ok := dcontext.GetLogger(app).(*logrus.Entry)
if !ok {
// somehow, we are not using logrus
return
@@ -589,7 +614,7 @@ func (app *App) configureSecret(configuration *configuration.Configuration) {
panic(fmt.Sprintf("could not generate random bytes for HTTP secret: %v", err))
}
configuration.HTTP.Secret = string(secretBytes[:])
ctxu.GetLogger(app).Warn("No HTTP secret provided - generated random secret. This may cause problems with uploads if multiple registries are behind a load-balancer. To provide a shared secret, fill in http.secret in the configuration file or set the REGISTRY_HTTP_SECRET environment variable.")
dcontext.GetLogger(app).Warn("No HTTP secret provided - generated random secret. This may cause problems with uploads if multiple registries are behind a load-balancer. To provide a shared secret, fill in http.secret in the configuration file or set the REGISTRY_HTTP_SECRET environment variable.")
}
}
@@ -598,15 +623,15 @@ func (app *App) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Prepare the context with our own little decorations.
ctx := r.Context()
ctx = ctxu.WithRequest(ctx, r)
ctx, w = ctxu.WithResponseWriter(ctx, w)
ctx = ctxu.WithLogger(ctx, ctxu.GetRequestLogger(ctx))
ctx = dcontext.WithRequest(ctx, r)
ctx, w = dcontext.WithResponseWriter(ctx, w)
ctx = dcontext.WithLogger(ctx, dcontext.GetRequestLogger(ctx))
r = r.WithContext(ctx)
defer func() {
status, ok := ctx.Value("http.response.status").(int)
if ok && status >= 200 && status <= 399 {
ctxu.GetResponseLogger(r.Context()).Infof("response completed")
dcontext.GetResponseLogger(r.Context()).Infof("response completed")
}
}()
@@ -637,12 +662,12 @@ func (app *App) dispatcher(dispatch dispatchFunc) http.Handler {
context := app.context(w, r)
if err := app.authorized(w, r, context); err != nil {
ctxu.GetLogger(context).Warnf("error authorizing context: %v", err)
dcontext.GetLogger(context).Warnf("error authorizing context: %v", err)
return
}
// Add username to request logging
context.Context = ctxu.WithLogger(context.Context, ctxu.GetLogger(context.Context, auth.UserNameKey))
context.Context = dcontext.WithLogger(context.Context, dcontext.GetLogger(context.Context, auth.UserNameKey))
// sync up context on the request.
r = r.WithContext(context)
@@ -650,20 +675,20 @@ func (app *App) dispatcher(dispatch dispatchFunc) http.Handler {
if app.nameRequired(r) {
nameRef, err := reference.WithName(getName(context))
if err != nil {
ctxu.GetLogger(context).Errorf("error parsing reference from context: %v", err)
dcontext.GetLogger(context).Errorf("error parsing reference from context: %v", err)
context.Errors = append(context.Errors, distribution.ErrRepositoryNameInvalid{
Name: getName(context),
Reason: err,
})
if err := errcode.ServeJSON(w, context.Errors); err != nil {
ctxu.GetLogger(context).Errorf("error serving error json: %v (from %v)", err, context.Errors)
dcontext.GetLogger(context).Errorf("error serving error json: %v (from %v)", err, context.Errors)
}
return
}
repository, err := app.registry.Repository(context, nameRef)
if err != nil {
ctxu.GetLogger(context).Errorf("error resolving repository: %v", err)
dcontext.GetLogger(context).Errorf("error resolving repository: %v", err)
switch err := err.(type) {
case distribution.ErrRepositoryUnknown:
@@ -675,23 +700,24 @@ func (app *App) dispatcher(dispatch dispatchFunc) http.Handler {
}
if err := errcode.ServeJSON(w, context.Errors); err != nil {
ctxu.GetLogger(context).Errorf("error serving error json: %v (from %v)", err, context.Errors)
dcontext.GetLogger(context).Errorf("error serving error json: %v (from %v)", err, context.Errors)
}
return
}
// assign and decorate the authorized repository with an event bridge.
context.Repository = notifications.Listen(
context.Repository, context.RepositoryRemover = notifications.Listen(
repository,
context.App.repoRemover,
app.eventBridge(context, r))
context.Repository, err = applyRepoMiddleware(app, context.Repository, app.Config.Middleware["repository"])
if err != nil {
ctxu.GetLogger(context).Errorf("error initializing repository middleware: %v", err)
dcontext.GetLogger(context).Errorf("error initializing repository middleware: %v", err)
context.Errors = append(context.Errors, errcode.ErrorCodeUnknown.WithDetail(err))
if err := errcode.ServeJSON(w, context.Errors); err != nil {
ctxu.GetLogger(context).Errorf("error serving error json: %v (from %v)", err, context.Errors)
dcontext.GetLogger(context).Errorf("error serving error json: %v (from %v)", err, context.Errors)
}
return
}
@@ -703,7 +729,7 @@ func (app *App) dispatcher(dispatch dispatchFunc) http.Handler {
// for layer upload).
if context.Errors.Len() > 0 {
if err := errcode.ServeJSON(w, context.Errors); err != nil {
ctxu.GetLogger(context).Errorf("error serving error json: %v (from %v)", err, context.Errors)
dcontext.GetLogger(context).Errorf("error serving error json: %v (from %v)", err, context.Errors)
}
app.logError(context, context.Errors)
@@ -723,31 +749,31 @@ type errDetailKey struct{}
func (errDetailKey) String() string { return "err.detail" }
func (app *App) logError(context context.Context, errors errcode.Errors) {
func (app *App) logError(ctx context.Context, errors errcode.Errors) {
for _, e1 := range errors {
var c ctxu.Context
var c context.Context
switch e1.(type) {
case errcode.Error:
e, _ := e1.(errcode.Error)
c = ctxu.WithValue(context, errCodeKey{}, e.Code)
c = ctxu.WithValue(c, errMessageKey{}, e.Code.Message())
c = ctxu.WithValue(c, errDetailKey{}, e.Detail)
c = context.WithValue(ctx, errCodeKey{}, e.Code)
c = context.WithValue(c, errMessageKey{}, e.Message)
c = context.WithValue(c, errDetailKey{}, e.Detail)
case errcode.ErrorCode:
e, _ := e1.(errcode.ErrorCode)
c = ctxu.WithValue(context, errCodeKey{}, e)
c = ctxu.WithValue(c, errMessageKey{}, e.Message())
c = context.WithValue(ctx, errCodeKey{}, e)
c = context.WithValue(c, errMessageKey{}, e.Message())
default:
// just normal go 'error'
c = ctxu.WithValue(context, errCodeKey{}, errcode.ErrorCodeUnknown)
c = ctxu.WithValue(c, errMessageKey{}, e1.Error())
c = context.WithValue(ctx, errCodeKey{}, errcode.ErrorCodeUnknown)
c = context.WithValue(c, errMessageKey{}, e1.Error())
}
c = ctxu.WithLogger(c, ctxu.GetLogger(c,
c = dcontext.WithLogger(c, dcontext.GetLogger(c,
errCodeKey{},
errMessageKey{},
errDetailKey{}))
ctxu.GetResponseLogger(c).Errorf("response completed with error")
dcontext.GetResponseLogger(c).Errorf("response completed with error")
}
}
@@ -755,8 +781,8 @@ func (app *App) logError(context context.Context, errors errcode.Errors) {
// called once per request.
func (app *App) context(w http.ResponseWriter, r *http.Request) *Context {
ctx := r.Context()
ctx = ctxu.WithVars(ctx, r)
ctx = ctxu.WithLogger(ctx, ctxu.GetLogger(ctx,
ctx = dcontext.WithVars(ctx, r)
ctx = dcontext.WithLogger(ctx, dcontext.GetLogger(ctx,
"vars.name",
"vars.reference",
"vars.digest",
@@ -783,7 +809,7 @@ func (app *App) context(w http.ResponseWriter, r *http.Request) *Context {
// repository. If it succeeds, the context may access the requested
// repository. An error will be returned if access is not available.
func (app *App) authorized(w http.ResponseWriter, r *http.Request, context *Context) error {
ctxu.GetLogger(context).Debug("authorizing request")
dcontext.GetLogger(context).Debug("authorizing request")
repo := getName(context)
if app.accessController == nil {
@@ -809,7 +835,7 @@ func (app *App) authorized(w http.ResponseWriter, r *http.Request, context *Cont
// that mistake elsewhere in the code, allowing any operation to
// proceed.
if err := errcode.ServeJSON(w, errcode.ErrorCodeUnauthorized); err != nil {
ctxu.GetLogger(context).Errorf("error serving error json: %v (from %v)", err, context.Errors)
dcontext.GetLogger(context).Errorf("error serving error json: %v (from %v)", err, context.Errors)
}
return fmt.Errorf("forbidden: no repository name")
}
@@ -824,20 +850,21 @@ func (app *App) authorized(w http.ResponseWriter, r *http.Request, context *Cont
err.SetHeaders(w)
if err := errcode.ServeJSON(w, errcode.ErrorCodeUnauthorized.WithDetail(accessRecords)); err != nil {
ctxu.GetLogger(context).Errorf("error serving error json: %v (from %v)", err, context.Errors)
dcontext.GetLogger(context).Errorf("error serving error json: %v (from %v)", err, context.Errors)
}
default:
// This condition is a potential security problem either in
// the configuration or whatever is backing the access
// controller. Just return a bad request with no information
// to avoid exposure. The request should not proceed.
ctxu.GetLogger(context).Errorf("error checking authorization: %v", err)
dcontext.GetLogger(context).Errorf("error checking authorization: %v", err)
w.WriteHeader(http.StatusBadRequest)
}
return err
}
dcontext.GetLogger(ctx).Info("authorized request")
// TODO(stevvooe): This pattern needs to be cleaned up a bit. One context
// should be replaced by another, rather than replacing the context on a
// mutable object.
@@ -851,9 +878,9 @@ func (app *App) eventBridge(ctx *Context, r *http.Request) notifications.Listene
actor := notifications.ActorRecord{
Name: getUserName(ctx, r),
}
request := notifications.NewRequestRecord(ctxu.GetRequestID(ctx), r)
request := notifications.NewRequestRecord(dcontext.GetRequestID(ctx), r)
return notifications.NewBridge(ctx.urlBuilder, app.events.source, actor, request, app.events.sink)
return notifications.NewBridge(ctx.urlBuilder, app.events.source, actor, request, app.events.sink, app.Config.Notifications.EventConfig.IncludeReferences)
}
// nameRequired returns true if the route requires a name.
@@ -986,7 +1013,7 @@ func badPurgeUploadConfig(reason string) {
// startUploadPurger schedules a goroutine which will periodically
// check upload directories for old files and delete them
func startUploadPurger(ctx context.Context, storageDriver storagedriver.StorageDriver, log ctxu.Logger, config map[interface{}]interface{}) {
func startUploadPurger(ctx context.Context, storageDriver storagedriver.StorageDriver, log dcontext.Logger, config map[interface{}]interface{}) {
if config["enabled"] == false {
return
}

View File

@@ -6,7 +6,7 @@ import (
"net/url"
"github.com/docker/distribution"
ctxu "github.com/docker/distribution/context"
dcontext "github.com/docker/distribution/context"
"github.com/docker/distribution/reference"
"github.com/docker/distribution/registry/api/errcode"
"github.com/docker/distribution/registry/api/v2"
@@ -39,7 +39,7 @@ func blobUploadDispatcher(ctx *Context, r *http.Request) http.Handler {
state, err := hmacKey(ctx.Config.HTTP.Secret).unpackUploadState(r.FormValue("_state"))
if err != nil {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctxu.GetLogger(ctx).Infof("error resolving upload: %v", err)
dcontext.GetLogger(ctx).Infof("error resolving upload: %v", err)
buh.Errors = append(buh.Errors, v2.ErrorCodeBlobUploadInvalid.WithDetail(err))
})
}
@@ -47,14 +47,14 @@ func blobUploadDispatcher(ctx *Context, r *http.Request) http.Handler {
if state.Name != ctx.Repository.Named().Name() {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctxu.GetLogger(ctx).Infof("mismatched repository name in upload state: %q != %q", state.Name, buh.Repository.Named().Name())
dcontext.GetLogger(ctx).Infof("mismatched repository name in upload state: %q != %q", state.Name, buh.Repository.Named().Name())
buh.Errors = append(buh.Errors, v2.ErrorCodeBlobUploadInvalid.WithDetail(err))
})
}
if state.UUID != buh.UUID {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctxu.GetLogger(ctx).Infof("mismatched uuid in upload state: %q != %q", state.UUID, buh.UUID)
dcontext.GetLogger(ctx).Infof("mismatched uuid in upload state: %q != %q", state.UUID, buh.UUID)
buh.Errors = append(buh.Errors, v2.ErrorCodeBlobUploadInvalid.WithDetail(err))
})
}
@@ -62,7 +62,7 @@ func blobUploadDispatcher(ctx *Context, r *http.Request) http.Handler {
blobs := ctx.Repository.Blobs(buh)
upload, err := blobs.Resume(buh, buh.UUID)
if err != nil {
ctxu.GetLogger(ctx).Errorf("error resolving upload: %v", err)
dcontext.GetLogger(ctx).Errorf("error resolving upload: %v", err)
if err == distribution.ErrBlobUploadUnknown {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
buh.Errors = append(buh.Errors, v2.ErrorCodeBlobUploadUnknown.WithDetail(err))
@@ -77,7 +77,7 @@ func blobUploadDispatcher(ctx *Context, r *http.Request) http.Handler {
if size := upload.Size(); size != buh.State.Offset {
defer upload.Close()
ctxu.GetLogger(ctx).Errorf("upload resumed at wrong offest: %d != %d", size, buh.State.Offset)
dcontext.GetLogger(ctx).Errorf("upload resumed at wrong offest: %d != %d", size, buh.State.Offset)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
buh.Errors = append(buh.Errors, v2.ErrorCodeBlobUploadInvalid.WithDetail(err))
upload.Cancel(buh)
@@ -179,7 +179,7 @@ func (buh *blobUploadHandler) PatchBlobData(w http.ResponseWriter, r *http.Reque
// TODO(dmcgowan): support Content-Range header to seek and write range
if err := copyFullPayload(w, r, buh.Upload, -1, buh, "blob PATCH"); err != nil {
if err := copyFullPayload(buh, w, r, buh.Upload, -1, "blob PATCH"); err != nil {
buh.Errors = append(buh.Errors, errcode.ErrorCodeUnknown.WithDetail(err.Error()))
return
}
@@ -218,7 +218,7 @@ func (buh *blobUploadHandler) PutBlobUploadComplete(w http.ResponseWriter, r *ht
return
}
if err := copyFullPayload(w, r, buh.Upload, -1, buh, "blob PUT"); err != nil {
if err := copyFullPayload(buh, w, r, buh.Upload, -1, "blob PUT"); err != nil {
buh.Errors = append(buh.Errors, errcode.ErrorCodeUnknown.WithDetail(err.Error()))
return
}
@@ -246,7 +246,7 @@ func (buh *blobUploadHandler) PutBlobUploadComplete(w http.ResponseWriter, r *ht
case distribution.ErrBlobInvalidLength, distribution.ErrBlobDigestUnsupported:
buh.Errors = append(buh.Errors, v2.ErrorCodeBlobUploadInvalid.WithDetail(err))
default:
ctxu.GetLogger(buh).Errorf("unknown error completing upload: %v", err)
dcontext.GetLogger(buh).Errorf("unknown error completing upload: %v", err)
buh.Errors = append(buh.Errors, errcode.ErrorCodeUnknown.WithDetail(err))
}
@@ -255,7 +255,7 @@ func (buh *blobUploadHandler) PutBlobUploadComplete(w http.ResponseWriter, r *ht
// Clean up the backend blob data if there was an error.
if err := buh.Upload.Cancel(buh); err != nil {
// If the cleanup fails, all we can do is observe and report.
ctxu.GetLogger(buh).Errorf("error canceling upload after error: %v", err)
dcontext.GetLogger(buh).Errorf("error canceling upload after error: %v", err)
}
return
@@ -275,7 +275,7 @@ func (buh *blobUploadHandler) CancelBlobUpload(w http.ResponseWriter, r *http.Re
w.Header().Set("Docker-Upload-UUID", buh.UUID)
if err := buh.Upload.Cancel(buh); err != nil {
ctxu.GetLogger(buh).Errorf("error encountered canceling upload: %v", err)
dcontext.GetLogger(buh).Errorf("error encountered canceling upload: %v", err)
buh.Errors = append(buh.Errors, errcode.ErrorCodeUnknown.WithDetail(err))
}
@@ -297,7 +297,7 @@ func (buh *blobUploadHandler) blobUploadResponse(w http.ResponseWriter, r *http.
token, err := hmacKey(buh.Config.HTTP.Secret).packUploadState(buh.State)
if err != nil {
ctxu.GetLogger(buh).Infof("error building upload state token: %s", err)
dcontext.GetLogger(buh).Infof("error building upload state token: %s", err)
return err
}
@@ -307,7 +307,7 @@ func (buh *blobUploadHandler) blobUploadResponse(w http.ResponseWriter, r *http.
"_state": []string{token},
})
if err != nil {
ctxu.GetLogger(buh).Infof("error building upload url: %s", err)
dcontext.GetLogger(buh).Infof("error building upload url: %s", err)
return err
}
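
The `_state` token above is how the registry keeps resumable uploads stateless on the server: the upload state is serialized, HMAC-signed with the configured HTTP secret, and handed to the client, which must echo it back untouched. Below is a minimal sketch of that pattern; the struct fields and helper names are assumptions for illustration, not the exact upstream types.

package main

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
)

// uploadState is an assumed stand-in for the registry's upload state.
type uploadState struct {
	Name   string `json:"name"`
	UUID   string `json:"uuid"`
	Offset int64  `json:"offset"`
}

type hmacKey []byte

// pack serializes the state and prepends an HMAC so the client can
// carry it without being able to tamper with it.
func (k hmacKey) pack(s uploadState) (string, error) {
	payload, err := json.Marshal(s)
	if err != nil {
		return "", err
	}
	mac := hmac.New(sha256.New, k)
	mac.Write(payload)
	return base64.URLEncoding.EncodeToString(append(mac.Sum(nil), payload...)), nil
}

// unpack verifies the HMAC before trusting the embedded state.
func (k hmacKey) unpack(token string) (uploadState, error) {
	var s uploadState
	raw, err := base64.URLEncoding.DecodeString(token)
	if err != nil || len(raw) < sha256.Size {
		return s, errors.New("invalid state token")
	}
	mac := hmac.New(sha256.New, k)
	mac.Write(raw[sha256.Size:])
	if !hmac.Equal(mac.Sum(nil), raw[:sha256.Size]) {
		return s, errors.New("state token HMAC mismatch")
	}
	return s, json.Unmarshal(raw[sha256.Size:], &s)
}

func main() {
	key := hmacKey("secret")
	tok, _ := key.pack(uploadState{Name: "foo/bar", UUID: "abc", Offset: 42})
	s, err := key.unpack(tok)
	fmt.Println(s, err)
}
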

View File

@@ -1,16 +1,16 @@
package handlers
import (
"context"
"fmt"
"net/http"
"github.com/docker/distribution"
ctxu "github.com/docker/distribution/context"
dcontext "github.com/docker/distribution/context"
"github.com/docker/distribution/registry/api/errcode"
"github.com/docker/distribution/registry/api/v2"
"github.com/docker/distribution/registry/auth"
"github.com/opencontainers/go-digest"
"golang.org/x/net/context"
)
// Context should contain the request specific context for use across
@@ -25,6 +25,9 @@ type Context struct {
// should be scoped to a single repository. This field may be nil.
Repository distribution.Repository
// RepositoryRemover provides method to delete a repository
RepositoryRemover distribution.RepositoryRemover
// Errors is a collection of errors encountered during the request to be
// returned to the client API. If errors are added to the collection, the
// handler *must not* start the response via http.ResponseWriter.
@@ -44,26 +47,26 @@ func (ctx *Context) Value(key interface{}) interface{} {
}
func getName(ctx context.Context) (name string) {
return ctxu.GetStringValue(ctx, "vars.name")
return dcontext.GetStringValue(ctx, "vars.name")
}
func getReference(ctx context.Context) (reference string) {
return ctxu.GetStringValue(ctx, "vars.reference")
return dcontext.GetStringValue(ctx, "vars.reference")
}
var errDigestNotAvailable = fmt.Errorf("digest not available in context")
func getDigest(ctx context.Context) (dgst digest.Digest, err error) {
dgstStr := ctxu.GetStringValue(ctx, "vars.digest")
dgstStr := dcontext.GetStringValue(ctx, "vars.digest")
if dgstStr == "" {
ctxu.GetLogger(ctx).Errorf("digest not available")
dcontext.GetLogger(ctx).Errorf("digest not available")
return "", errDigestNotAvailable
}
d, err := digest.Parse(dgstStr)
if err != nil {
ctxu.GetLogger(ctx).Errorf("error parsing digest=%q: %v", dgstStr, err)
dcontext.GetLogger(ctx).Errorf("error parsing digest=%q: %v", dgstStr, err)
return "", err
}
@@ -71,13 +74,13 @@ func getDigest(ctx context.Context) (dgst digest.Digest, err error) {
}
func getUploadUUID(ctx context.Context) (uuid string) {
return ctxu.GetStringValue(ctx, "vars.uuid")
return dcontext.GetStringValue(ctx, "vars.uuid")
}
// getUserName attempts to resolve a username from the context and request. If
// a username cannot be resolved, the empty string is returned.
func getUserName(ctx context.Context, r *http.Request) string {
username := ctxu.GetStringValue(ctx, auth.UserNameKey)
username := dcontext.GetStringValue(ctx, auth.UserNameKey)
// Fallback to request user with basic auth
if username == "" {

View File

@@ -1,11 +1,12 @@
package handlers
import (
"context"
"errors"
"io"
"net/http"
ctxu "github.com/docker/distribution/context"
dcontext "github.com/docker/distribution/context"
)
// closeResources closes all the provided resources after running the target
@@ -24,15 +25,9 @@ func closeResources(handler http.Handler, closers ...io.Closer) http.Handler {
// upload, it avoids sending a 400 error to keep the logs cleaner.
//
// The copy will be limited to `limit` bytes, if limit is greater than zero.
func copyFullPayload(responseWriter http.ResponseWriter, r *http.Request, destWriter io.Writer, limit int64, context ctxu.Context, action string) error {
func copyFullPayload(ctx context.Context, responseWriter http.ResponseWriter, r *http.Request, destWriter io.Writer, limit int64, action string) error {
// Get a channel that tells us if the client disconnects
var clientClosed <-chan bool
if notifier, ok := responseWriter.(http.CloseNotifier); ok {
clientClosed = notifier.CloseNotify()
} else {
ctxu.GetLogger(context).Warnf("the ResponseWriter does not implement CloseNotifier (type: %T)", responseWriter)
}
clientClosed := r.Context().Done()
var body = r.Body
if limit > 0 {
body = http.MaxBytesReader(responseWriter, body, limit)
@@ -52,7 +47,7 @@ func copyFullPayload(responseWriter http.ResponseWriter, r *http.Request, destWr
// instead of showing 0 for the HTTP status.
responseWriter.WriteHeader(499)
ctxu.GetLoggerWithFields(context, map[interface{}]interface{}{
dcontext.GetLoggerWithFields(ctx, map[interface{}]interface{}{
"error": err,
"copied": copied,
"contentLength": r.ContentLength,
@@ -63,7 +58,7 @@ func copyFullPayload(responseWriter http.ResponseWriter, r *http.Request, destWr
}
if err != nil {
ctxu.GetLogger(context).Errorf("unknown error reading request payload: %v", err)
dcontext.GetLogger(ctx).Errorf("unknown error reading request payload: %v", err)
return err
}
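
The refactor above drops the http.CloseNotifier dance in favor of r.Context().Done(), which the standard library cancels when the client disconnects. A minimal sketch of the resulting shape, under the same 499-on-disconnect convention (an nginx-ism, not an HTTP standard); names here are illustrative, not the upstream helper:

package copysketch

import (
	"fmt"
	"io"
	"net/http"
)

// copyPayload bounds the body and treats request-context cancellation
// as a client disconnect, mirroring the refactored helper above.
func copyPayload(w http.ResponseWriter, r *http.Request, dst io.Writer, limit int64) error {
	var body io.Reader = r.Body
	if limit > 0 {
		// MaxBytesReader makes Copy fail past the limit instead of
		// silently truncating, and signals the server to close the connection.
		body = http.MaxBytesReader(w, r.Body, limit)
	}
	copied, err := io.Copy(dst, body)
	select {
	case <-r.Context().Done():
		// Client went away mid-upload; log-friendly 499, as above.
		w.WriteHeader(499)
		return fmt.Errorf("client disconnected after %d bytes: %v", copied, r.Context().Err())
	default:
	}
	return err
}
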

View File

@@ -36,7 +36,7 @@ func (mail *mailer) sendMail(subject, message string) error {
auth,
mail.From,
mail.To,
[]byte(msg),
msg,
)
if err != nil {
return err

View File

@@ -7,8 +7,9 @@ import (
"strings"
"github.com/docker/distribution"
ctxu "github.com/docker/distribution/context"
dcontext "github.com/docker/distribution/context"
"github.com/docker/distribution/manifest/manifestlist"
"github.com/docker/distribution/manifest/ocischema"
"github.com/docker/distribution/manifest/schema1"
"github.com/docker/distribution/manifest/schema2"
"github.com/docker/distribution/reference"
@@ -17,6 +18,7 @@ import (
"github.com/docker/distribution/registry/auth"
"github.com/gorilla/handlers"
"github.com/opencontainers/go-digest"
"github.com/opencontainers/image-spec/specs-go/v1"
)
// These constants determine which architecture and OS to choose from a
@@ -25,6 +27,18 @@ const (
defaultArch = "amd64"
defaultOS = "linux"
maxManifestBodySize = 4 << 20
imageClass = "image"
)
type storageType int
const (
manifestSchema1 storageType = iota // 0
manifestSchema2 // 1
manifestlistSchema // 2
ociSchema // 3
ociImageIndexSchema // 4
numStorageTypes // 5
)
// manifestDispatcher takes the request context and builds the
@@ -66,14 +80,46 @@ type manifestHandler struct {
// GetManifest fetches the image manifest from the storage backend, if it exists.
func (imh *manifestHandler) GetManifest(w http.ResponseWriter, r *http.Request) {
ctxu.GetLogger(imh).Debug("GetImageManifest")
dcontext.GetLogger(imh).Debug("GetImageManifest")
manifests, err := imh.Repository.Manifests(imh)
if err != nil {
imh.Errors = append(imh.Errors, err)
return
}
var supports [numStorageTypes]bool
// this parsing of Accept headers is not quite as full-featured as godoc.org's parser, but we don't care about "q=" values
// https://github.com/golang/gddo/blob/e91d4165076d7474d20abda83f92d15c7ebc3e81/httputil/header/header.go#L165-L202
for _, acceptHeader := range r.Header["Accept"] {
// r.Header[...] is a slice in case the request contains the same header more than once
// if the header isn't set, we'll get the zero value, which "range" will handle gracefully
// we need to split each header value on "," to get the full list of "Accept" values (per RFC 2616)
// https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.1
for _, mediaType := range strings.Split(acceptHeader, ",") {
// remove "; q=..." if present
if i := strings.Index(mediaType, ";"); i >= 0 {
mediaType = mediaType[:i]
}
// it's common (but not required) for Accept values to be space separated ("a/b, c/d, e/f")
mediaType = strings.TrimSpace(mediaType)
if mediaType == schema2.MediaTypeManifest {
supports[manifestSchema2] = true
}
if mediaType == manifestlist.MediaTypeManifestList {
supports[manifestlistSchema] = true
}
if mediaType == v1.MediaTypeImageManifest {
supports[ociSchema] = true
}
if mediaType == v1.MediaTypeImageIndex {
supports[ociImageIndexSchema] = true
}
}
}
var manifest distribution.Manifest
if imh.Tag != "" {
tags := imh.Repository.Tags(imh)
desc, err := tags.Get(imh, imh.Tag)
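
The Accept-header loop added above is self-contained enough to extract. A standalone sketch of the same parsing, returning the set of acceptable media types (package and function names are illustrative):

package main

import (
	"fmt"
	"strings"
)

// acceptedMediaTypes splits each Accept value on ",", drops any
// ";q=..." parameters, and trims spaces, exactly as the loop above does.
func acceptedMediaTypes(headerValues []string) map[string]bool {
	accepted := map[string]bool{}
	for _, v := range headerValues {
		for _, mediaType := range strings.Split(v, ",") {
			if i := strings.Index(mediaType, ";"); i >= 0 {
				mediaType = mediaType[:i] // discard "; q=..." and other parameters
			}
			accepted[strings.TrimSpace(mediaType)] = true
		}
	}
	return accepted
}

func main() {
	fmt.Println(acceptedMediaTypes([]string{
		"application/vnd.docker.distribution.manifest.v2+json, application/vnd.oci.image.manifest.v1+json;q=0.5",
	}))
}
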
@@ -97,7 +143,7 @@ func (imh *manifestHandler) GetManifest(w http.ResponseWriter, r *http.Request)
if imh.Tag != "" {
options = append(options, distribution.WithTag(imh.Tag))
}
manifest, err = manifests.Get(imh, imh.Digest, options...)
manifest, err := manifests.Get(imh, imh.Digest, options...)
if err != nil {
if _, ok := err.(distribution.ErrManifestUnknownRevision); ok {
imh.Errors = append(imh.Errors, v2.ErrorCodeManifestUnknown.WithDetail(err))
@@ -106,52 +152,44 @@ func (imh *manifestHandler) GetManifest(w http.ResponseWriter, r *http.Request)
}
return
}
supportsSchema2 := false
supportsManifestList := false
// this parsing of Accept headers is not quite as full-featured as godoc.org's parser, but we don't care about "q=" values
// https://github.com/golang/gddo/blob/e91d4165076d7474d20abda83f92d15c7ebc3e81/httputil/header/header.go#L165-L202
for _, acceptHeader := range r.Header["Accept"] {
// r.Header[...] is a slice in case the request contains the same header more than once
// if the header isn't set, we'll get the zero value, which "range" will handle gracefully
// we need to split each header value on "," to get the full list of "Accept" values (per RFC 2616)
// https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.1
for _, mediaType := range strings.Split(acceptHeader, ",") {
// remove "; q=..." if present
if i := strings.Index(mediaType, ";"); i >= 0 {
mediaType = mediaType[:i]
}
// it's common (but not required) for Accept values to be space separated ("a/b, c/d, e/f")
mediaType = strings.TrimSpace(mediaType)
if mediaType == schema2.MediaTypeManifest {
supportsSchema2 = true
}
if mediaType == manifestlist.MediaTypeManifestList {
supportsManifestList = true
}
// determine the type of the returned manifest
manifestType := manifestSchema1
schema2Manifest, isSchema2 := manifest.(*schema2.DeserializedManifest)
manifestList, isManifestList := manifest.(*manifestlist.DeserializedManifestList)
if isSchema2 {
manifestType = manifestSchema2
} else if _, isOCImanifest := manifest.(*ocischema.DeserializedManifest); isOCImanifest {
manifestType = ociSchema
} else if isManifestList {
if manifestList.MediaType == manifestlist.MediaTypeManifestList {
manifestType = manifestlistSchema
} else if manifestList.MediaType == v1.MediaTypeImageIndex {
manifestType = ociImageIndexSchema
}
}
schema2Manifest, isSchema2 := manifest.(*schema2.DeserializedManifest)
manifestList, isManifestList := manifest.(*manifestlist.DeserializedManifestList)
if manifestType == ociSchema && !supports[ociSchema] {
imh.Errors = append(imh.Errors, v2.ErrorCodeManifestUnknown.WithMessage("OCI manifest found, but accept header does not support OCI manifests"))
return
}
if manifestType == ociImageIndexSchema && !supports[ociImageIndexSchema] {
imh.Errors = append(imh.Errors, v2.ErrorCodeManifestUnknown.WithMessage("OCI index found, but accept header does not support OCI indexes"))
return
}
// Only rewrite schema2 manifests when they are being fetched by tag.
// If they are being fetched by digest, we can't return something not
// matching the digest.
if imh.Tag != "" && isSchema2 && !supportsSchema2 {
if imh.Tag != "" && manifestType == manifestSchema2 && !supports[manifestSchema2] {
// Rewrite manifest in schema1 format
ctxu.GetLogger(imh).Infof("rewriting manifest %s in schema1 format to support old client", imh.Digest.String())
dcontext.GetLogger(imh).Infof("rewriting manifest %s in schema1 format to support old client", imh.Digest.String())
manifest, err = imh.convertSchema2Manifest(schema2Manifest)
if err != nil {
return
}
} else if imh.Tag != "" && isManifestList && !supportsManifestList {
} else if imh.Tag != "" && manifestType == manifestlistSchema && !supports[manifestlistSchema] {
// Rewrite manifest in schema1 format
ctxu.GetLogger(imh).Infof("rewriting manifest list %s in schema1 format to support old client", imh.Digest.String())
dcontext.GetLogger(imh).Infof("rewriting manifest list %s in schema1 format to support old client", imh.Digest.String())
// Find the image manifest corresponding to the default
// platform
@@ -179,7 +217,7 @@ func (imh *manifestHandler) GetManifest(w http.ResponseWriter, r *http.Request)
}
// If necessary, convert the image manifest
if schema2Manifest, isSchema2 := manifest.(*schema2.DeserializedManifest); isSchema2 && !supportsSchema2 {
if schema2Manifest, isSchema2 := manifest.(*schema2.DeserializedManifest); isSchema2 && !supports[manifestSchema2] {
manifest, err = imh.convertSchema2Manifest(schema2Manifest)
if err != nil {
return
@@ -252,7 +290,7 @@ func etagMatch(r *http.Request, etag string) bool {
// PutManifest validates and stores a manifest in the registry.
func (imh *manifestHandler) PutManifest(w http.ResponseWriter, r *http.Request) {
ctxu.GetLogger(imh).Debug("PutImageManifest")
dcontext.GetLogger(imh).Debug("PutImageManifest")
manifests, err := imh.Repository.Manifests(imh)
if err != nil {
imh.Errors = append(imh.Errors, err)
@@ -260,7 +298,7 @@ func (imh *manifestHandler) PutManifest(w http.ResponseWriter, r *http.Request)
}
var jsonBuf bytes.Buffer
if err := copyFullPayload(w, r, &jsonBuf, maxManifestBodySize, imh, "image manifest PUT"); err != nil {
if err := copyFullPayload(imh, w, r, &jsonBuf, maxManifestBodySize, "image manifest PUT"); err != nil {
// copyFullPayload reports the error if necessary
imh.Errors = append(imh.Errors, v2.ErrorCodeManifestInvalid.WithDetail(err.Error()))
return
@@ -275,7 +313,7 @@ func (imh *manifestHandler) PutManifest(w http.ResponseWriter, r *http.Request)
if imh.Digest != "" {
if desc.Digest != imh.Digest {
ctxu.GetLogger(imh).Errorf("payload digest does match: %q != %q", desc.Digest, imh.Digest)
dcontext.GetLogger(imh).Errorf("payload digest does match: %q != %q", desc.Digest, imh.Digest)
imh.Errors = append(imh.Errors, v2.ErrorCodeDigestInvalid)
return
}
@@ -286,6 +324,14 @@ func (imh *manifestHandler) PutManifest(w http.ResponseWriter, r *http.Request)
return
}
isAnOCIManifest := mediaType == v1.MediaTypeImageManifest || mediaType == v1.MediaTypeImageIndex
if isAnOCIManifest {
dcontext.GetLogger(imh).Debug("Putting an OCI Manifest!")
} else {
dcontext.GetLogger(imh).Debug("Putting a Docker Manifest!")
}
var options []distribution.ManifestServiceOption
if imh.Tag != "" {
options = append(options, distribution.WithTag(imh.Tag))
@@ -331,7 +377,6 @@ func (imh *manifestHandler) PutManifest(w http.ResponseWriter, r *http.Request)
default:
imh.Errors = append(imh.Errors, errcode.ErrorCodeUnknown.WithDetail(err))
}
return
}
@@ -358,12 +403,14 @@ func (imh *manifestHandler) PutManifest(w http.ResponseWriter, r *http.Request)
// NOTE(stevvooe): Given the behavior above, this is absurdly unlikely to
// happen. We'll log the error here but proceed as if it worked. Worst
// case, we set an empty location header.
ctxu.GetLogger(imh).Errorf("error building manifest url from digest: %v", err)
dcontext.GetLogger(imh).Errorf("error building manifest url from digest: %v", err)
}
w.Header().Set("Location", location)
w.Header().Set("Docker-Content-Digest", imh.Digest.String())
w.WriteHeader(http.StatusCreated)
dcontext.GetLogger(imh).Debug("Succeeded in putting manifest!")
}
// applyResourcePolicy checks whether the resource class matches what has
@@ -377,16 +424,22 @@ func (imh *manifestHandler) applyResourcePolicy(manifest distribution.Manifest)
var class string
switch m := manifest.(type) {
case *schema1.SignedManifest:
class = "image"
class = imageClass
case *schema2.DeserializedManifest:
switch m.Config.MediaType {
case schema2.MediaTypeImageConfig:
class = "image"
class = imageClass
case schema2.MediaTypePluginConfig:
class = "plugin"
default:
message := fmt.Sprintf("unknown manifest class for %s", m.Config.MediaType)
return errcode.ErrorCodeDenied.WithMessage(message)
return errcode.ErrorCodeDenied.WithMessage("unknown manifest class for " + m.Config.MediaType)
}
case *ocischema.DeserializedManifest:
switch m.Config.MediaType {
case v1.MediaTypeImageConfig:
class = imageClass
default:
return errcode.ErrorCodeDenied.WithMessage("unknown manifest class for " + m.Config.MediaType)
}
}
@@ -403,8 +456,7 @@ func (imh *manifestHandler) applyResourcePolicy(manifest distribution.Manifest)
}
}
if !allowedClass {
message := fmt.Sprintf("registry does not allow %s manifest", class)
return errcode.ErrorCodeDenied.WithMessage(message)
return errcode.ErrorCodeDenied.WithMessage(fmt.Sprintf("registry does not allow %s manifest", class))
}
resources := auth.AuthorizedResources(imh)
@@ -414,7 +466,7 @@ func (imh *manifestHandler) applyResourcePolicy(manifest distribution.Manifest)
for _, r := range resources {
if r.Name == n {
if r.Class == "" {
r.Class = "image"
r.Class = imageClass
}
if r.Class == class {
return nil
@@ -425,8 +477,7 @@ func (imh *manifestHandler) applyResourcePolicy(manifest distribution.Manifest)
// resource was found but no matching class was found
if foundResource {
message := fmt.Sprintf("repository not authorized for %s manifest", class)
return errcode.ErrorCodeDenied.WithMessage(message)
return errcode.ErrorCodeDenied.WithMessage(fmt.Sprintf("repository not authorized for %s manifest", class))
}
return nil
@@ -435,7 +486,7 @@ func (imh *manifestHandler) applyResourcePolicy(manifest distribution.Manifest)
// DeleteManifest removes the manifest with the given digest from the registry.
func (imh *manifestHandler) DeleteManifest(w http.ResponseWriter, r *http.Request) {
ctxu.GetLogger(imh).Debug("DeleteImageManifest")
dcontext.GetLogger(imh).Debug("DeleteImageManifest")
manifests, err := imh.Repository.Manifests(imh)
if err != nil {

View File

@@ -1,10 +1,10 @@
package middleware
import (
"context"
"fmt"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
"github.com/docker/distribution/registry/storage"
)

View File

@@ -1,10 +1,10 @@
package middleware
import (
"context"
"fmt"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
)
// InitFunc is the type of a RepositoryMiddleware factory function and is

View File

@@ -79,9 +79,5 @@ func ping(manager challenge.Manager, endpoint, versionHeader string) error {
}
defer resp.Body.Close()
if err := manager.AddResponse(resp); err != nil {
return err
}
return nil
return manager.AddResponse(resp)
}

View File

@@ -1,6 +1,7 @@
package proxy
import (
"context"
"io"
"net/http"
"strconv"
@@ -8,14 +9,14 @@ import (
"time"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
dcontext "github.com/docker/distribution/context"
"github.com/docker/distribution/reference"
"github.com/docker/distribution/registry/proxy/scheduler"
"github.com/opencontainers/go-digest"
)
// todo(richardscothern): from cache control header or config file
const blobTTL = time.Duration(24 * 7 * time.Hour)
const blobTTL = 24 * 7 * time.Hour
type proxyBlobStore struct {
localStore distribution.BlobStore
@@ -116,7 +117,7 @@ func (pbs *proxyBlobStore) storeLocal(ctx context.Context, dgst digest.Digest) e
func (pbs *proxyBlobStore) ServeBlob(ctx context.Context, w http.ResponseWriter, r *http.Request, dgst digest.Digest) error {
served, err := pbs.serveLocal(ctx, w, r, dgst)
if err != nil {
context.GetLogger(ctx).Errorf("Error serving blob from local storage: %s", err.Error())
dcontext.GetLogger(ctx).Errorf("Error serving blob from local storage: %s", err.Error())
return err
}
@@ -140,12 +141,12 @@ func (pbs *proxyBlobStore) ServeBlob(ctx context.Context, w http.ResponseWriter,
go func(dgst digest.Digest) {
if err := pbs.storeLocal(ctx, dgst); err != nil {
context.GetLogger(ctx).Errorf("Error committing to storage: %s", err.Error())
dcontext.GetLogger(ctx).Errorf("Error committing to storage: %s", err.Error())
}
blobRef, err := reference.WithDigest(pbs.repositoryName, dgst)
if err != nil {
context.GetLogger(ctx).Errorf("Error creating reference: %s", err)
dcontext.GetLogger(ctx).Errorf("Error creating reference: %s", err)
return
}
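
ServeBlob above implements a serve-through cache: serve locally when possible, otherwise stream from the remote while a goroutine commits the blob locally and schedules its eviction after blobTTL. A simplified sketch of that flow (the signatures are assumptions, and the upstream store also coordinates concurrent fetches, which this omits):

package proxysketch

import (
	"context"
	"log"
	"time"
)

const blobTTL = 24 * 7 * time.Hour

type server interface {
	Serve(ctx context.Context, dgst string) error
}

// serveThrough prefers the local store, falls back to the remote, and
// caches plus schedules eviction in the background so the client
// response is not delayed by the local commit.
func serveThrough(ctx context.Context, local, remote server, storeLocal func(context.Context, string) error, schedule func(dgst string, ttl time.Duration), dgst string) error {
	if err := local.Serve(ctx, dgst); err == nil {
		return nil
	}
	go func() {
		if err := storeLocal(ctx, dgst); err != nil {
			log.Printf("error committing to storage: %v", err)
			return
		}
		schedule(dgst, blobTTL) // evict the cached blob after its TTL
	}()
	return remote.Serve(ctx, dgst)
}
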

View File

@@ -1,6 +1,7 @@
package proxy
import (
"context"
"io/ioutil"
"math/rand"
"net/http"
@@ -10,7 +11,6 @@ import (
"time"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
"github.com/docker/distribution/reference"
"github.com/docker/distribution/registry/proxy/scheduler"
"github.com/docker/distribution/registry/storage"
@@ -350,24 +350,30 @@ func testProxyStoreServe(t *testing.T, te *testEnv, numClients int) {
w := httptest.NewRecorder()
r, err := http.NewRequest("GET", "", nil)
if err != nil {
t.Fatal(err)
t.Error(err)
return
}
err = te.store.ServeBlob(te.ctx, w, r, remoteBlob.Digest)
if err != nil {
t.Fatalf(err.Error())
t.Errorf(err.Error())
return
}
bodyBytes := w.Body.Bytes()
localDigest := digest.FromBytes(bodyBytes)
if localDigest != remoteBlob.Digest {
t.Fatalf("Mismatching blob fetch from proxy")
t.Errorf("Mismatching blob fetch from proxy")
return
}
}
}()
}
wg.Wait()
if t.Failed() {
t.FailNow()
}
remoteBlobCount := len(te.inRemote)
sbsMu.Lock()
@@ -404,7 +410,6 @@ func testProxyStoreServe(t *testing.T, te *testEnv, numClients int) {
}
}
localStats = te.LocalStats()
remoteStats = te.RemoteStats()
// Ensure remote unchanged

View File

@@ -1,17 +1,18 @@
package proxy
import (
"context"
"time"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
dcontext "github.com/docker/distribution/context"
"github.com/docker/distribution/reference"
"github.com/docker/distribution/registry/proxy/scheduler"
"github.com/opencontainers/go-digest"
)
// todo(richardscothern): from cache control header or config
const repositoryTTL = time.Duration(24 * 7 * time.Hour)
const repositoryTTL = 24 * 7 * time.Hour
type proxyManifestStore struct {
ctx context.Context
@@ -72,7 +73,7 @@ func (pms proxyManifestStore) Get(ctx context.Context, dgst digest.Digest, optio
// Schedule the manifest blob for removal
repoBlob, err := reference.WithDigest(pms.repositoryName, dgst)
if err != nil {
context.GetLogger(ctx).Errorf("Error creating reference: %s", err)
dcontext.GetLogger(ctx).Errorf("Error creating reference: %s", err)
return nil, err
}

View File

@@ -1,12 +1,12 @@
package proxy
import (
"context"
"io"
"sync"
"testing"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
"github.com/docker/distribution/manifest"
"github.com/docker/distribution/manifest/schema1"
"github.com/docker/distribution/reference"
@@ -95,7 +95,8 @@ func newManifestStoreTestEnv(t *testing.T, name, tag string) *manifestStoreTestE
ctx := context.Background()
truthRegistry, err := storage.NewRegistry(ctx, inmemory.New(),
storage.BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()),
storage.Schema1SigningKey(k))
storage.Schema1SigningKey(k),
storage.EnableSchema1)
if err != nil {
t.Fatalf("error creating registry: %v", err)
}
@@ -117,7 +118,7 @@ func newManifestStoreTestEnv(t *testing.T, name, tag string) *manifestStoreTestE
t.Fatalf(err.Error())
}
localRegistry, err := storage.NewRegistry(ctx, inmemory.New(), storage.BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()), storage.EnableRedirect, storage.DisableDigestResumption, storage.Schema1SigningKey(k))
localRegistry, err := storage.NewRegistry(ctx, inmemory.New(), storage.BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()), storage.EnableRedirect, storage.DisableDigestResumption, storage.Schema1SigningKey(k), storage.EnableSchema1)
if err != nil {
t.Fatalf("error creating registry: %v", err)
}
@@ -216,7 +217,7 @@ func TestProxyManifests(t *testing.T) {
t.Fatalf("Error checking existence")
}
if !exists {
t.Errorf("Unexpected non-existant manifest")
t.Errorf("Unexpected non-existent manifest")
}
if (*localStats)["exists"] != 1 && (*remoteStats)["exists"] != 1 {
@@ -251,7 +252,7 @@ func TestProxyManifests(t *testing.T) {
t.Fatal(err)
}
if !exists {
t.Errorf("Unexpected non-existant manifest")
t.Errorf("Unexpected non-existent manifest")
}
if (*localStats)["exists"] != 2 && (*remoteStats)["exists"] != 1 {

View File

@@ -1,6 +1,7 @@
package proxy
import (
"context"
"fmt"
"net/http"
"net/url"
@@ -8,7 +9,7 @@ import (
"github.com/docker/distribution"
"github.com/docker/distribution/configuration"
"github.com/docker/distribution/context"
dcontext "github.com/docker/distribution/context"
"github.com/docker/distribution/reference"
"github.com/docker/distribution/registry/client"
"github.com/docker/distribution/registry/client/auth"
@@ -120,8 +121,21 @@ func (pr *proxyingRegistry) Repositories(ctx context.Context, repos []string, la
func (pr *proxyingRegistry) Repository(ctx context.Context, name reference.Named) (distribution.Repository, error) {
c := pr.authChallenger
tkopts := auth.TokenHandlerOptions{
Transport: http.DefaultTransport,
Credentials: c.credentialStore(),
Scopes: []auth.Scope{
auth.RepositoryScope{
Repository: name.Name(),
Actions: []string{"pull"},
},
},
Logger: dcontext.GetLogger(ctx),
}
tr := transport.NewTransport(http.DefaultTransport,
auth.NewAuthorizer(c.challengeManager(), auth.NewTokenHandler(http.DefaultTransport, c.credentialStore(), name.Name(), "pull")))
auth.NewAuthorizer(c.challengeManager(),
auth.NewTokenHandlerWithOptions(tkopts)))
localRepo, err := pr.embedded.Repository(ctx, name)
if err != nil {
@@ -132,7 +146,7 @@ func (pr *proxyingRegistry) Repository(ctx context.Context, name reference.Named
return nil, err
}
remoteRepo, err := client.NewRepository(ctx, name, pr.remoteURL.String(), tr)
remoteRepo, err := client.NewRepository(name, pr.remoteURL.String(), tr)
if err != nil {
return nil, err
}
@@ -218,7 +232,7 @@ func (r *remoteAuthChallenger) tryEstablishChallenges(ctx context.Context) error
return err
}
context.GetLogger(ctx).Infof("Challenge established with upstream : %s %s", remoteURL, r.cm)
dcontext.GetLogger(ctx).Infof("Challenge established with upstream : %s %s", remoteURL, r.cm)
return nil
}

View File

@@ -1,8 +1,9 @@
package proxy
import (
"context"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
)
// proxyTagService supports local and remote lookup of tags.

View File

@@ -1,13 +1,13 @@
package proxy
import (
"context"
"reflect"
"sort"
"sync"
"testing"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
)
type mockTagStore struct {

View File

@@ -1,12 +1,13 @@
package scheduler
import (
"context"
"encoding/json"
"fmt"
"sync"
"time"
"github.com/docker/distribution/context"
dcontext "github.com/docker/distribution/context"
"github.com/docker/distribution/reference"
"github.com/docker/distribution/registry/storage/driver"
)
@@ -120,7 +121,7 @@ func (ttles *TTLExpirationScheduler) Start() error {
return fmt.Errorf("Scheduler already started")
}
context.GetLogger(ttles.ctx).Infof("Starting cached object TTL expiration scheduler...")
dcontext.GetLogger(ttles.ctx).Infof("Starting cached object TTL expiration scheduler...")
ttles.stopped = false
// Start timer for each deserialized entry
@@ -142,7 +143,7 @@ func (ttles *TTLExpirationScheduler) Start() error {
err := ttles.writeState()
if err != nil {
context.GetLogger(ttles.ctx).Errorf("Error writing scheduler state: %s", err)
dcontext.GetLogger(ttles.ctx).Errorf("Error writing scheduler state: %s", err)
} else {
ttles.indexDirty = false
}
@@ -163,7 +164,7 @@ func (ttles *TTLExpirationScheduler) add(r reference.Reference, ttl time.Duratio
Expiry: time.Now().Add(ttl),
EntryType: eType,
}
context.GetLogger(ttles.ctx).Infof("Adding new scheduler entry for %s with ttl=%s", entry.Key, entry.Expiry.Sub(time.Now()))
dcontext.GetLogger(ttles.ctx).Infof("Adding new scheduler entry for %s with ttl=%s", entry.Key, entry.Expiry.Sub(time.Now()))
if oldEntry, present := ttles.entries[entry.Key]; present && oldEntry.timer != nil {
oldEntry.timer.Stop()
}
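
The entry bookkeeping above (stop the old timer, start a new one, fire OnExpire, persist state) reduces to a small core. A pared-down sketch using time.AfterFunc, without the persistence; the struct is an assumption, not the upstream TTLExpirationScheduler:

package ttlsketch

import (
	"sync"
	"time"
)

// scheduler keeps one timer per key, superseding any existing timer,
// with onExpire standing in for the OnExpire hook logged above.
type scheduler struct {
	mu       sync.Mutex
	timers   map[string]*time.Timer
	onExpire func(key string)
}

func newScheduler(onExpire func(string)) *scheduler {
	return &scheduler{timers: map[string]*time.Timer{}, onExpire: onExpire}
}

func (s *scheduler) add(key string, ttl time.Duration) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if old, ok := s.timers[key]; ok {
		old.Stop() // replacing an entry stops its previous timer, as above
	}
	s.timers[key] = time.AfterFunc(ttl, func() {
		s.mu.Lock()
		delete(s.timers, key)
		s.mu.Unlock()
		s.onExpire(key)
	})
}
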
@@ -193,10 +194,10 @@ func (ttles *TTLExpirationScheduler) startTimer(entry *schedulerEntry, ttl time.
ref, err := reference.Parse(entry.Key)
if err == nil {
if err := f(ref); err != nil {
context.GetLogger(ttles.ctx).Errorf("Scheduler error returned from OnExpire(%s): %s", entry.Key, err)
dcontext.GetLogger(ttles.ctx).Errorf("Scheduler error returned from OnExpire(%s): %s", entry.Key, err)
}
} else {
context.GetLogger(ttles.ctx).Errorf("Error unpacking reference: %s", err)
dcontext.GetLogger(ttles.ctx).Errorf("Error unpacking reference: %s", err)
}
delete(ttles.entries, entry.Key)
@@ -210,7 +211,7 @@ func (ttles *TTLExpirationScheduler) Stop() {
defer ttles.Unlock()
if err := ttles.writeState(); err != nil {
context.GetLogger(ttles.ctx).Errorf("Error writing scheduler state: %s", err)
dcontext.GetLogger(ttles.ctx).Errorf("Error writing scheduler state: %s", err)
}
for _, entry := range ttles.entries {

View File

@@ -1,12 +1,15 @@
package registry
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"rsc.io/letsencrypt"
@@ -14,18 +17,22 @@ import (
logstash "github.com/bshuster-repo/logrus-logstash-hook"
"github.com/bugsnag/bugsnag-go"
"github.com/docker/distribution/configuration"
"github.com/docker/distribution/context"
dcontext "github.com/docker/distribution/context"
"github.com/docker/distribution/health"
"github.com/docker/distribution/registry/handlers"
"github.com/docker/distribution/registry/listener"
"github.com/docker/distribution/uuid"
"github.com/docker/distribution/version"
"github.com/docker/go-metrics"
gorhandlers "github.com/gorilla/handlers"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/yvasiyarov/gorelic"
)
// this channel gets notified when the process receives a signal. It is global to ease unit testing
var quit = make(chan os.Signal, 1)
// ServeCmd is a cobra command for running the registry.
var ServeCmd = &cobra.Command{
Use: "serve <config>",
@@ -34,7 +41,7 @@ var ServeCmd = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) {
// setup context
ctx := context.WithVersion(context.Background(), version.Version)
ctx := dcontext.WithVersion(dcontext.Background(), version.Version)
config, err := resolveConfiguration(args)
if err != nil {
@@ -57,6 +64,15 @@ var ServeCmd = &cobra.Command{
log.Fatalln(err)
}
if config.HTTP.Debug.Prometheus.Enabled {
path := config.HTTP.Debug.Prometheus.Path
if path == "" {
path = "/metrics"
}
log.Info("providing prometheus metrics on ", path)
http.Handle(path, metrics.Handler())
}
if err = registry.ListenAndServe(); err != nil {
log.Fatalln(err)
}
@@ -81,7 +97,7 @@ func NewRegistry(ctx context.Context, config *configuration.Configuration) (*Reg
// inject a logger into the uuid library. warns us if there is a problem
// with uuid generation under low entropy.
uuid.Loggerf = context.GetLogger(ctx).Warnf
uuid.Loggerf = dcontext.GetLogger(ctx).Warnf
app := handlers.NewApp(ctx, config)
// TODO(aaronl): The global scope of the health checks means NewRegistry
@@ -128,8 +144,6 @@ func (registry *Registry) ListenAndServe() error {
tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
tls.TLS_RSA_WITH_AES_128_CBC_SHA,
tls.TLS_RSA_WITH_AES_256_CBC_SHA,
},
}
@@ -146,6 +160,9 @@ func (registry *Registry) ListenAndServe() error {
return err
}
}
if len(config.HTTP.TLS.LetsEncrypt.Hosts) > 0 {
m.SetHosts(config.HTTP.TLS.LetsEncrypt.Hosts)
}
tlsConf.GetCertificate = m.GetCertificate
} else {
tlsConf.Certificates = make([]tls.Certificate, 1)
@@ -170,7 +187,7 @@ func (registry *Registry) ListenAndServe() error {
}
for _, subj := range pool.Subjects() {
context.GetLogger(registry.app).Debugf("CA Subject: %s", string(subj))
dcontext.GetLogger(registry.app).Debugf("CA Subject: %s", string(subj))
}
tlsConf.ClientAuth = tls.RequireAndVerifyClientCert
@@ -178,12 +195,34 @@ func (registry *Registry) ListenAndServe() error {
}
ln = tls.NewListener(ln, tlsConf)
context.GetLogger(registry.app).Infof("listening on %v, tls", ln.Addr())
dcontext.GetLogger(registry.app).Infof("listening on %v, tls", ln.Addr())
} else {
context.GetLogger(registry.app).Infof("listening on %v", ln.Addr())
dcontext.GetLogger(registry.app).Infof("listening on %v", ln.Addr())
}
return registry.server.Serve(ln)
if config.HTTP.DrainTimeout == 0 {
return registry.server.Serve(ln)
}
// setup channel to get notified on SIGTERM signal
signal.Notify(quit, syscall.SIGTERM)
serveErr := make(chan error)
// Start serving in goroutine and listen for stop signal in main thread
go func() {
serveErr <- registry.server.Serve(ln)
}()
select {
case err := <-serveErr:
return err
case <-quit:
dcontext.GetLogger(registry.app).Info("stopping server gracefully. Draining connections for ", config.HTTP.DrainTimeout)
// shutdown the server with a grace period of configured timeout
c, cancel := context.WithTimeout(context.Background(), config.HTTP.DrainTimeout)
defer cancel()
return registry.server.Shutdown(c)
}
}
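
The drain logic above generalizes to any net/http server: serve in a goroutine, then race the serve error against the stop signal and shut down with a deadline. A compact, self-contained sketch (the fixed 10-second timeout stands in for config.HTTP.DrainTimeout):

package main

import (
	"context"
	"log"
	"net/http"
	"os"
	"os/signal"
	"syscall"
	"time"
)

func main() {
	srv := &http.Server{Addr: ":8080"}
	quit := make(chan os.Signal, 1)
	signal.Notify(quit, syscall.SIGTERM)

	serveErr := make(chan error, 1)
	go func() { serveErr <- srv.ListenAndServe() }()

	select {
	case err := <-serveErr:
		log.Fatal(err)
	case <-quit:
		// Give in-flight requests a bounded time to finish.
		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()
		if err := srv.Shutdown(ctx); err != nil {
			log.Fatal(err)
		}
	}
}
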
func configureReporting(app *handlers.App) http.Handler {
@@ -225,13 +264,6 @@ func configureReporting(app *handlers.App) http.Handler {
// configureLogging prepares the context with a logger using the
// configuration.
func configureLogging(ctx context.Context, config *configuration.Configuration) (context.Context, error) {
if config.Log.Level == "" && config.Log.Formatter == "" {
// If no config for logging is set, fallback to deprecated "Loglevel".
log.SetLevel(logLevel(config.Loglevel))
ctx = context.WithLogger(ctx, context.GetLogger(ctx))
return ctx, nil
}
log.SetLevel(logLevel(config.Log.Level))
formatter := config.Log.Formatter
@@ -270,8 +302,8 @@ func configureLogging(ctx context.Context, config *configuration.Configuration)
fields = append(fields, k)
}
ctx = context.WithValues(ctx, config.Log.Fields)
ctx = context.WithLogger(ctx, context.GetLogger(ctx, fields...))
ctx = dcontext.WithValues(ctx, config.Log.Fields)
ctx = dcontext.WithLogger(ctx, dcontext.GetLogger(ctx, fields...))
}
return ctx, nil

View File

@@ -1,10 +1,19 @@
package registry
import (
"bufio"
"context"
"fmt"
"io/ioutil"
"net"
"net/http"
"os"
"reflect"
"testing"
"time"
"github.com/docker/distribution/configuration"
_ "github.com/docker/distribution/registry/storage/driver/inmemory"
)
// Tests to ensure nextProtos returns the correct protocols when:
@@ -28,3 +37,64 @@ func TestNextProtos(t *testing.T) {
t.Fatalf("expected protos to equal [http/1.1], got %s", protos)
}
}
func setupRegistry() (*Registry, error) {
config := &configuration.Configuration{}
// TODO: this needs to change to something ephemeral as the test will fail if there is any server
// already listening on port 5000
config.HTTP.Addr = ":5000"
config.HTTP.DrainTimeout = time.Duration(10) * time.Second
config.Storage = map[string]configuration.Parameters{"inmemory": map[string]interface{}{}}
return NewRegistry(context.Background(), config)
}
func TestGracefulShutdown(t *testing.T) {
registry, err := setupRegistry()
if err != nil {
t.Fatal(err)
}
// run registry server
var errchan chan error
go func() {
errchan <- registry.ListenAndServe()
}()
select {
case err = <-errchan:
t.Fatalf("Error listening: %v", err)
default:
}
// Wait a few seconds for the server to start listening; there is no readiness signal to wait on
time.Sleep(3 * time.Second)
// send incomplete request
conn, err := net.Dial("tcp", "localhost:5000")
if err != nil {
t.Fatal(err)
}
fmt.Fprintf(conn, "GET /v2/ ")
// send stop signal
quit <- os.Interrupt
time.Sleep(100 * time.Millisecond)
// try connecting again. it shouldn't succeed, since the listener is closed
_, err = net.Dial("tcp", "localhost:5000")
if err == nil {
t.Fatal("Managed to connect after stopping.")
}
// make sure earlier request is not disconnected and response can be received
fmt.Fprintf(conn, "HTTP/1.1\r\nHost: 127.0.0.1\r\n\r\n")
resp, err := http.ReadResponse(bufio.NewReader(conn), nil)
if err != nil {
t.Fatal(err)
}
if resp.Status != "200 OK" {
t.Error("response status is not 200 OK: ", resp.Status)
}
if body, err := ioutil.ReadAll(resp.Body); err != nil || string(body) != "{}" {
t.Error("Body is not {}; ", string(body))
}
}

View File

@@ -4,7 +4,7 @@ import (
"fmt"
"os"
"github.com/docker/distribution/context"
dcontext "github.com/docker/distribution/context"
"github.com/docker/distribution/registry/storage"
"github.com/docker/distribution/registry/storage/driver/factory"
"github.com/docker/distribution/version"
@@ -18,6 +18,7 @@ func init() {
RootCmd.AddCommand(ServeCmd)
RootCmd.AddCommand(GCCmd)
GCCmd.Flags().BoolVarP(&dryRun, "dry-run", "d", false, "do everything except remove the blobs")
GCCmd.Flags().BoolVarP(&removeUntagged, "delete-untagged", "m", false, "delete manifests that are not currently referenced via tag")
RootCmd.Flags().BoolVarP(&showVersion, "version", "v", false, "show the version and exit")
}
@@ -36,6 +37,7 @@ var RootCmd = &cobra.Command{
}
var dryRun bool
var removeUntagged bool
// GCCmd is the cobra command that corresponds to the garbage-collect subcommand
var GCCmd = &cobra.Command{
@@ -56,7 +58,7 @@ var GCCmd = &cobra.Command{
os.Exit(1)
}
ctx := context.Background()
ctx := dcontext.Background()
ctx, err = configureLogging(ctx, config)
if err != nil {
fmt.Fprintf(os.Stderr, "unable to configure logging with config: %s", err)
@@ -75,7 +77,10 @@ var GCCmd = &cobra.Command{
os.Exit(1)
}
err = storage.MarkAndSweep(ctx, driver, registry, dryRun)
err = storage.MarkAndSweep(ctx, driver, registry, storage.GCOpts{
DryRun: dryRun,
RemoveUntagged: removeUntagged,
})
if err != nil {
fmt.Fprintf(os.Stderr, "failed to garbage collect: %v", err)
os.Exit(1)

View File

@@ -2,17 +2,16 @@ package storage
import (
"bytes"
"context"
"crypto/sha256"
"fmt"
"io"
"io/ioutil"
"os"
"path"
"reflect"
"testing"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
"github.com/docker/distribution/reference"
"github.com/docker/distribution/registry/storage/cache/memory"
"github.com/docker/distribution/registry/storage/driver/testdriver"
@@ -96,7 +95,7 @@ func TestSimpleBlobUpload(t *testing.T) {
}
// Do a resume, get unknown upload
blobUpload, err = bs.Resume(ctx, blobUpload.ID())
_, err = bs.Resume(ctx, blobUpload.ID())
if err != distribution.ErrBlobUploadUnknown {
t.Fatalf("unexpected error resuming upload, should be unknown: %v", err)
}
@@ -278,7 +277,7 @@ func TestSimpleBlobRead(t *testing.T) {
t.Fatalf("expected not found error when testing for existence: %v", err)
}
rc, err := bs.Open(ctx, dgst)
_, err = bs.Open(ctx, dgst)
if err != distribution.ErrBlobUnknown {
t.Fatalf("expected not found error when opening non-existent blob: %v", err)
}
@@ -300,7 +299,7 @@ func TestSimpleBlobRead(t *testing.T) {
t.Fatalf("committed blob has incorrect length: %v != %v", desc.Size, randomLayerSize)
}
rc, err = bs.Open(ctx, desc.Digest) // note that we are opening with original digest.
rc, err := bs.Open(ctx, desc.Digest) // note that we are opening with original digest.
if err != nil {
t.Fatalf("error opening blob with %v: %v", dgst, err)
}
@@ -323,7 +322,7 @@ func TestSimpleBlobRead(t *testing.T) {
}
// Now seek back the blob, read the whole thing and check against randomLayerData
offset, err := rc.Seek(0, os.SEEK_SET)
offset, err := rc.Seek(0, io.SeekStart)
if err != nil {
t.Fatalf("error seeking blob: %v", err)
}
@@ -342,7 +341,7 @@ func TestSimpleBlobRead(t *testing.T) {
}
// Reset the randomLayerReader and read back the buffer
_, err = randomLayerReader.Seek(0, os.SEEK_SET)
_, err = randomLayerReader.Seek(0, io.SeekStart)
if err != nil {
t.Fatalf("error resetting layer reader: %v", err)
}
@@ -397,7 +396,7 @@ func TestBlobMount(t *testing.T) {
t.Fatalf("error getting seeker size of random data: %v", err)
}
nn, err := io.Copy(blobUpload, randomDataReader)
_, err = io.Copy(blobUpload, randomDataReader)
if err != nil {
t.Fatalf("unexpected error uploading layer data: %v", err)
}
@@ -460,7 +459,7 @@ func TestBlobMount(t *testing.T) {
defer rc.Close()
h := sha256.New()
nn, err = io.Copy(h, rc)
nn, err := io.Copy(h, rc)
if err != nil {
t.Fatalf("error reading layer: %v", err)
}
@@ -573,17 +572,17 @@ func simpleUpload(t *testing.T, bs distribution.BlobIngester, blob []byte, expec
// the original state, returning the size. The state of the seeker should be
// treated as unknown if an error is returned.
func seekerSize(seeker io.ReadSeeker) (int64, error) {
current, err := seeker.Seek(0, os.SEEK_CUR)
current, err := seeker.Seek(0, io.SeekCurrent)
if err != nil {
return 0, err
}
end, err := seeker.Seek(0, os.SEEK_END)
end, err := seeker.Seek(0, io.SeekEnd)
if err != nil {
return 0, err
}
resumed, err := seeker.Seek(current, os.SEEK_SET)
resumed, err := seeker.Seek(current, io.SeekStart)
if err != nil {
return 0, err
}

View File

@@ -1,9 +1,11 @@
package storage
import (
"context"
"expvar"
"sync/atomic"
dcontext "github.com/docker/distribution/context"
"github.com/docker/distribution/registry/storage/cache"
)
@@ -25,6 +27,10 @@ func (bsc *blobStatCollector) Metrics() cache.Metrics {
return bsc.metrics
}
func (bsc *blobStatCollector) Logger(ctx context.Context) cache.Logger {
return dcontext.GetLogger(ctx)
}
// blobStatterCacheMetrics keeps track of cache metrics for blob descriptor
// cache requests. Note this is kept globally and made available via expvar.
// For more detailed metrics, it's recommended to instrument a particular cache

View File

@@ -1,12 +1,12 @@
package storage
import (
"context"
"fmt"
"net/http"
"time"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
"github.com/docker/distribution/registry/storage/driver"
"github.com/opencontainers/go-digest"
)

View File

@@ -1,10 +1,11 @@
package storage
import (
"context"
"path"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
dcontext "github.com/docker/distribution/context"
"github.com/docker/distribution/registry/storage/driver"
"github.com/opencontainers/go-digest"
)
@@ -64,7 +65,7 @@ func (bs *blobStore) Put(ctx context.Context, mediaType string, p []byte) (distr
// content already present
return desc, nil
} else if err != distribution.ErrBlobUnknown {
context.GetLogger(ctx).Errorf("blobStore: error stating content (%v): %v", dgst, err)
dcontext.GetLogger(ctx).Errorf("blobStore: error stating content (%v): %v", dgst, err)
// real error, return it
return distribution.Descriptor{}, err
}
@@ -87,13 +88,12 @@ func (bs *blobStore) Put(ctx context.Context, mediaType string, p []byte) (distr
}
func (bs *blobStore) Enumerate(ctx context.Context, ingester func(dgst digest.Digest) error) error {
specPath, err := pathFor(blobsPathSpec{})
if err != nil {
return err
}
err = Walk(ctx, bs.driver, specPath, func(fileInfo driver.FileInfo) error {
return bs.driver.Walk(ctx, specPath, func(fileInfo driver.FileInfo) error {
// skip directories
if fileInfo.IsDir() {
return nil
@@ -113,7 +113,6 @@ func (bs *blobStore) Enumerate(ctx context.Context, ingester func(dgst digest.Di
return ingester(digest)
})
return err
}
// path returns the canonical path for the blob identified by digest. The blob
@@ -195,7 +194,7 @@ func (bs *blobStatter) Stat(ctx context.Context, dgst digest.Digest) (distributi
// NOTE(stevvooe): This represents a corruption situation. Somehow, we
// calculated a blob path and then detected a directory. We log the
// error and then error on the side of not knowing about the blob.
context.GetLogger(ctx).Warnf("blob path should not be a directory: %q", path)
dcontext.GetLogger(ctx).Warnf("blob path should not be a directory: %q", path)
return distribution.Descriptor{}, distribution.ErrBlobUnknown
}

View File

@@ -1,6 +1,7 @@
package storage
import (
"context"
"errors"
"fmt"
"io"
@@ -8,7 +9,7 @@ import (
"time"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
dcontext "github.com/docker/distribution/context"
storagedriver "github.com/docker/distribution/registry/storage/driver"
"github.com/opencontainers/go-digest"
"github.com/sirupsen/logrus"
@@ -32,7 +33,7 @@ type blobWriter struct {
id string
startedAt time.Time
digester digest.Digester
written int64 // track the contiguous write
written int64 // track the write to digester
fileWriter storagedriver.FileWriter
driver storagedriver.StorageDriver
@@ -56,7 +57,7 @@ func (bw *blobWriter) StartedAt() time.Time {
// Commit marks the upload as completed, returning a valid descriptor. The
// final size and digest are checked against the first descriptor provided.
func (bw *blobWriter) Commit(ctx context.Context, desc distribution.Descriptor) (distribution.Descriptor, error) {
context.GetLogger(ctx).Debug("(*blobWriter).Commit")
dcontext.GetLogger(ctx).Debug("(*blobWriter).Commit")
if err := bw.fileWriter.Commit(); err != nil {
return distribution.Descriptor{}, err
@@ -94,20 +95,16 @@ func (bw *blobWriter) Commit(ctx context.Context, desc distribution.Descriptor)
// Cancel the blob upload process, releasing any resources associated with
// the writer and canceling the operation.
func (bw *blobWriter) Cancel(ctx context.Context) error {
context.GetLogger(ctx).Debug("(*blobWriter).Cancel")
dcontext.GetLogger(ctx).Debug("(*blobWriter).Cancel")
if err := bw.fileWriter.Cancel(); err != nil {
return err
}
if err := bw.Close(); err != nil {
context.GetLogger(ctx).Errorf("error closing blobwriter: %s", err)
dcontext.GetLogger(ctx).Errorf("error closing blobwriter: %s", err)
}
if err := bw.removeResources(ctx); err != nil {
return err
}
return nil
return bw.removeResources(ctx)
}
func (bw *blobWriter) Size() int64 {
@@ -122,7 +119,12 @@ func (bw *blobWriter) Write(p []byte) (int, error) {
return 0, err
}
n, err := io.MultiWriter(bw.fileWriter, bw.digester.Hash()).Write(p)
_, err := bw.fileWriter.Write(p)
if err != nil {
return 0, err
}
n, err := bw.digester.Hash().Write(p)
bw.written += int64(n)
return n, err
@@ -136,7 +138,11 @@ func (bw *blobWriter) ReadFrom(r io.Reader) (n int64, err error) {
return 0, err
}
nn, err := io.Copy(io.MultiWriter(bw.fileWriter, bw.digester.Hash()), r)
// Using a TeeReader instead of MultiWriter ensures Copy returns
// the amount written to the digester as well as ensuring that we
// write to the fileWriter first
tee := io.TeeReader(r, bw.fileWriter)
nn, err := io.Copy(bw.digester.Hash(), tee)
bw.written += nn
return nn, err
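
The TeeReader comment above is worth unpacking: io.TeeReader(r, w) writes every chunk to w before the consumer sees it, so io.Copy(digester, tee) both orders the writes (file first, then hash) and returns exactly the count the digester consumed. A minimal sketch of that ordering:

package teesketch

import (
	"crypto/sha256"
	"io"
)

// writeThrough persists to file first, then hashes, returning the
// number of bytes the hash actually consumed.
func writeThrough(file io.Writer, r io.Reader) (int64, [sha256.Size]byte, error) {
	h := sha256.New()
	// Each Read from the tee writes to file before Copy hands the
	// bytes to h, so a failed file write surfaces as a Copy error.
	n, err := io.Copy(h, io.TeeReader(r, file))
	var sum [sha256.Size]byte
	copy(sum[:], h.Sum(nil))
	return n, sum, err
}
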
@@ -261,7 +267,7 @@ func (bw *blobWriter) validateBlob(ctx context.Context, desc distribution.Descri
}
if !verified {
context.GetLoggerWithFields(ctx,
dcontext.GetLoggerWithFields(ctx,
map[interface{}]interface{}{
"canonical": canonical,
"provided": desc.Digest,
@@ -365,7 +371,7 @@ func (bw *blobWriter) removeResources(ctx context.Context) error {
// This should be uncommon enough such that returning an error
// should be okay. At this point, the upload should be mostly
// complete, but perhaps the backend became inaccessible.
context.GetLogger(ctx).Errorf("unable to delete layer upload resources %q: %v", dirPath, err)
dcontext.GetLogger(ctx).Errorf("unable to delete layer upload resources %q: %v", dirPath, err)
return err
}
}
@@ -383,7 +389,7 @@ func (bw *blobWriter) Reader() (io.ReadCloser, error) {
}
switch err.(type) {
case storagedriver.PathNotFoundError:
context.GetLogger(bw.ctx).Debugf("Nothing found on try %d, sleeping...", try)
dcontext.GetLogger(bw.ctx).Debugf("Nothing found on try %d, sleeping...", try)
time.Sleep(1 * time.Second)
try++
default:

View File

@@ -3,18 +3,15 @@
package storage
import (
"context"
"encoding"
"fmt"
"hash"
"path"
"strconv"
"github.com/docker/distribution/context"
storagedriver "github.com/docker/distribution/registry/storage/driver"
"github.com/sirupsen/logrus"
"github.com/stevvooe/resumable"
// register resumable hashes with import
_ "github.com/stevvooe/resumable/sha256"
_ "github.com/stevvooe/resumable/sha512"
)
// resumeDigest attempts to restore the state of the internal hash function
@@ -24,12 +21,13 @@ func (bw *blobWriter) resumeDigest(ctx context.Context) error {
return errResumableDigestNotAvailable
}
h, ok := bw.digester.Hash().(resumable.Hash)
h, ok := bw.digester.Hash().(encoding.BinaryUnmarshaler)
if !ok {
return errResumableDigestNotAvailable
}
offset := bw.fileWriter.Size()
if offset == int64(h.Len()) {
if offset == bw.written {
// State of digester is already at the requested offset.
return nil
}
@@ -52,20 +50,21 @@ func (bw *blobWriter) resumeDigest(ctx context.Context) error {
if hashStateMatch.offset == 0 {
// No need to load any state, just reset the hasher.
h.Reset()
h.(hash.Hash).Reset()
} else {
storedState, err := bw.driver.GetContent(ctx, hashStateMatch.path)
if err != nil {
return err
}
if err = h.Restore(storedState); err != nil {
if err = h.UnmarshalBinary(storedState); err != nil {
return err
}
bw.written = hashStateMatch.offset
}
// Mind the gap.
if gapLen := offset - int64(h.Len()); gapLen > 0 {
if gapLen := offset - bw.written; gapLen > 0 {
return errResumableDigestNotAvailable
}
@@ -120,26 +119,26 @@ func (bw *blobWriter) storeHashState(ctx context.Context) error {
return errResumableDigestNotAvailable
}
h, ok := bw.digester.Hash().(resumable.Hash)
h, ok := bw.digester.Hash().(encoding.BinaryMarshaler)
if !ok {
return errResumableDigestNotAvailable
}
state, err := h.MarshalBinary()
if err != nil {
return err
}
uploadHashStatePath, err := pathFor(uploadHashStatePathSpec{
name: bw.blobStore.repository.Named().String(),
id: bw.id,
alg: bw.digester.Digest().Algorithm(),
offset: int64(h.Len()),
offset: bw.written,
})
if err != nil {
return err
}
hashState, err := h.State()
if err != nil {
return err
}
return bw.driver.PutContent(ctx, uploadHashStatePath, hashState)
return bw.driver.PutContent(ctx, uploadHashStatePath, state)
}
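
The switch above from the third-party resumable.Hash to encoding.BinaryMarshaler works because the standard library's hash implementations can snapshot and restore their internal state. A runnable sketch of the round trip this code relies on:

package main

import (
	"crypto/sha256"
	"encoding"
	"fmt"
)

func main() {
	h := sha256.New()
	h.Write([]byte("hello "))

	// Snapshot the hash state mid-stream, as storeHashState does above.
	state, err := h.(encoding.BinaryMarshaler).MarshalBinary()
	if err != nil {
		panic(err)
	}

	// Later (perhaps in another process): restore and keep writing,
	// as resumeDigest does above.
	h2 := sha256.New()
	if err := h2.(encoding.BinaryUnmarshaler).UnmarshalBinary(state); err != nil {
		panic(err)
	}
	h2.Write([]byte("world"))

	full := sha256.Sum256([]byte("hello world"))
	fmt.Printf("%x\n%x\n", h2.Sum(nil), full[:]) // identical digests
}
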

View File

@@ -1,11 +1,11 @@
package cachecheck
import (
"context"
"reflect"
"testing"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
"github.com/docker/distribution/registry/storage/cache"
"github.com/opencontainers/go-digest"
)
@@ -26,12 +26,12 @@ func checkBlobDescriptorCacheEmptyRepository(ctx context.Context, t *testing.T,
t.Fatalf("expected unknown blob error with empty store: %v", err)
}
cache, err := provider.RepositoryScoped("")
_, err := provider.RepositoryScoped("")
if err == nil {
t.Fatalf("expected an error when asking for invalid repo")
}
cache, err = provider.RepositoryScoped("foo/bar")
cache, err := provider.RepositoryScoped("foo/bar")
if err != nil {
t.Fatalf("unexpected error getting repository: %v", err)
}

View File

@@ -1,10 +1,11 @@
package cache
import (
"github.com/docker/distribution/context"
"github.com/opencontainers/go-digest"
"context"
"github.com/docker/distribution"
prometheus "github.com/docker/distribution/metrics"
"github.com/opencontainers/go-digest"
)
// Metrics is used to hold metric counters
@@ -16,12 +17,20 @@ type Metrics struct {
Misses uint64
}
// Logger can be provided on the MetricsTracker to log errors.
//
// Usually, this is just a proxy to dcontext.GetLogger.
type Logger interface {
Errorf(format string, args ...interface{})
}
// MetricsTracker represents a metric tracker
// which simply counts the number of hits and misses.
type MetricsTracker interface {
Hit()
Miss()
Metrics() Metrics
Logger(context.Context) Logger
}
type cachedBlobStatter struct {
@@ -30,6 +39,11 @@ type cachedBlobStatter struct {
tracker MetricsTracker
}
var (
// cacheCount counts cache requests received, broken out by "type" into hits and misses
cacheCount = prometheus.StorageNamespace.NewLabeledCounter("cache", "The number of cache requests received", "type")
)
// NewCachedBlobStatter creates a new statter which prefers a cache and
// falls back to a backend.
func NewCachedBlobStatter(cache distribution.BlobDescriptorService, backend distribution.BlobDescriptorService) distribution.BlobDescriptorService {
@@ -50,20 +64,22 @@ func NewCachedBlobStatterWithMetrics(cache distribution.BlobDescriptorService, b
}
func (cbds *cachedBlobStatter) Stat(ctx context.Context, dgst digest.Digest) (distribution.Descriptor, error) {
cacheCount.WithValues("Request").Inc(1)
desc, err := cbds.cache.Stat(ctx, dgst)
if err != nil {
if err != distribution.ErrBlobUnknown {
context.GetLogger(ctx).Errorf("error retrieving descriptor from cache: %v", err)
logErrorf(ctx, cbds.tracker, "error retrieving descriptor from cache: %v", err)
}
goto fallback
}
cacheCount.WithValues("Hit").Inc(1)
if cbds.tracker != nil {
cbds.tracker.Hit()
}
return desc, nil
fallback:
cacheCount.WithValues("Miss").Inc(1)
if cbds.tracker != nil {
cbds.tracker.Miss()
}
@@ -73,7 +89,7 @@ fallback:
}
if err := cbds.cache.SetDescriptor(ctx, dgst, desc); err != nil {
context.GetLogger(ctx).Errorf("error adding descriptor %v to cache: %v", desc.Digest, err)
logErrorf(ctx, cbds.tracker, "error adding descriptor %v to cache: %v", desc.Digest, err)
}
return desc, err
@@ -95,7 +111,19 @@ func (cbds *cachedBlobStatter) Clear(ctx context.Context, dgst digest.Digest) er
func (cbds *cachedBlobStatter) SetDescriptor(ctx context.Context, dgst digest.Digest, desc distribution.Descriptor) error {
if err := cbds.cache.SetDescriptor(ctx, dgst, desc); err != nil {
context.GetLogger(ctx).Errorf("error adding descriptor %v to cache: %v", desc.Digest, err)
logErrorf(ctx, cbds.tracker, "error adding descriptor %v to cache: %v", desc.Digest, err)
}
return nil
}
func logErrorf(ctx context.Context, tracker MetricsTracker, format string, args ...interface{}) {
if tracker == nil {
return
}
logger := tracker.Logger(ctx)
if logger == nil {
return
}
logger.Errorf(format, args...)
}
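For readers skimming the control flow above: the statter is a classic read-through cache. A self-contained sketch of the same shape with plain maps, including the hit/miss counting (all names here are illustrative, not the package's API):

package main

import (
	"errors"
	"fmt"
)

var errUnknown = errors.New("unknown digest")

// readThrough mirrors cachedBlobStatter.Stat: try the cache, count the
// hit or miss, fall back to the backend, and repopulate the cache.
func readThrough(key string, cache, backend map[string]string, hits, misses *int) (string, error) {
	if v, ok := cache[key]; ok {
		*hits++
		return v, nil
	}
	*misses++
	v, ok := backend[key]
	if !ok {
		return "", errUnknown
	}
	cache[key] = v // best effort; the real code only logs a failure here
	return v, nil
}

func main() {
	cache := map[string]string{}
	backend := map[string]string{"sha256:abc": "application/octet-stream"}
	var hits, misses int
	readThrough("sha256:abc", cache, backend, &hits, &misses) // miss, then cached
	readThrough("sha256:abc", cache, backend, &hits, &misses) // hit
	fmt.Println(hits, misses) // 1 1
}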

View File

@@ -1,10 +1,10 @@
package memory
import (
"context"
"sync"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
"github.com/docker/distribution/reference"
"github.com/docker/distribution/registry/storage/cache"
"github.com/opencontainers/go-digest"

View File

@@ -1,10 +1,10 @@
package redis
import (
"context"
"fmt"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
"github.com/docker/distribution/reference"
"github.com/docker/distribution/registry/storage/cache"
"github.com/garyburd/redigo/redis"
@@ -72,7 +72,7 @@ func (rbds *redisBlobDescriptorService) Clear(ctx context.Context, dgst digest.D
defer conn.Close()
// Not atomic in redis <= 2.3
reply, err := conn.Do("HDEL", rbds.blobDescriptorHashKey(dgst), "digest", "length", "mediatype")
reply, err := conn.Do("HDEL", rbds.blobDescriptorHashKey(dgst), "digest", "size", "mediatype")
if err != nil {
return err
}
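The only change here is the hash field name ("length" becomes "size") so the deletion matches the fields the service now stores. A hedged redigo sketch of the same call, assuming a local Redis and a hypothetical key name:

package main

import (
	"fmt"

	"github.com/garyburd/redigo/redis"
)

func main() {
	// Assumes a Redis server on localhost; the key name is hypothetical.
	conn, err := redis.Dial("tcp", "localhost:6379")
	if err != nil {
		panic(err)
	}
	defer conn.Close()

	// Remove the three descriptor fields in one round trip.
	deleted, err := redis.Int(conn.Do("HDEL", "blobs::sha256:abc", "digest", "size", "mediatype"))
	if err != nil {
		panic(err)
	}
	fmt.Println("fields deleted:", deleted)
}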

View File

@@ -1,23 +1,21 @@
package storage
import (
"context"
"errors"
"io"
"path"
"strings"
"github.com/docker/distribution/context"
"github.com/docker/distribution/reference"
"github.com/docker/distribution/registry/storage/driver"
)
// errFinishedWalk signals an early exit to the walk when the current query
// is satisfied.
var errFinishedWalk = errors.New("finished walk")
// Returns a list, or partial list, of repositories in the registry.
// Because it's quite an expensive operation, it should only be used when building up
// an initial set of repositories.
func (reg *registry) Repositories(ctx context.Context, repos []string, last string) (n int, err error) {
var finishedWalk bool
var foundRepos []string
if len(repos) == 0 {
@@ -29,7 +27,7 @@ func (reg *registry) Repositories(ctx context.Context, repos []string, last stri
return 0, err
}
err = Walk(ctx, reg.blobStore.driver, root, func(fileInfo driver.FileInfo) error {
err = reg.blobStore.driver.Walk(ctx, root, func(fileInfo driver.FileInfo) error {
err := handleRepository(fileInfo, root, last, func(repoPath string) error {
foundRepos = append(foundRepos, repoPath)
return nil
@@ -40,7 +38,8 @@ func (reg *registry) Repositories(ctx context.Context, repos []string, last stri
// if we've filled our array, no need to walk any further
if len(foundRepos) == len(repos) {
return errFinishedWalk
finishedWalk = true
return driver.ErrSkipDir
}
return nil
@@ -48,14 +47,11 @@ func (reg *registry) Repositories(ctx context.Context, repos []string, last stri
n = copy(repos, foundRepos)
switch err {
case nil:
// nil means that we completed walk and didn't fill buffer. No more
// records are available.
err = io.EOF
case errFinishedWalk:
// more records are available.
err = nil
if err != nil {
return n, err
} else if !finishedWalk {
// We didn't fill buffer. No more records are available.
return n, io.EOF
}
return n, err
@@ -68,13 +64,23 @@ func (reg *registry) Enumerate(ctx context.Context, ingester func(string) error)
return err
}
err = Walk(ctx, reg.blobStore.driver, root, func(fileInfo driver.FileInfo) error {
err = reg.blobStore.driver.Walk(ctx, root, func(fileInfo driver.FileInfo) error {
return handleRepository(fileInfo, root, "", ingester)
})
return err
}
// Remove removes a repository from storage
func (reg *registry) Remove(ctx context.Context, name reference.Named) error {
root, err := pathFor(repositoriesRootPathSpec{})
if err != nil {
return err
}
repoDir := path.Join(root, name.Name())
return reg.driver.Delete(ctx, repoDir)
}
// lessPath returns true if one path a is less than path b.
//
// A component-wise comparison is done, rather than the lexical comparison of
@@ -144,9 +150,9 @@ func handleRepository(fileInfo driver.FileInfo, root, last string, fn func(repoP
return err
}
}
return ErrSkipDir
return driver.ErrSkipDir
} else if strings.HasPrefix(file, "_") {
return ErrSkipDir
return driver.ErrSkipDir
}
return nil
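The errFinishedWalk sentinel is gone: the walk now records finishedWalk and prunes with driver.ErrSkipDir, while the io.EOF contract for callers is unchanged. A hedged sketch of paging through the catalog under that contract (the buffer size and names are illustrative; reg is assumed to expose the Repositories method shown above):

package catalogexample

import (
	"context"
	"io"
)

// listAll drains the paginated Repositories API.
func listAll(ctx context.Context, reg interface {
	Repositories(ctx context.Context, repos []string, last string) (int, error)
}) ([]string, error) {
	var all []string
	last := ""
	buf := make([]string, 100)
	for {
		n, err := reg.Repositories(ctx, buf, last)
		all = append(all, buf[:n]...)
		if err == io.EOF {
			return all, nil // the walk completed; no more records
		}
		if err != nil {
			return all, err
		}
		last = buf[n-1] // resume after the last repository returned
	}
}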

View File

@@ -1,13 +1,13 @@
package storage
import (
"context"
"fmt"
"io"
"math/rand"
"testing"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
"github.com/docker/distribution/reference"
"github.com/docker/distribution/registry/storage/cache/memory"
"github.com/docker/distribution/registry/storage/driver"
@@ -26,7 +26,7 @@ type setupEnv struct {
func setupFS(t *testing.T) *setupEnv {
d := inmemory.New()
ctx := context.Background()
registry, err := NewRegistry(ctx, d, BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()), EnableRedirect)
registry, err := NewRegistry(ctx, d, BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()), EnableRedirect, EnableSchema1)
if err != nil {
t.Fatalf("error creating registry: %v", err)
}
@@ -207,7 +207,7 @@ func testEq(a, b []string, size int) bool {
func setupBadWalkEnv(t *testing.T) *setupEnv {
d := newBadListDriver()
ctx := context.Background()
registry, err := NewRegistry(ctx, d, BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()), EnableRedirect)
registry, err := NewRegistry(ctx, d, BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()), EnableRedirect, EnableSchema1)
if err != nil {
t.Fatalf("error creating registry: %v", err)
}
@@ -273,7 +273,7 @@ func BenchmarkPathCompareNative(B *testing.B) {
for i := 0; i < B.N; i++ {
c := a < b
c = c && false
_ = c && false
}
}
@@ -285,7 +285,7 @@ func BenchmarkPathCompareNativeEqual(B *testing.B) {
for i := 0; i < B.N; i++ {
c := a < b
c = c && false
_ = c && false
}
}

View File

@@ -1,22 +0,0 @@
// +build !noresumabledigest
package storage
import (
"testing"
digest "github.com/opencontainers/go-digest"
"github.com/stevvooe/resumable"
_ "github.com/stevvooe/resumable/sha256"
)
// TestResumableDetection just ensures that the resumable capability of a hash
// is exposed through the digester type, which is just a hash plus a Digest
// method.
func TestResumableDetection(t *testing.T) {
d := digest.Canonical.Digester()
if _, ok := d.Hash().(resumable.Hash); !ok {
t.Fatalf("expected digester to implement resumable.Hash: %#v, %v", d, d.Hash())
}
}

View File

@@ -5,6 +5,7 @@ package azure
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"io/ioutil"
@@ -12,7 +13,6 @@ import (
"strings"
"time"
"github.com/docker/distribution/context"
storagedriver "github.com/docker/distribution/registry/storage/driver"
"github.com/docker/distribution/registry/storage/driver/base"
"github.com/docker/distribution/registry/storage/driver/factory"
@@ -87,7 +87,7 @@ func New(accountName, accountKey, container, realm string) (*Driver, error) {
// Create registry container
containerRef := blobClient.GetContainerReference(container)
if _, err = containerRef.CreateIfNotExists(); err != nil {
if _, err = containerRef.CreateIfNotExists(nil); err != nil {
return nil, err
}
@@ -104,7 +104,8 @@ func (d *driver) Name() string {
// GetContent retrieves the content stored at "path" as a []byte.
func (d *driver) GetContent(ctx context.Context, path string) ([]byte, error) {
blob, err := d.client.GetBlob(d.container, path)
blobRef := d.client.GetContainerReference(d.container).GetBlobReference(path)
blob, err := blobRef.Get(nil)
if err != nil {
if is404(err) {
return nil, storagedriver.PathNotFoundError{Path: path}
@@ -118,7 +119,10 @@ func (d *driver) GetContent(ctx context.Context, path string) ([]byte, error) {
// PutContent stores the []byte content at a location designated by "path".
func (d *driver) PutContent(ctx context.Context, path string, contents []byte) error {
if limit := 64 * 1024 * 1024; len(contents) > limit { // max size for block blobs uploaded via single "Put Blob"
// max size for block blobs uploaded via single "Put Blob" for version after "2016-05-31"
// https://docs.microsoft.com/en-us/rest/api/storageservices/put-blob#remarks
const limit = 256 * 1024 * 1024
if len(contents) > limit {
return fmt.Errorf("uploading %d bytes with PutContent is not supported; limit: %d bytes", len(contents), limit)
}
@@ -133,41 +137,49 @@ func (d *driver) PutContent(ctx context.Context, path string, contents []byte) e
// losing the existing data while migrating it to BlockBlob type. However,
// the expectation is that clients pushing will retry when they get an error
// response.
props, err := d.client.GetBlobProperties(d.container, path)
blobRef := d.client.GetContainerReference(d.container).GetBlobReference(path)
err := blobRef.GetProperties(nil)
if err != nil && !is404(err) {
return fmt.Errorf("failed to get blob properties: %v", err)
}
if err == nil && props.BlobType != azure.BlobTypeBlock {
if err := d.client.DeleteBlob(d.container, path, nil); err != nil {
return fmt.Errorf("failed to delete legacy blob (%s): %v", props.BlobType, err)
if err == nil && blobRef.Properties.BlobType != azure.BlobTypeBlock {
if err := blobRef.Delete(nil); err != nil {
return fmt.Errorf("failed to delete legacy blob (%s): %v", blobRef.Properties.BlobType, err)
}
}
r := bytes.NewReader(contents)
return d.client.CreateBlockBlobFromReader(d.container, path, uint64(len(contents)), r, nil)
// reset properties to empty before doing overwrite
blobRef.Properties = azure.BlobProperties{}
return blobRef.CreateBlockBlobFromReader(r, nil)
}
// Reader retrieves an io.ReadCloser for the content stored at "path" with a
// given byte offset.
func (d *driver) Reader(ctx context.Context, path string, offset int64) (io.ReadCloser, error) {
if ok, err := d.client.BlobExists(d.container, path); err != nil {
blobRef := d.client.GetContainerReference(d.container).GetBlobReference(path)
if ok, err := blobRef.Exists(); err != nil {
return nil, err
} else if !ok {
return nil, storagedriver.PathNotFoundError{Path: path}
}
info, err := d.client.GetBlobProperties(d.container, path)
err := blobRef.GetProperties(nil)
if err != nil {
return nil, err
}
size := int64(info.ContentLength)
info := blobRef.Properties
size := info.ContentLength
if offset >= size {
return ioutil.NopCloser(bytes.NewReader(nil)), nil
}
bytesRange := fmt.Sprintf("%v-", offset)
resp, err := d.client.GetBlobRange(d.container, path, bytesRange, nil)
resp, err := blobRef.GetRange(&azure.GetBlobRangeOptions{
Range: &azure.BlobRange{
Start: uint64(offset),
End: 0,
},
})
if err != nil {
return nil, err
}
@@ -177,20 +189,22 @@ func (d *driver) Reader(ctx context.Context, path string, offset int64) (io.Read
// Writer returns a FileWriter which will store the content written to it
// at the location designated by "path" after the call to Commit.
func (d *driver) Writer(ctx context.Context, path string, append bool) (storagedriver.FileWriter, error) {
blobExists, err := d.client.BlobExists(d.container, path)
blobRef := d.client.GetContainerReference(d.container).GetBlobReference(path)
blobExists, err := blobRef.Exists()
if err != nil {
return nil, err
}
var size int64
if blobExists {
if append {
blobProperties, err := d.client.GetBlobProperties(d.container, path)
err = blobRef.GetProperties(nil)
if err != nil {
return nil, err
}
blobProperties := blobRef.Properties
size = blobProperties.ContentLength
} else {
err := d.client.DeleteBlob(d.container, path, nil)
err = blobRef.Delete(nil)
if err != nil {
return nil, err
}
@@ -199,7 +213,7 @@ func (d *driver) Writer(ctx context.Context, path string, append bool) (storaged
if append {
return nil, storagedriver.PathNotFoundError{Path: path}
}
err := d.client.PutAppendBlob(d.container, path, nil)
err = blobRef.PutAppendBlob(nil)
if err != nil {
return nil, err
}
@@ -211,24 +225,21 @@ func (d *driver) Writer(ctx context.Context, path string, append bool) (storaged
// Stat retrieves the FileInfo for the given path, including the current size
// in bytes and the creation time.
func (d *driver) Stat(ctx context.Context, path string) (storagedriver.FileInfo, error) {
blobRef := d.client.GetContainerReference(d.container).GetBlobReference(path)
// Check if the path is a blob
if ok, err := d.client.BlobExists(d.container, path); err != nil {
if ok, err := blobRef.Exists(); err != nil {
return nil, err
} else if ok {
blob, err := d.client.GetBlobProperties(d.container, path)
if err != nil {
return nil, err
}
mtim, err := time.Parse(http.TimeFormat, blob.LastModified)
err = blobRef.GetProperties(nil)
if err != nil {
return nil, err
}
blobProperties := blobRef.Properties
return storagedriver.FileInfoInternal{FileInfoFields: storagedriver.FileInfoFields{
Path: path,
Size: int64(blob.ContentLength),
ModTime: mtim,
Size: blobProperties.ContentLength,
ModTime: time.Time(blobProperties.LastModified),
IsDir: false,
}}, nil
}
@@ -281,8 +292,10 @@ func (d *driver) List(ctx context.Context, path string) ([]string, error) {
// Move moves an object stored at sourcePath to destPath, removing the original
// object.
func (d *driver) Move(ctx context.Context, sourcePath string, destPath string) error {
sourceBlobURL := d.client.GetBlobURL(d.container, sourcePath)
err := d.client.CopyBlob(d.container, destPath, sourceBlobURL)
srcBlobRef := d.client.GetContainerReference(d.container).GetBlobReference(sourcePath)
sourceBlobURL := srcBlobRef.GetURL()
destBlobRef := d.client.GetContainerReference(d.container).GetBlobReference(destPath)
err := destBlobRef.Copy(sourceBlobURL, nil)
if err != nil {
if is404(err) {
return storagedriver.PathNotFoundError{Path: sourcePath}
@@ -290,12 +303,13 @@ func (d *driver) Move(ctx context.Context, sourcePath string, destPath string) e
return err
}
return d.client.DeleteBlob(d.container, sourcePath, nil)
return srcBlobRef.Delete(nil)
}
// Delete recursively deletes all objects stored at "path" and its subpaths.
func (d *driver) Delete(ctx context.Context, path string) error {
ok, err := d.client.DeleteBlobIfExists(d.container, path, nil)
blobRef := d.client.GetContainerReference(d.container).GetBlobReference(path)
ok, err := blobRef.DeleteIfExists(nil)
if err != nil {
return err
}
@@ -310,7 +324,8 @@ func (d *driver) Delete(ctx context.Context, path string) error {
}
for _, b := range blobs {
if err = d.client.DeleteBlob(d.container, b, nil); err != nil {
blobRef = d.client.GetContainerReference(d.container).GetBlobReference(b)
if err = blobRef.Delete(nil); err != nil {
return err
}
}
@@ -333,7 +348,21 @@ func (d *driver) URLFor(ctx context.Context, path string, options map[string]int
expiresTime = t
}
}
return d.client.GetBlobSASURI(d.container, path, expiresTime, "r")
blobRef := d.client.GetContainerReference(d.container).GetBlobReference(path)
return blobRef.GetSASURI(azure.BlobSASOptions{
BlobServiceSASPermissions: azure.BlobServiceSASPermissions{
Read: true,
},
SASOptions: azure.SASOptions{
Expiry: expiresTime,
},
})
}
// Walk traverses a filesystem defined within driver, starting
// from the given path, calling f on each file
func (d *driver) Walk(ctx context.Context, path string, f storagedriver.WalkFn) error {
return storagedriver.WalkFallback(ctx, d, path, f)
}
// directDescendants will find direct descendants (blobs or virtual containers)
@@ -461,7 +490,8 @@ func (w *writer) Cancel() error {
return fmt.Errorf("already committed")
}
w.cancelled = true
return w.driver.client.DeleteBlob(w.driver.container, w.path, nil)
blobRef := w.driver.client.GetContainerReference(w.driver.container).GetBlobReference(w.path)
return blobRef.Delete(nil)
}
func (w *writer) Commit() error {
@@ -484,12 +514,13 @@ type blockWriter struct {
func (bw *blockWriter) Write(p []byte) (int, error) {
n := 0
blobRef := bw.client.GetContainerReference(bw.container).GetBlobReference(bw.path)
for offset := 0; offset < len(p); offset += maxChunkSize {
chunkSize := maxChunkSize
if offset+chunkSize > len(p) {
chunkSize = len(p) - offset
}
err := bw.client.AppendBlock(bw.container, bw.path, p[offset:offset+chunkSize], nil)
err := blobRef.AppendBlock(p[offset:offset+chunkSize], nil)
if err != nil {
return n, err
}
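The loop above caps each AppendBlock call at maxChunkSize. A standalone sketch of that chunking, with a hypothetical appendFn standing in for the Azure SDK call and an illustrative chunk size:

package main

import "fmt"

const maxChunkSize = 4 * 1024 * 1024 // illustrative value only

// writeChunks appends p in fixed-size pieces, returning how many bytes
// were appended before any error, mirroring blockWriter.Write.
func writeChunks(p []byte, appendFn func([]byte) error) (int, error) {
	n := 0
	for offset := 0; offset < len(p); offset += maxChunkSize {
		chunkSize := maxChunkSize
		if offset+chunkSize > len(p) {
			chunkSize = len(p) - offset
		}
		if err := appendFn(p[offset : offset+chunkSize]); err != nil {
			return n, err
		}
		n += chunkSize
	}
	return n, nil
}

func main() {
	var appended int
	writeChunks(make([]byte, 10*1024*1024), func(chunk []byte) error {
		appended += len(chunk)
		return nil
	})
	fmt.Println(appended) // 10485760
}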

View File

@@ -38,12 +38,25 @@
package base
import (
"context"
"io"
"time"
"github.com/docker/distribution/context"
dcontext "github.com/docker/distribution/context"
prometheus "github.com/docker/distribution/metrics"
storagedriver "github.com/docker/distribution/registry/storage/driver"
"github.com/docker/go-metrics"
)
var (
// storageAction is the metrics of blob related operations
storageAction = prometheus.StorageNamespace.NewLabeledTimer("action", "The number of seconds that the storage action takes", "driver", "action")
)
func init() {
metrics.Register(prometheus.StorageNamespace)
}
// Base provides a wrapper around a storagedriver implementation that provides
// common path and bounds checking.
type Base struct {
@@ -79,32 +92,37 @@ func (base *Base) setDriverName(e error) error {
// GetContent wraps GetContent of underlying storage driver.
func (base *Base) GetContent(ctx context.Context, path string) ([]byte, error) {
ctx, done := context.WithTrace(ctx)
ctx, done := dcontext.WithTrace(ctx)
defer done("%s.GetContent(%q)", base.Name(), path)
if !storagedriver.PathRegexp.MatchString(path) {
return nil, storagedriver.InvalidPathError{Path: path, DriverName: base.StorageDriver.Name()}
}
start := time.Now()
b, e := base.StorageDriver.GetContent(ctx, path)
storageAction.WithValues(base.Name(), "GetContent").UpdateSince(start)
return b, base.setDriverName(e)
}
// PutContent wraps PutContent of underlying storage driver.
func (base *Base) PutContent(ctx context.Context, path string, content []byte) error {
ctx, done := context.WithTrace(ctx)
ctx, done := dcontext.WithTrace(ctx)
defer done("%s.PutContent(%q)", base.Name(), path)
if !storagedriver.PathRegexp.MatchString(path) {
return storagedriver.InvalidPathError{Path: path, DriverName: base.StorageDriver.Name()}
}
return base.setDriverName(base.StorageDriver.PutContent(ctx, path, content))
start := time.Now()
err := base.setDriverName(base.StorageDriver.PutContent(ctx, path, content))
storageAction.WithValues(base.Name(), "PutContent").UpdateSince(start)
return err
}
// Reader wraps Reader of underlying storage driver.
func (base *Base) Reader(ctx context.Context, path string, offset int64) (io.ReadCloser, error) {
ctx, done := context.WithTrace(ctx)
ctx, done := dcontext.WithTrace(ctx)
defer done("%s.Reader(%q, %d)", base.Name(), path, offset)
if offset < 0 {
@@ -121,7 +139,7 @@ func (base *Base) Reader(ctx context.Context, path string, offset int64) (io.Rea
// Writer wraps Writer of underlying storage driver.
func (base *Base) Writer(ctx context.Context, path string, append bool) (storagedriver.FileWriter, error) {
ctx, done := context.WithTrace(ctx)
ctx, done := dcontext.WithTrace(ctx)
defer done("%s.Writer(%q, %v)", base.Name(), path, append)
if !storagedriver.PathRegexp.MatchString(path) {
@@ -134,33 +152,37 @@ func (base *Base) Writer(ctx context.Context, path string, append bool) (storage
// Stat wraps Stat of underlying storage driver.
func (base *Base) Stat(ctx context.Context, path string) (storagedriver.FileInfo, error) {
ctx, done := context.WithTrace(ctx)
ctx, done := dcontext.WithTrace(ctx)
defer done("%s.Stat(%q)", base.Name(), path)
if !storagedriver.PathRegexp.MatchString(path) && path != "/" {
return nil, storagedriver.InvalidPathError{Path: path, DriverName: base.StorageDriver.Name()}
}
start := time.Now()
fi, e := base.StorageDriver.Stat(ctx, path)
storageAction.WithValues(base.Name(), "Stat").UpdateSince(start)
return fi, base.setDriverName(e)
}
// List wraps List of underlying storage driver.
func (base *Base) List(ctx context.Context, path string) ([]string, error) {
ctx, done := context.WithTrace(ctx)
ctx, done := dcontext.WithTrace(ctx)
defer done("%s.List(%q)", base.Name(), path)
if !storagedriver.PathRegexp.MatchString(path) && path != "/" {
return nil, storagedriver.InvalidPathError{Path: path, DriverName: base.StorageDriver.Name()}
}
start := time.Now()
str, e := base.StorageDriver.List(ctx, path)
storageAction.WithValues(base.Name(), "List").UpdateSince(start)
return str, base.setDriverName(e)
}
// Move wraps Move of underlying storage driver.
func (base *Base) Move(ctx context.Context, sourcePath string, destPath string) error {
ctx, done := context.WithTrace(ctx)
ctx, done := dcontext.WithTrace(ctx)
defer done("%s.Move(%q, %q", base.Name(), sourcePath, destPath)
if !storagedriver.PathRegexp.MatchString(sourcePath) {
@@ -169,30 +191,50 @@ func (base *Base) Move(ctx context.Context, sourcePath string, destPath string)
return storagedriver.InvalidPathError{Path: destPath, DriverName: base.StorageDriver.Name()}
}
return base.setDriverName(base.StorageDriver.Move(ctx, sourcePath, destPath))
start := time.Now()
err := base.setDriverName(base.StorageDriver.Move(ctx, sourcePath, destPath))
storageAction.WithValues(base.Name(), "Move").UpdateSince(start)
return err
}
// Delete wraps Delete of underlying storage driver.
func (base *Base) Delete(ctx context.Context, path string) error {
ctx, done := context.WithTrace(ctx)
ctx, done := dcontext.WithTrace(ctx)
defer done("%s.Delete(%q)", base.Name(), path)
if !storagedriver.PathRegexp.MatchString(path) {
return storagedriver.InvalidPathError{Path: path, DriverName: base.StorageDriver.Name()}
}
return base.setDriverName(base.StorageDriver.Delete(ctx, path))
start := time.Now()
err := base.setDriverName(base.StorageDriver.Delete(ctx, path))
storageAction.WithValues(base.Name(), "Delete").UpdateSince(start)
return err
}
// URLFor wraps URLFor of underlying storage driver.
func (base *Base) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
ctx, done := context.WithTrace(ctx)
ctx, done := dcontext.WithTrace(ctx)
defer done("%s.URLFor(%q)", base.Name(), path)
if !storagedriver.PathRegexp.MatchString(path) {
return "", storagedriver.InvalidPathError{Path: path, DriverName: base.StorageDriver.Name()}
}
start := time.Now()
str, e := base.StorageDriver.URLFor(ctx, path, options)
storageAction.WithValues(base.Name(), "URLFor").UpdateSince(start)
return str, base.setDriverName(e)
}
// Walk wraps Walk of underlying storage driver.
func (base *Base) Walk(ctx context.Context, path string, f storagedriver.WalkFn) error {
ctx, done := dcontext.WithTrace(ctx)
defer done("%s.Walk(%q)", base.Name(), path)
if !storagedriver.PathRegexp.MatchString(path) && path != "/" {
return storagedriver.InvalidPathError{Path: path, DriverName: base.StorageDriver.Name()}
}
return base.setDriverName(base.StorageDriver.Walk(ctx, path, f))
}
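Every wrapper now follows the same shape: record time.Now(), call the wrapped driver, then UpdateSince on the labeled timer. The pattern in isolation, with a plain callback in place of go-metrics (names are illustrative):

package main

import (
	"fmt"
	"time"
)

// timed wraps a storage operation and reports its duration, the way the
// Base methods above report to storageAction.
func timed(driver, action string, record func(driver, action string, d time.Duration), op func() error) error {
	start := time.Now()
	err := op()
	record(driver, action, time.Since(start))
	return err
}

func main() {
	timed("inmemory", "GetContent", func(drv, act string, d time.Duration) {
		fmt.Printf("%s.%s took %s\n", drv, act, d)
	}, func() error {
		time.Sleep(10 * time.Millisecond) // stand-in for the real driver call
		return nil
	})
}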

View File

@@ -1,10 +1,13 @@
package base
import (
"context"
"fmt"
"io"
"reflect"
"strconv"
"sync"
"github.com/docker/distribution/context"
storagedriver "github.com/docker/distribution/registry/storage/driver"
)
@@ -15,6 +18,46 @@ type regulator struct {
available uint64
}
// GetLimitFromParameter takes an interface type as decoded from the YAML
// configuration and returns a uint64 representing the maximum number of
// concurrent calls given a minimum limit and default.
//
// If the parameter supplied is of an invalid type this returns an error.
func GetLimitFromParameter(param interface{}, min, def uint64) (uint64, error) {
limit := def
switch v := param.(type) {
case string:
var err error
if limit, err = strconv.ParseUint(v, 0, 64); err != nil {
return limit, fmt.Errorf("parameter must be an integer, '%v' invalid", param)
}
case uint64:
limit = v
case int, int32, int64:
val := reflect.ValueOf(v).Convert(reflect.TypeOf(param)).Int()
// if param is negative casting to uint64 will wrap around and
// give you the hugest thread limit ever. Let's be sensible, here
if val > 0 {
limit = uint64(val)
} else {
limit = min
}
case uint, uint32:
limit = reflect.ValueOf(v).Convert(reflect.TypeOf(param)).Uint()
case nil:
// use the default
default:
return 0, fmt.Errorf("invalid value '%#v'", param)
}
if limit < min {
return min, nil
}
return limit, nil
}
// NewRegulator wraps the given driver and is used to regulate concurrent calls
// to the given storage driver to a maximum of the given limit. This is useful
// for storage drivers that would otherwise create an unbounded number of OS

View File

@@ -1,6 +1,7 @@
package base
import (
"fmt"
"sync"
"testing"
"time"
@@ -65,3 +66,33 @@ func TestRegulatorEnterExit(t *testing.T) {
}
}
}
func TestGetLimitFromParameter(t *testing.T) {
tests := []struct {
Input interface{}
Expected uint64
Min uint64
Default uint64
Err error
}{
{"foo", 0, 5, 5, fmt.Errorf("parameter must be an integer, 'foo' invalid")},
{"50", 50, 5, 5, nil},
{"5", 25, 25, 50, nil}, // lower than Min returns Min
{nil, 50, 25, 50, nil}, // nil returns default
{812, 812, 25, 50, nil},
}
for _, item := range tests {
t.Run(fmt.Sprint(item.Input), func(t *testing.T) {
actual, err := GetLimitFromParameter(item.Input, item.Min, item.Default)
if err != nil && item.Err != nil && err.Error() != item.Err.Error() {
t.Fatalf("GetLimitFromParameter error, expected %#v got %#v", item.Err, err)
}
if actual != item.Expected {
t.Fatalf("GetLimitFromParameter result error, expected %d got %d", item.Expected, actual)
}
})
}
}

View File

@@ -3,16 +3,14 @@ package filesystem
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"io/ioutil"
"os"
"path"
"reflect"
"strconv"
"time"
"github.com/docker/distribution/context"
storagedriver "github.com/docker/distribution/registry/storage/driver"
"github.com/docker/distribution/registry/storage/driver/base"
"github.com/docker/distribution/registry/storage/driver/factory"
@@ -85,33 +83,9 @@ func fromParametersImpl(parameters map[string]interface{}) (*DriverParameters, e
rootDirectory = fmt.Sprint(rootDir)
}
// Get maximum number of threads for blocking filesystem operations,
// if specified
threads := parameters["maxthreads"]
switch v := threads.(type) {
case string:
if maxThreads, err = strconv.ParseUint(v, 0, 64); err != nil {
return nil, fmt.Errorf("maxthreads parameter must be an integer, %v invalid", threads)
}
case uint64:
maxThreads = v
case int, int32, int64:
val := reflect.ValueOf(v).Convert(reflect.TypeOf(threads)).Int()
// If threads is negative casting to uint64 will wrap around and
// give you the hugest thread limit ever. Let's be sensible, here
if val > 0 {
maxThreads = uint64(val)
}
case uint, uint32:
maxThreads = reflect.ValueOf(v).Convert(reflect.TypeOf(threads)).Uint()
case nil:
// do nothing
default:
return nil, fmt.Errorf("invalid value for maxthreads: %#v", threads)
}
if maxThreads < minThreads {
maxThreads = minThreads
maxThreads, err = base.GetLimitFromParameter(parameters["maxthreads"], minThreads, defaultMaxThreads)
if err != nil {
return nil, fmt.Errorf("maxthreads config error: %s", err.Error())
}
}
@@ -184,11 +158,11 @@ func (d *driver) Reader(ctx context.Context, path string, offset int64) (io.Read
return nil, err
}
seekPos, err := file.Seek(int64(offset), os.SEEK_SET)
seekPos, err := file.Seek(offset, io.SeekStart)
if err != nil {
file.Close()
return nil, err
} else if seekPos < int64(offset) {
} else if seekPos < offset {
file.Close()
return nil, storagedriver.InvalidOffsetError{Path: path, Offset: offset}
}
@@ -217,12 +191,12 @@ func (d *driver) Writer(ctx context.Context, subPath string, append bool) (stora
return nil, err
}
} else {
n, err := fp.Seek(0, os.SEEK_END)
n, err := fp.Seek(0, io.SeekEnd)
if err != nil {
fp.Close()
return nil, err
}
offset = int64(n)
offset = n
}
return newFileWriter(fp, offset), nil
@@ -315,6 +289,12 @@ func (d *driver) URLFor(ctx context.Context, path string, options map[string]int
return "", storagedriver.ErrUnsupportedMethod{}
}
// Walk traverses a filesystem defined within driver, starting
// from the given path, calling f on each file
func (d *driver) Walk(ctx context.Context, path string, f storagedriver.WalkFn) error {
return storagedriver.WalkFallback(ctx, d, path, f)
}
// fullPath returns the absolute path of a key within the Driver's storage.
func (d *driver) fullPath(subPath string) string {
return path.Join(d.rootDirectory, subPath)
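Walk here simply delegates to storagedriver.WalkFallback, which emulates traversal on top of List/Stat for drivers without a native walk. A toy sketch of that list-then-recurse shape (types and names are illustrative; the real fallback also honors ErrSkipDir to prune subtrees):

package main

import "fmt"

// entry is a toy stand-in for a listed filesystem node.
type entry struct {
	path     string
	children []entry // non-empty => directory
}

// walk visits every entry depth-first, calling f on each.
func walk(e entry, f func(path string, isDir bool) error) error {
	if err := f(e.path, len(e.children) > 0); err != nil {
		return err
	}
	for _, c := range e.children {
		if err := walk(c, f); err != nil {
			return err
		}
	}
	return nil
}

func main() {
	root := entry{path: "/", children: []entry{
		{path: "/a"},
		{path: "/b", children: []entry{{path: "/b/c"}}},
	}}
	walk(root, func(p string, dir bool) error {
		fmt.Println(p, dir)
		return nil
	})
}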

View File

@@ -16,6 +16,8 @@ package gcs
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"io/ioutil"
@@ -29,20 +31,16 @@ import (
"strings"
"time"
"golang.org/x/net/context"
storagedriver "github.com/docker/distribution/registry/storage/driver"
"github.com/docker/distribution/registry/storage/driver/base"
"github.com/docker/distribution/registry/storage/driver/factory"
"github.com/sirupsen/logrus"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"golang.org/x/oauth2/jwt"
"google.golang.org/api/googleapi"
"google.golang.org/cloud"
"google.golang.org/cloud/storage"
"github.com/sirupsen/logrus"
ctx "github.com/docker/distribution/context"
storagedriver "github.com/docker/distribution/registry/storage/driver"
"github.com/docker/distribution/registry/storage/driver/base"
"github.com/docker/distribution/registry/storage/driver/factory"
)
const (
@@ -52,6 +50,8 @@ const (
uploadSessionContentType = "application/x-docker-upload-session"
minChunkSize = 256 * 1024
defaultChunkSize = 20 * minChunkSize
defaultMaxConcurrency = 50
minConcurrency = 25
maxTries = 5
)
@@ -67,6 +67,12 @@ type driverParameters struct {
client *http.Client
rootDirectory string
chunkSize int
// maxConcurrency limits the number of concurrent driver operations
// to GCS, which ultimately increases reliability of many simultaneous
// pushes by ensuring we aren't DoSing our own server with many
// connections.
maxConcurrency uint64
}
func init() {
@@ -92,6 +98,16 @@ type driver struct {
chunkSize int
}
// Wrapper wraps `driver` with a throttler, ensuring that no more than N
// GCS actions can occur concurrently. The default limit is 50.
type Wrapper struct {
baseEmbed
}
type baseEmbed struct {
base.Base
}
// FromParameters constructs a new Driver with a given parameters map
// Required parameters:
// - bucket
@@ -143,6 +159,31 @@ func FromParameters(parameters map[string]interface{}) (storagedriver.StorageDri
return nil, err
}
ts = jwtConf.TokenSource(context.Background())
} else if credentials, ok := parameters["credentials"]; ok {
credentialMap, ok := credentials.(map[interface{}]interface{})
if !ok {
return nil, fmt.Errorf("The credentials were not specified in the correct format")
}
stringMap := map[string]interface{}{}
for k, v := range credentialMap {
key, ok := k.(string)
if !ok {
return nil, fmt.Errorf("One of the credential keys was not a string: %s", fmt.Sprint(k))
}
stringMap[key] = v
}
data, err := json.Marshal(stringMap)
if err != nil {
return nil, fmt.Errorf("Failed to marshal gcs credentials to json")
}
jwtConf, err = google.JWTConfigFromJSON(data, storage.ScopeFullControl)
if err != nil {
return nil, err
}
ts = jwtConf.TokenSource(context.Background())
} else {
var err error
ts, err = google.DefaultTokenSource(context.Background(), storage.ScopeFullControl)
@@ -151,13 +192,19 @@ func FromParameters(parameters map[string]interface{}) (storagedriver.StorageDri
}
}
maxConcurrency, err := base.GetLimitFromParameter(parameters["maxconcurrency"], minConcurrency, defaultMaxConcurrency)
if err != nil {
return nil, fmt.Errorf("maxconcurrency config error: %s", err)
}
params := driverParameters{
bucket: fmt.Sprint(bucket),
rootDirectory: fmt.Sprint(rootDirectory),
email: jwtConf.Email,
privateKey: jwtConf.PrivateKey,
client: oauth2.NewClient(context.Background(), ts),
chunkSize: chunkSize,
bucket: fmt.Sprint(bucket),
rootDirectory: fmt.Sprint(rootDirectory),
email: jwtConf.Email,
privateKey: jwtConf.PrivateKey,
client: oauth2.NewClient(context.Background(), ts),
chunkSize: chunkSize,
maxConcurrency: maxConcurrency,
}
return New(params)
@@ -181,8 +228,12 @@ func New(params driverParameters) (storagedriver.StorageDriver, error) {
chunkSize: params.chunkSize,
}
return &base.Base{
StorageDriver: d,
return &Wrapper{
baseEmbed: baseEmbed{
Base: base.Base{
StorageDriver: base.NewRegulator(d, params.maxConcurrency),
},
},
}, nil
}
@@ -194,7 +245,7 @@ func (d *driver) Name() string {
// GetContent retrieves the content stored at "path" as a []byte.
// This should primarily be used for small objects.
func (d *driver) GetContent(context ctx.Context, path string) ([]byte, error) {
func (d *driver) GetContent(context context.Context, path string) ([]byte, error) {
gcsContext := d.context(context)
name := d.pathToKey(path)
var rc io.ReadCloser
@@ -220,7 +271,7 @@ func (d *driver) GetContent(context ctx.Context, path string) ([]byte, error) {
// PutContent stores the []byte content at a location designated by "path".
// This should primarily be used for small objects.
func (d *driver) PutContent(context ctx.Context, path string, contents []byte) error {
func (d *driver) PutContent(context context.Context, path string, contents []byte) error {
return retry(func() error {
wc := storage.NewWriter(d.context(context), d.bucket, d.pathToKey(path))
wc.ContentType = "application/octet-stream"
@@ -231,7 +282,7 @@ func (d *driver) PutContent(context ctx.Context, path string, contents []byte) e
// Reader retrieves an io.ReadCloser for the content stored at "path"
// with a given byte offset.
// May be used to resume reading a stream by providing a nonzero offset.
func (d *driver) Reader(context ctx.Context, path string, offset int64) (io.ReadCloser, error) {
func (d *driver) Reader(context context.Context, path string, offset int64) (io.ReadCloser, error) {
res, err := getObject(d.client, d.bucket, d.pathToKey(path), offset)
if err != nil {
if res != nil {
@@ -290,7 +341,7 @@ func getObject(client *http.Client, bucket string, name string, offset int64) (*
// Writer returns a FileWriter which will store the content written to it
// at the location designated by "path" after the call to Commit.
func (d *driver) Writer(context ctx.Context, path string, append bool) (storagedriver.FileWriter, error) {
func (d *driver) Writer(context context.Context, path string, append bool) (storagedriver.FileWriter, error) {
writer := &writer{
client: d.client,
bucket: d.bucket,
@@ -542,7 +593,7 @@ func retry(req request) error {
// Stat retrieves the FileInfo for the given path, including the current
// size in bytes and the creation time.
func (d *driver) Stat(context ctx.Context, path string) (storagedriver.FileInfo, error) {
func (d *driver) Stat(context context.Context, path string) (storagedriver.FileInfo, error) {
var fi storagedriver.FileInfoFields
// try to get as file
gcsContext := d.context(context)
@@ -588,7 +639,7 @@ func (d *driver) Stat(context ctx.Context, path string) (storagedriver.FileInfo,
// List returns a list of the objects that are direct descendants of the
// given path.
func (d *driver) List(context ctx.Context, path string) ([]string, error) {
func (d *driver) List(context context.Context, path string) ([]string, error) {
var query *storage.Query
query = &storage.Query{}
query.Delimiter = "/"
@@ -626,7 +677,7 @@ func (d *driver) List(context ctx.Context, path string) ([]string, error) {
// Move moves an object stored at sourcePath to destPath, removing the
// original object.
func (d *driver) Move(context ctx.Context, sourcePath string, destPath string) error {
func (d *driver) Move(context context.Context, sourcePath string, destPath string) error {
gcsContext := d.context(context)
_, err := storageCopyObject(gcsContext, d.bucket, d.pathToKey(sourcePath), d.bucket, d.pathToKey(destPath), nil)
if err != nil {
@@ -674,7 +725,7 @@ func (d *driver) listAll(context context.Context, prefix string) ([]string, erro
}
// Delete recursively deletes all objects stored at "path" and its subpaths.
func (d *driver) Delete(context ctx.Context, path string) error {
func (d *driver) Delete(context context.Context, path string) error {
prefix := d.pathToDirKey(path)
gcsContext := d.context(context)
keys, err := d.listAll(gcsContext, prefix)
@@ -749,7 +800,7 @@ func storageCopyObject(context context.Context, srcBucket, srcName string, destB
// URLFor returns a URL which may be used to retrieve the content stored at
// the given path, possibly using the given options.
// Returns ErrUnsupportedMethod if this driver has no privateKey
func (d *driver) URLFor(context ctx.Context, path string, options map[string]interface{}) (string, error) {
func (d *driver) URLFor(context context.Context, path string, options map[string]interface{}) (string, error) {
if d.privateKey == nil {
return "", storagedriver.ErrUnsupportedMethod{}
}
@@ -782,6 +833,12 @@ func (d *driver) URLFor(context ctx.Context, path string, options map[string]int
return storage.SignedURL(d.bucket, name, opts)
}
// Walk traverses a filesystem defined within driver, starting
// from the given path, calling f on each file
func (d *driver) Walk(ctx context.Context, path string, f storagedriver.WalkFn) error {
return storagedriver.WalkFallback(ctx, d, path, f)
}
func startSession(client *http.Client, bucket string, name string) (uri string, err error) {
u := &url.URL{
Scheme: "https",
@@ -856,12 +913,12 @@ func putChunk(client *http.Client, sessionURI string, chunk []byte, from int64,
return bytesPut, err
}
func (d *driver) context(context ctx.Context) context.Context {
func (d *driver) context(context context.Context) context.Context {
return cloud.WithContext(context, dummyProjectID, d.client)
}
func (d *driver) pathToKey(path string) string {
return strings.TrimRight(d.rootDirectory+strings.TrimLeft(path, "/"), "/")
return strings.TrimSpace(strings.TrimRight(d.rootDirectory+strings.TrimLeft(path, "/"), "/"))
}
func (d *driver) pathToDirKey(path string) string {

View File

@@ -3,12 +3,12 @@
package gcs
import (
"fmt"
"io/ioutil"
"os"
"testing"
"fmt"
ctx "github.com/docker/distribution/context"
dcontext "github.com/docker/distribution/context"
storagedriver "github.com/docker/distribution/registry/storage/driver"
"github.com/docker/distribution/registry/storage/driver/testsuites"
"golang.org/x/oauth2"
@@ -49,7 +49,7 @@ func init() {
var email string
var privateKey []byte
ts, err = google.DefaultTokenSource(ctx.Background(), storage.ScopeFullControl)
ts, err = google.DefaultTokenSource(dcontext.Background(), storage.ScopeFullControl)
if err != nil {
// Assume that the file contents are within the environment variable since it exists
// but does not contain a valid file path
@@ -65,7 +65,7 @@ func init() {
if email == "" {
panic("Error reading JWT config : missing client_email property")
}
ts = jwtConfig.TokenSource(ctx.Background())
ts = jwtConfig.TokenSource(dcontext.Background())
}
gcsDriverConstructor = func(rootDirectory string) (storagedriver.StorageDriver, error) {
@@ -74,7 +74,7 @@ func init() {
rootDirectory: root,
email: email,
privateKey: privateKey,
client: oauth2.NewClient(ctx.Background(), ts),
client: oauth2.NewClient(dcontext.Background(), ts),
chunkSize: defaultChunkSize,
}
@@ -104,7 +104,7 @@ func TestCommitEmpty(t *testing.T) {
}
filename := "/test"
ctx := ctx.Background()
ctx := dcontext.Background()
writer, err := driver.Writer(ctx, filename, false)
defer driver.Delete(ctx, filename)
@@ -150,7 +150,7 @@ func TestCommit(t *testing.T) {
}
filename := "/test"
ctx := ctx.Background()
ctx := dcontext.Background()
contents := make([]byte, defaultChunkSize)
writer, err := driver.Writer(ctx, filename, false)
@@ -247,7 +247,7 @@ func TestEmptyRootList(t *testing.T) {
filename := "/test"
contents := []byte("contents")
ctx := ctx.Background()
ctx := dcontext.Background()
err = rootedDriver.PutContent(ctx, filename, contents)
if err != nil {
t.Fatalf("unexpected error creating content: %v", err)
@@ -290,7 +290,7 @@ func TestMoveDirectory(t *testing.T) {
t.Fatalf("unexpected error creating rooted driver: %v", err)
}
ctx := ctx.Background()
ctx := dcontext.Background()
contents := []byte("contents")
// Create a regular file.
err = driver.PutContent(ctx, "/parent/dir/foo", contents)

View File

@@ -1,13 +1,13 @@
package inmemory
import (
"context"
"fmt"
"io"
"io/ioutil"
"sync"
"time"
"github.com/docker/distribution/context"
storagedriver "github.com/docker/distribution/registry/storage/driver"
"github.com/docker/distribution/registry/storage/driver/base"
"github.com/docker/distribution/registry/storage/driver/factory"
@@ -73,7 +73,7 @@ func (d *driver) GetContent(ctx context.Context, path string) ([]byte, error) {
d.mutex.RLock()
defer d.mutex.RUnlock()
rc, err := d.Reader(ctx, path, 0)
rc, err := d.reader(ctx, path, 0)
if err != nil {
return nil, err
}
@@ -108,6 +108,10 @@ func (d *driver) Reader(ctx context.Context, path string, offset int64) (io.Read
d.mutex.RLock()
defer d.mutex.RUnlock()
return d.reader(ctx, path, offset)
}
func (d *driver) reader(ctx context.Context, path string, offset int64) (io.ReadCloser, error) {
if offset < 0 {
return nil, storagedriver.InvalidOffsetError{Path: path, Offset: offset}
}
@@ -240,6 +244,12 @@ func (d *driver) URLFor(ctx context.Context, path string, options map[string]int
return "", storagedriver.ErrUnsupportedMethod{}
}
// Walk traverses a filesystem defined within driver, starting
// from the given path, calling f on each file
func (d *driver) Walk(ctx context.Context, path string, f storagedriver.WalkFn) error {
return storagedriver.WalkFallback(ctx, d, path, f)
}
type writer struct {
d *driver
f *file

View File

@@ -4,6 +4,7 @@
package middleware
import (
"context"
"crypto/x509"
"encoding/pem"
"fmt"
@@ -13,9 +14,9 @@ import (
"time"
"github.com/aws/aws-sdk-go/service/cloudfront/sign"
"github.com/docker/distribution/context"
dcontext "github.com/docker/distribution/context"
storagedriver "github.com/docker/distribution/registry/storage/driver"
storagemiddleware "github.com/docker/distribution/registry/storage/driver/middleware"
"github.com/docker/distribution/registry/storage/driver/middleware"
)
// cloudFrontStorageMiddleware provides a simple implementation of layerHandler that
@@ -23,6 +24,7 @@ import (
// then issues HTTP Temporary Redirects to this CloudFront content URL.
type cloudFrontStorageMiddleware struct {
storagedriver.StorageDriver
awsIPs *awsIPs
urlSigner *sign.URLSigner
baseURL string
duration time.Duration
@@ -33,7 +35,13 @@ var _ storagedriver.StorageDriver = &cloudFrontStorageMiddleware{}
// newCloudFrontLayerHandler constructs and returns a new CloudFront
// LayerHandler implementation.
// Required options: baseurl, privatekey, keypairid
// Optional options: ipfilteredby, awsregion
// ipfilteredby: valid values are "none|aws|awsregion". "none" (the default) does not filter any IP;
// "aws" sends only AWS-internal IPs to S3 directly; "awsregion" sends only the regions listed in the
// awsregion option to S3 directly.
// awsregion: a comma-separated string of AWS regions.
func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error) {
// parse baseurl
base, ok := options["baseurl"]
if !ok {
return nil, fmt.Errorf("no baseurl provided")
@@ -51,6 +59,8 @@ func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, o
if _, err := url.Parse(baseURL); err != nil {
return nil, fmt.Errorf("invalid baseurl: %v", err)
}
// parse privatekey to get pkPath
pk, ok := options["privatekey"]
if !ok {
return nil, fmt.Errorf("no privatekey provided")
@@ -59,6 +69,8 @@ func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, o
if !ok {
return nil, fmt.Errorf("privatekey must be a string")
}
// parse keypairid
kpid, ok := options["keypairid"]
if !ok {
return nil, fmt.Errorf("no keypairid provided")
@@ -68,12 +80,13 @@ func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, o
return nil, fmt.Errorf("keypairid must be a string")
}
// get urlSigner from the file specified in pkPath
pkBytes, err := ioutil.ReadFile(pkPath)
if err != nil {
return nil, fmt.Errorf("failed to read privatekey file: %s", err)
}
block, _ := pem.Decode([]byte(pkBytes))
block, _ := pem.Decode(pkBytes)
if block == nil {
return nil, fmt.Errorf("failed to decode private key as an rsa private key")
}
@@ -81,12 +94,11 @@ func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, o
if err != nil {
return nil, err
}
urlSigner := sign.NewURLSigner(keypairID, privateKey)
// parse duration
duration := 20 * time.Minute
d, ok := options["duration"]
if ok {
if d, ok := options["duration"]; ok {
switch d := d.(type) {
case time.Duration:
duration = d
@@ -99,11 +111,62 @@ func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, o
}
}
// parse updatefrenquency
updateFrequency := defaultUpdateFrequency
if u, ok := options["updatefrenquency"]; ok {
switch u := u.(type) {
case time.Duration:
updateFrequency = u
case string:
updateFreq, err := time.ParseDuration(u)
if err != nil {
return nil, fmt.Errorf("invalid updatefrenquency: %s", err)
}
updateFrequency = updateFreq
}
}
// parse iprangesurl
ipRangesURL := defaultIPRangesURL
if i, ok := options["iprangesurl"]; ok {
if iprangeurl, ok := i.(string); ok {
ipRangesURL = iprangeurl
} else {
return nil, fmt.Errorf("iprangesurl must be a string")
}
}
// parse ipfilteredby
var awsIPs *awsIPs
if ipFilteredBy, ok := options["ipfilteredby"].(string); ok {
switch strings.ToLower(strings.TrimSpace(ipFilteredBy)) {
case "", "none":
awsIPs = nil
case "aws":
awsIPs = newAWSIPs(ipRangesURL, updateFrequency, nil)
case "awsregion":
var awsRegion []string
if regions, ok := options["awsregion"].(string); ok {
for _, awsRegions := range strings.Split(regions, ",") {
awsRegion = append(awsRegion, strings.ToLower(strings.TrimSpace(awsRegions)))
}
awsIPs = newAWSIPs(ipRangesURL, updateFrequency, awsRegion)
} else {
return nil, fmt.Errorf("awsRegion must be a comma separated string of valid aws regions")
}
default:
return nil, fmt.Errorf("ipfilteredby only allows a string the following value: none|aws|awsregion")
}
} else {
return nil, fmt.Errorf("ipfilteredby only allows a string with the following value: none|aws|awsregion")
}
return &cloudFrontStorageMiddleware{
StorageDriver: storageDriver,
urlSigner: urlSigner,
baseURL: baseURL,
duration: duration,
awsIPs: awsIPs,
}, nil
}
@@ -113,16 +176,21 @@ type S3BucketKeyer interface {
S3BucketKey(path string) string
}
// Resolve returns an http.Handler which can serve the contents of the given
// Layer, or an error if not supported by the storagedriver.
// URLFor attempts to find a url which may be used to retrieve the file at the given path.
// Returns an error if the file cannot be found.
func (lh *cloudFrontStorageMiddleware) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
// TODO(endophage): currently only supports S3
keyer, ok := lh.StorageDriver.(S3BucketKeyer)
if !ok {
context.GetLogger(ctx).Warn("the CloudFront middleware does not support this backend storage driver")
dcontext.GetLogger(ctx).Warn("the CloudFront middleware does not support this backend storage driver")
return lh.StorageDriver.URLFor(ctx, path, options)
}
if eligibleForS3(ctx, lh.awsIPs) {
return lh.StorageDriver.URLFor(ctx, path, options)
}
// Get signed cloudfront url.
cfURL, err := lh.urlSigner.Sign(lh.baseURL+keyer.S3BucketKey(path), time.Now().Add(lh.duration))
if err != nil {
return "", err

View File

@@ -0,0 +1,223 @@
package middleware
import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net"
"net/http"
"strings"
"sync"
"time"
dcontext "github.com/docker/distribution/context"
)
const (
// ipRangesURL is the URL to get definition of AWS IPs
defaultIPRangesURL = "https://ip-ranges.amazonaws.com/ip-ranges.json"
// updateFrequency tells how frequently AWS IPs need to be updated
defaultUpdateFrequency = time.Hour * 12
)
// newAWSIPs returns a new awsIPs object.
// If awsRegion is `nil`, it accepts any region. Otherwise, it only allows the regions specified.
func newAWSIPs(host string, updateFrequency time.Duration, awsRegion []string) *awsIPs {
ips := &awsIPs{
host: host,
updateFrequency: updateFrequency,
awsRegion: awsRegion,
updaterStopChan: make(chan bool),
}
if err := ips.tryUpdate(); err != nil {
dcontext.GetLogger(context.Background()).WithError(err).Warn("failed to update AWS IP")
}
go ips.updater()
return ips
}
// awsIPs tracks a list of AWS ips, filtered by awsRegion
type awsIPs struct {
host string
updateFrequency time.Duration
ipv4 []net.IPNet
ipv6 []net.IPNet
mutex sync.RWMutex
awsRegion []string
updaterStopChan chan bool
initialized bool
}
type awsIPResponse struct {
Prefixes []prefixEntry `json:"prefixes"`
V6Prefixes []prefixEntry `json:"ipv6_prefixes"`
}
type prefixEntry struct {
IPV4Prefix string `json:"ip_prefix"`
IPV6Prefix string `json:"ipv6_prefix"`
Region string `json:"region"`
Service string `json:"service"`
}
func fetchAWSIPs(url string) (awsIPResponse, error) {
var response awsIPResponse
resp, err := http.Get(url)
if err != nil {
return response, err
}
if resp.StatusCode != 200 {
body, _ := ioutil.ReadAll(resp.Body)
return response, fmt.Errorf("failed to fetch network data. response = %s", body)
}
decoder := json.NewDecoder(resp.Body)
err = decoder.Decode(&response)
if err != nil {
return response, err
}
return response, nil
}
// tryUpdate attempts to download the new set of ip addresses.
// tryUpdate must be thread safe with contains
func (s *awsIPs) tryUpdate() error {
response, err := fetchAWSIPs(s.host)
if err != nil {
return err
}
var ipv4 []net.IPNet
var ipv6 []net.IPNet
processAddress := func(output *[]net.IPNet, prefix string, region string) {
regionAllowed := false
if len(s.awsRegion) > 0 {
for _, ar := range s.awsRegion {
if strings.ToLower(region) == ar {
regionAllowed = true
break
}
}
} else {
regionAllowed = true
}
_, network, err := net.ParseCIDR(prefix)
if err != nil {
dcontext.GetLoggerWithFields(dcontext.Background(), map[interface{}]interface{}{
"cidr": prefix,
}).Error("unparseable cidr")
return
}
if regionAllowed {
*output = append(*output, *network)
}
}
for _, prefix := range response.Prefixes {
processAddress(&ipv4, prefix.IPV4Prefix, prefix.Region)
}
for _, prefix := range response.V6Prefixes {
processAddress(&ipv6, prefix.IPV6Prefix, prefix.Region)
}
s.mutex.Lock()
defer s.mutex.Unlock()
// Update each attr of awsips atomically.
s.ipv4 = ipv4
s.ipv6 = ipv6
s.initialized = true
return nil
}
// This function is meant to be run in a background goroutine.
// It will periodically update the ips from aws.
func (s *awsIPs) updater() {
defer close(s.updaterStopChan)
for {
time.Sleep(s.updateFrequency)
select {
case <-s.updaterStopChan:
dcontext.GetLogger(context.Background()).Info("aws ip updater received stop signal")
return
default:
err := s.tryUpdate()
if err != nil {
dcontext.GetLogger(context.Background()).WithError(err).Error("git AWS IP")
}
}
}
}
// getCandidateNetworks returns either the ipv4 or ipv6 networks
// that were last read from aws. The networks returned
// have the same type as the ip address provided.
func (s *awsIPs) getCandidateNetworks(ip net.IP) []net.IPNet {
s.mutex.RLock()
defer s.mutex.RUnlock()
if ip.To4() != nil {
return s.ipv4
} else if ip.To16() != nil {
return s.ipv6
} else {
dcontext.GetLoggerWithFields(dcontext.Background(), map[interface{}]interface{}{
"ip": ip,
}).Error("unknown ip address format")
// assume mismatch, pass through cloudfront
return nil
}
}
// Contains determines whether the host is within aws.
func (s *awsIPs) contains(ip net.IP) bool {
networks := s.getCandidateNetworks(ip)
for _, network := range networks {
if network.Contains(ip) {
return true
}
}
return false
}
// parseIPFromRequest attempts to extract the ip address of the
// client that made the request
func parseIPFromRequest(ctx context.Context) (net.IP, error) {
request, err := dcontext.GetRequest(ctx)
if err != nil {
return nil, err
}
ipStr := dcontext.RemoteIP(request)
ip := net.ParseIP(ipStr)
if ip == nil {
return nil, fmt.Errorf("invalid ip address from requester: %s", ipStr)
}
return ip, nil
}
// eligibleForS3 checks if a request is eligible for using S3 directly.
// It returns true only when the client IP belongs to one of the allowed AWS IP ranges.
func eligibleForS3(ctx context.Context, awsIPs *awsIPs) bool {
if awsIPs != nil && awsIPs.initialized {
if addr, err := parseIPFromRequest(ctx); err == nil {
request, err := dcontext.GetRequest(ctx)
if err != nil {
dcontext.GetLogger(ctx).Warnf("the CloudFront middleware cannot parse the request: %s", err)
} else {
loggerField := map[interface{}]interface{}{
"user-client": request.UserAgent(),
"ip": dcontext.RemoteIP(request),
}
if awsIPs.contains(addr) {
dcontext.GetLoggerWithFields(ctx, loggerField).Info("request from the allowed AWS region, skipping CloudFront")
return true
}
dcontext.GetLoggerWithFields(ctx, loggerField).Warn("request not from the allowed AWS region, fallback to CloudFront")
}
} else {
dcontext.GetLogger(ctx).WithError(err).Warn("failed to parse ip address from context, fallback to CloudFront")
}
}
return false
}
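contains is a linear scan over the parsed CIDR blocks; net.IPNet.Contains does the actual matching. The core of it in isolation, with illustrative prefixes and addresses:

package main

import (
	"fmt"
	"net"
)

func main() {
	// Parse a couple of illustrative AWS-style prefixes once...
	var networks []*net.IPNet
	for _, cidr := range []string{"192.168.0.0/24", "ff00::/16"} {
		_, network, err := net.ParseCIDR(cidr)
		if err != nil {
			continue // the updater logs and skips unparseable prefixes
		}
		networks = append(networks, network)
	}

	// ...then each membership check is just a Contains call.
	for _, ip := range []string{"192.168.0.7", "10.0.0.1", "ff00::1"} {
		parsed := net.ParseIP(ip)
		in := false
		for _, n := range networks {
			if n.Contains(parsed) {
				in = true
				break
			}
		}
		fmt.Println(ip, in)
	}
}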

View File

@@ -0,0 +1,402 @@
package middleware
import (
"context"
"crypto/rand"
"encoding/json"
"fmt"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
dcontext "github.com/docker/distribution/context"
"reflect" // used as a replacement for testify
)
// Rather than pull in all of testify
func assertEqual(t *testing.T, x, y interface{}) {
if !reflect.DeepEqual(x, y) {
t.Errorf("%s: Not equal! Expected='%v', Actual='%v'\n", t.Name(), x, y)
t.FailNow()
}
}
type mockIPRangeHandler struct {
data awsIPResponse
}
func (m mockIPRangeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
bytes, err := json.Marshal(m.data)
if err != nil {
w.WriteHeader(500)
return
}
w.Write(bytes)
}
func newTestHandler(data awsIPResponse) *httptest.Server {
return httptest.NewServer(mockIPRangeHandler{
data: data,
})
}
func serverIPRanges(server *httptest.Server) string {
return fmt.Sprintf("%s/", server.URL)
}
func setupTest(data awsIPResponse) *httptest.Server {
// This is a basic schema which only claims the exact ip
// is in aws.
server := newTestHandler(data)
return server
}
func TestS3TryUpdate(t *testing.T) {
t.Parallel()
server := setupTest(awsIPResponse{
Prefixes: []prefixEntry{
{IPV4Prefix: "123.231.123.231/32"},
},
})
defer server.Close()
ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
assertEqual(t, 1, len(ips.ipv4))
assertEqual(t, 0, len(ips.ipv6))
}
func TestMatchIPV6(t *testing.T) {
t.Parallel()
server := setupTest(awsIPResponse{
V6Prefixes: []prefixEntry{
{IPV6Prefix: "ff00::/16"},
},
})
defer server.Close()
ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
ips.tryUpdate()
assertEqual(t, true, ips.contains(net.ParseIP("ff00::")))
assertEqual(t, 1, len(ips.ipv6))
assertEqual(t, 0, len(ips.ipv4))
}
func TestMatchIPV4(t *testing.T) {
t.Parallel()
server := setupTest(awsIPResponse{
Prefixes: []prefixEntry{
{IPV4Prefix: "192.168.0.0/24"},
},
})
defer server.Close()
ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
ips.tryUpdate()
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0")))
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1")))
assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
}
func TestMatchIPV4_2(t *testing.T) {
t.Parallel()
server := setupTest(awsIPResponse{
Prefixes: []prefixEntry{
{
IPV4Prefix: "192.168.0.0/24",
Region: "us-east-1",
},
},
})
defer server.Close()
ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
ips.tryUpdate()
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0")))
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1")))
assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
}
func TestMatchIPV4WithRegionMatched(t *testing.T) {
t.Parallel()
server := setupTest(awsIPResponse{
Prefixes: []prefixEntry{
{
IPV4Prefix: "192.168.0.0/24",
Region: "us-east-1",
},
},
})
defer server.Close()
ips := newAWSIPs(serverIPRanges(server), time.Hour, []string{"us-east-1"})
ips.tryUpdate()
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0")))
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1")))
assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
}
func TestMatchIPV4WithRegionMatch_2(t *testing.T) {
t.Parallel()
server := setupTest(awsIPResponse{
Prefixes: []prefixEntry{
{
IPV4Prefix: "192.168.0.0/24",
Region: "us-east-1",
},
},
})
defer server.Close()
ips := newAWSIPs(serverIPRanges(server), time.Hour, []string{"us-west-2", "us-east-1"})
ips.tryUpdate()
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0")))
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1")))
assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
}
func TestMatchIPV4WithRegionNotMatched(t *testing.T) {
t.Parallel()
server := setupTest(awsIPResponse{
Prefixes: []prefixEntry{
{
IPV4Prefix: "192.168.0.0/24",
Region: "us-east-1",
},
},
})
defer server.Close()
ips := newAWSIPs(serverIPRanges(server), time.Hour, []string{"us-west-2"})
ips.tryUpdate()
assertEqual(t, false, ips.contains(net.ParseIP("192.168.0.0")))
assertEqual(t, false, ips.contains(net.ParseIP("192.168.0.1")))
assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
}
func TestInvalidData(t *testing.T) {
t.Parallel()
// Invalid entries from aws should be ignored.
server := setupTest(awsIPResponse{
Prefixes: []prefixEntry{
{IPV4Prefix: "9000"},
{IPV4Prefix: "192.168.0.0/24"},
},
})
defer server.Close()
ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
ips.tryUpdate()
assertEqual(t, 1, len(ips.ipv4))
}
func TestInvalidNetworkType(t *testing.T) {
t.Parallel()
server := setupTest(awsIPResponse{
Prefixes: []prefixEntry{
{IPV4Prefix: "192.168.0.0/24"},
},
V6Prefixes: []prefixEntry{
{IPV6Prefix: "ff00::/8"},
{IPV6Prefix: "fe00::/8"},
},
})
defer server.Close()
ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
assertEqual(t, 0, len(ips.getCandidateNetworks(make([]byte, 17)))) // 17 bytes does not correspond to any net type
assertEqual(t, 1, len(ips.getCandidateNetworks(make([]byte, 4)))) // netv4 networks
assertEqual(t, 2, len(ips.getCandidateNetworks(make([]byte, 16)))) // netv6 networks
}
func TestParsing(t *testing.T) {
var data = `{
"prefixes": [{
"ip_prefix": "192.168.0.0",
"region": "someregion",
"service": "s3"}],
"ipv6_prefixes": [{
"ipv6_prefix": "2001:4860:4860::8888",
"region": "anotherregion",
"service": "ec2"}]
}`
rawMockHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte(data)) })
t.Parallel()
server := httptest.NewServer(rawMockHandler)
defer server.Close()
schema, err := fetchAWSIPs(server.URL)
assertEqual(t, nil, err)
assertEqual(t, 1, len(schema.Prefixes))
assertEqual(t, prefixEntry{
IPV4Prefix: "192.168.0.0",
Region: "someregion",
Service: "s3",
}, schema.Prefixes[0])
assertEqual(t, 1, len(schema.V6Prefixes))
assertEqual(t, prefixEntry{
IPV6Prefix: "2001:4860:4860::8888",
Region: "anotherregion",
Service: "ec2",
}, schema.V6Prefixes[0])
}
func TestUpdateCalledRegularly(t *testing.T) {
t.Parallel()
updateCount := 0
server := httptest.NewServer(http.HandlerFunc(
func(rw http.ResponseWriter, req *http.Request) {
updateCount++
rw.Write([]byte("ok"))
}))
defer server.Close()
newAWSIPs(fmt.Sprintf("%s/", server.URL), time.Second, nil)
time.Sleep(time.Second*4 + time.Millisecond*500)
if updateCount < 4 {
t.Errorf("Update should have been called at least 4 times, actual=%d", updateCount)
}
}
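For reference, a minimal, self-contained sketch of the periodic-refresh pattern this test exercises; startUpdater is a hypothetical stand-in for whatever loop newAWSIPs starts internally, not the driver's actual updater:

package main

import (
	"fmt"
	"sync/atomic"
	"time"
)

// startUpdater calls update immediately, then once per interval, until stop
// is closed. Hypothetical helper for illustration only.
func startUpdater(update func(), interval time.Duration) chan struct{} {
	stop := make(chan struct{})
	go func() {
		ticker := time.NewTicker(interval)
		defer ticker.Stop()
		for {
			update()
			select {
			case <-ticker.C:
			case <-stop:
				return
			}
		}
	}()
	return stop
}

func main() {
	var count int64
	stop := startUpdater(func() { atomic.AddInt64(&count, 1) }, time.Second)
	time.Sleep(time.Second*4 + time.Millisecond*500)
	close(stop)
	// Typically 5 calls in ~4.5s; the test above allows slack and only
	// requires at least 4.
	fmt.Println(atomic.LoadInt64(&count))
}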
func TestEligibleForS3(t *testing.T) {
awsIPs := &awsIPs{
ipv4: []net.IPNet{{
IP: net.ParseIP("192.168.1.1"),
Mask: net.IPv4Mask(255, 255, 255, 0),
}},
initialized: true,
}
empty := context.TODO()
makeContext := func(ip string) context.Context {
req := &http.Request{
RemoteAddr: ip,
}
return dcontext.WithRequest(empty, req)
}
cases := []struct {
Context context.Context
Expected bool
}{
{Context: empty, Expected: false},
{Context: makeContext("192.168.1.2"), Expected: true},
{Context: makeContext("192.168.0.2"), Expected: false},
}
for _, testCase := range cases {
name := fmt.Sprintf("Client IP = %v",
testCase.Context.Value("http.request.ip"))
t.Run(name, func(t *testing.T) {
assertEqual(t, testCase.Expected, eligibleForS3(testCase.Context, awsIPs))
})
}
}
func TestEligibleForS3WithAWSIPNotInitialized(t *testing.T) {
awsIPs := &awsIPs{
ipv4: []net.IPNet{{
IP: net.ParseIP("192.168.1.1"),
Mask: net.IPv4Mask(255, 255, 255, 0),
}},
initialized: false,
}
empty := context.TODO()
makeContext := func(ip string) context.Context {
req := &http.Request{
RemoteAddr: ip,
}
return dcontext.WithRequest(empty, req)
}
cases := []struct {
Context context.Context
Expected bool
}{
{Context: empty, Expected: false},
{Context: makeContext("192.168.1.2"), Expected: false},
{Context: makeContext("192.168.0.2"), Expected: false},
}
for _, testCase := range cases {
name := fmt.Sprintf("Client IP = %v",
testCase.Context.Value("http.request.ip"))
t.Run(name, func(t *testing.T) {
assertEqual(t, testCase.Expected, eligibleForS3(testCase.Context, awsIPs))
})
}
}
// populateRandomNetworks fills ips with randomly generated IPv4 and IPv6
// networks, for the purpose of benchmarking contains() performance.
func populateRandomNetworks(b *testing.B, ips *awsIPs, ipv4Count, ipv6Count int) {
generateNetworks := func(dest *[]net.IPNet, bytes int, count int) {
for i := 0; i < count; i++ {
ip := make([]byte, bytes)
_, err := rand.Read(ip)
if err != nil {
b.Fatalf("failed to generate network for test : %s", err.Error())
}
mask := make([]byte, bytes)
for i := 0; i < bytes; i++ {
mask[i] = 0xff
}
*dest = append(*dest, net.IPNet{
IP: ip,
Mask: mask,
})
}
}
generateNetworks(&ips.ipv4, 4, ipv4Count)
generateNetworks(&ips.ipv6, 16, ipv6Count)
}
func BenchmarkContainsRandom(b *testing.B) {
// Generate a random network configuration, of size comparable to
// aws official networks list
// curl -s https://ip-ranges.amazonaws.com/ip-ranges.json | jq '.prefixes | length'
// 941
numNetworksPerType := 1000 // keep in sync with the above
// intentionally skip constructor when creating awsIPs, to avoid updater routine.
// This benchmark is only concerned with contains() performance.
awsIPs := awsIPs{}
populateRandomNetworks(b, &awsIPs, numNetworksPerType, numNetworksPerType)
ipv4 := make([][]byte, b.N)
ipv6 := make([][]byte, b.N)
for i := 0; i < b.N; i++ {
ipv4[i] = make([]byte, 4)
ipv6[i] = make([]byte, 16)
rand.Read(ipv4[i])
rand.Read(ipv6[i])
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
awsIPs.contains(ipv4[i])
awsIPs.contains(ipv6[i])
}
}
func BenchmarkContainsProd(b *testing.B) {
awsIPs := newAWSIPs(defaultIPRangesURL, defaultUpdateFrequency, nil)
ipv4 := make([][]byte, b.N)
ipv6 := make([][]byte, b.N)
for i := 0; i < b.N; i++ {
ipv4[i] = make([]byte, 4)
ipv6[i] = make([]byte, 16)
rand.Read(ipv4[i])
rand.Read(ipv6[i])
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
awsIPs.contains(ipv4[i])
awsIPs.contains(ipv6[i])
}
}
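For context, the benchmarks above time linear scans over the parsed CIDR lists. A minimal, self-contained sketch of that membership check (containsIP is illustrative only; the real awsIPs type also synchronizes access with a mutex and is refreshed by a background updater):

package main

import (
	"fmt"
	"net"
)

// containsIP reports whether ip falls inside any of the given networks,
// using the same linear scan the benchmarks above measure.
func containsIP(networks []net.IPNet, ip net.IP) bool {
	for _, network := range networks {
		if network.Contains(ip) {
			return true
		}
	}
	return false
}

func main() {
	_, cidr, _ := net.ParseCIDR("192.168.0.0/24")
	networks := []net.IPNet{*cidr}
	fmt.Println(containsIP(networks, net.ParseIP("192.168.0.7"))) // true
	fmt.Println(containsIP(networks, net.ParseIP("192.169.0.7"))) // false
}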

View File

@@ -1,10 +1,10 @@
package middleware
import (
"context"
"fmt"
"net/url"
"github.com/docker/distribution/context"
storagedriver "github.com/docker/distribution/registry/storage/driver"
storagemiddleware "github.com/docker/distribution/registry/storage/driver/middleware"
)

View File

@@ -1,6 +1,7 @@
package middleware
import (
"context"
"testing"
check "gopkg.in/check.v1"
@@ -36,7 +37,7 @@ func (s *MiddlewareSuite) TestHttpsPort(c *check.C) {
c.Assert(m.scheme, check.Equals, "https")
c.Assert(m.host, check.Equals, "example.com:5443")
url, err := middleware.URLFor(nil, "/rick/data", nil)
url, err := middleware.URLFor(context.TODO(), "/rick/data", nil)
c.Assert(err, check.Equals, nil)
c.Assert(url, check.Equals, "https://example.com:5443/rick/data")
}
@@ -52,7 +53,7 @@ func (s *MiddlewareSuite) TestHTTP(c *check.C) {
c.Assert(m.scheme, check.Equals, "http")
c.Assert(m.host, check.Equals, "example.com")
url, err := middleware.URLFor(nil, "morty/data", nil)
url, err := middleware.URLFor(context.TODO(), "morty/data", nil)
c.Assert(err, check.Equals, nil)
c.Assert(url, check.Equals, "http://example.com/morty/data")
}

View File

@@ -13,6 +13,7 @@ package oss
import (
"bytes"
"context"
"fmt"
"io"
"io/ioutil"
@@ -22,8 +23,6 @@ import (
"strings"
"time"
"github.com/docker/distribution/context"
"github.com/denverdino/aliyungo/oss"
storagedriver "github.com/docker/distribution/registry/storage/driver"
"github.com/docker/distribution/registry/storage/driver/base"
@@ -480,6 +479,12 @@ func (d *driver) URLFor(ctx context.Context, path string, options map[string]int
return signedURL, nil
}
// Walk traverses a filesystem defined within driver, starting
// from the given path, calling f on each file.
func (d *driver) Walk(ctx context.Context, path string, f storagedriver.WalkFn) error {
return storagedriver.WalkFallback(ctx, d, path, f)
}
func (d *driver) ossPath(path string) string {
return strings.TrimLeft(strings.TrimRight(d.RootDirectory, "/")+path, "/")
}

View File

@@ -4,16 +4,14 @@ package oss
import (
"io/ioutil"
"os"
"strconv"
"testing"
alioss "github.com/denverdino/aliyungo/oss"
"github.com/docker/distribution/context"
storagedriver "github.com/docker/distribution/registry/storage/driver"
"github.com/docker/distribution/registry/storage/driver/testsuites"
//"log"
"os"
"strconv"
"testing"
"gopkg.in/check.v1"
)

View File

@@ -13,6 +13,8 @@ package s3
import (
"bytes"
"context"
"crypto/tls"
"fmt"
"io"
"io/ioutil"
@@ -29,11 +31,12 @@ import (
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/docker/distribution/context"
dcontext "github.com/docker/distribution/context"
"github.com/docker/distribution/registry/client/transport"
storagedriver "github.com/docker/distribution/registry/storage/driver"
"github.com/docker/distribution/registry/storage/driver/base"
@@ -89,6 +92,7 @@ type DriverParameters struct {
Encrypt bool
KeyID string
Secure bool
SkipVerify bool
V4Auth bool
ChunkSize int64
MultipartCopyChunkSize int64
@@ -102,25 +106,11 @@ type DriverParameters struct {
}
func init() {
for _, region := range []string{
"us-east-1",
"us-east-2",
"us-west-1",
"us-west-2",
"eu-west-1",
"eu-west-2",
"eu-central-1",
"ap-south-1",
"ap-southeast-1",
"ap-southeast-2",
"ap-northeast-1",
"ap-northeast-2",
"sa-east-1",
"cn-north-1",
"us-gov-west-1",
"ca-central-1",
} {
validRegions[region] = struct{}{}
partitions := endpoints.DefaultPartitions()
for _, p := range partitions {
for region := range p.Regions() {
validRegions[region] = struct{}{}
}
}
for _, objectACL := range []string{
@@ -196,14 +186,14 @@ func FromParameters(parameters map[string]interface{}) (*Driver, error) {
regionEndpoint = ""
}
regionName, ok := parameters["region"]
regionName := parameters["region"]
if regionName == nil || fmt.Sprint(regionName) == "" {
return nil, fmt.Errorf("No region parameter provided")
}
region := fmt.Sprint(regionName)
// Don't check the region value if a custom endpoint is provided.
if regionEndpoint == "" {
if _, ok = validRegions[region]; !ok {
if _, ok := validRegions[region]; !ok {
return nil, fmt.Errorf("Invalid region provided: %v", region)
}
}
@@ -247,6 +237,23 @@ func FromParameters(parameters map[string]interface{}) (*Driver, error) {
return nil, fmt.Errorf("The secure parameter should be a boolean")
}
skipVerifyBool := false
skipVerify := parameters["skipverify"]
switch skipVerify := skipVerify.(type) {
case string:
b, err := strconv.ParseBool(skipVerify)
if err != nil {
return nil, fmt.Errorf("The skipVerify parameter should be a boolean")
}
skipVerifyBool = b
case bool:
skipVerifyBool = skipVerify
case nil:
// do nothing
default:
return nil, fmt.Errorf("The skipVerify parameter should be a boolean")
}
v4Bool := true
v4auth := parameters["v4auth"]
switch v4auth := v4auth.(type) {
@@ -343,6 +350,7 @@ func FromParameters(parameters map[string]interface{}) (*Driver, error) {
encryptBool,
fmt.Sprint(keyID),
secureBool,
skipVerifyBool,
v4Bool,
chunkSize,
multipartCopyChunkSize,
@@ -397,6 +405,10 @@ func New(params DriverParameters) (*Driver, error) {
}
awsConfig := aws.NewConfig()
sess, err := session.NewSession()
if err != nil {
return nil, fmt.Errorf("failed to create new session: %v", err)
}
creds := credentials.NewChainCredentials([]credentials.Provider{
&credentials.StaticProvider{
Value: credentials.Value{
@@ -407,7 +419,7 @@ func New(params DriverParameters) (*Driver, error) {
},
&credentials.EnvProvider{},
&credentials.SharedCredentialsProvider{},
&ec2rolecreds.EC2RoleProvider{Client: ec2metadata.New(session.New())},
&ec2rolecreds.EC2RoleProvider{Client: ec2metadata.New(sess)},
})
if params.RegionEndpoint != "" {
@@ -419,13 +431,29 @@ func New(params DriverParameters) (*Driver, error) {
awsConfig.WithRegion(params.Region)
awsConfig.WithDisableSSL(!params.Secure)
if params.UserAgent != "" {
awsConfig.WithHTTPClient(&http.Client{
Transport: transport.NewTransport(http.DefaultTransport, transport.NewHeaderRequestModifier(http.Header{http.CanonicalHeaderKey("User-Agent"): []string{params.UserAgent}})),
})
if params.UserAgent != "" || params.SkipVerify {
httpTransport := http.DefaultTransport
if params.SkipVerify {
httpTransport = &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
}
if params.UserAgent != "" {
awsConfig.WithHTTPClient(&http.Client{
Transport: transport.NewTransport(httpTransport, transport.NewHeaderRequestModifier(http.Header{http.CanonicalHeaderKey("User-Agent"): []string{params.UserAgent}})),
})
} else {
awsConfig.WithHTTPClient(&http.Client{
Transport: transport.NewTransport(httpTransport),
})
}
}
s3obj := s3.New(session.New(awsConfig))
sess, err = session.NewSession(awsConfig)
if err != nil {
return nil, fmt.Errorf("failed to create new session with aws config: %v", err)
}
s3obj := s3.New(sess)
// enable S3 compatible signature v2 signing instead
if !params.V4Auth {
@@ -448,11 +476,11 @@ func New(params DriverParameters) (*Driver, error) {
// }
d := &driver{
S3: s3obj,
Bucket: params.Bucket,
ChunkSize: params.ChunkSize,
Encrypt: params.Encrypt,
KeyID: params.KeyID,
S3: s3obj,
Bucket: params.Bucket,
ChunkSize: params.ChunkSize,
Encrypt: params.Encrypt,
KeyID: params.KeyID,
MultipartCopyChunkSize: params.MultipartCopyChunkSize,
MultipartCopyMaxConcurrency: params.MultipartCopyMaxConcurrency,
MultipartCopyThresholdSize: params.MultipartCopyThresholdSize,
@@ -874,6 +902,136 @@ func (d *driver) URLFor(ctx context.Context, path string, options map[string]int
return req.Presign(expiresIn)
}
// Walk traverses a filesystem defined within driver, starting
// from the given path, calling f on each file.
func (d *driver) Walk(ctx context.Context, from string, f storagedriver.WalkFn) error {
path := from
if !strings.HasSuffix(path, "/") {
path = path + "/"
}
prefix := ""
if d.s3Path("") == "" {
prefix = "/"
}
var objectCount int64
if err := d.doWalk(ctx, &objectCount, d.s3Path(path), prefix, f); err != nil {
return err
}
// S3 doesn't have the concept of empty directories, so it'll return path not found if there are no objects
if objectCount == 0 {
return storagedriver.PathNotFoundError{Path: from}
}
return nil
}
type walkInfoContainer struct {
storagedriver.FileInfoFields
prefix *string
}
// Path provides the full path of the target of this file info.
func (wi walkInfoContainer) Path() string {
return wi.FileInfoFields.Path
}
// Size returns the current length in bytes of the file. The value is
// meaningless if IsDir returns true.
func (wi walkInfoContainer) Size() int64 {
return wi.FileInfoFields.Size
}
// ModTime returns the modification time for the file. For backends that
// don't have a modification time, the creation time should be returned.
func (wi walkInfoContainer) ModTime() time.Time {
return wi.FileInfoFields.ModTime
}
// IsDir returns true if the path is a directory.
func (wi walkInfoContainer) IsDir() bool {
return wi.FileInfoFields.IsDir
}
func (d *driver) doWalk(parentCtx context.Context, objectCount *int64, path, prefix string, f storagedriver.WalkFn) error {
var retError error
listObjectsInput := &s3.ListObjectsV2Input{
Bucket: aws.String(d.Bucket),
Prefix: aws.String(path),
Delimiter: aws.String("/"),
MaxKeys: aws.Int64(listMax),
}
ctx, done := dcontext.WithTrace(parentCtx)
defer done("s3aws.ListObjectsV2Pages(%s)", path)
listObjectErr := d.S3.ListObjectsV2PagesWithContext(ctx, listObjectsInput, func(objects *s3.ListObjectsV2Output, lastPage bool) bool {
*objectCount += *objects.KeyCount
walkInfos := make([]walkInfoContainer, 0, *objects.KeyCount)
for _, dir := range objects.CommonPrefixes {
commonPrefix := *dir.Prefix
walkInfos = append(walkInfos, walkInfoContainer{
prefix: dir.Prefix,
FileInfoFields: storagedriver.FileInfoFields{
IsDir: true,
Path: strings.Replace(commonPrefix[:len(commonPrefix)-1], d.s3Path(""), prefix, 1),
},
})
}
for _, file := range objects.Contents {
walkInfos = append(walkInfos, walkInfoContainer{
FileInfoFields: storagedriver.FileInfoFields{
IsDir: false,
Size: *file.Size,
ModTime: *file.LastModified,
Path: strings.Replace(*file.Key, d.s3Path(""), prefix, 1),
},
})
}
sort.SliceStable(walkInfos, func(i, j int) bool { return walkInfos[i].FileInfoFields.Path < walkInfos[j].FileInfoFields.Path })
for _, walkInfo := range walkInfos {
err := f(walkInfo)
if err == storagedriver.ErrSkipDir {
if walkInfo.IsDir() {
continue
} else {
break
}
} else if err != nil {
retError = err
return false
}
if walkInfo.IsDir() {
if err := d.doWalk(ctx, objectCount, *walkInfo.prefix, prefix, f); err != nil {
retError = err
return false
}
}
}
return true
})
if retError != nil {
return retError
}
if listObjectErr != nil {
return listObjectErr
}
return nil
}
func (d *driver) s3Path(path string) string {
return strings.TrimLeft(strings.TrimRight(d.RootDirectory, "/")+path, "/")
}
@@ -1019,10 +1177,10 @@ func (w *writer) Write(p []byte) (int, error) {
Bucket: aws.String(w.driver.Bucket),
Key: aws.String(w.key),
})
defer resp.Body.Close()
if err != nil {
return 0, err
}
defer resp.Body.Close()
w.parts = nil
w.readyPart, err = ioutil.ReadAll(resp.Body)
if err != nil {

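For context, a standalone sketch of the opt-in skip-verify transport pattern wired into the driver above (newHTTPClient is illustrative; the driver additionally layers the User-Agent request modifier on top of this transport):

package main

import (
	"crypto/tls"
	"fmt"
	"net/http"
)

// newHTTPClient returns a client whose transport skips TLS certificate
// verification only when the operator explicitly opts in, mirroring the
// skipverify parameter handling above.
func newHTTPClient(skipVerify bool) *http.Client {
	transport := http.DefaultTransport
	if skipVerify {
		// Insecure: only sensible against self-signed test endpoints.
		transport = &http.Transport{
			TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
		}
	}
	return &http.Client{Transport: transport}
}

func main() {
	client := newHTTPClient(false)
	fmt.Printf("%T\n", client.Transport)
}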
View File

@@ -31,6 +31,7 @@ func init() {
encrypt := os.Getenv("S3_ENCRYPT")
keyID := os.Getenv("S3_KEY_ID")
secure := os.Getenv("S3_SECURE")
skipVerify := os.Getenv("S3_SKIP_VERIFY")
v4Auth := os.Getenv("S3_V4_AUTH")
region := os.Getenv("AWS_REGION")
objectACL := os.Getenv("S3_OBJECT_ACL")
@@ -59,6 +60,14 @@ func init() {
}
}
skipVerifyBool := false
if skipVerify != "" {
skipVerifyBool, err = strconv.ParseBool(skipVerify)
if err != nil {
return nil, err
}
}
v4Bool := true
if v4Auth != "" {
v4Bool, err = strconv.ParseBool(v4Auth)
@@ -76,6 +85,7 @@ func init() {
encryptBool,
keyID,
secureBool,
skipVerifyBool,
v4Bool,
minChunkSize,
defaultMultipartCopyChunkSize,
@@ -139,14 +149,14 @@ func TestEmptyRootList(t *testing.T) {
}
defer rootedDriver.Delete(ctx, filename)
keys, err := emptyRootDriver.List(ctx, "/")
keys, _ := emptyRootDriver.List(ctx, "/")
for _, path := range keys {
if !storagedriver.PathRegexp.MatchString(path) {
t.Fatalf("unexpected string in path: %q != %q", path, storagedriver.PathRegexp)
}
}
keys, err = slashRootDriver.List(ctx, "/")
keys, _ = slashRootDriver.List(ctx, "/")
for _, path := range keys {
if !storagedriver.PathRegexp.MatchString(path) {
t.Fatalf("unexpected string in path: %q != %q", path, storagedriver.PathRegexp)

View File

@@ -209,9 +209,9 @@ func (v2 *signer) Sign() error {
v2.signature = base64.StdEncoding.EncodeToString(hash.Sum(nil))
if expires {
params["Signature"] = []string{string(v2.signature)}
params["Signature"] = []string{v2.signature}
} else {
headers["Authorization"] = []string{"AWS " + accessKey + ":" + string(v2.signature)}
headers["Authorization"] = []string{"AWS " + accessKey + ":" + v2.signature}
}
log.WithFields(log.Fields{

View File

@@ -1,761 +0,0 @@
// Package s3 provides a storagedriver.StorageDriver implementation to
// store blobs in Amazon S3 cloud storage.
//
// This package leverages the docker/goamz client library for interfacing with
// S3. It is intended to be deprecated in favor of the s3-aws driver
// implementation.
//
// Because S3 is a key, value store the Stat call does not support last modification
// time for directories (directories are an abstraction for key, value stores)
//
// Keep in mind that S3 guarantees only read-after-write consistency for new
// objects, but no read-after-update or list-after-write consistency.
package s3
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http"
"reflect"
"strconv"
"strings"
"time"
"github.com/docker/goamz/aws"
"github.com/docker/goamz/s3"
"github.com/docker/distribution/context"
"github.com/docker/distribution/registry/client/transport"
storagedriver "github.com/docker/distribution/registry/storage/driver"
"github.com/docker/distribution/registry/storage/driver/base"
"github.com/docker/distribution/registry/storage/driver/factory"
)
const driverName = "s3goamz"
// minChunkSize defines the minimum multipart upload chunk size
// S3 API requires multipart upload chunks to be at least 5MB
const minChunkSize = 5 << 20
const defaultChunkSize = 2 * minChunkSize
// listMax is the largest amount of objects you can request from S3 in a list call
const listMax = 1000
//DriverParameters A struct that encapsulates all of the driver parameters after all values have been set
type DriverParameters struct {
AccessKey string
SecretKey string
Bucket string
Region aws.Region
Encrypt bool
Secure bool
V4Auth bool
ChunkSize int64
RootDirectory string
StorageClass s3.StorageClass
UserAgent string
}
func init() {
factory.Register(driverName, &s3DriverFactory{})
}
// s3DriverFactory implements the factory.StorageDriverFactory interface
type s3DriverFactory struct{}
func (factory *s3DriverFactory) Create(parameters map[string]interface{}) (storagedriver.StorageDriver, error) {
return FromParameters(parameters)
}
type driver struct {
S3 *s3.S3
Bucket *s3.Bucket
ChunkSize int64
Encrypt bool
RootDirectory string
StorageClass s3.StorageClass
}
type baseEmbed struct {
base.Base
}
// Driver is a storagedriver.StorageDriver implementation backed by Amazon S3
// Objects are stored at absolute keys in the provided bucket.
type Driver struct {
baseEmbed
}
// FromParameters constructs a new Driver with a given parameters map
// Required parameters:
// - accesskey
// - secretkey
// - region
// - bucket
// - encrypt
func FromParameters(parameters map[string]interface{}) (*Driver, error) {
// Providing no values for these is valid in case the user is authenticating
// with an IAM on an ec2 instance (in which case the instance credentials will
// be summoned when GetAuth is called)
accessKey := parameters["accesskey"]
if accessKey == nil {
accessKey = ""
}
secretKey := parameters["secretkey"]
if secretKey == nil {
secretKey = ""
}
regionName := parameters["region"]
if regionName == nil || fmt.Sprint(regionName) == "" {
return nil, fmt.Errorf("No region parameter provided")
}
region := aws.GetRegion(fmt.Sprint(regionName))
if region.Name == "" {
return nil, fmt.Errorf("Invalid region provided: %v", region)
}
bucket := parameters["bucket"]
if bucket == nil || fmt.Sprint(bucket) == "" {
return nil, fmt.Errorf("No bucket parameter provided")
}
encryptBool := false
encrypt := parameters["encrypt"]
switch encrypt := encrypt.(type) {
case string:
b, err := strconv.ParseBool(encrypt)
if err != nil {
return nil, fmt.Errorf("The encrypt parameter should be a boolean")
}
encryptBool = b
case bool:
encryptBool = encrypt
case nil:
// do nothing
default:
return nil, fmt.Errorf("The encrypt parameter should be a boolean")
}
secureBool := true
secure := parameters["secure"]
switch secure := secure.(type) {
case string:
b, err := strconv.ParseBool(secure)
if err != nil {
return nil, fmt.Errorf("The secure parameter should be a boolean")
}
secureBool = b
case bool:
secureBool = secure
case nil:
// do nothing
default:
return nil, fmt.Errorf("The secure parameter should be a boolean")
}
v4AuthBool := false
v4Auth := parameters["v4auth"]
switch v4Auth := v4Auth.(type) {
case string:
b, err := strconv.ParseBool(v4Auth)
if err != nil {
return nil, fmt.Errorf("The v4auth parameter should be a boolean")
}
v4AuthBool = b
case bool:
v4AuthBool = v4Auth
case nil:
// do nothing
default:
return nil, fmt.Errorf("The v4auth parameter should be a boolean")
}
chunkSize := int64(defaultChunkSize)
chunkSizeParam := parameters["chunksize"]
switch v := chunkSizeParam.(type) {
case string:
vv, err := strconv.ParseInt(v, 0, 64)
if err != nil {
return nil, fmt.Errorf("chunksize parameter must be an integer, %v invalid", chunkSizeParam)
}
chunkSize = vv
case int64:
chunkSize = v
case int, uint, int32, uint32, uint64:
chunkSize = reflect.ValueOf(v).Convert(reflect.TypeOf(chunkSize)).Int()
case nil:
// do nothing
default:
return nil, fmt.Errorf("invalid value for chunksize: %#v", chunkSizeParam)
}
if chunkSize < minChunkSize {
return nil, fmt.Errorf("The chunksize %#v parameter should be a number that is larger than or equal to %d", chunkSize, minChunkSize)
}
rootDirectory := parameters["rootdirectory"]
if rootDirectory == nil {
rootDirectory = ""
}
storageClass := s3.StandardStorage
storageClassParam := parameters["storageclass"]
if storageClassParam != nil {
storageClassString, ok := storageClassParam.(string)
if !ok {
return nil, fmt.Errorf("The storageclass parameter must be one of %v, %v invalid", []s3.StorageClass{s3.StandardStorage, s3.ReducedRedundancy}, storageClassParam)
}
// All valid storage class parameters are UPPERCASE, so be a bit more flexible here
storageClassCasted := s3.StorageClass(strings.ToUpper(storageClassString))
if storageClassCasted != s3.StandardStorage && storageClassCasted != s3.ReducedRedundancy {
return nil, fmt.Errorf("The storageclass parameter must be one of %v, %v invalid", []s3.StorageClass{s3.StandardStorage, s3.ReducedRedundancy}, storageClassParam)
}
storageClass = storageClassCasted
}
userAgent := parameters["useragent"]
if userAgent == nil {
userAgent = ""
}
params := DriverParameters{
fmt.Sprint(accessKey),
fmt.Sprint(secretKey),
fmt.Sprint(bucket),
region,
encryptBool,
secureBool,
v4AuthBool,
chunkSize,
fmt.Sprint(rootDirectory),
storageClass,
fmt.Sprint(userAgent),
}
return New(params)
}
// New constructs a new Driver with the given AWS credentials, region, encryption flag, and
// bucketName
func New(params DriverParameters) (*Driver, error) {
auth, err := aws.GetAuth(params.AccessKey, params.SecretKey, "", time.Time{})
if err != nil {
return nil, fmt.Errorf("unable to resolve aws credentials, please ensure that 'accesskey' and 'secretkey' are properly set or the credentials are available in $HOME/.aws/credentials: %v", err)
}
if !params.Secure {
params.Region.S3Endpoint = strings.Replace(params.Region.S3Endpoint, "https", "http", 1)
}
s3obj := s3.New(auth, params.Region)
if params.UserAgent != "" {
s3obj.Client = &http.Client{
Transport: transport.NewTransport(http.DefaultTransport,
transport.NewHeaderRequestModifier(http.Header{
http.CanonicalHeaderKey("User-Agent"): []string{params.UserAgent},
}),
),
}
}
if params.V4Auth {
s3obj.Signature = aws.V4Signature
} else if mustV4Auth(params.Region.Name) {
return nil, fmt.Errorf("The %s region only works with v4 authentication", params.Region.Name)
}
bucket := s3obj.Bucket(params.Bucket)
// TODO Currently multipart uploads have no timestamps, so this would be unwise
// if you initiated a new s3driver while another one is running on the same bucket.
// multis, _, err := bucket.ListMulti("", "")
// if err != nil {
// return nil, err
// }
// for _, multi := range multis {
// err := multi.Abort()
// //TODO appropriate to do this error checking?
// if err != nil {
// return nil, err
// }
// }
d := &driver{
S3: s3obj,
Bucket: bucket,
ChunkSize: params.ChunkSize,
Encrypt: params.Encrypt,
RootDirectory: params.RootDirectory,
StorageClass: params.StorageClass,
}
return &Driver{
baseEmbed: baseEmbed{
Base: base.Base{
StorageDriver: d,
},
},
}, nil
}
// Implement the storagedriver.StorageDriver interface
func (d *driver) Name() string {
return driverName
}
// GetContent retrieves the content stored at "path" as a []byte.
func (d *driver) GetContent(ctx context.Context, path string) ([]byte, error) {
content, err := d.Bucket.Get(d.s3Path(path))
if err != nil {
return nil, parseError(path, err)
}
return content, nil
}
// PutContent stores the []byte content at a location designated by "path".
func (d *driver) PutContent(ctx context.Context, path string, contents []byte) error {
return parseError(path, d.Bucket.Put(d.s3Path(path), contents, d.getContentType(), getPermissions(), d.getOptions()))
}
// Reader retrieves an io.ReadCloser for the content stored at "path" with a
// given byte offset.
func (d *driver) Reader(ctx context.Context, path string, offset int64) (io.ReadCloser, error) {
headers := make(http.Header)
headers.Add("Range", "bytes="+strconv.FormatInt(offset, 10)+"-")
resp, err := d.Bucket.GetResponseWithHeaders(d.s3Path(path), headers)
if err != nil {
if s3Err, ok := err.(*s3.Error); ok && s3Err.Code == "InvalidRange" {
return ioutil.NopCloser(bytes.NewReader(nil)), nil
}
return nil, parseError(path, err)
}
return resp.Body, nil
}
// Writer returns a FileWriter which will store the content written to it
// at the location designated by "path" after the call to Commit.
func (d *driver) Writer(ctx context.Context, path string, append bool) (storagedriver.FileWriter, error) {
key := d.s3Path(path)
if !append {
// TODO (brianbland): cancel other uploads at this path
multi, err := d.Bucket.InitMulti(key, d.getContentType(), getPermissions(), d.getOptions())
if err != nil {
return nil, err
}
return d.newWriter(key, multi, nil), nil
}
multis, _, err := d.Bucket.ListMulti(key, "")
if err != nil {
return nil, parseError(path, err)
}
for _, multi := range multis {
if key != multi.Key {
continue
}
parts, err := multi.ListParts()
if err != nil {
return nil, parseError(path, err)
}
var multiSize int64
for _, part := range parts {
multiSize += part.Size
}
return d.newWriter(key, multi, parts), nil
}
return nil, storagedriver.PathNotFoundError{Path: path}
}
// Stat retrieves the FileInfo for the given path, including the current size
// in bytes and the creation time.
func (d *driver) Stat(ctx context.Context, path string) (storagedriver.FileInfo, error) {
listResponse, err := d.Bucket.List(d.s3Path(path), "", "", 1)
if err != nil {
return nil, err
}
fi := storagedriver.FileInfoFields{
Path: path,
}
if len(listResponse.Contents) == 1 {
if listResponse.Contents[0].Key != d.s3Path(path) {
fi.IsDir = true
} else {
fi.IsDir = false
fi.Size = listResponse.Contents[0].Size
timestamp, err := time.Parse(time.RFC3339Nano, listResponse.Contents[0].LastModified)
if err != nil {
return nil, err
}
fi.ModTime = timestamp
}
} else if len(listResponse.CommonPrefixes) == 1 {
fi.IsDir = true
} else {
return nil, storagedriver.PathNotFoundError{Path: path}
}
return storagedriver.FileInfoInternal{FileInfoFields: fi}, nil
}
// List returns a list of the objects that are direct descendants of the given path.
func (d *driver) List(ctx context.Context, opath string) ([]string, error) {
path := opath
if path != "/" && path[len(path)-1] != '/' {
path = path + "/"
}
// This is to cover for the cases when the rootDirectory of the driver is either "" or "/".
// In those cases, there is no root prefix to replace and we must actually add a "/" to all
// results in order to keep them as valid paths as recognized by storagedriver.PathRegexp
prefix := ""
if d.s3Path("") == "" {
prefix = "/"
}
listResponse, err := d.Bucket.List(d.s3Path(path), "/", "", listMax)
if err != nil {
return nil, parseError(opath, err)
}
files := []string{}
directories := []string{}
for {
for _, key := range listResponse.Contents {
files = append(files, strings.Replace(key.Key, d.s3Path(""), prefix, 1))
}
for _, commonPrefix := range listResponse.CommonPrefixes {
directories = append(directories, strings.Replace(commonPrefix[0:len(commonPrefix)-1], d.s3Path(""), prefix, 1))
}
if !listResponse.IsTruncated {
break
}
listResponse, err = d.Bucket.List(d.s3Path(path), "/", listResponse.NextMarker, listMax)
if err != nil {
return nil, err
}
}
if opath != "/" {
if len(files) == 0 && len(directories) == 0 {
// Treat empty response as missing directory, since we don't actually
// have directories in s3.
return nil, storagedriver.PathNotFoundError{Path: opath}
}
}
return append(files, directories...), nil
}
// Move moves an object stored at sourcePath to destPath, removing the original
// object.
func (d *driver) Move(ctx context.Context, sourcePath string, destPath string) error {
/* This is terrible, but aws doesn't have an actual move. */
_, err := d.Bucket.PutCopy(d.s3Path(destPath), getPermissions(),
s3.CopyOptions{Options: d.getOptions(), ContentType: d.getContentType()}, d.Bucket.Name+"/"+d.s3Path(sourcePath))
if err != nil {
return parseError(sourcePath, err)
}
return d.Delete(ctx, sourcePath)
}
// Delete recursively deletes all objects stored at "path" and its subpaths.
func (d *driver) Delete(ctx context.Context, path string) error {
s3Path := d.s3Path(path)
listResponse, err := d.Bucket.List(s3Path, "", "", listMax)
if err != nil || len(listResponse.Contents) == 0 {
return storagedriver.PathNotFoundError{Path: path}
}
s3Objects := make([]s3.Object, listMax)
for len(listResponse.Contents) > 0 {
numS3Objects := len(listResponse.Contents)
for index, key := range listResponse.Contents {
// Stop if we encounter a key that is not a subpath (so that deleting "/a" does not delete "/ab").
if len(key.Key) > len(s3Path) && (key.Key)[len(s3Path)] != '/' {
numS3Objects = index
break
}
s3Objects[index].Key = key.Key
}
err := d.Bucket.DelMulti(s3.Delete{Quiet: false, Objects: s3Objects[0:numS3Objects]})
if err != nil {
return nil
}
if numS3Objects < len(listResponse.Contents) {
return nil
}
listResponse, err = d.Bucket.List(d.s3Path(path), "", "", listMax)
if err != nil {
return err
}
}
return nil
}
// URLFor returns a URL which may be used to retrieve the content stored at the given path.
// May return an UnsupportedMethodErr in certain StorageDriver implementations.
func (d *driver) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
methodString := "GET"
method, ok := options["method"]
if ok {
methodString, ok = method.(string)
if !ok || (methodString != "GET" && methodString != "HEAD") {
return "", storagedriver.ErrUnsupportedMethod{}
}
}
expiresTime := time.Now().Add(20 * time.Minute)
expires, ok := options["expiry"]
if ok {
et, ok := expires.(time.Time)
if ok {
expiresTime = et
}
}
return d.Bucket.SignedURLWithMethod(methodString, d.s3Path(path), expiresTime, nil, nil), nil
}
func (d *driver) s3Path(path string) string {
return strings.TrimLeft(strings.TrimRight(d.RootDirectory, "/")+path, "/")
}
// S3BucketKey returns the s3 bucket key for the given storage driver path.
func (d *Driver) S3BucketKey(path string) string {
return d.StorageDriver.(*driver).s3Path(path)
}
func parseError(path string, err error) error {
if s3Err, ok := err.(*s3.Error); ok && s3Err.Code == "NoSuchKey" {
return storagedriver.PathNotFoundError{Path: path}
}
return err
}
func (d *driver) getOptions() s3.Options {
return s3.Options{
SSE: d.Encrypt,
StorageClass: d.StorageClass,
}
}
func getPermissions() s3.ACL {
return s3.Private
}
// mustV4Auth checks whether must use v4 auth in specific region.
// Please see documentation at http://docs.aws.amazon.com/general/latest/gr/signature-version-2.html
func mustV4Auth(region string) bool {
switch region {
case "eu-central-1", "cn-north-1", "us-east-2",
"ca-central-1", "ap-south-1", "ap-northeast-2", "eu-west-2":
return true
}
return false
}
func (d *driver) getContentType() string {
return "application/octet-stream"
}
// writer attempts to upload parts to S3 in a buffered fashion where the last
// part is at least as large as the chunksize, so the multipart upload could be
// cleanly resumed in the future. This is violated if Close is called after less
// than a full chunk is written.
type writer struct {
driver *driver
key string
multi *s3.Multi
parts []s3.Part
size int64
readyPart []byte
pendingPart []byte
closed bool
committed bool
cancelled bool
}
func (d *driver) newWriter(key string, multi *s3.Multi, parts []s3.Part) storagedriver.FileWriter {
var size int64
for _, part := range parts {
size += part.Size
}
return &writer{
driver: d,
key: key,
multi: multi,
parts: parts,
size: size,
}
}
func (w *writer) Write(p []byte) (int, error) {
if w.closed {
return 0, fmt.Errorf("already closed")
} else if w.committed {
return 0, fmt.Errorf("already committed")
} else if w.cancelled {
return 0, fmt.Errorf("already cancelled")
}
// If the last written part is smaller than minChunkSize, we need to make a
// new multipart upload :sadface:
if len(w.parts) > 0 && int(w.parts[len(w.parts)-1].Size) < minChunkSize {
err := w.multi.Complete(w.parts)
if err != nil {
w.multi.Abort()
return 0, err
}
multi, err := w.driver.Bucket.InitMulti(w.key, w.driver.getContentType(), getPermissions(), w.driver.getOptions())
if err != nil {
return 0, err
}
w.multi = multi
// If the entire written file is smaller than minChunkSize, we need to make
// a new part from scratch :double sad face:
if w.size < minChunkSize {
contents, err := w.driver.Bucket.Get(w.key)
if err != nil {
return 0, err
}
w.parts = nil
w.readyPart = contents
} else {
// Otherwise we can use the old file as the new first part
_, part, err := multi.PutPartCopy(1, s3.CopyOptions{}, w.driver.Bucket.Name+"/"+w.key)
if err != nil {
return 0, err
}
w.parts = []s3.Part{part}
}
}
var n int
for len(p) > 0 {
// If no parts are ready to write, fill up the first part
if neededBytes := int(w.driver.ChunkSize) - len(w.readyPart); neededBytes > 0 {
if len(p) >= neededBytes {
w.readyPart = append(w.readyPart, p[:neededBytes]...)
n += neededBytes
p = p[neededBytes:]
} else {
w.readyPart = append(w.readyPart, p...)
n += len(p)
p = nil
}
}
if neededBytes := int(w.driver.ChunkSize) - len(w.pendingPart); neededBytes > 0 {
if len(p) >= neededBytes {
w.pendingPart = append(w.pendingPart, p[:neededBytes]...)
n += neededBytes
p = p[neededBytes:]
err := w.flushPart()
if err != nil {
w.size += int64(n)
return n, err
}
} else {
w.pendingPart = append(w.pendingPart, p...)
n += len(p)
p = nil
}
}
}
w.size += int64(n)
return n, nil
}
func (w *writer) Size() int64 {
return w.size
}
func (w *writer) Close() error {
if w.closed {
return fmt.Errorf("already closed")
}
w.closed = true
return w.flushPart()
}
func (w *writer) Cancel() error {
if w.closed {
return fmt.Errorf("already closed")
} else if w.committed {
return fmt.Errorf("already committed")
}
w.cancelled = true
err := w.multi.Abort()
return err
}
func (w *writer) Commit() error {
if w.closed {
return fmt.Errorf("already closed")
} else if w.committed {
return fmt.Errorf("already committed")
} else if w.cancelled {
return fmt.Errorf("already cancelled")
}
err := w.flushPart()
if err != nil {
return err
}
w.committed = true
err = w.multi.Complete(w.parts)
if err != nil {
w.multi.Abort()
return err
}
return nil
}
// flushPart flushes buffers to write a part to S3.
// Only called by Write (with both buffers full) and Close/Commit (always)
func (w *writer) flushPart() error {
if len(w.readyPart) == 0 && len(w.pendingPart) == 0 {
// nothing to write
return nil
}
if len(w.pendingPart) < int(w.driver.ChunkSize) {
// closing with a small pending part
// combine ready and pending to avoid writing a small part
w.readyPart = append(w.readyPart, w.pendingPart...)
w.pendingPart = nil
}
part, err := w.multi.PutPart(len(w.parts)+1, bytes.NewReader(w.readyPart))
if err != nil {
return err
}
w.parts = append(w.parts, part)
w.readyPart = w.pendingPart
w.pendingPart = nil
return nil
}

View File

@@ -1,201 +0,0 @@
package s3
import (
"io/ioutil"
"os"
"strconv"
"testing"
"github.com/docker/distribution/context"
storagedriver "github.com/docker/distribution/registry/storage/driver"
"github.com/docker/distribution/registry/storage/driver/testsuites"
"github.com/docker/goamz/aws"
"github.com/docker/goamz/s3"
"gopkg.in/check.v1"
)
// Hook up gocheck into the "go test" runner.
func Test(t *testing.T) { check.TestingT(t) }
var s3DriverConstructor func(rootDirectory string, storageClass s3.StorageClass) (*Driver, error)
var skipS3 func() string
func init() {
accessKey := os.Getenv("AWS_ACCESS_KEY")
secretKey := os.Getenv("AWS_SECRET_KEY")
bucket := os.Getenv("S3_BUCKET")
encrypt := os.Getenv("S3_ENCRYPT")
secure := os.Getenv("S3_SECURE")
v4auth := os.Getenv("S3_USE_V4_AUTH")
region := os.Getenv("AWS_REGION")
root, err := ioutil.TempDir("", "driver-")
if err != nil {
panic(err)
}
defer os.Remove(root)
s3DriverConstructor = func(rootDirectory string, storageClass s3.StorageClass) (*Driver, error) {
encryptBool := false
if encrypt != "" {
encryptBool, err = strconv.ParseBool(encrypt)
if err != nil {
return nil, err
}
}
secureBool := true
if secure != "" {
secureBool, err = strconv.ParseBool(secure)
if err != nil {
return nil, err
}
}
v4AuthBool := false
if v4auth != "" {
v4AuthBool, err = strconv.ParseBool(v4auth)
if err != nil {
return nil, err
}
}
parameters := DriverParameters{
accessKey,
secretKey,
bucket,
aws.GetRegion(region),
encryptBool,
secureBool,
v4AuthBool,
minChunkSize,
rootDirectory,
storageClass,
driverName + "-test",
}
return New(parameters)
}
// Skip S3 storage driver tests if environment variable parameters are not provided
skipS3 = func() string {
if accessKey == "" || secretKey == "" || region == "" || bucket == "" || encrypt == "" {
return "Must set AWS_ACCESS_KEY, AWS_SECRET_KEY, AWS_REGION, S3_BUCKET, and S3_ENCRYPT to run S3 tests"
}
return ""
}
testsuites.RegisterSuite(func() (storagedriver.StorageDriver, error) {
return s3DriverConstructor(root, s3.StandardStorage)
}, skipS3)
}
func TestEmptyRootList(t *testing.T) {
if skipS3() != "" {
t.Skip(skipS3())
}
validRoot, err := ioutil.TempDir("", "driver-")
if err != nil {
t.Fatalf("unexpected error creating temporary directory: %v", err)
}
defer os.Remove(validRoot)
rootedDriver, err := s3DriverConstructor(validRoot, s3.StandardStorage)
if err != nil {
t.Fatalf("unexpected error creating rooted driver: %v", err)
}
emptyRootDriver, err := s3DriverConstructor("", s3.StandardStorage)
if err != nil {
t.Fatalf("unexpected error creating empty root driver: %v", err)
}
slashRootDriver, err := s3DriverConstructor("/", s3.StandardStorage)
if err != nil {
t.Fatalf("unexpected error creating slash root driver: %v", err)
}
filename := "/test"
contents := []byte("contents")
ctx := context.Background()
err = rootedDriver.PutContent(ctx, filename, contents)
if err != nil {
t.Fatalf("unexpected error creating content: %v", err)
}
defer rootedDriver.Delete(ctx, filename)
keys, err := emptyRootDriver.List(ctx, "/")
for _, path := range keys {
if !storagedriver.PathRegexp.MatchString(path) {
t.Fatalf("unexpected string in path: %q != %q", path, storagedriver.PathRegexp)
}
}
keys, err = slashRootDriver.List(ctx, "/")
for _, path := range keys {
if !storagedriver.PathRegexp.MatchString(path) {
t.Fatalf("unexpected string in path: %q != %q", path, storagedriver.PathRegexp)
}
}
}
func TestStorageClass(t *testing.T) {
if skipS3() != "" {
t.Skip(skipS3())
}
rootDir, err := ioutil.TempDir("", "driver-")
if err != nil {
t.Fatalf("unexpected error creating temporary directory: %v", err)
}
defer os.Remove(rootDir)
standardDriver, err := s3DriverConstructor(rootDir, s3.StandardStorage)
if err != nil {
t.Fatalf("unexpected error creating driver with standard storage: %v", err)
}
rrDriver, err := s3DriverConstructor(rootDir, s3.ReducedRedundancy)
if err != nil {
t.Fatalf("unexpected error creating driver with reduced redundancy storage: %v", err)
}
standardFilename := "/test-standard"
rrFilename := "/test-rr"
contents := []byte("contents")
ctx := context.Background()
err = standardDriver.PutContent(ctx, standardFilename, contents)
if err != nil {
t.Fatalf("unexpected error creating content: %v", err)
}
defer standardDriver.Delete(ctx, standardFilename)
err = rrDriver.PutContent(ctx, rrFilename, contents)
if err != nil {
t.Fatalf("unexpected error creating content: %v", err)
}
defer rrDriver.Delete(ctx, rrFilename)
standardDriverUnwrapped := standardDriver.Base.StorageDriver.(*driver)
resp, err := standardDriverUnwrapped.Bucket.GetResponse(standardDriverUnwrapped.s3Path(standardFilename))
if err != nil {
t.Fatalf("unexpected error retrieving standard storage file: %v", err)
}
defer resp.Body.Close()
// Amazon only populates this header value for non-standard storage classes
if storageClass := resp.Header.Get("x-amz-storage-class"); storageClass != "" {
t.Fatalf("unexpected storage class for standard file: %v", storageClass)
}
rrDriverUnwrapped := rrDriver.Base.StorageDriver.(*driver)
resp, err = rrDriverUnwrapped.Bucket.GetResponse(rrDriverUnwrapped.s3Path(rrFilename))
if err != nil {
t.Fatalf("unexpected error retrieving reduced-redundancy storage file: %v", err)
}
defer resp.Body.Close()
if storageClass := resp.Header.Get("x-amz-storage-class"); storageClass != string(s3.ReducedRedundancy) {
t.Fatalf("unexpected storage class for reduced-redundancy file: %v", storageClass)
}
}

View File

@@ -1,13 +1,12 @@
package driver
import (
"context"
"fmt"
"io"
"regexp"
"strconv"
"strings"
"github.com/docker/distribution/context"
)
// Version is a string representing the storage driver version, of the form
@@ -84,6 +83,13 @@ type StorageDriver interface {
// May return an ErrUnsupportedMethod in certain StorageDriver
// implementations.
URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error)
// Walk traverses a filesystem defined within driver, starting
// from the given path, calling f on each file.
// If the returned error from the WalkFn is ErrSkipDir and fileInfo refers
// to a directory, the directory will not be entered and Walk
// will continue the traversal. If fileInfo refers to a normal file, processing stops.
Walk(ctx context.Context, path string, f WalkFn) error
}
// FileWriter provides an abstraction for an opened writable file-like object in

View File

@@ -18,6 +18,7 @@ package swift
import (
"bufio"
"bytes"
"context"
"crypto/rand"
"crypto/sha1"
"crypto/tls"
@@ -34,7 +35,6 @@ import (
"github.com/mitchellh/mapstructure"
"github.com/ncw/swift"
"github.com/docker/distribution/context"
storagedriver "github.com/docker/distribution/registry/storage/driver"
"github.com/docker/distribution/registry/storage/driver/base"
"github.com/docker/distribution/registry/storage/driver/factory"
@@ -142,6 +142,19 @@ func FromParameters(parameters map[string]interface{}) (*Driver, error) {
InsecureSkipVerify: false,
}
// Sanitize some entries before trying to decode parameters with mapstructure.
// Tenant and TenantID values that contain only digits and are passed as ENV
// variables arrive as integers rather than strings, and the parser fails in
// that case.
_, ok := parameters["tenant"]
if ok {
parameters["tenant"] = fmt.Sprint(parameters["tenant"])
}
_, ok = parameters["tenantid"]
if ok {
parameters["tenantid"] = fmt.Sprint(parameters["tenantid"])
}
if err := mapstructure.Decode(parameters, &params); err != nil {
return nil, err
}
@@ -644,6 +657,12 @@ func (d *driver) URLFor(ctx context.Context, path string, options map[string]int
return tempURL, nil
}
// Walk traverses a filesystem defined within driver, starting
// from the given path, calling f on each file.
func (d *driver) Walk(ctx context.Context, path string, f storagedriver.WalkFn) error {
return storagedriver.WalkFallback(ctx, d, path, f)
}
func (d *driver) swiftPath(path string) string {
return strings.TrimLeft(strings.TrimRight(d.Prefix+"/files"+path, "/"), "/")
}
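For context, the tenant sanitization above guards against a decode failure that is easy to reproduce: mapstructure will not coerce an int into a string field. A minimal sketch, where swiftParams and the printed error wording are illustrative:

package main

import (
	"fmt"

	"github.com/mitchellh/mapstructure"
)

type swiftParams struct {
	Tenant string
}

func main() {
	src := map[string]interface{}{"tenant": 12345}
	var p swiftParams
	// Fails: mapstructure does not coerce int into a string field.
	fmt.Println(mapstructure.Decode(src, &p))
	// The sanitization applied above: stringify before decoding.
	src["tenant"] = fmt.Sprint(src["tenant"])
	fmt.Println(mapstructure.Decode(src, &p), p.Tenant)
}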

View File

@@ -148,14 +148,14 @@ func TestEmptyRootList(t *testing.T) {
t.Fatalf("unexpected error creating content: %v", err)
}
keys, err := emptyRootDriver.List(ctx, "/")
keys, _ := emptyRootDriver.List(ctx, "/")
for _, path := range keys {
if !storagedriver.PathRegexp.MatchString(path) {
t.Fatalf("unexpected string in path: %q != %q", path, storagedriver.PathRegexp)
}
}
keys, err = slashRootDriver.List(ctx, "/")
keys, _ = slashRootDriver.List(ctx, "/")
for _, path := range keys {
if !storagedriver.PathRegexp.MatchString(path) {
t.Fatalf("unexpected string in path: %q != %q", path, storagedriver.PathRegexp)
@@ -234,11 +234,11 @@ func TestFilenameChunking(t *testing.T) {
}
// Test 0 and < 0 sizes
actual, err = chunkFilenames(nil, 0)
_, err = chunkFilenames(nil, 0)
if err == nil {
t.Fatal("expected error for size = 0")
}
actual, err = chunkFilenames(nil, -1)
_, err = chunkFilenames(nil, -1)
if err == nil {
t.Fatal("expected error for size = -1")
}

View File

@@ -1,7 +1,8 @@
package testdriver
import (
"github.com/docker/distribution/context"
"context"
storagedriver "github.com/docker/distribution/registry/storage/driver"
"github.com/docker/distribution/registry/storage/driver/factory"
"github.com/docker/distribution/registry/storage/driver/inmemory"

View File

@@ -2,6 +2,7 @@ package testsuites
import (
"bytes"
"context"
"crypto/sha1"
"io"
"io/ioutil"
@@ -15,10 +16,8 @@ import (
"testing"
"time"
"gopkg.in/check.v1"
"github.com/docker/distribution/context"
storagedriver "github.com/docker/distribution/registry/storage/driver"
"gopkg.in/check.v1"
)
// Test hooks up gocheck into the "go test" runner.
@@ -137,7 +136,7 @@ func (suite *DriverSuite) deletePath(c *check.C, path string) {
err = nil
}
c.Assert(err, check.IsNil)
paths, err := suite.StorageDriver.List(suite.ctx, path)
paths, _ := suite.StorageDriver.List(suite.ctx, path)
if len(paths) == 0 {
break
}
@@ -652,7 +651,7 @@ func (suite *DriverSuite) TestURLFor(c *check.C) {
}
c.Assert(err, check.IsNil)
response, err = http.Head(url)
response, _ = http.Head(url)
c.Assert(response.StatusCode, check.Equals, 200)
c.Assert(response.ContentLength, check.Equals, int64(32))
}
@@ -1117,7 +1116,7 @@ func (suite *DriverSuite) testFileStreams(c *check.C, size int64) {
c.Assert(err, check.IsNil)
tf.Sync()
tf.Seek(0, os.SEEK_SET)
tf.Seek(0, io.SeekStart)
writer, err := suite.StorageDriver.Writer(suite.ctx, filename, false)
c.Assert(err, check.IsNil)

View File

@@ -0,0 +1,61 @@
package driver
import (
"context"
"errors"
"sort"
"github.com/sirupsen/logrus"
)
// ErrSkipDir is used as a return value from onFileFunc to indicate that
// the directory named in the call is to be skipped. It is not returned
// as an error by any function.
var ErrSkipDir = errors.New("skip this directory")
// WalkFn is called once per file by Walk
type WalkFn func(fileInfo FileInfo) error
// WalkFallback traverses a filesystem defined within driver, starting
// from the given path, calling f on each file. It drives itself using the
// driver's List and Stat methods.
// If the returned error from the WalkFn is ErrSkipDir and fileInfo refers
// to a directory, the directory will not be entered and Walk
// will continue the traversal. If fileInfo refers to a normal file, processing stops.
func WalkFallback(ctx context.Context, driver StorageDriver, from string, f WalkFn) error {
children, err := driver.List(ctx, from)
if err != nil {
return err
}
sort.Stable(sort.StringSlice(children))
for _, child := range children {
// TODO(stevvooe): Calling driver.Stat for every entry is quite
// expensive when running against backends with a slow Stat
// implementation, such as s3. This is very likely a serious
// performance bottleneck.
fileInfo, err := driver.Stat(ctx, child)
if err != nil {
switch err.(type) {
case PathNotFoundError:
// repository was removed in between listing and enumeration. Ignore it.
logrus.WithField("path", child).Infof("ignoring deleted path")
continue
default:
return err
}
}
err = f(fileInfo)
if err == nil && fileInfo.IsDir() {
if err := WalkFallback(ctx, driver, child, f); err != nil {
return err
}
} else if err == ErrSkipDir {
// Stop iteration if it's a file, otherwise noop if it's a directory
if !fileInfo.IsDir() {
return nil
}
} else if err != nil {
return err
}
}
return nil
}
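A minimal usage sketch of the semantics documented above: returning ErrSkipDir from the callback prunes a directory, while returning it for a file stops the walk. The paths are illustrative, and the inmemory driver is assumed to satisfy the updated interface:

package main

import (
	"context"
	"fmt"
	"strings"

	storagedriver "github.com/docker/distribution/registry/storage/driver"
	"github.com/docker/distribution/registry/storage/driver/inmemory"
)

func main() {
	ctx := context.Background()
	d := inmemory.New()
	_ = d.PutContent(ctx, "/repo/_uploads/scratch", []byte("tmp"))
	_ = d.PutContent(ctx, "/repo/manifests/latest", []byte("data"))

	err := d.Walk(ctx, "/repo", func(fi storagedriver.FileInfo) error {
		if fi.IsDir() && strings.HasSuffix(fi.Path(), "/_uploads") {
			return storagedriver.ErrSkipDir // prune the uploads subtree
		}
		fmt.Println(fi.Path())
		return nil
	})
	if err != nil {
		fmt.Println("walk error:", err)
	}
}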

View File

@@ -0,0 +1,47 @@
package driver
import (
"context"
"fmt"
"testing"
)
type changingFileSystem struct {
StorageDriver
fileset []string
keptFiles map[string]bool
}
func (cfs *changingFileSystem) List(ctx context.Context, path string) ([]string, error) {
return cfs.fileset, nil
}
func (cfs *changingFileSystem) Stat(ctx context.Context, path string) (FileInfo, error) {
kept, ok := cfs.keptFiles[path]
if ok && kept {
return &FileInfoInternal{
FileInfoFields: FileInfoFields{
Path: path,
},
}, nil
}
return nil, PathNotFoundError{}
}
func TestWalkFileRemoved(t *testing.T) {
d := &changingFileSystem{
fileset: []string{"zoidberg", "bender"},
keptFiles: map[string]bool{
"zoidberg": true,
},
}
infos := []FileInfo{}
err := WalkFallback(context.Background(), d, "", func(fileInfo FileInfo) error {
infos = append(infos, fileInfo)
return nil
})
if len(infos) != 1 || infos[0].Path() != "zoidberg" {
t.Errorf(fmt.Sprintf("unexpected path set during walk: %s", infos))
}
if err != nil {
t.Fatalf(err.Error())
}
}

View File

@@ -0,0 +1,9 @@
package storage
import "fmt"
// pushError formats an error type given a path and an error
// and pushes it to a slice of errors
func pushError(errors []error, path string, err error) []error {
return append(errors, fmt.Errorf("%s: %s", path, err))
}

View File

@@ -3,12 +3,11 @@ package storage
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"io/ioutil"
"os"
"github.com/docker/distribution/context"
storagedriver "github.com/docker/distribution/registry/storage/driver"
)
@@ -81,12 +80,12 @@ func (fr *fileReader) Seek(offset int64, whence int) (int64, error) {
newOffset := fr.offset
switch whence {
case os.SEEK_CUR:
newOffset += int64(offset)
case os.SEEK_END:
newOffset = fr.size + int64(offset)
case os.SEEK_SET:
newOffset = int64(offset)
case io.SeekCurrent:
newOffset += offset
case io.SeekEnd:
newOffset = fr.size + offset
case io.SeekStart:
newOffset = offset
}
if newOffset < 0 {

View File

@@ -4,7 +4,6 @@ import (
"bytes"
"io"
mrand "math/rand"
"os"
"testing"
"github.com/docker/distribution/context"
@@ -72,7 +71,7 @@ func TestFileReaderSeek(t *testing.T) {
for _, repitition := range mrand.Perm(repititions - 1) {
targetOffset := int64(len(pattern) * repitition)
// Seek to a multiple of pattern size and read pattern size bytes
offset, err := fr.Seek(targetOffset, os.SEEK_SET)
offset, err := fr.Seek(targetOffset, io.SeekStart)
if err != nil {
t.Fatalf("unexpected error seeking: %v", err)
}
@@ -97,7 +96,7 @@ func TestFileReaderSeek(t *testing.T) {
}
// Check offset
current, err := fr.Seek(0, os.SEEK_CUR)
current, err := fr.Seek(0, io.SeekCurrent)
if err != nil {
t.Fatalf("error checking current offset: %v", err)
}
@@ -107,7 +106,7 @@ func TestFileReaderSeek(t *testing.T) {
}
}
start, err := fr.Seek(0, os.SEEK_SET)
start, err := fr.Seek(0, io.SeekStart)
if err != nil {
t.Fatalf("error seeking to start: %v", err)
}
@@ -116,7 +115,7 @@ func TestFileReaderSeek(t *testing.T) {
t.Fatalf("expected to seek to start: %v != 0", start)
}
end, err := fr.Seek(0, os.SEEK_END)
end, err := fr.Seek(0, io.SeekEnd)
if err != nil {
t.Fatalf("error checking current offset: %v", err)
}
@@ -128,13 +127,13 @@ func TestFileReaderSeek(t *testing.T) {
// 4. Seek before start, ensure error.
// seek before start
before, err := fr.Seek(-1, os.SEEK_SET)
before, err := fr.Seek(-1, io.SeekStart)
if err == nil {
t.Fatalf("error expected, returned offset=%v", before)
}
// 5. Seek after end,
after, err := fr.Seek(1, os.SEEK_END)
after, err := fr.Seek(1, io.SeekEnd)
if err != nil {
t.Fatalf("unexpected error expected, returned offset=%v", after)
}

View File

@@ -1,10 +1,10 @@
package storage
import (
"context"
"fmt"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
"github.com/docker/distribution/reference"
"github.com/docker/distribution/registry/storage/driver"
"github.com/opencontainers/go-digest"
@@ -14,8 +14,21 @@ func emit(format string, a ...interface{}) {
fmt.Printf(format+"\n", a...)
}
// GCOpts contains options for garbage collector
type GCOpts struct {
DryRun bool
RemoveUntagged bool
}
// ManifestDel contains manifest structure which will be deleted
type ManifestDel struct {
Name string
Digest digest.Digest
Tags []string
}
// MarkAndSweep performs a mark and sweep of registry data
func MarkAndSweep(ctx context.Context, storageDriver driver.StorageDriver, registry distribution.Namespace, dryRun bool) error {
func MarkAndSweep(ctx context.Context, storageDriver driver.StorageDriver, registry distribution.Namespace, opts GCOpts) error {
repositoryEnumerator, ok := registry.(distribution.RepositoryEnumerator)
if !ok {
return fmt.Errorf("unable to convert Namespace to RepositoryEnumerator")
@@ -23,6 +36,7 @@ func MarkAndSweep(ctx context.Context, storageDriver driver.StorageDriver, regis
// mark
markSet := make(map[digest.Digest]struct{})
manifestArr := make([]ManifestDel, 0)
err := repositoryEnumerator.Enumerate(ctx, func(repoName string) error {
emit(repoName)
@@ -47,6 +61,25 @@ func MarkAndSweep(ctx context.Context, storageDriver driver.StorageDriver, regis
}
err = manifestEnumerator.Enumerate(ctx, func(dgst digest.Digest) error {
if opts.RemoveUntagged {
// fetch all tags where this manifest is the latest one
tags, err := repository.Tags(ctx).Lookup(ctx, distribution.Descriptor{Digest: dgst})
if err != nil {
return fmt.Errorf("failed to retrieve tags for digest %v: %v", dgst, err)
}
if len(tags) == 0 {
emit("manifest eligible for deletion: %s", dgst)
// fetch all tags from the repository: any of these tags could reference
// the manifest in its history, which means we need to check (and delete)
// those references when deleting the manifest
allTags, err := repository.Tags(ctx).All(ctx)
if err != nil {
return fmt.Errorf("failed to retrieve tags %v", err)
}
manifestArr = append(manifestArr, ManifestDel{Name: repoName, Digest: dgst, Tags: allTags})
return nil
}
}
// Mark the manifest's blob
emit("%s: marking manifest %s ", repoName, dgst)
markSet[dgst] = struct{}{}
@@ -84,6 +117,15 @@ func MarkAndSweep(ctx context.Context, storageDriver driver.StorageDriver, regis
}
// sweep
vacuum := NewVacuum(ctx, storageDriver)
if !opts.DryRun {
for _, obj := range manifestArr {
err = vacuum.RemoveManifest(obj.Name, obj.Digest, obj.Tags)
if err != nil {
return fmt.Errorf("failed to delete manifest %s: %v", obj.Digest, err)
}
}
}
blobService := registry.Blobs()
deleteSet := make(map[digest.Digest]struct{})
err = blobService.Enumerate(ctx, func(dgst digest.Digest) error {
@@ -96,12 +138,10 @@ func MarkAndSweep(ctx context.Context, storageDriver driver.StorageDriver, regis
if err != nil {
return fmt.Errorf("error enumerating blobs: %v", err)
}
emit("\n%d blobs marked, %d blobs eligible for deletion", len(markSet), len(deleteSet))
// Construct vacuum
vacuum := NewVacuum(ctx, storageDriver)
emit("\n%d blobs marked, %d blobs and %d manifests eligible for deletion", len(markSet), len(deleteSet), len(manifestArr))
for dgst := range deleteSet {
emit("blob eligible for deletion: %s", dgst)
if dryRun {
if opts.DryRun {
continue
}
err = vacuum.RemoveBlob(string(dgst))
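
The hunks above replace MarkAndSweep's bare dryRun flag with a GCOpts struct and collect untagged manifests into manifestArr so the sweep phase can remove them via vacuum.RemoveManifest. A minimal sketch of driving the new signature; runGC is a hypothetical helper, assuming it lives in the storage package next to MarkAndSweep:

package storage

import (
	"context"

	"github.com/docker/distribution"
	"github.com/docker/distribution/registry/storage/driver"
)

// runGC is a hypothetical wrapper: report with a dry run first, then delete.
func runGC(ctx context.Context, d driver.StorageDriver, ns distribution.Namespace) error {
	// DryRun only emits what would be removed; nothing is deleted.
	if err := MarkAndSweep(ctx, d, ns, GCOpts{DryRun: true, RemoveUntagged: true}); err != nil {
		return err
	}
	// The second pass actually removes untagged manifests and unreferenced blobs.
	return MarkAndSweep(ctx, d, ns, GCOpts{DryRun: false, RemoveUntagged: true})
}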

View File

@@ -27,7 +27,7 @@ func createRegistry(t *testing.T, driver driver.StorageDriver, options ...Regist
if err != nil {
t.Fatal(err)
}
options = append([]RegistryOption{EnableDelete, Schema1SigningKey(k)}, options...)
options = append([]RegistryOption{EnableDelete, Schema1SigningKey(k), EnableSchema1}, options...)
registry, err := NewRegistry(ctx, driver, options...)
if err != nil {
t.Fatalf("Failed to construct namespace")
@@ -61,6 +61,23 @@ func makeManifestService(t *testing.T, repository distribution.Repository) distr
return manifestService
}
func allManifests(t *testing.T, manifestService distribution.ManifestService) map[digest.Digest]struct{} {
ctx := context.Background()
allManMap := make(map[digest.Digest]struct{})
manifestEnumerator, ok := manifestService.(distribution.ManifestEnumerator)
if !ok {
t.Fatalf("unable to convert ManifestService into ManifestEnumerator")
}
err := manifestEnumerator.Enumerate(ctx, func(dgst digest.Digest) error {
allManMap[dgst] = struct{}{}
return nil
})
if err != nil {
t.Fatalf("Error getting all manifests: %v", err)
}
return allManMap
}
func allBlobs(t *testing.T, registry distribution.Namespace) map[digest.Digest]struct{} {
ctx := context.Background()
blobService := registry.Blobs()
@@ -147,7 +164,7 @@ func TestNoDeletionNoEffect(t *testing.T) {
registry := createRegistry(t, inmemoryDriver)
repo := makeRepository(t, registry, "palailogos")
manifestService, err := repo.Manifests(ctx)
manifestService, _ := repo.Manifests(ctx)
image1 := uploadRandomSchema1Image(t, repo)
image2 := uploadRandomSchema1Image(t, repo)
@@ -169,7 +186,10 @@ func TestNoDeletionNoEffect(t *testing.T) {
before := allBlobs(t, registry)
// Run GC
err = MarkAndSweep(context.Background(), inmemoryDriver, registry, false)
err = MarkAndSweep(context.Background(), inmemoryDriver, registry, GCOpts{
DryRun: false,
RemoveUntagged: false,
})
if err != nil {
t.Fatalf("Failed mark and sweep: %v", err)
}
@@ -180,6 +200,102 @@ func TestNoDeletionNoEffect(t *testing.T) {
}
}
func TestDeleteManifestIfTagNotFound(t *testing.T) {
ctx := context.Background()
inmemoryDriver := inmemory.New()
registry := createRegistry(t, inmemoryDriver)
repo := makeRepository(t, registry, "deletemanifests")
manifestService, _ := repo.Manifests(ctx)
// Create random layers
randomLayers1, err := testutil.CreateRandomLayers(3)
if err != nil {
t.Fatalf("failed to make layers: %v", err)
}
randomLayers2, err := testutil.CreateRandomLayers(3)
if err != nil {
t.Fatalf("failed to make layers: %v", err)
}
// Upload all layers
err = testutil.UploadBlobs(repo, randomLayers1)
if err != nil {
t.Fatalf("failed to upload layers: %v", err)
}
err = testutil.UploadBlobs(repo, randomLayers2)
if err != nil {
t.Fatalf("failed to upload layers: %v", err)
}
// Construct manifests
manifest1, err := testutil.MakeSchema1Manifest(getKeys(randomLayers1))
if err != nil {
t.Fatalf("failed to make manifest: %v", err)
}
manifest2, err := testutil.MakeSchema1Manifest(getKeys(randomLayers2))
if err != nil {
t.Fatalf("failed to make manifest: %v", err)
}
_, err = manifestService.Put(ctx, manifest1)
if err != nil {
t.Fatalf("manifest upload failed: %v", err)
}
_, err = manifestService.Put(ctx, manifest2)
if err != nil {
t.Fatalf("manifest upload failed: %v", err)
}
manifestEnumerator, _ := manifestService.(distribution.ManifestEnumerator)
manifestEnumerator.Enumerate(ctx, func(dgst digest.Digest) error {
repo.Tags(ctx).Tag(ctx, "test", distribution.Descriptor{Digest: dgst})
return nil
})
before1 := allBlobs(t, registry)
before2 := allManifests(t, manifestService)
// run GC with dry-run (should not remove anything)
err = MarkAndSweep(context.Background(), inmemoryDriver, registry, GCOpts{
DryRun: true,
RemoveUntagged: true,
})
if err != nil {
t.Fatalf("Failed mark and sweep: %v", err)
}
afterDry1 := allBlobs(t, registry)
afterDry2 := allManifests(t, manifestService)
if len(before1) != len(afterDry1) {
t.Fatalf("Garbage collection affected blobs storage: %d != %d", len(before1), len(afterDry1))
}
if len(before2) != len(afterDry2) {
t.Fatalf("Garbage collection affected manifest storage: %d != %d", len(before2), len(afterDry2))
}
// Run GC (removes everything because no manifests with tags exist)
err = MarkAndSweep(context.Background(), inmemoryDriver, registry, GCOpts{
DryRun: false,
RemoveUntagged: true,
})
if err != nil {
t.Fatalf("Failed mark and sweep: %v", err)
}
after1 := allBlobs(t, registry)
after2 := allManifests(t, manifestService)
if len(before1) == len(after1) {
t.Fatalf("Garbage collection affected blobs storage: %d == %d", len(before1), len(after1))
}
if len(before2) == len(after2) {
t.Fatalf("Garbage collection affected manifest storage: %d == %d", len(before2), len(after2))
}
}
func TestGCWithMissingManifests(t *testing.T) {
ctx := context.Background()
d := inmemory.New()
@@ -200,7 +316,10 @@ func TestGCWithMissingManifests(t *testing.T) {
t.Fatal(err)
}
err = MarkAndSweep(context.Background(), d, registry, false)
err = MarkAndSweep(context.Background(), d, registry, GCOpts{
DryRun: false,
RemoveUntagged: false,
})
if err != nil {
t.Fatalf("Failed mark and sweep: %v", err)
}
@@ -217,7 +336,7 @@ func TestDeletionHasEffect(t *testing.T) {
registry := createRegistry(t, inmemoryDriver)
repo := makeRepository(t, registry, "komnenos")
manifests, err := repo.Manifests(ctx)
manifests, _ := repo.Manifests(ctx)
image1 := uploadRandomSchema1Image(t, repo)
image2 := uploadRandomSchema1Image(t, repo)
@@ -227,7 +346,10 @@ func TestDeletionHasEffect(t *testing.T) {
manifests.Delete(ctx, image3.manifestDigest)
// Run GC
err = MarkAndSweep(context.Background(), inmemoryDriver, registry, false)
err := MarkAndSweep(context.Background(), inmemoryDriver, registry, GCOpts{
DryRun: false,
RemoveUntagged: false,
})
if err != nil {
t.Fatalf("Failed mark and sweep: %v", err)
}
@@ -361,7 +483,10 @@ func TestOrphanBlobDeleted(t *testing.T) {
uploadRandomSchema2Image(t, repo)
// Run GC
err = MarkAndSweep(context.Background(), inmemoryDriver, registry, false)
err = MarkAndSweep(context.Background(), inmemoryDriver, registry, GCOpts{
DryRun: false,
RemoveUntagged: false,
})
if err != nil {
t.Fatalf("Failed mark and sweep: %v", err)
}

View File

@@ -1,11 +1,11 @@
package storage
import (
"context"
"errors"
"io"
"io/ioutil"
"github.com/docker/distribution/context"
"github.com/docker/distribution/registry/storage/driver"
)

View File

@@ -1,13 +1,14 @@
package storage
import (
"context"
"fmt"
"net/http"
"path"
"time"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
dcontext "github.com/docker/distribution/context"
"github.com/docker/distribution/reference"
"github.com/docker/distribution/registry/storage/driver"
"github.com/docker/distribution/uuid"
@@ -86,7 +87,7 @@ func (lbs *linkedBlobStore) Put(ctx context.Context, mediaType string, p []byte)
// Place the data in the blob store first.
desc, err := lbs.blobStore.Put(ctx, mediaType, p)
if err != nil {
context.GetLogger(ctx).Errorf("error putting into main store: %v", err)
dcontext.GetLogger(ctx).Errorf("error putting into main store: %v", err)
return distribution.Descriptor{}, err
}
@@ -125,7 +126,7 @@ func WithMountFrom(ref reference.Canonical) distribution.BlobCreateOption {
// Writer begins a blob write session, returning a handle.
func (lbs *linkedBlobStore) Create(ctx context.Context, options ...distribution.BlobCreateOption) (distribution.BlobWriter, error) {
context.GetLogger(ctx).Debug("(*linkedBlobStore).Writer")
dcontext.GetLogger(ctx).Debug("(*linkedBlobStore).Writer")
var opts distribution.CreateOptions
@@ -174,7 +175,7 @@ func (lbs *linkedBlobStore) Create(ctx context.Context, options ...distribution.
}
func (lbs *linkedBlobStore) Resume(ctx context.Context, id string) (distribution.BlobWriter, error) {
context.GetLogger(ctx).Debug("(*linkedBlobStore).Resume")
dcontext.GetLogger(ctx).Debug("(*linkedBlobStore).Resume")
startedAtPath, err := pathFor(uploadStartedAtPathSpec{
name: lbs.repository.Named().Name(),
@@ -236,7 +237,7 @@ func (lbs *linkedBlobStore) Enumerate(ctx context.Context, ingestor func(digest.
if err != nil {
return err
}
err = Walk(ctx, lbs.blobStore.driver, rootPath, func(fileInfo driver.FileInfo) error {
return lbs.driver.Walk(ctx, rootPath, func(fileInfo driver.FileInfo) error {
// exit early if directory...
if fileInfo.IsDir() {
return nil
@@ -272,12 +273,6 @@ func (lbs *linkedBlobStore) Enumerate(ctx context.Context, ingestor func(digest.
return nil
})
if err != nil {
return err
}
return nil
}
func (lbs *linkedBlobStore) mount(ctx context.Context, sourceRepo reference.Named, dgst digest.Digest, sourceStat *distribution.Descriptor) (distribution.Descriptor, error) {
@@ -317,14 +312,14 @@ func (lbs *linkedBlobStore) newBlobUpload(ctx context.Context, uuid, path string
}
bw := &blobWriter{
ctx: ctx,
blobStore: lbs,
id: uuid,
startedAt: startedAt,
digester: digest.Canonical.Digester(),
fileWriter: fw,
driver: lbs.driver,
path: path,
ctx: ctx,
blobStore: lbs,
id: uuid,
startedAt: startedAt,
digester: digest.Canonical.Digester(),
fileWriter: fw,
driver: lbs.driver,
path: path,
resumableDigestEnabled: lbs.resumableDigestEnabled,
}
@@ -411,7 +406,7 @@ func (lbs *linkedBlobStatter) Stat(ctx context.Context, dgst digest.Digest) (dis
if target != dgst {
// Track when we are doing cross-digest domain lookups. ie, sha512 to sha256.
context.GetLogger(ctx).Warnf("looking up blob with canonical target: %v -> %v", dgst, target)
dcontext.GetLogger(ctx).Warnf("looking up blob with canonical target: %v -> %v", dgst, target)
}
// TODO(stevvooe): Look up repository local mediatype and replace that on

View File

@@ -1,6 +1,7 @@
package storage
import (
"context"
"fmt"
"io"
"reflect"
@@ -8,11 +9,9 @@ import (
"testing"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
"github.com/opencontainers/go-digest"
"github.com/docker/distribution/reference"
"github.com/docker/distribution/testutil"
"github.com/opencontainers/go-digest"
)
func TestLinkedBlobStoreCreateWithMountFrom(t *testing.T) {
@@ -163,8 +162,8 @@ type mockBlobDescriptorServiceFactory struct {
func (f *mockBlobDescriptorServiceFactory) BlobAccessController(svc distribution.BlobDescriptorService) distribution.BlobDescriptorService {
return &mockBlobDescriptorService{
BlobDescriptorService: svc,
t: f.t,
stats: f.stats,
t: f.t,
stats: f.stats,
}
}

View File

@@ -1,11 +1,11 @@
package storage
import (
"context"
"fmt"
"encoding/json"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
dcontext "github.com/docker/distribution/context"
"github.com/docker/distribution/manifest/manifestlist"
"github.com/opencontainers/go-digest"
)
@@ -20,18 +20,18 @@ type manifestListHandler struct {
var _ ManifestHandler = &manifestListHandler{}
func (ms *manifestListHandler) Unmarshal(ctx context.Context, dgst digest.Digest, content []byte) (distribution.Manifest, error) {
context.GetLogger(ms.ctx).Debug("(*manifestListHandler).Unmarshal")
dcontext.GetLogger(ms.ctx).Debug("(*manifestListHandler).Unmarshal")
var m manifestlist.DeserializedManifestList
if err := json.Unmarshal(content, &m); err != nil {
m := &manifestlist.DeserializedManifestList{}
if err := m.UnmarshalJSON(content); err != nil {
return nil, err
}
return &m, nil
return m, nil
}
func (ms *manifestListHandler) Put(ctx context.Context, manifestList distribution.Manifest, skipDependencyVerification bool) (digest.Digest, error) {
context.GetLogger(ms.ctx).Debug("(*manifestListHandler).Put")
dcontext.GetLogger(ms.ctx).Debug("(*manifestListHandler).Put")
m, ok := manifestList.(*manifestlist.DeserializedManifestList)
if !ok {
@@ -49,7 +49,7 @@ func (ms *manifestListHandler) Put(ctx context.Context, manifestList distributio
revision, err := ms.blobStore.Put(ctx, mt, payload)
if err != nil {
context.GetLogger(ctx).Errorf("error putting payload into blobstore: %v", err)
dcontext.GetLogger(ctx).Errorf("error putting payload into blobstore: %v", err)
return "", err
}
@@ -63,6 +63,10 @@ func (ms *manifestListHandler) Put(ctx context.Context, manifestList distributio
func (ms *manifestListHandler) verifyManifest(ctx context.Context, mnfst manifestlist.DeserializedManifestList, skipDependencyVerification bool) error {
var errs distribution.ErrManifestVerification
if mnfst.SchemaVersion != 2 {
return fmt.Errorf("unrecognized manifest list schema version %d", mnfst.SchemaVersion)
}
if !skipDependencyVerification {
// This manifest service is different from the blob service
// returned by Blob. It uses a linked blob store to ensure that

View File

@@ -1,16 +1,19 @@
package storage
import (
"context"
"encoding/json"
"fmt"
"encoding/json"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
dcontext "github.com/docker/distribution/context"
"github.com/docker/distribution/manifest"
"github.com/docker/distribution/manifest/manifestlist"
"github.com/docker/distribution/manifest/ocischema"
"github.com/docker/distribution/manifest/schema1"
"github.com/docker/distribution/manifest/schema2"
"github.com/opencontainers/go-digest"
"github.com/opencontainers/image-spec/specs-go/v1"
)
// A ManifestHandler gets and puts manifests of a particular type.
@@ -47,13 +50,14 @@ type manifestStore struct {
schema1Handler ManifestHandler
schema2Handler ManifestHandler
ocischemaHandler ManifestHandler
manifestListHandler ManifestHandler
}
var _ distribution.ManifestService = &manifestStore{}
func (ms *manifestStore) Exists(ctx context.Context, dgst digest.Digest) (bool, error) {
context.GetLogger(ms.ctx).Debug("(*manifestStore).Exists")
dcontext.GetLogger(ms.ctx).Debug("(*manifestStore).Exists")
_, err := ms.blobStore.Stat(ms.ctx, dgst)
if err != nil {
@@ -68,7 +72,7 @@ func (ms *manifestStore) Exists(ctx context.Context, dgst digest.Digest) (bool,
}
func (ms *manifestStore) Get(ctx context.Context, dgst digest.Digest, options ...distribution.ManifestServiceOption) (distribution.Manifest, error) {
context.GetLogger(ms.ctx).Debug("(*manifestStore).Get")
dcontext.GetLogger(ms.ctx).Debug("(*manifestStore).Get")
// TODO(stevvooe): Need to check descriptor from above to ensure that the
// mediatype is as we expect for the manifest store.
@@ -98,8 +102,22 @@ func (ms *manifestStore) Get(ctx context.Context, dgst digest.Digest, options ..
switch versioned.MediaType {
case schema2.MediaTypeManifest:
return ms.schema2Handler.Unmarshal(ctx, dgst, content)
case manifestlist.MediaTypeManifestList:
case v1.MediaTypeImageManifest:
return ms.ocischemaHandler.Unmarshal(ctx, dgst, content)
case manifestlist.MediaTypeManifestList, v1.MediaTypeImageIndex:
return ms.manifestListHandler.Unmarshal(ctx, dgst, content)
case "":
// OCI image or image index - no media type in the content
// First see if it looks like an image index
res, err := ms.manifestListHandler.Unmarshal(ctx, dgst, content)
if err == nil {
resIndex := res.(*manifestlist.DeserializedManifestList)
if resIndex.Manifests != nil {
return resIndex, nil
}
}
// Otherwise, assume it must be an image manifest
return ms.ocischemaHandler.Unmarshal(ctx, dgst, content)
default:
return nil, distribution.ErrManifestVerification{fmt.Errorf("unrecognized manifest content type %s", versioned.MediaType)}
}
@@ -109,13 +127,15 @@ func (ms *manifestStore) Get(ctx context.Context, dgst digest.Digest, options ..
}
func (ms *manifestStore) Put(ctx context.Context, manifest distribution.Manifest, options ...distribution.ManifestServiceOption) (digest.Digest, error) {
context.GetLogger(ms.ctx).Debug("(*manifestStore).Put")
dcontext.GetLogger(ms.ctx).Debug("(*manifestStore).Put")
switch manifest.(type) {
case *schema1.SignedManifest:
return ms.schema1Handler.Put(ctx, manifest, ms.skipDependencyVerification)
case *schema2.DeserializedManifest:
return ms.schema2Handler.Put(ctx, manifest, ms.skipDependencyVerification)
case *ocischema.DeserializedManifest:
return ms.ocischemaHandler.Put(ctx, manifest, ms.skipDependencyVerification)
case *manifestlist.DeserializedManifestList:
return ms.manifestListHandler.Put(ctx, manifest, ms.skipDependencyVerification)
}
@@ -125,7 +145,7 @@ func (ms *manifestStore) Put(ctx context.Context, manifest distribution.Manifest
// Delete removes the revision of the specified manifest.
func (ms *manifestStore) Delete(ctx context.Context, dgst digest.Digest) error {
context.GetLogger(ms.ctx).Debug("(*manifestStore).Delete")
dcontext.GetLogger(ms.ctx).Debug("(*manifestStore).Delete")
return ms.blobStore.Delete(ctx, dgst)
}
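
The enlarged switch in (*manifestStore).Get dispatches on the payload's top-level mediaType field and, because the OCI spec allows that field to be omitted, probes an empty mediaType as an image index before falling back to an image manifest. A standalone sketch of the same dispatch, assuming the vendored manifest packages; kindOf is a hypothetical helper:

package main

import (
	"encoding/json"
	"fmt"

	"github.com/docker/distribution/manifest/manifestlist"
	"github.com/docker/distribution/manifest/schema2"
	v1 "github.com/opencontainers/image-spec/specs-go/v1"
)

// kindOf mirrors the media-type routing in (*manifestStore).Get.
func kindOf(content []byte) (string, error) {
	var versioned struct {
		MediaType string `json:"mediaType"`
	}
	if err := json.Unmarshal(content, &versioned); err != nil {
		return "", err
	}
	switch versioned.MediaType {
	case schema2.MediaTypeManifest:
		return "schema2 manifest", nil
	case v1.MediaTypeImageManifest:
		return "OCI image manifest", nil
	case manifestlist.MediaTypeManifestList, v1.MediaTypeImageIndex:
		return "manifest list / OCI index", nil
	case "":
		// No media type in the content: try an index first, as Get does.
		var index manifestlist.ManifestList
		if err := json.Unmarshal(content, &index); err == nil && index.Manifests != nil {
			return "OCI image index (no mediaType)", nil
		}
		return "OCI image manifest (no mediaType)", nil
	default:
		return "", fmt.Errorf("unrecognized manifest content type %s", versioned.MediaType)
	}
}

func main() {
	k, err := kindOf([]byte(`{"schemaVersion": 2, "mediaType": "` + schema2.MediaTypeManifest + `"}`))
	fmt.Println(k, err)
}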

View File

@@ -2,13 +2,15 @@ package storage
import (
"bytes"
"context"
"io"
"reflect"
"testing"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
"github.com/docker/distribution/manifest"
"github.com/docker/distribution/manifest/manifestlist"
"github.com/docker/distribution/manifest/ocischema"
"github.com/docker/distribution/manifest/schema1"
"github.com/docker/distribution/reference"
"github.com/docker/distribution/registry/storage/cache/memory"
@@ -17,6 +19,7 @@ import (
"github.com/docker/distribution/testutil"
"github.com/docker/libtrust"
"github.com/opencontainers/go-digest"
"github.com/opencontainers/image-spec/specs-go/v1"
)
type manifestStoreTestEnv struct {
@@ -56,10 +59,18 @@ func TestManifestStorage(t *testing.T) {
if err != nil {
t.Fatal(err)
}
testManifestStorage(t, BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()), EnableDelete, EnableRedirect, Schema1SigningKey(k))
testManifestStorage(t, true, BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()), EnableDelete, EnableRedirect, Schema1SigningKey(k), EnableSchema1)
}
func testManifestStorage(t *testing.T, options ...RegistryOption) {
func TestManifestStorageV1Unsupported(t *testing.T) {
k, err := libtrust.GenerateECP256PrivateKey()
if err != nil {
t.Fatal(err)
}
testManifestStorage(t, false, BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()), EnableDelete, EnableRedirect, Schema1SigningKey(k))
}
func testManifestStorage(t *testing.T, schema1Enabled bool, options ...RegistryOption) {
repoName, _ := reference.WithName("foo/bar")
env := newManifestStoreTestEnv(t, repoName, "thetag", options...)
ctx := context.Background()
@@ -111,6 +122,15 @@ func testManifestStorage(t *testing.T, options ...RegistryOption) {
t.Fatalf("expected errors putting manifest with full verification")
}
// If schema1 is not enabled, do a short version of this test, just checking
// if we get the right error when we Put
if !schema1Enabled {
if err != distribution.ErrSchemaV1Unsupported {
t.Fatalf("got the wrong error when schema1 is disabled: %s", err)
}
return
}
switch err := err.(type) {
case distribution.ErrManifestVerification:
if len(err) != 2 {
@@ -356,6 +376,175 @@ func testManifestStorage(t *testing.T, options ...RegistryOption) {
}
}
func TestOCIManifestStorage(t *testing.T) {
testOCIManifestStorage(t, "includeMediaTypes=true", true)
testOCIManifestStorage(t, "includeMediaTypes=false", false)
}
func testOCIManifestStorage(t *testing.T, testname string, includeMediaTypes bool) {
var imageMediaType string
var indexMediaType string
if includeMediaTypes {
imageMediaType = v1.MediaTypeImageManifest
indexMediaType = v1.MediaTypeImageIndex
} else {
imageMediaType = ""
indexMediaType = ""
}
repoName, _ := reference.WithName("foo/bar")
env := newManifestStoreTestEnv(t, repoName, "thetag",
BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()),
EnableDelete, EnableRedirect)
ctx := context.Background()
ms, err := env.repository.Manifests(ctx)
if err != nil {
t.Fatal(err)
}
// Build a manifest and store it and its layers in the registry
blobStore := env.repository.Blobs(ctx)
builder := ocischema.NewManifestBuilder(blobStore, []byte{}, map[string]string{})
err = builder.(*ocischema.Builder).SetMediaType(imageMediaType)
if err != nil {
t.Fatal(err)
}
// Add some layers
for i := 0; i < 2; i++ {
rs, ds, err := testutil.CreateRandomTarFile()
if err != nil {
t.Fatalf("%s: unexpected error generating test layer file", testname)
}
dgst := digest.Digest(ds)
wr, err := env.repository.Blobs(env.ctx).Create(env.ctx)
if err != nil {
t.Fatalf("%s: unexpected error creating test upload: %v", testname, err)
}
if _, err := io.Copy(wr, rs); err != nil {
t.Fatalf("%s: unexpected error copying to upload: %v", testname, err)
}
if _, err := wr.Commit(env.ctx, distribution.Descriptor{Digest: dgst}); err != nil {
t.Fatalf("%s: unexpected error finishing upload: %v", testname, err)
}
builder.AppendReference(distribution.Descriptor{Digest: dgst})
}
manifest, err := builder.Build(ctx)
if err != nil {
t.Fatalf("%s: unexpected error generating manifest: %v", testname, err)
}
// before putting the manifest, test for proper handling of SchemaVersion
if manifest.(*ocischema.DeserializedManifest).Manifest.SchemaVersion != 2 {
t.Fatalf("%s: unexpected error generating default version for oci manifest", testname)
}
manifest.(*ocischema.DeserializedManifest).Manifest.SchemaVersion = 0
var manifestDigest digest.Digest
if manifestDigest, err = ms.Put(ctx, manifest); err != nil {
if err.Error() != "unrecognized manifest schema version 0" {
t.Fatalf("%s: unexpected error putting manifest: %v", testname, err)
}
manifest.(*ocischema.DeserializedManifest).Manifest.SchemaVersion = 2
if manifestDigest, err = ms.Put(ctx, manifest); err != nil {
t.Fatalf("%s: unexpected error putting manifest: %v", testname, err)
}
}
// Also create an image index that contains the manifest
descriptor, err := env.registry.BlobStatter().Stat(ctx, manifestDigest)
if err != nil {
t.Fatalf("%s: unexpected error getting manifest descriptor", testname)
}
descriptor.MediaType = v1.MediaTypeImageManifest
platformSpec := manifestlist.PlatformSpec{
Architecture: "atari2600",
OS: "CP/M",
}
manifestDescriptors := []manifestlist.ManifestDescriptor{
{
Descriptor: descriptor,
Platform: platformSpec,
},
}
imageIndex, err := manifestlist.FromDescriptorsWithMediaType(manifestDescriptors, indexMediaType)
if err != nil {
t.Fatalf("%s: unexpected error creating image index: %v", testname, err)
}
var indexDigest digest.Digest
if indexDigest, err = ms.Put(ctx, imageIndex); err != nil {
t.Fatalf("%s: unexpected error putting image index: %v", testname, err)
}
// Now check that we can retrieve the manifest
fromStore, err := ms.Get(ctx, manifestDigest)
if err != nil {
t.Fatalf("%s: unexpected error fetching manifest: %v", testname, err)
}
fetchedManifest, ok := fromStore.(*ocischema.DeserializedManifest)
if !ok {
t.Fatalf("%s: unexpected type for fetched manifest", testname)
}
if fetchedManifest.MediaType != imageMediaType {
t.Fatalf("%s: unexpected MediaType for result, %s", testname, fetchedManifest.MediaType)
}
if fetchedManifest.SchemaVersion != ocischema.SchemaVersion.SchemaVersion {
t.Fatalf("%s: unexpected schema version for result, %d", testname, fetchedManifest.SchemaVersion)
}
payloadMediaType, _, err := fromStore.Payload()
if err != nil {
t.Fatalf("%s: error getting payload %v", testname, err)
}
if payloadMediaType != v1.MediaTypeImageManifest {
t.Fatalf("%s: unexpected MediaType for manifest payload, %s", testname, payloadMediaType)
}
// and the image index
fromStore, err = ms.Get(ctx, indexDigest)
if err != nil {
t.Fatalf("%s: unexpected error fetching image index: %v", testname, err)
}
fetchedIndex, ok := fromStore.(*manifestlist.DeserializedManifestList)
if !ok {
t.Fatalf("%s: unexpected type for fetched manifest", testname)
}
if fetchedIndex.MediaType != indexMediaType {
t.Fatalf("%s: unexpected MediaType for result, %s", testname, fetchedManifest.MediaType)
}
payloadMediaType, _, err = fromStore.Payload()
if err != nil {
t.Fatalf("%s: error getting payload %v", testname, err)
}
if payloadMediaType != v1.MediaTypeImageIndex {
t.Fatalf("%s: unexpected MediaType for index payload, %s", testname, payloadMediaType)
}
}
// TestLinkPathFuncs ensures that the link path functions' behavior is locked
// down and implemented as expected.
func TestLinkPathFuncs(t *testing.T) {
@@ -387,5 +576,4 @@ func TestLinkPathFuncs(t *testing.T) {
t.Fatalf("incorrect path returned: %q != %q", p, testcase.expected)
}
}
}

View File

@@ -0,0 +1,133 @@
package storage
import (
"context"
"fmt"
"net/url"
"github.com/docker/distribution"
dcontext "github.com/docker/distribution/context"
"github.com/docker/distribution/manifest/ocischema"
"github.com/opencontainers/go-digest"
"github.com/opencontainers/image-spec/specs-go/v1"
)
// ocischemaManifestHandler is a ManifestHandler that covers ocischema manifests.
type ocischemaManifestHandler struct {
repository distribution.Repository
blobStore distribution.BlobStore
ctx context.Context
manifestURLs manifestURLs
}
var _ ManifestHandler = &ocischemaManifestHandler{}
func (ms *ocischemaManifestHandler) Unmarshal(ctx context.Context, dgst digest.Digest, content []byte) (distribution.Manifest, error) {
dcontext.GetLogger(ms.ctx).Debug("(*ocischemaManifestHandler).Unmarshal")
m := &ocischema.DeserializedManifest{}
if err := m.UnmarshalJSON(content); err != nil {
return nil, err
}
return m, nil
}
func (ms *ocischemaManifestHandler) Put(ctx context.Context, manifest distribution.Manifest, skipDependencyVerification bool) (digest.Digest, error) {
dcontext.GetLogger(ms.ctx).Debug("(*ocischemaManifestHandler).Put")
m, ok := manifest.(*ocischema.DeserializedManifest)
if !ok {
return "", fmt.Errorf("non-ocischema manifest put to ocischemaManifestHandler: %T", manifest)
}
if err := ms.verifyManifest(ms.ctx, *m, skipDependencyVerification); err != nil {
return "", err
}
mt, payload, err := m.Payload()
if err != nil {
return "", err
}
revision, err := ms.blobStore.Put(ctx, mt, payload)
if err != nil {
dcontext.GetLogger(ctx).Errorf("error putting payload into blobstore: %v", err)
return "", err
}
return revision.Digest, nil
}
// verifyManifest ensures that the manifest content is valid from the
// perspective of the registry. As a policy, the registry only tries to store
// valid content, leaving trust policies of that content up to consumers.
func (ms *ocischemaManifestHandler) verifyManifest(ctx context.Context, mnfst ocischema.DeserializedManifest, skipDependencyVerification bool) error {
var errs distribution.ErrManifestVerification
if mnfst.Manifest.SchemaVersion != 2 {
return fmt.Errorf("unrecognized manifest schema version %d", mnfst.Manifest.SchemaVersion)
}
if skipDependencyVerification {
return nil
}
manifestService, err := ms.repository.Manifests(ctx)
if err != nil {
return err
}
blobsService := ms.repository.Blobs(ctx)
for _, descriptor := range mnfst.References() {
var err error
switch descriptor.MediaType {
case v1.MediaTypeImageLayer, v1.MediaTypeImageLayerGzip, v1.MediaTypeImageLayerNonDistributable, v1.MediaTypeImageLayerNonDistributableGzip:
allow := ms.manifestURLs.allow
deny := ms.manifestURLs.deny
for _, u := range descriptor.URLs {
var pu *url.URL
pu, err = url.Parse(u)
if err != nil || (pu.Scheme != "http" && pu.Scheme != "https") || pu.Fragment != "" || (allow != nil && !allow.MatchString(u)) || (deny != nil && deny.MatchString(u)) {
err = errInvalidURL
break
}
}
if err == nil && len(descriptor.URLs) == 0 {
// If no URLs, require that the blob exists
_, err = blobsService.Stat(ctx, descriptor.Digest)
}
case v1.MediaTypeImageManifest:
var exists bool
exists, err = manifestService.Exists(ctx, descriptor.Digest)
if err != nil || !exists {
err = distribution.ErrBlobUnknown // just coerce to unknown.
}
fallthrough // double check the blob store.
default:
// forward all else to blob storage
if len(descriptor.URLs) == 0 {
_, err = blobsService.Stat(ctx, descriptor.Digest)
}
}
if err != nil {
if err != distribution.ErrBlobUnknown {
errs = append(errs, err)
}
// On error here, we always append unknown blob errors.
errs = append(errs, distribution.ErrManifestBlobUnknown{Digest: descriptor.Digest})
}
}
if len(errs) != 0 {
return errs
}
return nil
}

View File

@@ -0,0 +1,138 @@
package storage
import (
"context"
"regexp"
"testing"
"github.com/docker/distribution"
"github.com/docker/distribution/manifest"
"github.com/docker/distribution/manifest/ocischema"
"github.com/docker/distribution/registry/storage/driver/inmemory"
"github.com/opencontainers/image-spec/specs-go/v1"
)
func TestVerifyOCIManifestNonDistributableLayer(t *testing.T) {
ctx := context.Background()
inmemoryDriver := inmemory.New()
registry := createRegistry(t, inmemoryDriver,
ManifestURLsAllowRegexp(regexp.MustCompile("^https?://foo")),
ManifestURLsDenyRegexp(regexp.MustCompile("^https?://foo/nope")))
repo := makeRepository(t, registry, "test")
manifestService := makeManifestService(t, repo)
config, err := repo.Blobs(ctx).Put(ctx, v1.MediaTypeImageConfig, nil)
if err != nil {
t.Fatal(err)
}
layer, err := repo.Blobs(ctx).Put(ctx, v1.MediaTypeImageLayerGzip, nil)
if err != nil {
t.Fatal(err)
}
nonDistributableLayer := distribution.Descriptor{
Digest: "sha256:463435349086340864309863409683460843608348608934092322395278926a",
Size: 6323,
MediaType: v1.MediaTypeImageLayerNonDistributableGzip,
}
template := ocischema.Manifest{
Versioned: manifest.Versioned{
SchemaVersion: 2,
MediaType: v1.MediaTypeImageManifest,
},
Config: config,
}
type testcase struct {
BaseLayer distribution.Descriptor
URLs []string
Err error
}
cases := []testcase{
{
nonDistributableLayer,
nil,
distribution.ErrManifestBlobUnknown{Digest: nonDistributableLayer.Digest},
},
{
layer,
[]string{"http://foo/bar"},
nil,
},
{
nonDistributableLayer,
[]string{"file:///local/file"},
errInvalidURL,
},
{
nonDistributableLayer,
[]string{"http://foo/bar#baz"},
errInvalidURL,
},
{
nonDistributableLayer,
[]string{""},
errInvalidURL,
},
{
nonDistributableLayer,
[]string{"https://foo/bar", ""},
errInvalidURL,
},
{
nonDistributableLayer,
[]string{"", "https://foo/bar"},
errInvalidURL,
},
{
nonDistributableLayer,
[]string{"http://nope/bar"},
errInvalidURL,
},
{
nonDistributableLayer,
[]string{"http://foo/nope"},
errInvalidURL,
},
{
nonDistributableLayer,
[]string{"http://foo/bar"},
nil,
},
{
nonDistributableLayer,
[]string{"https://foo/bar"},
nil,
},
}
for _, c := range cases {
m := template
l := c.BaseLayer
l.URLs = c.URLs
m.Layers = []distribution.Descriptor{l}
dm, err := ocischema.FromStruct(m)
if err != nil {
t.Error(err)
continue
}
_, err = manifestService.Put(ctx, dm)
if verr, ok := err.(distribution.ErrManifestVerification); ok {
// Extract the first error
if len(verr) == 2 {
if _, ok = verr[1].(distribution.ErrManifestBlobUnknown); ok {
err = verr[0]
}
} else if len(verr) == 1 {
err = verr[0]
}
}
if err != c.Err {
t.Errorf("%#v: expected %v, got %v", l, c.Err, err)
}
}
}

View File

@@ -1,14 +1,14 @@
package storage
import (
"context"
"path"
"strings"
"time"
"github.com/docker/distribution/context"
storageDriver "github.com/docker/distribution/registry/storage/driver"
"github.com/docker/distribution/uuid"
log "github.com/sirupsen/logrus"
"github.com/sirupsen/logrus"
)
// uploadData stored the location of temporary files created during a layer upload
@@ -22,7 +22,7 @@ func newUploadData() uploadData {
return uploadData{
containingDir: "",
// default to far in future to protect against missing startedat
startedAt: time.Now().Add(time.Duration(10000 * time.Hour)),
startedAt: time.Now().Add(10000 * time.Hour),
}
}
@@ -30,13 +30,13 @@ func newUploadData() uploadData {
// created before olderThan. The list of files deleted and errors
// encountered are returned
func PurgeUploads(ctx context.Context, driver storageDriver.StorageDriver, olderThan time.Time, actuallyDelete bool) ([]string, []error) {
log.Infof("PurgeUploads starting: olderThan=%s, actuallyDelete=%t", olderThan, actuallyDelete)
logrus.Infof("PurgeUploads starting: olderThan=%s, actuallyDelete=%t", olderThan, actuallyDelete)
uploadData, errors := getOutstandingUploads(ctx, driver)
var deleted []string
for _, uploadData := range uploadData {
if uploadData.startedAt.Before(olderThan) {
var err error
log.Infof("Upload files in %s have older date (%s) than purge date (%s). Removing upload directory.",
logrus.Infof("Upload files in %s have older date (%s) than purge date (%s). Removing upload directory.",
uploadData.containingDir, uploadData.startedAt, olderThan)
if actuallyDelete {
err = driver.Delete(ctx, uploadData.containingDir)
@@ -49,7 +49,7 @@ func PurgeUploads(ctx context.Context, driver storageDriver.StorageDriver, older
}
}
log.Infof("Purge uploads finished. Num deleted=%d, num errors=%d", len(deleted), len(errors))
logrus.Infof("Purge uploads finished. Num deleted=%d, num errors=%d", len(deleted), len(errors))
return deleted, errors
}
@@ -67,7 +67,7 @@ func getOutstandingUploads(ctx context.Context, driver storageDriver.StorageDriv
return uploads, append(errors, err)
}
err = Walk(ctx, driver, root, func(fileInfo storageDriver.FileInfo) error {
err = driver.Walk(ctx, root, func(fileInfo storageDriver.FileInfo) error {
filePath := fileInfo.Path()
_, file := path.Split(filePath)
if file[0] == '_' {
@@ -75,7 +75,7 @@ func getOutstandingUploads(ctx context.Context, driver storageDriver.StorageDriv
inUploadDir = (file == "_uploads")
if fileInfo.IsDir() && !inUploadDir {
return ErrSkipDir
return storageDriver.ErrSkipDir
}
}

View File

@@ -1,12 +1,12 @@
package storage
import (
"context"
"path"
"strings"
"testing"
"time"
"github.com/docker/distribution/context"
"github.com/docker/distribution/registry/storage/driver"
"github.com/docker/distribution/registry/storage/driver/inmemory"
"github.com/docker/distribution/uuid"
@@ -142,7 +142,7 @@ func TestPurgeMissingStartedAt(t *testing.T) {
oneHourAgo := time.Now().Add(-1 * time.Hour)
fs, ctx := testUploadFS(t, 1, "test-repo", oneHourAgo)
err := Walk(ctx, fs, "/", func(fileInfo driver.FileInfo) error {
err := fs.Walk(ctx, "/", func(fileInfo driver.FileInfo) error {
filePath := fileInfo.Path()
_, file := path.Split(filePath)

View File

@@ -1,10 +1,10 @@
package storage
import (
"context"
"regexp"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
"github.com/docker/distribution/reference"
"github.com/docker/distribution/registry/storage/cache"
storagedriver "github.com/docker/distribution/registry/storage/driver"
@@ -19,10 +19,12 @@ type registry struct {
statter *blobStatter // global statter service.
blobDescriptorCacheProvider cache.BlobDescriptorCacheProvider
deleteEnabled bool
schema1Enabled bool
resumableDigestEnabled bool
schema1SigningKey libtrust.PrivateKey
blobDescriptorServiceFactory distribution.BlobDescriptorServiceFactory
manifestURLs manifestURLs
driver storagedriver.StorageDriver
}
// manifestURLs holds regular expressions for controlling manifest URL whitelisting
@@ -48,6 +50,13 @@ func EnableDelete(registry *registry) error {
return nil
}
// EnableSchema1 is a functional option for NewRegistry. It enables pushing of
// schema1 manifests.
func EnableSchema1(registry *registry) error {
registry.schema1Enabled = true
return nil
}
// DisableDigestResumption is a functional option for NewRegistry. It should be
// used if the registry is acting as a caching proxy.
func DisableDigestResumption(registry *registry) error {
@@ -133,6 +142,7 @@ func NewRegistry(ctx context.Context, driver storagedriver.StorageDriver, option
},
statter: statter,
resumableDigestEnabled: true,
driver: driver,
}
for _, option := range options {
@@ -237,16 +247,30 @@ func (repo *repository) Manifests(ctx context.Context, options ...distribution.M
linkDirectoryPathSpec: manifestDirectoryPathSpec,
}
ms := &manifestStore{
ctx: ctx,
repository: repo,
blobStore: blobStore,
schema1Handler: &signedManifestHandler{
var v1Handler ManifestHandler
if repo.schema1Enabled {
v1Handler = &signedManifestHandler{
ctx: ctx,
schema1SigningKey: repo.schema1SigningKey,
repository: repo,
blobStore: blobStore,
},
}
} else {
v1Handler = &v1UnsupportedHandler{
innerHandler: &signedManifestHandler{
ctx: ctx,
schema1SigningKey: repo.schema1SigningKey,
repository: repo,
blobStore: blobStore,
},
}
}
ms := &manifestStore{
ctx: ctx,
repository: repo,
blobStore: blobStore,
schema1Handler: v1Handler,
schema2Handler: &schema2ManifestHandler{
ctx: ctx,
repository: repo,
@@ -258,6 +282,12 @@ func (repo *repository) Manifests(ctx context.Context, options ...distribution.M
repository: repo,
blobStore: blobStore,
},
ocischemaHandler: &ocischemaManifestHandler{
ctx: ctx,
repository: repo,
blobStore: blobStore,
manifestURLs: repo.registry.manifestURLs,
},
}
// Apply options
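
EnableSchema1 follows the same functional-option pattern as EnableDelete: each option is a func(*registry) error that NewRegistry applies in order. A minimal sketch of opting back in to schema1 pushes against the in-memory driver; when the option is omitted, Manifests wraps the signedManifestHandler in v1UnsupportedHandler and Put fails with distribution.ErrSchemaV1Unsupported:

package main

import (
	"context"

	"github.com/docker/distribution/registry/storage"
	"github.com/docker/distribution/registry/storage/driver/inmemory"
)

func main() {
	ctx := context.Background()

	// Options toggle per-registry behavior; schema1 pushes are now opt-in.
	reg, err := storage.NewRegistry(ctx, inmemory.New(),
		storage.EnableDelete, storage.EnableSchema1)
	if err != nil {
		panic(err)
	}
	_ = reg
}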

View File

@@ -1,13 +1,13 @@
package storage
import (
"encoding/json"
"context"
"errors"
"fmt"
"net/url"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
dcontext "github.com/docker/distribution/context"
"github.com/docker/distribution/manifest/schema1"
"github.com/docker/distribution/manifest/schema2"
"github.com/opencontainers/go-digest"
@@ -30,18 +30,18 @@ type schema2ManifestHandler struct {
var _ ManifestHandler = &schema2ManifestHandler{}
func (ms *schema2ManifestHandler) Unmarshal(ctx context.Context, dgst digest.Digest, content []byte) (distribution.Manifest, error) {
context.GetLogger(ms.ctx).Debug("(*schema2ManifestHandler).Unmarshal")
dcontext.GetLogger(ms.ctx).Debug("(*schema2ManifestHandler).Unmarshal")
var m schema2.DeserializedManifest
if err := json.Unmarshal(content, &m); err != nil {
m := &schema2.DeserializedManifest{}
if err := m.UnmarshalJSON(content); err != nil {
return nil, err
}
return &m, nil
return m, nil
}
func (ms *schema2ManifestHandler) Put(ctx context.Context, manifest distribution.Manifest, skipDependencyVerification bool) (digest.Digest, error) {
context.GetLogger(ms.ctx).Debug("(*schema2ManifestHandler).Put")
dcontext.GetLogger(ms.ctx).Debug("(*schema2ManifestHandler).Put")
m, ok := manifest.(*schema2.DeserializedManifest)
if !ok {
@@ -59,7 +59,7 @@ func (ms *schema2ManifestHandler) Put(ctx context.Context, manifest distribution
revision, err := ms.blobStore.Put(ctx, mt, payload)
if err != nil {
context.GetLogger(ctx).Errorf("error putting payload into blobstore: %v", err)
dcontext.GetLogger(ctx).Errorf("error putting payload into blobstore: %v", err)
return "", err
}
@@ -72,6 +72,10 @@ func (ms *schema2ManifestHandler) Put(ctx context.Context, manifest distribution
func (ms *schema2ManifestHandler) verifyManifest(ctx context.Context, mnfst schema2.DeserializedManifest, skipDependencyVerification bool) error {
var errs distribution.ErrManifestVerification
if mnfst.Manifest.SchemaVersion != 2 {
return fmt.Errorf("unrecognized manifest schema version %d", mnfst.Manifest.SchemaVersion)
}
if skipDependencyVerification {
return nil
}

View File

@@ -1,11 +1,12 @@
package storage
import (
"context"
"encoding/json"
"fmt"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
dcontext "github.com/docker/distribution/context"
"github.com/docker/distribution/manifest/schema1"
"github.com/docker/distribution/reference"
"github.com/docker/libtrust"
@@ -24,7 +25,7 @@ type signedManifestHandler struct {
var _ ManifestHandler = &signedManifestHandler{}
func (ms *signedManifestHandler) Unmarshal(ctx context.Context, dgst digest.Digest, content []byte) (distribution.Manifest, error) {
context.GetLogger(ms.ctx).Debug("(*signedManifestHandler).Unmarshal")
dcontext.GetLogger(ms.ctx).Debug("(*signedManifestHandler).Unmarshal")
var (
signatures [][]byte
@@ -56,7 +57,7 @@ func (ms *signedManifestHandler) Unmarshal(ctx context.Context, dgst digest.Dige
}
func (ms *signedManifestHandler) Put(ctx context.Context, manifest distribution.Manifest, skipDependencyVerification bool) (digest.Digest, error) {
context.GetLogger(ms.ctx).Debug("(*signedManifestHandler).Put")
dcontext.GetLogger(ms.ctx).Debug("(*signedManifestHandler).Put")
sm, ok := manifest.(*schema1.SignedManifest)
if !ok {
@@ -72,7 +73,7 @@ func (ms *signedManifestHandler) Put(ctx context.Context, manifest distribution.
revision, err := ms.blobStore.Put(ctx, mt, payload)
if err != nil {
context.GetLogger(ctx).Errorf("error putting payload into blobstore: %v", err)
dcontext.GetLogger(ctx).Errorf("error putting payload into blobstore: %v", err)
return "", err
}

View File

@@ -1,10 +1,10 @@
package storage
import (
"context"
"path"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
storagedriver "github.com/docker/distribution/registry/storage/driver"
"github.com/opencontainers/go-digest"
)
@@ -122,17 +122,20 @@ func (ts *tagStore) Untag(ctx context.Context, tag string) error {
name: ts.repository.Named().Name(),
tag: tag,
})
switch err.(type) {
case storagedriver.PathNotFoundError:
return distribution.ErrTagUnknown{Tag: tag}
case nil:
break
default:
if err != nil {
return err
}
return ts.blobStore.driver.Delete(ctx, tagPath)
if err := ts.blobStore.driver.Delete(ctx, tagPath); err != nil {
switch err.(type) {
case storagedriver.PathNotFoundError:
return nil // Untag is idempotent, we don't care if it didn't exist
default:
return err
}
}
return nil
}
// linkedBlobStore returns the linkedBlobStore for the named tag, allowing one
@@ -176,9 +179,13 @@ func (ts *tagStore) Lookup(ctx context.Context, desc distribution.Descriptor) ([
tag: tag,
}
tagLinkPath, err := pathFor(tagLinkPathSpec)
tagLinkPath, _ := pathFor(tagLinkPathSpec)
tagDigest, err := ts.blobStore.readlink(ctx, tagLinkPath)
if err != nil {
switch err.(type) {
case storagedriver.PathNotFoundError:
continue
}
return nil, err
}
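
The rewritten Untag treats a missing tag path as success, so untagging is now idempotent, and Lookup likewise skips tags whose link path has disappeared instead of failing the whole scan. A short sketch of the caller-visible behavior, with repo and ctx as placeholders:

// Untagging a tag that was never set (or was already removed) is now a
// no-op; only real storage failures surface as errors.
if err := repo.Tags(ctx).Untag(ctx, "latest"); err != nil {
	return err
}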

View File

@@ -1,10 +1,10 @@
package storage
import (
"context"
"testing"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
"github.com/docker/distribution/reference"
"github.com/docker/distribution/registry/storage/driver/inmemory"
)
@@ -84,8 +84,8 @@ func TestTagStoreUnTag(t *testing.T) {
desc := distribution.Descriptor{Digest: "sha256:bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"}
err := tags.Untag(ctx, "latest")
if err == nil {
t.Errorf("Expected error untagging non-existant tag")
if err != nil {
t.Error(err)
}
err = tags.Tag(ctx, "latest", desc)

View File

@@ -1,7 +1,8 @@
package storage
import (
"github.com/docker/distribution/context"
"context"
"github.com/docker/distribution/registry/storage/driver"
)

View File

@@ -0,0 +1,23 @@
package storage
import (
"context"
"github.com/docker/distribution"
digest "github.com/opencontainers/go-digest"
)
// v1UnsupportedHandler is a ManifestHandler that unmarshals v1 manifests but
// refuses to Put v1 manifests
type v1UnsupportedHandler struct {
innerHandler ManifestHandler
}
var _ ManifestHandler = &v1UnsupportedHandler{}
func (v *v1UnsupportedHandler) Unmarshal(ctx context.Context, dgst digest.Digest, content []byte) (distribution.Manifest, error) {
return v.innerHandler.Unmarshal(ctx, dgst, content)
}
func (v *v1UnsupportedHandler) Put(ctx context.Context, manifest distribution.Manifest, skipDependencyVerification bool) (digest.Digest, error) {
return digest.Digest(""), distribution.ErrSchemaV1Unsupported
}

View File

@@ -1,9 +1,10 @@
package storage
import (
"context"
"path"
"github.com/docker/distribution/context"
dcontext "github.com/docker/distribution/context"
"github.com/docker/distribution/registry/storage/driver"
"github.com/opencontainers/go-digest"
)
@@ -39,7 +40,7 @@ func (v Vacuum) RemoveBlob(dgst string) error {
return err
}
context.GetLogger(v.ctx).Infof("Deleting blob: %s", blobPath)
dcontext.GetLogger(v.ctx).Infof("Deleting blob: %s", blobPath)
err = v.driver.Delete(v.ctx, blobPath)
if err != nil {
@@ -49,6 +50,40 @@ func (v Vacuum) RemoveBlob(dgst string) error {
return nil
}
// RemoveManifest removes a manifest from the filesystem
func (v Vacuum) RemoveManifest(name string, dgst digest.Digest, tags []string) error {
// remove each tag's manifest reference; if one is not found, continue to the next
for _, tag := range tags {
tagsPath, err := pathFor(manifestTagIndexEntryPathSpec{name: name, revision: dgst, tag: tag})
if err != nil {
return err
}
_, err = v.driver.Stat(v.ctx, tagsPath)
if err != nil {
switch err := err.(type) {
case driver.PathNotFoundError:
continue
default:
return err
}
}
dcontext.GetLogger(v.ctx).Infof("deleting manifest tag reference: %s", tagsPath)
err = v.driver.Delete(v.ctx, tagsPath)
if err != nil {
return err
}
}
manifestPath, err := pathFor(manifestRevisionPathSpec{name: name, revision: dgst})
if err != nil {
return err
}
dcontext.GetLogger(v.ctx).Infof("deleting manifest: %s", manifestPath)
return v.driver.Delete(v.ctx, manifestPath)
}
// RemoveRepository removes a repository directory from the
// filesystem
func (v Vacuum) RemoveRepository(repoName string) error {
@@ -57,7 +92,7 @@ func (v Vacuum) RemoveRepository(repoName string) error {
return err
}
repoDir := path.Join(rootForRepository, repoName)
context.GetLogger(v.ctx).Infof("Deleting repo: %s", repoDir)
dcontext.GetLogger(v.ctx).Infof("Deleting repo: %s", repoDir)
err = v.driver.Delete(v.ctx, repoDir)
if err != nil {
return err
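
RemoveManifest drops each tag's index entry for the digest before deleting the revision directory itself, so an interrupted run cannot leave a tag link pointing at a manifest that is already gone. For illustration only, with a hypothetical repository foo/bar, tag latest and digest sha256:abc..., the two path specs resolve roughly to:

/docker/registry/v2/repositories/foo/bar/_manifests/tags/latest/index/sha256/abc.../link
/docker/registry/v2/repositories/foo/bar/_manifests/revisions/sha256/abc...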

View File

@@ -1,59 +0,0 @@
package storage
import (
"errors"
"fmt"
"sort"
"github.com/docker/distribution/context"
storageDriver "github.com/docker/distribution/registry/storage/driver"
)
// ErrSkipDir is used as a return value from onFileFunc to indicate that
// the directory named in the call is to be skipped. It is not returned
// as an error by any function.
var ErrSkipDir = errors.New("skip this directory")
// WalkFn is called once per file by Walk.
// If the returned error is ErrSkipDir and fileInfo refers
// to a directory, the directory will not be entered and Walk
// will continue the traversal. Otherwise Walk will return the error and stop.
type WalkFn func(fileInfo storageDriver.FileInfo) error
// Walk traverses a filesystem defined within driver, starting
// from the given path, calling f on each file
func Walk(ctx context.Context, driver storageDriver.StorageDriver, from string, f WalkFn) error {
children, err := driver.List(ctx, from)
if err != nil {
return err
}
sort.Stable(sort.StringSlice(children))
for _, child := range children {
// TODO(stevvooe): Calling driver.Stat for every entry is quite
// expensive when running against backends with a slow Stat
// implementation, such as s3. This is very likely a serious
// performance bottleneck.
fileInfo, err := driver.Stat(ctx, child)
if err != nil {
return err
}
err = f(fileInfo)
skipDir := (err == ErrSkipDir)
if err != nil && !skipDir {
return err
}
if fileInfo.IsDir() && !skipDir {
if err := Walk(ctx, driver, child, f); err != nil {
return err
}
}
}
return nil
}
// pushError formats an error type given a path and an error
// and pushes it to a slice of errors
func pushError(errors []error, path string, err error) []error {
return append(errors, fmt.Errorf("%s: %s", path, err))
}
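
walk.go is deleted because traversal moved onto the StorageDriver interface itself: storage.Walk(ctx, driver, path, fn) becomes driver.Walk(ctx, path, fn) and storage.ErrSkipDir becomes driver.ErrSkipDir, matching the linkedBlobStore and PurgeUploads hunks above. A minimal sketch of the replacement API against the in-memory driver:

package main

import (
	"context"
	"fmt"

	"github.com/docker/distribution/registry/storage/driver"
	"github.com/docker/distribution/registry/storage/driver/inmemory"
)

func main() {
	ctx := context.Background()
	d := inmemory.New()
	_ = d.PutContent(ctx, "/a/b/c", []byte("data"))

	// Walk is now a StorageDriver method; returning driver.ErrSkipDir from
	// the callback prunes a directory without aborting the traversal.
	err := d.Walk(ctx, "/", func(fi driver.FileInfo) error {
		if fi.IsDir() && fi.Path() == "/skip" {
			return driver.ErrSkipDir
		}
		fmt.Println(fi.Path())
		return nil
	})
	if err != nil {
		fmt.Println("walk error:", err)
	}
}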

View File

@@ -1,152 +0,0 @@
package storage
import (
"fmt"
"sort"
"testing"
"github.com/docker/distribution/context"
"github.com/docker/distribution/registry/storage/driver"
"github.com/docker/distribution/registry/storage/driver/inmemory"
)
func testFS(t *testing.T) (driver.StorageDriver, map[string]string, context.Context) {
d := inmemory.New()
ctx := context.Background()
expected := map[string]string{
"/a": "dir",
"/a/b": "dir",
"/a/b/c": "dir",
"/a/b/c/d": "file",
"/a/b/c/e": "file",
"/a/b/f": "dir",
"/a/b/f/g": "file",
"/a/b/f/h": "file",
"/a/b/f/i": "file",
"/z": "dir",
"/z/y": "file",
}
for p, typ := range expected {
if typ != "file" {
continue
}
if err := d.PutContent(ctx, p, []byte(p)); err != nil {
t.Fatalf("unable to put content into fixture: %v", err)
}
}
return d, expected, ctx
}
func TestWalkErrors(t *testing.T) {
d, expected, ctx := testFS(t)
fileCount := len(expected)
err := Walk(ctx, d, "", func(fileInfo driver.FileInfo) error {
return nil
})
if err == nil {
t.Error("Expected invalid root err")
}
errEarlyExpected := fmt.Errorf("Early termination")
err = Walk(ctx, d, "/", func(fileInfo driver.FileInfo) error {
// error on the 2nd file
if fileInfo.Path() == "/a/b" {
return errEarlyExpected
}
delete(expected, fileInfo.Path())
return nil
})
if len(expected) != fileCount-1 {
t.Error("Walk failed to terminate with error")
}
if err != errEarlyExpected {
if err == nil {
t.Fatalf("expected an error due to early termination")
} else {
t.Error(err.Error())
}
}
err = Walk(ctx, d, "/nonexistent", func(fileInfo driver.FileInfo) error {
return nil
})
if err == nil {
t.Errorf("Expected missing file err")
}
}
func TestWalk(t *testing.T) {
d, expected, ctx := testFS(t)
var traversed []string
err := Walk(ctx, d, "/", func(fileInfo driver.FileInfo) error {
filePath := fileInfo.Path()
filetype, ok := expected[filePath]
if !ok {
t.Fatalf("Unexpected file in walk: %q", filePath)
}
if fileInfo.IsDir() {
if filetype != "dir" {
t.Errorf("Unexpected file type: %q", filePath)
}
} else {
if filetype != "file" {
t.Errorf("Unexpected file type: %q", filePath)
}
// each file has its own path as the contents. If the length
// doesn't match the path length, fail.
if fileInfo.Size() != int64(len(fileInfo.Path())) {
t.Fatalf("unexpected size for %q: %v != %v",
fileInfo.Path(), fileInfo.Size(), len(fileInfo.Path()))
}
}
delete(expected, filePath)
traversed = append(traversed, filePath)
return nil
})
if len(expected) > 0 {
t.Errorf("Missed files in walk: %q", expected)
}
if !sort.StringsAreSorted(traversed) {
t.Errorf("result should be sorted: %v", traversed)
}
if err != nil {
t.Fatalf(err.Error())
}
}
func TestWalkSkipDir(t *testing.T) {
d, expected, ctx := testFS(t)
err := Walk(ctx, d, "/", func(fileInfo driver.FileInfo) error {
filePath := fileInfo.Path()
if filePath == "/a/b" {
// skip processing /a/b/c and /a/b/c/d
return ErrSkipDir
}
delete(expected, filePath)
return nil
})
if err != nil {
t.Fatalf(err.Error())
}
if _, ok := expected["/a/b/c"]; !ok {
t.Errorf("/a/b/c not skipped")
}
if _, ok := expected["/a/b/c/d"]; !ok {
t.Errorf("/a/b/c/d not skipped")
}
if _, ok := expected["/a/b/c/e"]; !ok {
t.Errorf("/a/b/c/e not skipped")
}
}