Adds organization model type (#162)

* Adds organization model type

Fixes test cases

* Switches to using org username vs name

* Adds authorization based on your organization

* Adds org to token

Minor bug fixes

* Fixes test case
This commit is contained in:
Chris Anderson 2023-05-27 08:55:24 -05:00 committed by GitHub
parent 9554d17a1d
commit 8e042ee9b1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 223 additions and 111 deletions

View file

@ -73,6 +73,7 @@ jobs:
DB_USERNAME: root
DB_PORT: 3306
BASE_URL: https://parcelvoy.com
AUTH_DRIVER: basic
QUEUE_DRIVER: memory
STORAGE_DRIVER: s3

View file

@ -0,0 +1,42 @@
exports.up = async function(knex) {
await knex.schema.createTable('organizations', function(table) {
table.increments()
table.string('username').unique()
table.string('domain').index()
table.json('auth')
table.timestamp('created_at').defaultTo(knex.fn.now())
table.timestamp('updated_at').defaultTo(knex.fn.now())
})
await knex.schema.table('projects', function(table) {
table.integer('organization_id')
.references('id')
.inTable('organizations')
.onDelete('CASCADE')
.unsigned()
.after('id')
})
await knex.schema.table('admins', function(table) {
table.integer('organization_id')
.references('id')
.inTable('organizations')
.onDelete('CASCADE')
.unsigned()
.after('id')
})
const orgId = await knex('organizations').insert({ id: 1, username: 'main' })
await knex.raw('UPDATE projects SET organization_id = ? WHERE organization_id IS NULL', [orgId])
await knex.raw('UPDATE admins SET organization_id = ? WHERE organization_id IS NULL', [orgId])
}
exports.down = async function(knex) {
await knex.schema.dropTable('organizations')
await knex.schema.table('projects', function(table) {
table.dropColumn('organization_id')
})
await knex.schema.table('admins', function(table) {
table.dropColumn('organization_id')
})
}

View file

@ -1,6 +1,7 @@
import Model, { ModelParams } from '../core/Model'
export class Admin extends Model {
organization_id!: number
email!: string
first_name?: string
last_name?: string

View file

@ -16,11 +16,15 @@ export const getAdminByEmail = async (email: string): Promise<Admin | undefined>
return await Admin.first(qb => qb.where('email', email))
}
export const createOrUpdateAdmin = async (params: AdminParams): Promise<Admin> => {
export const createOrUpdateAdmin = async ({ organization_id, ...params }: AdminParams): Promise<Admin> => {
const admin = await getAdminByEmail(params.email)
if (admin?.id) {
return Admin.updateAndFetch(admin.id, params)
} else {
return Admin.insertAndFetch(params)
return Admin.insertAndFetch({
...params,
organization_id,
})
}
}

View file

@ -3,9 +3,8 @@ import AuthProvider from './AuthProvider'
import OpenIDProvider, { OpenIDConfig } from './OpenIDAuthProvider'
import SAMLProvider, { SAMLConfig } from './SAMLAuthProvider'
import { DriverConfig } from '../config/env'
import LoggerAuthProvider from './LoggerAuthProvider'
import { logger } from '../config/logger'
import BasicAuthProvider, { BasicAuthConfig } from './BasicAuthProvider'
import { getOrganizationByUsername } from '../organizations/OrganizationService'
export type AuthProviderName = 'basic' | 'saml' | 'openid' | 'logger'
@ -20,23 +19,38 @@ export default class Auth {
provider: AuthProvider
constructor(config?: AuthConfig) {
this.provider = Auth.provider(config)
}
static provider(config?: AuthConfig): AuthProvider {
if (config?.driver === 'basic') {
this.provider = new BasicAuthProvider(config)
return new BasicAuthProvider(config)
} else if (config?.driver === 'saml') {
this.provider = new SAMLProvider(config)
return new SAMLProvider(config)
} else if (config?.driver === 'openid') {
this.provider = new OpenIDProvider(config)
return new OpenIDProvider(config)
} else {
logger.info({}, 'No valid auth provider has been set, using logger as fallback')
this.provider = new LoggerAuthProvider()
throw new Error('A valid auth driver must be set!')
}
}
async start(ctx: Context): Promise<void> {
return await this.provider.start(ctx)
const provider = await this.loadProvider(ctx)
return await provider.start(ctx)
}
async validate(ctx: Context): Promise<void> {
return await this.provider.validate(ctx)
const provider = await this.loadProvider(ctx)
return await provider.validate(ctx)
}
private async loadProvider(ctx: Context): Promise<AuthProvider> {
if (ctx.subdomains && ctx.subdomains[0]) {
const subdomain = ctx.subdomains[0]
const org = await getOrganizationByUsername(subdomain)
ctx.state.organization = org
if (org) return Auth.provider(org.auth)
}
return this.provider
}
}

View file

@ -10,6 +10,7 @@ import { getTokenCookies, isAccessTokenRevoked } from './TokenRepository'
export interface JwtAdmin {
id: number
organization_id: number
}
export interface State {
@ -35,13 +36,13 @@ const parseAuth = async (ctx: Context) => {
}
if (token.startsWith('pk_')) {
// public key
// Public key
return {
scope: 'public',
key: await getProjectApiKey(token),
}
} else if (token.startsWith('sk_')) {
// secret key
// Secret key
return {
scope: 'secret',
key: await getProjectApiKey(token),
@ -51,7 +52,7 @@ const parseAuth = async (ctx: Context) => {
if (await isAccessTokenRevoked(token)) {
throw new RequestError(AuthError.AccessDenied)
}
// user jwt
// User JWT
return {
scope: 'admin',
admin,

View file

@ -2,27 +2,41 @@ import { Context } from 'koa'
import App from '../app'
import { RequestError } from '../core/errors'
import AuthError from './AuthError'
import { AdminParams } from './Admin'
import { createOrUpdateAdmin, getAdmin } from './AdminRepository'
import { Admin, AdminParams } from './Admin'
import { createOrUpdateAdmin } from './AdminRepository'
import { generateAccessToken, OAuthResponse, setTokenCookies } from './TokenRepository'
import Organization from '../organizations/Organization'
import { State } from './AuthMiddleware'
import { createOrganization, getOrganizationByDomain } from '../organizations/OrganizationService'
type OrgState = State & { organization?: Organization }
export type AuthContext = Context & { state: OrgState }
export default abstract class AuthProvider {
abstract start(ctx: Context): Promise<void>
abstract validate(ctx: Context): Promise<void>
abstract start(ctx: AuthContext): Promise<void>
abstract validate(ctx: AuthContext): Promise<void>
async login(id: number, ctx?: Context, redirect?: string): Promise<OAuthResponse>
async login(params: AdminParams, ctx?: Context, redirect?: string): Promise<OAuthResponse>
async login(params: AdminParams | number, ctx?: Context, redirect?: string): Promise<OAuthResponse> {
async loadAuthOrganization(ctx: AuthContext, domain: string) {
const organization = ctx.state.organization ?? await getOrganizationByDomain(domain)
if (!organization) {
return await createOrganization(domain)
}
return organization
}
async login(params: AdminParams, ctx?: AuthContext, redirect?: string): Promise<OAuthResponse> {
// If existing, update otherwise create new admin based on params
const admin = typeof params === 'number'
? await getAdmin(params)
: await createOrUpdateAdmin(params)
const admin = await createOrUpdateAdmin(params)
if (!admin) throw new RequestError(AuthError.AdminNotFound)
const oauth = generateAccessToken(admin, ctx)
return await this.generateOauth(admin, ctx, redirect)
}
private async generateOauth(admin: Admin, ctx?: AuthContext, redirect?: string) {
const oauth = await generateAccessToken(admin, ctx)
if (ctx) {
setTokenCookies(ctx, oauth)
@ -31,7 +45,7 @@ export default abstract class AuthProvider {
return oauth
}
async logout(params: AdminParams, ctx: Context) {
async logout(params: Pick<AdminParams, 'email'>, ctx: AuthContext) {
console.log(params, ctx)
// not sure how we find the refresh token for a given session atm
// revokeRefreshToken()

View file

@ -46,7 +46,10 @@ export default class BasicAuthProvider extends AuthProvider {
admin = await Admin.insertAndFetch({ email, first_name: 'Admin' })
}
// Get the only org that can exist for this method of login
const { id } = await this.loadAuthOrganization(ctx, 'local')
// Process the login
await this.login(admin.id, ctx)
await this.login({ email, organization_id: id }, ctx)
}
}

View file

@ -1,50 +0,0 @@
import { Context } from 'koa'
import { verify } from 'jsonwebtoken'
import { AuthTypeConfig } from './Auth'
import { getAdminByEmail } from './AdminRepository'
import AuthProvider from './AuthProvider'
import { generateAccessToken } from './TokenRepository'
import App from '../app'
import { logger } from '../config/logger'
import { combineURLs } from '../utilities'
export interface LoggerAuthConfig extends AuthTypeConfig {
driver: 'logger'
}
export default class LoggerAuthProvider extends AuthProvider {
async start(ctx: Context) {
const { email } = ctx.request.body
if (!email) throw new Error()
// Find admin, otherwise silently break
const admin = await getAdminByEmail(email)
if (!admin) return
const jwt = generateAccessToken(admin.id, ctx)
const url = this.callbackUrl(jwt.access_token)
logger.info({ url }, 'login link')
ctx.redirect(url)
}
async validate(ctx: Context) {
const jwt = ctx.query.token as string
// Verify that the token is authentic and get the ID
const { id } = verify(jwt, App.main.env.secret) as { id: number }
// With the ID, process a new login
const oauth = await this.login(id, ctx)
logger.info(oauth, 'login credentials')
}
callbackUrl(token: string): string {
const baseUrl = combineURLs([App.main.env.baseUrl, 'api/auth/login'])
const url = new URL(baseUrl)
url.searchParams.set('token', token)
return url.href
}
}

View file

@ -1,10 +1,9 @@
import { addSeconds } from 'date-fns'
import { Context } from 'koa'
import { Issuer, generators, BaseClient, IdTokenClaims } from 'openid-client'
import { RequestError } from '../core/errors'
import AuthError from './AuthError'
import { AuthTypeConfig } from './Auth'
import AuthProvider from './AuthProvider'
import AuthProvider, { AuthContext } from './AuthProvider'
import { firstQueryParam } from '../utilities'
import { logger } from '../config/logger'
@ -14,7 +13,7 @@ export interface OpenIDConfig extends AuthTypeConfig {
clientId: string
cliendSecret: string
redirectUri: string
domainWhitelist: string[]
domain: string
}
export default class OpenIDAuthProvider extends AuthProvider {
@ -28,7 +27,7 @@ export default class OpenIDAuthProvider extends AuthProvider {
this.getClient()
}
async start(ctx: Context): Promise<void> {
async start(ctx: AuthContext): Promise<void> {
const client = await this.getClient()
@ -61,7 +60,7 @@ export default class OpenIDAuthProvider extends AuthProvider {
ctx.redirect(url)
}
async validate(ctx: Context): Promise<void> {
async validate(ctx: AuthContext): Promise<void> {
const client = await this.getClient()
// Unsafe cast, but Koa and library don't play nicely
@ -77,8 +76,8 @@ export default class OpenIDAuthProvider extends AuthProvider {
}
const claims = tokenSet.claims()
if (!this.isDomainWhitelisted(claims)) {
const domain = this.getDomain(claims)
if (!domain || !this.domainMatch(domain)) {
throw new RequestError(AuthError.InvalidDomain)
}
@ -86,11 +85,13 @@ export default class OpenIDAuthProvider extends AuthProvider {
throw new RequestError(AuthError.InvalidEmail)
}
const organization = await this.loadAuthOrganization(ctx, domain)
const admin = {
email: claims.email,
first_name: claims.given_name ?? claims.name,
last_name: claims.family_name,
image_url: claims.picture,
organization_id: organization.id,
}
await this.login(admin, ctx, state)
@ -115,12 +116,15 @@ export default class OpenIDAuthProvider extends AuthProvider {
return this.client
}
private isDomainWhitelisted(claims: IdTokenClaims): boolean {
private domainMatch(domain?: string): boolean {
if (!this.config.domain) return true
return this.config.domain === domain
}
private getDomain(claims: IdTokenClaims): string | undefined {
if (claims.hd && typeof claims.hd === 'string') {
return this.config.domainWhitelist.includes(claims.hd)
return claims.hd
}
return this.config.domainWhitelist.find(
domain => claims.email?.endsWith(domain),
) !== undefined
return claims.email?.split('@')[1]
}
}

View file

@ -75,7 +75,7 @@ export default class SAMLAuthProvider extends AuthProvider {
const [response, state] = result
// If there is no profile we take no action
if (!response.profile) return
if (!response.profile) throw new RequestError(AuthError.SAMLValidationError)
if (response.loggedOut) {
await this.logout({ email: response.profile.nameID }, ctx)
return
@ -83,7 +83,15 @@ export default class SAMLAuthProvider extends AuthProvider {
// If we are logging in, grab profile and create tokens
const { first_name, last_name, nameID: email } = response.profile
await this.login({ first_name, last_name, email }, ctx, state)
const domain = this.getDomain(email)
if (!email || !domain) throw new RequestError(AuthError.SAMLValidationError)
const { id } = await this.loadAuthOrganization(ctx, domain)
await this.login({ first_name, last_name, email, organization_id: id }, ctx, state)
}
private getDomain(email: string): string | undefined {
return email?.split('@')[1]
}
private async parseValidation(ctx: Context): Promise<[ValidatedSAMLResponse, string?] | undefined> {

View file

@ -25,15 +25,15 @@ export async function cleanupExpiredRevokedTokens(until: Date) {
await AccessToken.delete(qb => qb.where('expires_at', '<=', until))
}
export const generateAccessToken = (input: Admin | number, ctx?: Context) => {
const id = typeof input === 'number' ? input : input.id
export const generateAccessToken = async ({ id, organization_id }: Admin, ctx?: Context) => {
const expires_at = addSeconds(Date.now(), App.main.env.auth.tokenLife)
const token = sign({
id,
organization_id,
exp: Math.floor(expires_at.getTime() / 1000),
}, App.main.env.secret)
AccessToken.insert({
await AccessToken.insert({
admin_id: id,
expires_at,
token,

View file

@ -24,12 +24,12 @@ describe('CampaignService', () => {
}
const createCampaignDependencies = async (): Promise<CampaignRefs> => {
const adminId = await Admin.insert({
const admin = await Admin.insertAndFetch({
first_name: uuid(),
last_name: uuid(),
email: `${uuid()}@test.com`,
})
const project = await createProject(adminId, {
const project = await createProject(admin, {
name: uuid(),
timezone: 'utc',
})

View file

@ -111,7 +111,6 @@ export default (type?: EnvType): Env => {
clientId: process.env.AUTH_OPENID_CLIENT_ID,
clientSecret: process.env.AUTH_OPENID_CLIENT_SECRET,
redirectUri: process.env.AUTH_OPENID_REDIRECT_URI,
domainWhitelist: (process.env.AUTH_OPENID_DOMAIN_WHITELIST || '').split(','),
}),
logger: () => ({
tokenLife: defaultTokenLife,

View file

@ -0,0 +1,10 @@
import { AuthConfig } from '../auth/Auth'
import Model from '../core/Model'
export default class Organization extends Model {
username!: string
domain!: string
auth!: AuthConfig
static jsonAttributes = ['auth']
}

View file

@ -0,0 +1,28 @@
import { encodeHashid } from '../utilities'
import Organization from './Organization'
export const getOrganizationByUsername = async (username: string) => {
return await Organization.first(qb => qb.where('username', username))
}
export const getOrganizationByDomain = async (domain?: string) => {
if (!domain) return undefined
return await Organization.first(qb => qb.where('domain', domain))
}
export const createOrganization = async (domain: string): Promise<Organization> => {
const username = domain.split('.').shift()
const org = await Organization.insertAndFetch({
username,
domain,
})
// If for some reason the domain format is odd, generate
// a random username from the org id
if (!username) {
await Organization.updateAndFetch(org.id, {
username: encodeHashid(org.id),
})
}
return org
}

View file

@ -1,7 +1,7 @@
import Model, { ModelParams } from '../core/Model'
export default class Project extends Model {
organization_id!: number
name!: string
description?: string
deleted_at?: Date
@ -9,7 +9,7 @@ export default class Project extends Model {
timezone!: string
}
export type ProjectParams = Omit<Project, ModelParams | 'deleted_at'>
export type ProjectParams = Omit<Project, ModelParams | 'deleted_at' | 'organization_id'>
export const projectRoles = [
'support',

View file

@ -1,14 +1,15 @@
import Router from '@koa/router'
import Project, { ProjectParams } from './Project'
import { ProjectParams } from './Project'
import { JSONSchemaType, validate } from '../core/validate'
import { extractQueryParams } from '../utilities'
import { searchParamsSchema } from '../core/searchParams'
import { ParameterizedContext } from 'koa'
import { createProject, getProject, requireProjectRole, updateProject } from './ProjectService'
import { allProjects, createProject, getProject, pagedProjects, requireProjectRole, updateProject } from './ProjectService'
import { AuthState, ProjectState } from '../auth/AuthMiddleware'
import { getProjectAdmin } from './ProjectAdminRepository'
import { RequestError } from '../core/errors'
import { ProjectError } from './ProjectError'
import { getAdmin } from '../auth/AdminRepository'
export async function projectMiddleware(ctx: ParameterizedContext<ProjectState>, next: () => void) {
@ -40,11 +41,12 @@ export async function projectMiddleware(ctx: ParameterizedContext<ProjectState>,
const router = new Router<AuthState>({ prefix: '/projects' })
router.get('/', async ctx => {
ctx.body = await Project.searchParams(extractQueryParams(ctx.request.query, searchParamsSchema), ['name'])
const params = extractQueryParams(ctx.query, searchParamsSchema)
ctx.body = await pagedProjects(params, ctx.state.admin!.id)
})
router.get('/all', async ctx => {
ctx.body = await Project.all()
ctx.body = await allProjects(ctx.state.admin!.id)
})
const projectCreateParams: JSONSchemaType<ProjectParams> = {
@ -70,7 +72,8 @@ const projectCreateParams: JSONSchemaType<ProjectParams> = {
router.post('/', async ctx => {
const payload = validate(projectCreateParams, ctx.request.body)
ctx.body = await createProject(ctx.state.admin!.id, payload)
const admin = await getAdmin(ctx.state.admin!.id)
ctx.body = await createProject(admin!, payload)
})
export default router

View file

@ -7,12 +7,36 @@ import { uuid } from '../utilities'
import Project, { ProjectParams, ProjectRole, projectRoles } from './Project'
import { ProjectAdmin } from './ProjectAdmins'
import { ProjectApiKey, ProjectApiKeyParams } from './ProjectApiKey'
import { Admin } from '../auth/Admin'
import { getAdmin } from '../auth/AdminRepository'
export const adminProjectIds = async (adminId: number) => {
const records = await ProjectAdmin.all(qb => qb.where('admin_id', adminId))
return records.map(item => item.project_id)
}
export const pagedProjects = async (params: SearchParams, adminId: number) => {
const admin = await getAdmin(adminId)
const projectIds = await adminProjectIds(adminId)
return await Project.searchParams(params, ['name'], qb =>
qb.where(qb =>
qb.where('organization_id', admin!.organization_id)
.orWhereIn('projects.id', projectIds),
),
)
}
export const allProjects = async (adminId: number) => {
const admin = await getAdmin(adminId)
const projectIds = await adminProjectIds(adminId)
return await Project.all(qb =>
qb.where(qb =>
qb.where('organization_id', admin!.organization_id)
.orWhereIn('projects.id', projectIds),
),
)
}
export const getProject = async (id: number, adminId?: number) => {
return Project.first(
qb => {
@ -27,13 +51,16 @@ export const getProject = async (id: number, adminId?: number) => {
})
}
export const createProject = async (adminId: number, params: ProjectParams): Promise<Project> => {
const projectId = await Project.insert(params)
export const createProject = async (admin: Admin, params: ProjectParams): Promise<Project> => {
const projectId = await Project.insert({
...params,
organization_id: admin.organization_id,
})
// Add the user creating the project to it
await ProjectAdmin.insert({
project_id: projectId,
admin_id: adminId,
admin_id: admin.id,
role: 'admin',
})
@ -43,7 +70,7 @@ export const createProject = async (adminId: number, params: ProjectParams): Pro
await createSubscription(projectId, { name: 'Default Push', channel: 'push' })
await createSubscription(projectId, { name: 'Default Webhook', channel: 'webhook' })
const project = await getProject(projectId, adminId)
const project = await getProject(projectId, admin.id)
return project!
}

View file

@ -26,7 +26,7 @@ export default class Queue {
} else if (config?.driver === 'memory') {
this.provider = new MemoryQueueProvider(this)
} else {
throw new Error('A valid queue must be defined!')
throw new Error('A valid queue driver must be set!')
}
}

View file

@ -11,12 +11,15 @@ afterEach(() => {
describe('LinkService', () => {
describe('encodedLinkToParts', () => {
test('a properly encoded link decodes to parts', async () => {
const adminId = await Admin.insert({
const admin = await Admin.insertAndFetch({
first_name: uuid(),
last_name: uuid(),
email: `${uuid()}@test.com`,
})
const project = await createProject(adminId, { name: uuid() })
const project = await createProject(admin, {
name: uuid(),
timezone: 'utc',
})
const user = await createUser(project.id, {
anonymous_id: uuid(),
external_id: uuid(),