os_firewall_manage_iptables.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. #!/usr/bin/python
  2. # -*- coding: utf-8 -*-
import os
import subprocess
from subprocess import call, check_output
  4. DOCUMENTATION = '''
  5. ---
  6. module: os_firewall_manage_iptables
  7. short_description: This module manages iptables rules for a given chain
  8. author: Jason DeTiberus
  9. requirements: [ ]
  10. '''
  11. EXAMPLES = '''
  12. '''
  13. class IpTablesError(Exception):
  14. def __init__(self, msg, cmd, exit_code, output):
  15. self.msg = msg
  16. self.cmd = cmd
  17. self.exit_code = exit_code
  18. self.output = output
  19. class IpTablesAddRuleError(IpTablesError):
  20. pass
  21. class IpTablesRemoveRuleError(IpTablesError):
  22. pass
  23. class IpTablesSaveError(IpTablesError):
  24. pass
  25. class IpTablesCreateChainError(IpTablesError):
  26. def __init__(self, chain, msg, cmd, exit_code, output):
  27. super(IpTablesCreateChainError, self).__init__(msg, cmd, exit_code, output)
  28. self.chain = chain
  29. class IpTablesCreateJumpRuleError(IpTablesError):
  30. def __init__(self, chain, msg, cmd, exit_code, output):
  31. super(IpTablesCreateJumpRuleError, self).__init__(msg, cmd, exit_code,
  32. output)
  33. self.chain = chain
# TODO: implement rollbacks for any steps that were successful before a
# later step raised an exception. For example, when the chain is created
# successfully, but adding/removing the rule fails.
  37. class IpTablesManager:
  38. def __init__(self, module):
  39. self.module = module
  40. self.ip_version = module.params['ip_version']
  41. self.check_mode = module.check_mode
  42. self.chain = module.params['chain']
  43. self.create_jump_rule = module.params['create_jump_rule']
  44. self.jump_rule_chain = module.params['jump_rule_chain']
  45. self.cmd = self.gen_cmd()
  46. self.save_cmd = self.gen_save_cmd()
  47. self.output = []
  48. self.changed = False
  49. def save(self):
  50. try:
  51. self.output.append(check_output(self.save_cmd,
  52. stderr=subprocess.STDOUT))
  53. except subprocess.CalledProcessError as e:
  54. raise IpTablesSaveError(
  55. msg="Failed to save iptables rules",
  56. cmd=e.cmd, exit_code=e.returncode, output=e.output)
  57. def verify_chain(self):
  58. if not self.chain_exists():
  59. self.create_chain()
  60. if self.create_jump_rule and not self.jump_rule_exists():
  61. self.create_jump()
  62. def add_rule(self, port, proto):
  63. rule = self.gen_rule(port, proto)
  64. if not self.rule_exists(rule):
  65. self.verify_chain()
  66. if self.check_mode:
  67. self.changed = True
  68. self.output.append("Create rule for %s %s" % (proto, port))
  69. else:
  70. cmd = self.cmd + ['-A'] + rule
  71. try:
  72. self.output.append(check_output(cmd))
  73. self.changed = True
  74. self.save()
  75. except subprocess.CalledProcessError as e:
  76. raise IpTablesCreateChainError(
  77. chain=self.chain,
  78. msg="Failed to create rule for "
  79. "%s %s" % (self.proto, self.port),
  80. cmd=e.cmd, exit_code=e.returncode,
  81. output=e.output)
  82. def remove_rule(self, port, proto):
  83. rule = self.gen_rule(port, proto)
  84. if self.rule_exists(rule):
  85. if self.check_mode:
  86. self.changed = True
  87. self.output.append("Remove rule for %s %s" % (proto, port))
  88. else:
  89. cmd = self.cmd + ['-D'] + rule
  90. try:
  91. self.output.append(check_output(cmd))
  92. self.changed = True
  93. self.save()
  94. except subprocess.CalledProcessError as e:
  95. raise IpTablesRemoveChainError(
  96. chain=self.chain,
  97. msg="Failed to remove rule for %s %s" % (proto, port),
  98. cmd=e.cmd, exit_code=e.returncode, output=e.output)
  99. def rule_exists(self, rule):
  100. check_cmd = self.cmd + ['-C'] + rule
  101. return True if subprocess.call(check_cmd) == 0 else False
  102. def gen_rule(self, port, proto):
  103. return [self.chain, '-p', proto, '-m', 'state', '--state', 'NEW',
  104. '-m', proto, '--dport', str(port), '-j', 'ACCEPT']
  105. def create_jump(self):
  106. if self.check_mode:
  107. self.changed = True
  108. self.output.append("Create jump rule for chain %s" % self.chain)
  109. else:
  110. try:
  111. cmd = self.cmd + ['-L', self.jump_rule_chain, '--line-numbers']
  112. output = check_output(cmd, stderr=subprocess.STDOUT)
  113. # break the input rules into rows and columns
  114. input_rules = map(lambda s: s.split(), output.split('\n'))
  115. # Find the last numbered rule
  116. last_rule_num = None
  117. last_rule_target = None
  118. for rule in input_rules[:-1]:
  119. if rule:
  120. try:
  121. last_rule_num = int(rule[0])
  122. except ValueError:
  123. continue
  124. last_rule_target = rule[1]
  125. # Raise an exception if we do not find a valid rule
  126. if not last_rule_num or not last_rule_target:
  127. raise IpTablesCreateJumpRuleError(
  128. chain=self.chain,
  129. msg="Failed to find existing %s rules" % self.jump_rule_chain,
  130. cmd=None, exit_code=None, output=None)
  131. # Naively assume that if the last row is a REJECT rule, then
  132. # we can add insert our rule right before it, otherwise we
  133. # assume that we can just append the rule.
  134. if last_rule_target == 'REJECT':
  135. # insert rule
  136. cmd = self.cmd + ['-I', self.jump_rule_chain, str(last_rule_num)]
  137. else:
  138. # append rule
  139. cmd = self.cmd + ['-A', self.jump_rule_chain]
  140. cmd += ['-j', self.chain]
  141. output = check_output(cmd, stderr=subprocess.STDOUT)
  142. changed = True
  143. self.output.append(output)
  144. self.save()
  145. except subprocess.CalledProcessError as e:
  146. if '--line-numbers' in e.cmd:
  147. raise IpTablesCreateJumpRuleError(
  148. chain=self.chain,
  149. msg="Failed to query existing %s rules to " % self.jump_rule_chain +
  150. "determine jump rule location",
  151. cmd=e.cmd, exit_code=e.returncode,
  152. output=e.output)
  153. else:
  154. raise IpTablesCreateJumpRuleError(
  155. chain=self.chain,
  156. msg="Failed to create jump rule for chain %s" %
  157. self.chain,
  158. cmd=e.cmd, exit_code=e.returncode,
  159. output=e.output)
  160. def create_chain(self):
  161. if self.check_mode:
  162. self.changed = True
  163. self.output.append("Create chain %s" % self.chain)
  164. else:
  165. try:
  166. cmd = self.cmd + ['-N', self.chain]
  167. self.output.append(check_output(cmd,
  168. stderr=subprocess.STDOUT))
  169. self.changed = True
  170. self.output.append("Successfully created chain %s" %
  171. self.chain)
  172. self.save()
  173. except subprocess.CalledProcessError as e:
  174. raise IpTablesCreateChainError(
  175. chain=self.chain,
  176. msg="Failed to create chain: %s" % self.chain,
  177. cmd=e.cmd, exit_code=e.returncode, output=e.output
  178. )
  179. def jump_rule_exists(self):
  180. cmd = self.cmd + ['-C', self.jump_rule_chain, '-j', self.chain]
  181. return True if subprocess.call(cmd) == 0 else False
  182. def chain_exists(self):
  183. cmd = self.cmd + ['-L', self.chain]
  184. return True if subprocess.call(cmd) == 0 else False
  185. def gen_cmd(self):
  186. cmd = 'iptables' if self.ip_version == 'ipv4' else 'ip6tables'
  187. return ["/usr/sbin/%s" % cmd]
  188. def gen_save_cmd(self):
  189. cmd = 'iptables' if self.ip_version == 'ipv4' else 'ip6tables'
  190. return ['/usr/libexec/iptables/iptables.init', 'save']
  191. def main():
  192. module = AnsibleModule(
  193. argument_spec=dict(
  194. name=dict(required=True),
  195. action=dict(required=True, choices=['add', 'remove', 'verify_chain']),
  196. chain=dict(required=False, default='OS_FIREWALL_ALLOW'),
  197. create_jump_rule=dict(required=False, type='bool', default=True),
  198. jump_rule_chain=dict(required=False, default='INPUT'),
  199. protocol=dict(required=False, choices=['tcp', 'udp']),
  200. port=dict(required=False, type='int'),
  201. ip_version=dict(required=False, default='ipv4',
  202. choices=['ipv4', 'ipv6']),
  203. ),
  204. supports_check_mode=True
  205. )
  206. action = module.params['action']
  207. protocol = module.params['protocol']
  208. port = module.params['port']
  209. if action in ['add', 'remove']:
  210. if not protocol:
  211. error = "protocol is required when action is %s" % action
  212. module.fail_json(msg=error)
  213. if not port:
  214. error = "port is required when action is %s" % action
  215. module.fail_json(msg=error)
  216. iptables_manager = IpTablesManager(module)
  217. try:
  218. if action == 'add':
  219. iptables_manager.add_rule(port, protocol)
  220. elif action == 'remove':
  221. iptables_manager.remove_rule(port, protocol)
  222. elif action == 'verify_chain':
  223. iptables_manager.verify_chain()
  224. except IpTablesError as e:
  225. module.fail_json(msg=e.msg)
  226. return module.exit_json(changed=iptables_manager.changed,
  227. output=iptables_manager.output)
  228. # import module snippets
  229. from ansible.module_utils.basic import *
  230. main()