winlin

add s1 validation for client/server

@@ -759,6 +759,29 @@ int c1s1::c1_validate_digest(bool& is_valid) @@ -759,6 +759,29 @@ int c1s1::c1_validate_digest(bool& is_valid)
759 return ret; 759 return ret;
760 } 760 }
761 761
  762 +int c1s1::s1_validate_digest(bool& is_valid)
  763 +{
  764 + int ret = ERROR_SUCCESS;
  765 +
  766 + char* s1_digest = NULL;
  767 +
  768 + if ((ret = calc_s1_digest(s1_digest)) != ERROR_SUCCESS) {
  769 + srs_error("validate s1 error, failed to calc digest. ret=%d", ret);
  770 + return ret;
  771 + }
  772 +
  773 + srs_assert(s1_digest != NULL);
  774 + SrsAutoFree(char, s1_digest, true);
  775 +
  776 + if (schema == srs_schema0) {
  777 + is_valid = srs_bytes_equals(block1.digest.digest, s1_digest, 32);
  778 + } else {
  779 + is_valid = srs_bytes_equals(block0.digest.digest, s1_digest, 32);
  780 + }
  781 +
  782 + return ret;
  783 +}
  784 +
762 int c1s1::s1_create(c1s1* c1) 785 int c1s1::s1_create(c1s1* c1)
763 { 786 {
764 int ret = ERROR_SUCCESS; 787 int ret = ERROR_SUCCESS;
@@ -1076,6 +1099,13 @@ int SrsComplexHandshake::handshake_with_client(ISrsProtocolReaderWriter* skt, ch @@ -1076,6 +1099,13 @@ int SrsComplexHandshake::handshake_with_client(ISrsProtocolReaderWriter* skt, ch
1076 return ret; 1099 return ret;
1077 } 1100 }
1078 srs_verbose("create s1 from c1 success."); 1101 srs_verbose("create s1 from c1 success.");
  1102 + // verify s1
  1103 + if ((ret = s1.s1_validate_digest(is_valid)) != ERROR_SUCCESS || !is_valid) {
  1104 + ret = ERROR_RTMP_TRY_SIMPLE_HS;
  1105 + srs_info("valid s1 failed, try simple handshake. ret=%d", ret);
  1106 + return ret;
  1107 + }
  1108 + srs_verbose("verify s1 from c1 success.");
1079 1109
1080 c2s2 s2; 1110 c2s2 s2;
1081 if ((ret = s2.s2_create(&c1)) != ERROR_SUCCESS) { 1111 if ((ret = s2.s2_create(&c1)) != ERROR_SUCCESS) {
@@ -206,13 +206,17 @@ namespace srs @@ -206,13 +206,17 @@ namespace srs
206 */ 206 */
207 virtual int c1_parse(char* _c1s1, srs_schema_type _schema); 207 virtual int c1_parse(char* _c1s1, srs_schema_type _schema);
208 /** 208 /**
209 - * server: validate the parsed schema and c1s1 209 + * server: validate the parsed c1 schema
210 */ 210 */
211 virtual int c1_validate_digest(bool& is_valid); 211 virtual int c1_validate_digest(bool& is_valid);
212 /** 212 /**
213 * server: create and sign the s1 from c1. 213 * server: create and sign the s1 from c1.
214 */ 214 */
215 virtual int s1_create(c1s1* c1); 215 virtual int s1_create(c1s1* c1);
  216 + /**
  217 + * server: validate the parsed s1 schema
  218 + */
  219 + virtual int s1_validate_digest(bool& is_valid);
216 private: 220 private:
217 virtual int calc_s1_digest(char*& digest); 221 virtual int calc_s1_digest(char*& digest);
218 virtual int calc_c1_digest(char*& digest); 222 virtual int calc_c1_digest(char*& digest);