diff --git a/lib/rexml/parsers/baseparser.rb b/lib/rexml/parsers/baseparser.rb index 342f9482..b471feff 100644 --- a/lib/rexml/parsers/baseparser.rb +++ b/lib/rexml/parsers/baseparser.rb @@ -8,6 +8,22 @@ module REXML module Parsers + unless [].respond_to?(:tally) + module EnumerableTally + refine Enumerable do + def tally + counts = {} + each do |item| + counts[item] ||= 0 + counts[item] += 1 + end + counts + end + end + end + using EnumerableTally + end + if StringScanner::Version < "3.0.8" module StringScannerCaptures refine StringScanner do @@ -547,20 +563,29 @@ def unnormalize( string, entities=nil, filter=nil ) [Integer(m)].pack('U*') } matches.collect!{|x|x[0]}.compact! + if filter + matches.reject! do |entity_reference| + filter.include?(entity_reference) + end + end if matches.size > 0 - matches.each do |entity_reference| - unless filter and filter.include?(entity_reference) - entity_value = entity( entity_reference, entities ) - if entity_value - re = Private::DEFAULT_ENTITIES_PATTERNS[entity_reference] || /&#{entity_reference};/ - rv.gsub!( re, entity_value ) - if rv.bytesize > Security.entity_expansion_text_limit - raise "entity expansion has grown too large" - end - else - er = DEFAULT_ENTITIES[entity_reference] - rv.gsub!( er[0], er[2] ) if er + matches.tally.each do |entity_reference, n| + entity_expansion_count_before = @entity_expansion_count + entity_value = entity( entity_reference, entities ) + if entity_value + if n > 1 + entity_expansion_count_delta = + @entity_expansion_count - entity_expansion_count_before + record_entity_expansion(entity_expansion_count_delta * (n - 1)) + end + re = Private::DEFAULT_ENTITIES_PATTERNS[entity_reference] || /&#{entity_reference};/ + rv.gsub!( re, entity_value ) + if rv.bytesize > Security.entity_expansion_text_limit + raise "entity expansion has grown too large" end + else + er = DEFAULT_ENTITIES[entity_reference] + rv.gsub!( er[0], er[2] ) if er end end rv.gsub!( Private::DEFAULT_ENTITIES_PATTERNS['amp'], '&' ) @@ -570,8 +595,8 @@ def unnormalize( string, entities=nil, filter=nil ) private - def record_entity_expansion - @entity_expansion_count += 1 + def record_entity_expansion(delta=1) + @entity_expansion_count += delta if @entity_expansion_count > Security.entity_expansion_limit raise "number of entity expansions exceeded, processing aborted." end diff --git a/test/test_pullparser.rb b/test/test_pullparser.rb index 827fad1d..dbde8779 100644 --- a/test/test_pullparser.rb +++ b/test/test_pullparser.rb @@ -206,21 +206,23 @@ def test_empty_value XML + REXML::Security.entity_expansion_limit = 100000 parser = REXML::Parsers::PullParser.new(source) - assert_raise(RuntimeError.new("number of entity expansions exceeded, processing aborted.")) do - while parser.has_next? - parser.pull - end + while parser.has_next? + parser.pull end + assert_equal(11111, parser.entity_expansion_count) - REXML::Security.entity_expansion_limit = 100 + REXML::Security.entity_expansion_limit = @default_entity_expansion_limit parser = REXML::Parsers::PullParser.new(source) assert_raise(RuntimeError.new("number of entity expansions exceeded, processing aborted.")) do while parser.has_next? parser.pull end end - assert_equal(101, parser.entity_expansion_count) + assert do + parser.entity_expansion_count > @default_entity_expansion_limit + end end def test_with_default_entity diff --git a/test/test_sax.rb b/test/test_sax.rb index f452de50..d31de183 100644 --- a/test/test_sax.rb +++ b/test/test_sax.rb @@ -147,17 +147,19 @@ def test_empty_value XML + REXML::Security.entity_expansion_limit = 100000 sax = REXML::Parsers::SAX2Parser.new(source) - assert_raise(RuntimeError.new("number of entity expansions exceeded, processing aborted.")) do - sax.parse - end + sax.parse + assert_equal(11111, sax.entity_expansion_count) - REXML::Security.entity_expansion_limit = 100 + REXML::Security.entity_expansion_limit = @default_entity_expansion_limit sax = REXML::Parsers::SAX2Parser.new(source) assert_raise(RuntimeError.new("number of entity expansions exceeded, processing aborted.")) do sax.parse end - assert_equal(101, sax.entity_expansion_count) + assert do + sax.entity_expansion_count > @default_entity_expansion_limit + end end def test_with_default_entity