mirror of
https://github.com/discourse/discourse.git
synced 2025-09-06 10:50:21 +08:00
FIX: handle CORS in hijacked requests
This commit is contained in:
parent
c64774f4f8
commit
90a55d6f7c
3 changed files with 85 additions and 34 deletions
|
@ -1,41 +1,52 @@
|
||||||
if GlobalSetting.enable_cors
|
# frozen_string_literal: true
|
||||||
class Discourse::Cors
|
|
||||||
def initialize(app, options = nil)
|
|
||||||
@app = app
|
|
||||||
if GlobalSetting.enable_cors && GlobalSetting.cors_origin.present?
|
|
||||||
@global_origins = GlobalSetting.cors_origin.split(',').map(&:strip)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
def call(env)
|
class Discourse::Cors
|
||||||
if env['REQUEST_METHOD'] == ('OPTIONS') && env['HTTP_ACCESS_CONTROL_REQUEST_METHOD']
|
ORIGINS_ENV = "Discourse_Cors_Origins"
|
||||||
return [200, apply_headers(env), []]
|
|
||||||
end
|
|
||||||
|
|
||||||
status, headers, body = @app.call(env)
|
def initialize(app, options = nil)
|
||||||
[status, apply_headers(env, headers), body]
|
@app = app
|
||||||
end
|
if GlobalSetting.enable_cors && GlobalSetting.cors_origin.present?
|
||||||
|
@global_origins = GlobalSetting.cors_origin.split(',').map(&:strip)
|
||||||
def apply_headers(env, headers = nil)
|
|
||||||
headers ||= {}
|
|
||||||
|
|
||||||
origin = nil
|
|
||||||
cors_origins = @global_origins || []
|
|
||||||
cors_origins += SiteSetting.cors_origins.split('|') if SiteSetting.cors_origins
|
|
||||||
|
|
||||||
if cors_origins
|
|
||||||
if origin = env['HTTP_ORIGIN']
|
|
||||||
origin = nil unless cors_origins.include?(origin)
|
|
||||||
end
|
|
||||||
|
|
||||||
headers['Access-Control-Allow-Origin'] = origin || cors_origins[0]
|
|
||||||
headers['Access-Control-Allow-Headers'] = 'X-Requested-With, X-CSRF-Token, Discourse-Visible'
|
|
||||||
headers['Access-Control-Allow-Credentials'] = 'true'
|
|
||||||
end
|
|
||||||
|
|
||||||
headers
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def call(env)
|
||||||
|
|
||||||
|
cors_origins = @global_origins || []
|
||||||
|
cors_origins += SiteSetting.cors_origins.split('|') if SiteSetting.cors_origins.present?
|
||||||
|
cors_origins = cors_origins.presence
|
||||||
|
|
||||||
|
if env['REQUEST_METHOD'] == ('OPTIONS') && env['HTTP_ACCESS_CONTROL_REQUEST_METHOD']
|
||||||
|
return [200, Discourse::Cors.apply_headers(cors_origins, env, {}), []]
|
||||||
|
end
|
||||||
|
|
||||||
|
env[Discourse::Cors::ORIGINS_ENV] = cors_origins if cors_origins
|
||||||
|
|
||||||
|
status, headers, body = @app.call(env)
|
||||||
|
headers ||= {}
|
||||||
|
|
||||||
|
Discourse::Cors.apply_headers(cors_origins, env, headers) if cors_origins
|
||||||
|
|
||||||
|
[status, headers, body]
|
||||||
|
end
|
||||||
|
|
||||||
|
def self.apply_headers(cors_origins, env, headers)
|
||||||
|
origin = nil
|
||||||
|
|
||||||
|
if cors_origins
|
||||||
|
if origin = env['HTTP_ORIGIN']
|
||||||
|
origin = nil unless cors_origins.include?(origin)
|
||||||
|
end
|
||||||
|
|
||||||
|
headers['Access-Control-Allow-Origin'] = origin || cors_origins[0]
|
||||||
|
headers['Access-Control-Allow-Headers'] = 'X-Requested-With, X-CSRF-Token, Discourse-Visible'
|
||||||
|
headers['Access-Control-Allow-Credentials'] = 'true'
|
||||||
|
end
|
||||||
|
|
||||||
|
headers
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
if GlobalSetting.enable_cors
|
||||||
Rails.configuration.middleware.insert_before ActionDispatch::Flash, Discourse::Cors
|
Rails.configuration.middleware.insert_before ActionDispatch::Flash, Discourse::Cors
|
||||||
end
|
end
|
||||||
|
|
|
@ -55,6 +55,11 @@ module Hijack
|
||||||
body = response.body
|
body = response.body
|
||||||
|
|
||||||
headers = response.headers
|
headers = response.headers
|
||||||
|
# add cors if needed
|
||||||
|
if cors_origins = env_copy[Discourse::Cors::ORIGINS_ENV]
|
||||||
|
Discourse::Cors.apply_headers(cors_origins, env_copy, headers)
|
||||||
|
end
|
||||||
|
|
||||||
headers['Content-Length'] = body.bytesize
|
headers['Content-Length'] = body.bytesize
|
||||||
headers['Content-Type'] = response.content_type || "text/plain"
|
headers['Content-Type'] = response.content_type || "text/plain"
|
||||||
headers['Connection'] = "close"
|
headers['Connection'] = "close"
|
||||||
|
|
|
@ -79,6 +79,41 @@ describe Hijack do
|
||||||
expect(copy_req.object_id).not_to eq(orig_req.object_id)
|
expect(copy_req.object_id).not_to eq(orig_req.object_id)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
it "handles cors" do
|
||||||
|
SiteSetting.cors_origins = "www.rainbows.com"
|
||||||
|
|
||||||
|
app = lambda do |env|
|
||||||
|
tester = Hijack::Tester.new(env)
|
||||||
|
tester.hijack_test do
|
||||||
|
render body: "hello", status: 201
|
||||||
|
end
|
||||||
|
|
||||||
|
expect(tester.io.string).to include("Access-Control-Allow-Origin: www.rainbows.com")
|
||||||
|
end
|
||||||
|
|
||||||
|
env = {}
|
||||||
|
middleware = Discourse::Cors.new(app)
|
||||||
|
middleware.call(env)
|
||||||
|
|
||||||
|
# it can do pre-flight
|
||||||
|
env = {
|
||||||
|
'REQUEST_METHOD' => 'OPTIONS',
|
||||||
|
'HTTP_ACCESS_CONTROL_REQUEST_METHOD' => 'GET'
|
||||||
|
}
|
||||||
|
|
||||||
|
status, headers, _body = middleware.call(env)
|
||||||
|
|
||||||
|
expect(status).to eq(200)
|
||||||
|
|
||||||
|
expected = {
|
||||||
|
"Access-Control-Allow-Origin" => "www.rainbows.com",
|
||||||
|
"Access-Control-Allow-Headers" => "X-Requested-With, X-CSRF-Token, Discourse-Visible",
|
||||||
|
"Access-Control-Allow-Credentials" => "true"
|
||||||
|
}
|
||||||
|
|
||||||
|
expect(headers).to eq(expected)
|
||||||
|
end
|
||||||
|
|
||||||
it "handles expires_in" do
|
it "handles expires_in" do
|
||||||
tester.hijack_test do
|
tester.hijack_test do
|
||||||
expires_in 1.year
|
expires_in 1.year
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue