os_firewall_manage_iptables.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
  1. #!/usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. # vim: expandtab:tabstop=4:shiftwidth=4
import os
import subprocess
from subprocess import call, check_output
  5. DOCUMENTATION = '''
  6. ---
  7. module: os_firewall_manage_iptables
  8. short_description: This module manages iptables rules for a given chain
  9. author: Jason DeTiberus
  10. requirements: [ ]
  11. '''
  12. EXAMPLES = '''
  13. '''
  14. class IpTablesError(Exception):
  15. def __init__(self, msg, cmd, exit_code, output):
  16. self.msg = msg
  17. self.cmd = cmd
  18. self.exit_code = exit_code
  19. self.output = output
  20. class IpTablesAddRuleError(IpTablesError):
  21. pass
  22. class IpTablesRemoveRuleError(IpTablesError):
  23. pass
  24. class IpTablesSaveError(IpTablesError):
  25. pass
  26. class IpTablesCreateChainError(IpTablesError):
  27. def __init__(self, chain, msg, cmd, exit_code, output):
  28. super(IpTablesCreateChainError, self).__init__(msg, cmd, exit_code, output)
  29. self.chain = chain
  30. class IpTablesCreateJumpRuleError(IpTablesError):
  31. def __init__(self, chain, msg, cmd, exit_code, output):
  32. super(IpTablesCreateJumpRuleError, self).__init__(msg, cmd, exit_code,
  33. output)
  34. self.chain = chain
# TODO: implement rollbacks for any operations that succeeded before a later
# step raised an exception. For example, when the chain is created
# successfully but adding/removing the rule then fails.
  38. class IpTablesManager:
  39. def __init__(self, module):
  40. self.module = module
  41. self.ip_version = module.params['ip_version']
  42. self.check_mode = module.check_mode
  43. self.chain = module.params['chain']
  44. self.create_jump_rule = module.params['create_jump_rule']
  45. self.jump_rule_chain = module.params['jump_rule_chain']
  46. self.cmd = self.gen_cmd()
  47. self.save_cmd = self.gen_save_cmd()
  48. self.output = []
  49. self.changed = False
  50. def save(self):
  51. try:
  52. self.output.append(check_output(self.save_cmd,
  53. stderr=subprocess.STDOUT))
  54. except subprocess.CalledProcessError as e:
  55. raise IpTablesSaveError(
  56. msg="Failed to save iptables rules",
  57. cmd=e.cmd, exit_code=e.returncode, output=e.output)
  58. def verify_chain(self):
  59. if not self.chain_exists():
  60. self.create_chain()
  61. if self.create_jump_rule and not self.jump_rule_exists():
  62. self.create_jump()
  63. def add_rule(self, port, proto):
  64. rule = self.gen_rule(port, proto)
  65. if not self.rule_exists(rule):
  66. self.verify_chain()
  67. if self.check_mode:
  68. self.changed = True
  69. self.output.append("Create rule for %s %s" % (proto, port))
  70. else:
  71. cmd = self.cmd + ['-A'] + rule
  72. try:
  73. self.output.append(check_output(cmd))
  74. self.changed = True
  75. self.save()
  76. except subprocess.CalledProcessError as e:
  77. raise IpTablesCreateChainError(
  78. chain=self.chain,
  79. msg="Failed to create rule for "
  80. "%s %s" % (self.proto, self.port),
  81. cmd=e.cmd, exit_code=e.returncode,
  82. output=e.output)
  83. def remove_rule(self, port, proto):
  84. rule = self.gen_rule(port, proto)
  85. if self.rule_exists(rule):
  86. if self.check_mode:
  87. self.changed = True
  88. self.output.append("Remove rule for %s %s" % (proto, port))
  89. else:
  90. cmd = self.cmd + ['-D'] + rule
  91. try:
  92. self.output.append(check_output(cmd))
  93. self.changed = True
  94. self.save()
  95. except subprocess.CalledProcessError as e:
  96. raise IpTablesRemoveChainError(
  97. chain=self.chain,
  98. msg="Failed to remove rule for %s %s" % (proto, port),
  99. cmd=e.cmd, exit_code=e.returncode, output=e.output)
  100. def rule_exists(self, rule):
  101. check_cmd = self.cmd + ['-C'] + rule
  102. return True if subprocess.call(check_cmd) == 0 else False
  103. def gen_rule(self, port, proto):
  104. return [self.chain, '-p', proto, '-m', 'state', '--state', 'NEW',
  105. '-m', proto, '--dport', str(port), '-j', 'ACCEPT']
  106. def create_jump(self):
  107. if self.check_mode:
  108. self.changed = True
  109. self.output.append("Create jump rule for chain %s" % self.chain)
  110. else:
  111. try:
  112. cmd = self.cmd + ['-L', self.jump_rule_chain, '--line-numbers']
  113. output = check_output(cmd, stderr=subprocess.STDOUT)
  114. # break the input rules into rows and columns
  115. input_rules = map(lambda s: s.split(), output.split('\n'))
  116. # Find the last numbered rule
  117. last_rule_num = None
  118. last_rule_target = None
  119. for rule in input_rules[:-1]:
  120. if rule:
  121. try:
  122. last_rule_num = int(rule[0])
  123. except ValueError:
  124. continue
  125. last_rule_target = rule[1]
  126. # Naively assume that if the last row is a REJECT rule, then
  127. # we can add insert our rule right before it, otherwise we
  128. # assume that we can just append the rule.
  129. if last_rule_num and last_rule_target and last_rule_target == 'REJECT':
  130. # insert rule
  131. cmd = self.cmd + ['-I', self.jump_rule_chain, str(last_rule_num)]
  132. else:
  133. # append rule
  134. cmd = self.cmd + ['-A', self.jump_rule_chain]
  135. cmd += ['-j', self.chain]
  136. output = check_output(cmd, stderr=subprocess.STDOUT)
  137. changed = True
  138. self.output.append(output)
  139. self.save()
  140. except subprocess.CalledProcessError as e:
  141. if '--line-numbers' in e.cmd:
  142. raise IpTablesCreateJumpRuleError(
  143. chain=self.chain,
  144. msg="Failed to query existing %s rules to " % self.jump_rule_chain +
  145. "determine jump rule location",
  146. cmd=e.cmd, exit_code=e.returncode,
  147. output=e.output)
  148. else:
  149. raise IpTablesCreateJumpRuleError(
  150. chain=self.chain,
  151. msg="Failed to create jump rule for chain %s" %
  152. self.chain,
  153. cmd=e.cmd, exit_code=e.returncode,
  154. output=e.output)
  155. def create_chain(self):
  156. if self.check_mode:
  157. self.changed = True
  158. self.output.append("Create chain %s" % self.chain)
  159. else:
  160. try:
  161. cmd = self.cmd + ['-N', self.chain]
  162. self.output.append(check_output(cmd,
  163. stderr=subprocess.STDOUT))
  164. self.changed = True
  165. self.output.append("Successfully created chain %s" %
  166. self.chain)
  167. self.save()
  168. except subprocess.CalledProcessError as e:
  169. raise IpTablesCreateChainError(
  170. chain=self.chain,
  171. msg="Failed to create chain: %s" % self.chain,
  172. cmd=e.cmd, exit_code=e.returncode, output=e.output
  173. )
  174. def jump_rule_exists(self):
  175. cmd = self.cmd + ['-C', self.jump_rule_chain, '-j', self.chain]
  176. return True if subprocess.call(cmd) == 0 else False
  177. def chain_exists(self):
  178. cmd = self.cmd + ['-L', self.chain]
  179. return True if subprocess.call(cmd) == 0 else False
  180. def gen_cmd(self):
  181. cmd = 'iptables' if self.ip_version == 'ipv4' else 'ip6tables'
  182. return ["/usr/sbin/%s" % cmd]
  183. def gen_save_cmd(self):
  184. cmd = 'iptables' if self.ip_version == 'ipv4' else 'ip6tables'
  185. return ['/usr/libexec/iptables/iptables.init', 'save']
  186. def main():
  187. module = AnsibleModule(
  188. argument_spec=dict(
  189. name=dict(required=True),
  190. action=dict(required=True, choices=['add', 'remove', 'verify_chain']),
  191. chain=dict(required=False, default='OS_FIREWALL_ALLOW'),
  192. create_jump_rule=dict(required=False, type='bool', default=True),
  193. jump_rule_chain=dict(required=False, default='INPUT'),
  194. protocol=dict(required=False, choices=['tcp', 'udp']),
  195. port=dict(required=False, type='int'),
  196. ip_version=dict(required=False, default='ipv4',
  197. choices=['ipv4', 'ipv6']),
  198. ),
  199. supports_check_mode=True
  200. )
  201. action = module.params['action']
  202. protocol = module.params['protocol']
  203. port = module.params['port']
  204. if action in ['add', 'remove']:
  205. if not protocol:
  206. error = "protocol is required when action is %s" % action
  207. module.fail_json(msg=error)
  208. if not port:
  209. error = "port is required when action is %s" % action
  210. module.fail_json(msg=error)
  211. iptables_manager = IpTablesManager(module)
  212. try:
  213. if action == 'add':
  214. iptables_manager.add_rule(port, protocol)
  215. elif action == 'remove':
  216. iptables_manager.remove_rule(port, protocol)
  217. elif action == 'verify_chain':
  218. iptables_manager.verify_chain()
  219. except IpTablesError as e:
  220. module.fail_json(msg=e.msg)
  221. return module.exit_json(changed=iptables_manager.changed,
  222. output=iptables_manager.output)
  223. # import module snippets
  224. from ansible.module_utils.basic import *
# Run main() when this file is executed as a script.
if __name__ == '__main__':
    main()