Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an endpoint to download a file #818

Merged
merged 6 commits into from
Nov 4, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 6 additions & 17 deletions backend/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions backend/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"chalk": "^5.2.0",
"config": "^3.3.9",
"connect-mongo": "^5.1.0",
"content-disposition": "^0.5.4",
"cross-fetch": "^3.1.8",
"dedent-js": "^1.0.1",
"dev-null": "^0.1.1",
Expand Down
20 changes: 19 additions & 1 deletion backend/src/clients/s3.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { S3Client } from '@aws-sdk/client-s3'
import { GetObjectCommand, GetObjectRequest, S3Client } from '@aws-sdk/client-s3'
import { Upload } from '@aws-sdk/lib-storage'

import config from '../utils/v2/config.js'
Expand Down Expand Up @@ -31,3 +31,21 @@ export async function putObjectStream(bucket: string, key: string, body: Readabl
fileSize,
}
}

/**
 * Sends a `GetObjectCommand` for `key` in `bucket` and resolves to the SDK response.
 *
 * @param bucket - Name of the bucket, sent as `Bucket`.
 * @param key - Object key, sent as `Key`.
 * @param range - Optional byte range; when given it is sent as the `Range`
 *   field in the form `bytes=<start>-<end>`.
 * @returns The response of `client.send` for the command.
 */
export async function getObjectStream(bucket: string, key: string, range?: { start: number; end: number }) {
  // TODO(review): add tests for this method.

  // A client is requested for every call on purpose: on some platforms the S3
  // credentials expire relatively frequently, and tracking whether a cached
  // client has expired was judged too complex for a function expected to be
  // called only a few hundred to a thousand times a day.
  const client = await getS3Client()

  const request: GetObjectRequest = range
    ? { Bucket: bucket, Key: key, Range: `bytes=${range.start}-${range.end}` }
    : { Bucket: bucket, Key: key }

  return client.send(new GetObjectCommand(request))
}
21 changes: 20 additions & 1 deletion backend/src/connectors/v2/authorisation/Base.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import { AccessRequestDoc } from '../../../models/v2/AccessRequest.js'
import { FileInterfaceDoc } from '../../../models/v2/File.js'
import { ModelDoc, ModelVisibility } from '../../../models/v2/Model.js'
import { ReleaseDoc } from '../../../models/v2/Release.js'
import { UserDoc } from '../../../models/v2/User.js'
import { Access } from '../../../routes/v1/registryAuth.js'
import authentication from '../authentication/index.js'

export const ModelAction = {
Expand Down Expand Up @@ -29,6 +31,17 @@ export const AccessRequestAction = {
}
/** Union of the values of `AccessRequestAction` (previously derived from `ReleaseAction` by mistake). */
export type AccessRequestActionKeys = (typeof AccessRequestAction)[keyof typeof AccessRequestAction]

/**
 * Actions that can be authorised against a file.
 *
 * Declared `as const` so that `FileActionKeys` is the literal union of the
 * values below rather than plain `string`.
 */
export const FileAction = {
  Download: 'download',
} as const
/** Literal union of every value in `FileAction`. */
export type FileActionKeys = (typeof FileAction)[keyof typeof FileAction]

/**
 * Actions that can be authorised against an image.
 *
 * Declared `as const` so that `ImageActionKeys` is the literal union
 * `'pull' | 'push'` rather than plain `string`.
 */
export const ImageAction = {
  Pull: 'pull',
  Push: 'push',
} as const
/** Literal union of every value in `ImageAction`. */
export type ImageActionKeys = (typeof ImageAction)[keyof typeof ImageAction]

export abstract class BaseAuthorisationConnector {
abstract userModelAction(user: UserDoc, model: ModelDoc, action: ModelActionKeys): Promise<boolean>
abstract userReleaseAction(
Expand All @@ -43,7 +56,13 @@ export abstract class BaseAuthorisationConnector {
accessRequest: AccessRequestDoc,
action: AccessRequestActionKeys,
): Promise<boolean>

/**
 * Decides whether `user` may perform `action` on `file`, which is passed
 * together with the `model` it is being checked against.
 *
 * @returns A promise resolving to `true` when the action is allowed and
 *   `false` otherwise.
 */
abstract userFileAction(
user: UserDoc,
model: ModelDoc,
file: FileInterfaceDoc,
action: FileActionKeys,
): Promise<boolean>
/**
 * Decides whether `user` may perform the image `action` described by the
 * registry `access` request against `model`.
 *
 * @returns A promise resolving to `true` when the action is allowed and
 *   `false` otherwise.
 */
abstract userImageAction(user: UserDoc, model: ModelDoc, access: Access, action: ImageActionKeys): Promise<boolean>
async hasModelVisibilityAccess(user: UserDoc, model: ModelDoc) {
if (model.visibility === ModelVisibility.Public) {
return true
Expand Down
83 changes: 82 additions & 1 deletion backend/src/connectors/v2/authorisation/silly.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,22 @@
import { AccessRequestDoc } from '../../../models/v2/AccessRequest.js'
import { FileInterfaceDoc } from '../../../models/v2/File.js'
import { ModelDoc } from '../../../models/v2/Model.js'
import { ReleaseDoc } from '../../../models/v2/Release.js'
import { UserDoc } from '../../../models/v2/User.js'
import { AccessRequestActionKeys, BaseAuthorisationConnector, ModelActionKeys, ReleaseActionKeys } from './Base.js'
import { Access } from '../../../routes/v1/registryAuth.js'
import { getAccessRequestsByModel } from '../../../services/v2/accessRequest.js'
import log from '../../../services/v2/log.js'
import authentication from '../authentication/index.js'
import {
AccessRequestActionKeys,
BaseAuthorisationConnector,
FileAction,
FileActionKeys,
ImageAction,
ImageActionKeys,
ModelActionKeys,
ReleaseActionKeys,
} from './Base.js'

export class SillyAuthorisationConnector extends BaseAuthorisationConnector {
constructor() {
Expand Down Expand Up @@ -48,4 +62,71 @@ export class SillyAuthorisationConnector extends BaseAuthorisationConnector {
// Allow any other action to be completed
return true
}

/**
 * Checks whether `user` may perform `action` on `file` of `model`.
 *
 * Resolves to `true` for collaborators of a model the user can see, and for
 * non-collaborators who are downloading and are named on one of the model's
 * access requests. Every other case resolves to `false`.
 */
async userFileAction(
  user: UserDoc,
  model: ModelDoc,
  file: FileInterfaceDoc,
  action: FileActionKeys,
): Promise<boolean> {
  // Users who fail the model visibility check get nothing at all.
  const canSeeModel = await this.hasModelVisibilityAccess(user, model)
  if (!canSeeModel) {
    return false
  }

  // Collaborators are allowed every file action.
  const userEntities = await authentication.getEntities(user)
  const isCollaborator = model.collaborators.some(({ entity }) => userEntities.includes(entity))
  if (isCollaborator) {
    return true
  }

  // Everyone else is limited to downloading.
  if (action !== FileAction.Download) {
    log.warn({ userDn: user.dn, file: file._id }, 'Non-collaborator can only download artefacts')
    return false
  }

  // A download additionally needs an access request naming one of the user's entities.
  const requests = await getAccessRequestsByModel(user, model.id)
  const hasAccessRequest = requests.some((request) =>
    request.metadata.overview.entities.some((entity) => userEntities.includes(entity)),
  )
  if (hasAccessRequest) {
    return true
  }

  log.warn({ userDn: user.dn, file: file._id }, 'No valid access request found')
  return false
}

/**
 * Checks whether `user` may perform the image `action` on `model`.
 *
 * Resolves to `true` for collaborators of a model the user can see, and for
 * non-collaborators who are pulling and are named on one of the model's
 * access requests. Every other case resolves to `false`. `access` is only
 * used as logging context.
 */
async userImageAction(user: UserDoc, model: ModelDoc, access: Access, action: ImageActionKeys): Promise<boolean> {
  // Users who fail the model visibility check get nothing at all.
  const canSeeModel = await this.hasModelVisibilityAccess(user, model)
  if (!canSeeModel) {
    return false
  }

  // Collaborators are allowed every image action (push or pull).
  const userEntities = await authentication.getEntities(user)
  const isCollaborator = model.collaborators.some(({ entity }) => userEntities.includes(entity))
  if (isCollaborator) {
    return true
  }

  // Everyone else is limited to pulling.
  if (action !== ImageAction.Pull) {
    log.warn({ userDn: user.dn, access }, 'Non-collaborator can only pull models')
    return false
  }

  // A pull additionally needs an access request naming one of the user's entities.
  const requests = await getAccessRequestsByModel(user, model.id)
  const hasAccessRequest = requests.some((request) =>
    request.metadata.overview.entities.some((entity) => userEntities.includes(entity)),
  )
  if (hasAccessRequest) {
    return true
  }

  log.warn({ userDn: user.dn, access }, 'No valid access request found')
  return false
}
}
2 changes: 2 additions & 0 deletions backend/src/routes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ import { getModelAccessRequests } from './routes/v2/model/accessRequest/getModel
import { patchAccessRequest } from './routes/v2/model/accessRequest/patchAccessRequest.js'
import { postAccessRequest } from './routes/v2/model/accessRequest/postAccessRequest.js'
import { deleteFile } from './routes/v2/model/file/deleteFile.js'
import { getDownloadFile } from './routes/v2/model/file/getDownloadFile.js'
import { getFiles } from './routes/v2/model/file/getFiles.js'
import { postFinishMultipartUpload } from './routes/v2/model/file/postFinishMultipartUpload.js'
import { postSimpleUpload } from './routes/v2/model/file/postSimpleUpload.js'
Expand Down Expand Up @@ -226,6 +227,7 @@ if (config.experimental.v2) {
server.post('/api/v2/model/:modelId/access-request/:accessRequestId/review', ...postAccessRequestReviewResponse)

server.get('/api/v2/model/:modelId/files', ...getFiles)
server.get('/api/v2/model/:modelId/file/:fileId/download', ...getDownloadFile)
server.post('/api/v2/model/:modelId/files/upload/simple', ...postSimpleUpload)
server.post('/api/v2/model/:modelId/files/upload/multipart/start', ...postStartMultipartUpload)
server.post('/api/v2/model/:modelId/files/upload/multipart/finish', ...postFinishMultipartUpload)
Expand Down
32 changes: 4 additions & 28 deletions backend/src/routes/v1/registryAuth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ import jwt from 'jsonwebtoken'
import { isEqual } from 'lodash-es'
import { stringify as uuidStringify, v4 as uuidv4 } from 'uuid'

import authentication from '../../connectors/v2/authentication/index.js'
import { ImageAction } from '../../connectors/v2/authorisation/Base.js'
import authorisation from '../../connectors/v2/authorisation/index.js'
import { ModelDoc } from '../../models/v2/Model.js'
import { UserDoc as UserDocV2 } from '../../models/v2/User.js'
import { findDeploymentByUuid } from '../../services/deployment.js'
import { getAccessRequestsByModel } from '../../services/v2/accessRequest.js'
import log from '../../services/v2/log.js'
import { getModelById } from '../../services/v2/model.js'
import { ModelId, UserDoc } from '../../types/types.js'
Expand Down Expand Up @@ -175,32 +175,8 @@ async function checkAccessV2(access: Access, user: UserDocV2) {
return false
}

const entities = await authentication.getEntities(user)
if (model.collaborators.some((collaborator) => entities.includes(collaborator.entity))) {
// They are a collaborator to the model, let them push or pull.
return true
}

if (!isEqual(access.actions, ['pull'])) {
// If users are not collaborators, they should only be able to pull
log.warn({ userDn: user.dn, access }, 'Non-collaborator can only pull models')
return false
}

// TODO: If the model is 'public access' automatically approve pulls.

const accessRequests = await getAccessRequestsByModel(user, modelId)
const accessRequest = accessRequests.find((accessRequest) =>
accessRequest.metadata.overview.entities.some((entity) => entities.includes(entity)),
)

if (!accessRequest) {
// User does not have a valid access request
log.warn({ userDn: user.dn, access }, 'No valid access request found')
return false
}

return true
const action = isEqual(access.actions, ['pull']) ? ImageAction.Pull : ImageAction.Push
return authorisation.userImageAction(user, model, access, action)
}

async function checkAccess(access: Access, user: UserDoc) {
Expand Down
55 changes: 55 additions & 0 deletions backend/src/routes/v2/model/file/getDownloadFile.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import bodyParser from 'body-parser'
import contentDisposition from 'content-disposition'
import { Request, Response } from 'express'
import stream from 'stream'
import { z } from 'zod'

import { FileInterface } from '../../../../models/v2/File.js'
import { downloadFile, getFileById } from '../../../../services/v2/file.js'
import { BadReq, InternalError } from '../../../../utils/v2/error.js'
import { parse } from '../../../../utils/validate.js'

/**
 * Shape the download request is parsed against: both route parameters,
 * `modelId` and `fileId`, must be strings.
 */
export const getDownloadFileSchema = z.object({
  params: z.object({ modelId: z.string(), fileId: z.string() }),
})

/**
 * Body type given to the download handler's `Response`.
 *
 * NOTE(review): the handler in this file pipes the file body to the response
 * and never sends an object of this shape — confirm whether this type is
 * still wanted.
 */
interface GetDownloadFileResponse {
  files: FileInterface[]
}

/**
 * Express handler chain that streams the contents of a file to the client.
 *
 * The request is parsed against `getDownloadFileSchema`, the file record is
 * looked up with `getFileById`, and the object body returned by
 * `downloadFile` is piped into the response.
 *
 * Every check that can fail (range rejection, the authorisation performed
 * inside `downloadFile`, a missing body) runs BEFORE any header is set or
 * `writeHead` is called, so a thrown error never goes out with the download
 * headers (notably the long-lived `Cache-Control`) or after the status line
 * has been committed.
 *
 * @throws BadReq when the request carries a `Range` header (not supported yet).
 * @throws InternalError when the fetched object has no `Body`.
 */
export const getDownloadFile = [
  bodyParser.json(),
  async (req: Request, res: Response<GetDownloadFileResponse>) => {
    const {
      params: { fileId },
    } = parse(req, getDownloadFileSchema)

    const file = await getFileById(req.user, fileId)

    if (req.headers.range) {
      // TODO: support ranges
      // Range responses will need the same Content-Type, Cache-Control and
      // Content-Disposition headers as a full download, so the implementation
      // can share the header block below once it exists.
      throw BadReq('Ranges are not supported', { fileId })
    }

    // Named `object` rather than `stream` so the imported `stream` module is not shadowed.
    const object = await downloadFile(req.user, fileId)
    if (!object.Body) {
      throw InternalError('We were not able to retrieve the body of this file', { fileId })
    }

    // required to support utf-8 file names
    res.set('Content-Disposition', contentDisposition(file.name, { type: 'inline' }))
    res.set('Content-Type', file.mime)
    // NOTE(review): `public` lets shared caches store a response that is only
    // released after an authorisation check — confirm this is intended.
    res.set('Cache-Control', 'public, max-age=604800, immutable')
    res.set('Content-Length', String(file.size))
    // TODO: support ranges
    // res.set('Accept-Ranges', 'bytes')
    res.writeHead(200)

    // The AWS library doesn't seem to properly type 'Body' as being pipeable?
    const body = object.Body as stream.Readable

    // Without a listener, a read failure mid-transfer would surface as an
    // unhandled 'error' event; tear the response down instead.
    body.on('error', (err) => res.destroy(err))

    // Nothing is returned: the pipe ends the response once the body has been
    // fully streamed to the user.
    body.pipe(res)
  },
]
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
"description": "A description of what the model does.",
"type": "string",
"minLength": 1,
"maxLength": 5000,
"widget": "customTextInput"
"maxLength": 5000
},
"tags": {
"title": "Descriptive tags for the model.",
Expand Down
16 changes: 14 additions & 2 deletions backend/src/services/v2/file.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { putObjectStream } from '../../clients/s3.js'
import { ModelAction } from '../../connectors/v2/authorisation/Base.js'
import { getObjectStream, putObjectStream } from '../../clients/s3.js'
import { FileAction, ModelAction } from '../../connectors/v2/authorisation/Base.js'
import authorisation from '../../connectors/v2/authorisation/index.js'
import FileModel from '../../models/v2/File.js'
import { UserDoc } from '../../models/v2/User.js'
Expand Down Expand Up @@ -37,6 +37,18 @@ export async function uploadFile(user: UserDoc, modelId: string, name: string, m
return file
}

/**
 * Fetches the stored object for a file after checking that `user` is allowed
 * to download it.
 *
 * @param user - The user requesting the download.
 * @param fileId - Identifier used to look the file up via `getFileById`.
 * @param range - Optional byte range, forwarded unchanged to `getObjectStream`.
 * @returns The result of `getObjectStream` for the file's bucket and path.
 * @throws Forbidden when `authorisation.userFileAction` denies `FileAction.Download`.
 */
export async function downloadFile(user: UserDoc, fileId: string, range?: { start: number; end: number }) {
  // TODO(review): add tests for this function.
  const file = await getFileById(user, fileId)
  const model = await getModelById(user, file.modelId)

  const access = await authorisation.userFileAction(user, model, file, FileAction.Download)
  if (!access) {
    // The denied action is on a file, so the message names the file rather than the model.
    throw Forbidden('You do not have permission to download this file.', { user: user.dn, fileId })
  }

  return getObjectStream(file.bucket, file.path, range)
}

export async function getFileById(user: UserDoc, fileId: string) {
const file = await FileModel.findOne({
_id: fileId,
Expand Down
4 changes: 4 additions & 0 deletions backend/src/utils/v2/error.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,7 @@ export function Forbidden(message: string, context?: BailoError['context'], logg
/**
 * Builds an error through `GenericError` with 404 as its first argument
 * (presumably the HTTP status code — confirm against `GenericError`).
 */
export function NotFound(message: string, context?: BailoError['context'], logger?: Logger) {
  const code = 404
  return GenericError(code, message, context, logger)
}

/**
 * Builds an error through `GenericError` with 500 as its first argument
 * (presumably the HTTP status code — confirm against `GenericError`).
 */
export function InternalError(message: string, context?: BailoError['context'], logger?: Logger) {
  const code = 500
  return GenericError(code, message, context, logger)
}
Loading